186 changes: 186 additions & 0 deletions keras/src/callbacks/hard_terminate_on_nan_test.py
@@ -0,0 +1,186 @@
"""Tests for TerminateOnNaN callback."""

import os
import tempfile

import numpy as np
import pytest

import keras
from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import testing
from keras.src.callbacks import BackupAndRestore
from keras.src.callbacks import TerminateOnNaN


@pytest.mark.skipif(
backend.backend() in ["numpy", "openvino"],
reason="TerminateOnNaN not supported for NumPy or OpenVINO backend",
)
class TerminateOnNaNTest(testing.TestCase):
"""Test suite for TerminateOnNaN callback."""

def test_terminate_on_nan_graceful_stop(self):
"""Test that TerminateOnNaN (default) gracefully stops training."""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

# Create data that will cause NaN
x = np.array([[1.0], [2.0]])
y = np.array([[np.inf], [np.inf]])

callback = TerminateOnNaN(hard=False)

# Training should complete without raising RuntimeError
# (graceful stop via stop_training = True)
history = model.fit(
x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0
)

        # Training should stop early instead of completing all epochs.
        # `history.history["loss"]` records one entry per completed epoch,
        # so a graceful stop during the first epoch yields fewer than 2.
        self.assertLess(len(history.history["loss"]), 2)

def test_terminate_on_nan_hard_raises_error(self):
"""Test that TerminateOnNaN(hard=True) raises
RuntimeError on NaN loss.
"""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

# Create data that will cause NaN
x = np.array([[1.0], [2.0]])
y = np.array([[np.inf], [np.inf]])

callback = TerminateOnNaN(hard=True)

# Training should raise RuntimeError
with pytest.raises(RuntimeError, match="NaN or Inf loss encountered"):
model.fit(
x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
)

def test_hard_terminate_does_not_trigger_on_train_end(self):
"""Test that on_train_end is NOT called when
TerminateOnNaN(hard=True) raises.
"""

# Create a custom callback to track if on_train_end was called
class TrackingCallback(keras.src.callbacks.Callback):
def __init__(self):
super().__init__()
self.train_end_called = False

def on_train_end(self, logs=None):
self.train_end_called = True

model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0]])
y = np.array([[np.inf]])

tracking_callback = TrackingCallback()
hard_terminate_callback = TerminateOnNaN(hard=True)

# Should raise RuntimeError
with pytest.raises(RuntimeError):
model.fit(
x,
y,
epochs=1,
callbacks=[tracking_callback, hard_terminate_callback],
verbose=0,
)

# on_train_end should NOT have been called
self.assertFalse(tracking_callback.train_end_called)

def test_hard_terminate_preserves_backup(self):
"""Ensure BackupAndRestore directory is preserved when
TerminateOnNaN(hard=True) triggers.
"""
with tempfile.TemporaryDirectory() as tmpdir:
backup_dir = os.path.join(tmpdir, "backups")
os.makedirs(backup_dir, exist_ok=True)

# Create a fake file in the backup folder
fake_file = os.path.join(backup_dir, "checkpoint.txt")
with open(fake_file, "w") as f:
f.write("dummy checkpoint")

# Define a simple model
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

# Data that causes NaN
x_nan = np.array([[1.0]])
y_nan = np.array([[np.inf]])

hard_terminate_callback = TerminateOnNaN(hard=True)
backup_callback = BackupAndRestore(backup_dir=backup_dir)

# Monkeypatch BackupAndRestore to prevent cleanup on train_end
backup_callback.on_train_end = lambda logs=None: None

# Training should raise RuntimeError
with pytest.raises(RuntimeError):
model.fit(
x_nan,
y_nan,
epochs=1,
callbacks=[backup_callback, hard_terminate_callback],
verbose=0,
)

# Verify backup directory still exists and file inside is untouched
self.assertTrue(
os.path.exists(backup_dir),
f"Backup dir deleted: {backup_dir}",
)
self.assertTrue(
os.path.exists(fake_file),
"Backup file missing unexpectedly.",
)

def test_normal_training_does_not_raise(self):
"""Test that TerminateOnNaN does not raise on normal training."""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0], [2.0]])
y = np.array([[1.0], [2.0]])

# Test both hard=False and hard=True with normal data
for hard in [False, True]:
callback = TerminateOnNaN(hard=hard)

# Should complete without raising RuntimeError
history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)

# Should have completed 2 epochs
self.assertEqual(len(history.history["loss"]), 2)

def test_hard_terminate_stops_on_later_batch(self):
"""Ensure TerminateOnNaN(hard=True) stops training
if NaN appears in later batch.
"""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

# Batch 1: normal loss, Batch 2: NaN loss
x = np.array([[1.0], [2.0]])
y = np.array([[1.0], [np.inf]]) # NaN/Inf appears only in 2nd batch

callback = TerminateOnNaN(hard=True)

with pytest.raises(RuntimeError) as exc:
model.fit(
x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
)

        # The error message should reference the (0-based) batch index at
        # which the invalid loss was hit; with shuffling enabled by default,
        # that can be either batch 0 or batch 1.
        assert any(f"batch {i}" in str(exc.value) for i in [0, 1])
59 changes: 54 additions & 5 deletions keras/src/callbacks/terminate_on_nan.py
@@ -7,14 +7,63 @@

@keras_export("keras.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered."""
"""Callback that terminates training when a NaN loss is encountered.

    This callback monitors the loss during training and terminates training
    when a NaN or Inf loss is detected. By default, training is stopped
    gracefully by setting `model.stop_training = True`, which lets all
    callback cleanup methods, including `on_train_end()`, run as usual.

    Alternatively, pass `hard=True` to raise a `RuntimeError` immediately
    when a NaN or Inf loss is detected. Hard termination prevents
    `on_train_end()` from being called on other callbacks, which is useful
    for preserving backup state or avoiding unintended cleanup when training
    fails.

Args:
        hard: Boolean, defaults to `False`. If `False`, stops training
            gracefully via `model.stop_training = True`. If `True`,
            immediately raises a `RuntimeError` on a NaN/Inf loss, bypassing
            callback cleanup methods.

Example:

```
# Graceful termination (default)
callback = keras.callbacks.TerminateOnNaN()
model.fit(x, y, callbacks=[callback])

# Hard termination (strict failure)
callback = keras.callbacks.TerminateOnNaN(hard=True)
model.fit(x, y, callbacks=[callback])
```
"""

def __init__(self, hard: bool = False):
super().__init__()
self.hard = hard
self._supports_tf_logs = True

def on_batch_end(self, batch, logs=None):
"""Check for NaN/Inf loss at the end of each batch.

Args:
batch: Integer, index of batch within the current epoch.
logs: Dict, contains the return value of `model.train_step()`.

Raises:
RuntimeError: If loss is NaN/Inf and hard=True.
"""
logs = logs or {}
loss = logs.get("loss")
if loss is not None:
if np.isnan(loss) or np.isinf(loss):
                if self.hard:
                    raise RuntimeError(
                        f"NaN or Inf loss encountered at batch {batch}. "
                        f"Loss value: {loss}. Terminating training immediately."
                    )
                else:
                    io_utils.print_msg(
                        f"Batch {batch}: Invalid loss, terminating training"
                    )
                    self.model.stop_training = True
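
A minimal end-to-end sketch of how the `hard=True` mode introduced here can be combined with `BackupAndRestore`, mirroring the behavior exercised by the tests above: the invalid loss aborts `fit()` with a `RuntimeError`, `on_train_end()` is never reached, and any checkpoints already written under the backup directory are left in place. The model, data, and temporary directory below are illustrative choices, not part of this change; only the `hard` argument comes from this diff.

```
import tempfile

import numpy as np

import keras

# Toy model; the Inf targets force an invalid loss on purpose.
model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
x = np.array([[1.0], [2.0]])
y = np.array([[np.inf], [np.inf]])

backup_dir = tempfile.mkdtemp()  # illustrative checkpoint location
backup = keras.callbacks.BackupAndRestore(backup_dir=backup_dir)
guard = keras.callbacks.TerminateOnNaN(hard=True)  # `hard` is added in this diff

try:
    model.fit(
        x, y, epochs=2, batch_size=1, callbacks=[backup, guard], verbose=0
    )
except RuntimeError as err:
    # Hard termination bypasses on_train_end(), so BackupAndRestore does not
    # run its cleanup; whatever was written to `backup_dir` can be inspected
    # or reused for a later restart.
    print(f"Training aborted: {err}")
```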