
Conversation

@MalyalaKarthik66 (Contributor) commented Nov 10, 2025

Fix: #21771
Added a hard=True option to keras.callbacks.TerminateOnNaN that raises a RuntimeError immediately when NaN or Inf losses occur (see the usage sketch after the list below).

This initial implementation focuses on strict failure behavior. It intentionally:

  • Preserves the existing graceful mode (hard=False) for backward compatibility
  • Ensures immediate termination without invoking later callbacks (e.g., BackupAndRestore)
  • Guarantees consistent behavior across supported backends
  • Includes unit tests covering both graceful and hard termination modes
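A minimal usage sketch of the proposed behavior (editor's illustration, not code from the PR; it assumes the option ships as `hard=True` on `keras.callbacks.TerminateOnNaN`, and the tiny model plus oversized learning rate exist only to provoke a divergent loss):

```python
import numpy as np
import keras

# Toy model; the huge learning rate is only there to make the loss blow up.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.SGD(learning_rate=1e6), loss="mse")

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

try:
    model.fit(
        x,
        y,
        epochs=3,
        # hard=True is the option proposed here; hard=False keeps the
        # existing graceful `model.stop_training = True` behavior.
        callbacks=[keras.callbacks.TerminateOnNaN(hard=True)],
    )
except RuntimeError as err:
    print("Training hard-terminated:", err)
```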

@gemini-code-assist (Contributor)

Summary of Changes

Hello @MalyalaKarthik66, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new callback, HardTerminateOnNaN, to the Keras framework, providing a more stringent approach to handling numerical instability during model training. Unlike its predecessor, TerminateOnNaN, which attempts a controlled stop, this new callback immediately halts training by raising a RuntimeError when NaN or Inf losses are detected. This design choice is crucial for scenarios requiring immediate cessation of training and the preservation of intermediate states, as it bypasses the on_train_end lifecycle hooks that might otherwise clean up or alter critical data.

Highlights

  • New Callback Introduced: A new callback, HardTerminateOnNaN, has been added to keras.callbacks.
  • Strict Error Handling: This callback immediately raises a RuntimeError when NaN or Inf loss values are encountered during training, providing a stricter failure mechanism compared to TerminateOnNaN.
  • Preservation of State: Unlike TerminateOnNaN, HardTerminateOnNaN prevents on_train_end hooks from being triggered, which is beneficial for other callbacks (e.g., BackupAndRestore) that need to preserve their state without cleanup.
  • Comprehensive Unit Tests: New unit tests have been added to validate the correct termination behavior of HardTerminateOnNaN, including scenarios for raising errors, not calling on_train_end, preserving backup directories, and normal training.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in sharing feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (Contributor) left a comment


Code Review

This pull request introduces a HardTerminateOnNaN callback designed to immediately halt training by raising a RuntimeError upon encountering a NaN or Inf loss. The goal is to prevent cleanup hooks like on_train_end from executing, thereby preserving the training state for debugging. My review identified a critical issue with this core premise: Keras' training loop utilizes a try...finally block, which ensures on_train_end is always called, even when an exception occurs. This means the callback does not function as described. The tests also appear to be based on this incorrect assumption. I have provided detailed feedback on this fundamental issue. Additionally, I've included a suggestion to improve the test implementation for better consistency with the existing test suite.
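To make the `try...finally` point concrete, here is a minimal stand-in for the loop structure (editor's illustration, not Keras' actual `Trainer` code): an exception raised by a callback inside the loop still lets `on_train_end` run before the error propagates.

```python
class LoggingCallback:
    def on_train_end(self, logs=None):
        print("on_train_end still runs")


def fit_like_loop(callbacks):
    # Simplified stand-in for Trainer.fit(): the epoch loop is wrapped so
    # that on_train_end is always invoked, even if a callback raises.
    try:
        raise RuntimeError("NaN or Inf loss encountered")  # "hard" termination
    finally:
        for cb in callbacks:
            cb.on_train_end()


# Calling fit_like_loop([LoggingCallback()]) prints the message first and
# only then lets the RuntimeError propagate to the caller.
```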

Comment on lines 24 to 66
class HardTerminateOnNaN(Callback):
    """Callback that terminates training immediately
    when NaN/Inf loss is detected.

    This callback raises a RuntimeError when a NaN or Inf loss is encountered,
    which immediately stops training without triggering `on_train_end()` hooks
    for other callbacks. This is useful when you want to preserve backup states
    or prevent early stopping from restoring weights after a NaN failure.

    Unlike `TerminateOnNaN`, which gracefully stops training using
    `model.stop_training = True` and triggers all callback cleanup methods,
    `HardTerminateOnNaN` crashes the training loop immediately.

    Example:

    ```
    callback = keras.callbacks.HardTerminateOnNaN()
    model.fit(x, y, callbacks=[callback])
    ```
    """

    def __init__(self):
        super().__init__()
        self._supports_tf_logs = True

    def on_batch_end(self, batch, logs=None):
        """Check for NaN/Inf loss at the end of each batch.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step()`.

        Raises:
            RuntimeError: If loss is NaN or Inf.
        """
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                raise RuntimeError(
                    f"NaN or Inf loss encountered at batch {batch}. "
                    f"Loss value: {loss}. Terminating training immediately."
                )

critical

The premise of this callback appears to be that raising a RuntimeError will prevent on_train_end hooks from being called. However, the fit() method in Keras' Trainer class wraps the training loop in a try...finally block. This guarantees that callbacks.on_train_end() is executed even when an exception is raised during training.

This means that cleanup logic, such as deleting checkpoints in BackupAndRestore.on_train_end, will still run. This behavior contradicts the main goal of this "hard" termination callback, as stated in the docstring: "which immediately stops training without triggering on_train_end() hooks for other callbacks".

The tests for this callback also seem to reflect some confusion around this behavior:

  • test_hard_terminate_does_not_trigger_on_train_end asserts that on_train_end is not called, which seems incorrect given the fit loop's implementation.
  • test_hard_terminate_preserves_backup works around this by monkeypatching on_train_end, which sidesteps testing the actual behavior in a real-world scenario where the backup would be deleted.

To achieve the desired behavior, a different mechanism might be necessary. For example, you could introduce a state on the model that other callbacks can check within their on_train_end methods to determine if they should skip their cleanup logic due to a hard termination.
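A rough sketch of that flag-based suggestion (editor's illustration; the attribute name `_nan_terminated` and the skip logic are hypothetical and not part of Keras):

```python
import numpy as np
from keras.callbacks import Callback


class HardTerminateOnNaN(Callback):
    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (np.isnan(loss) or np.isinf(loss)):
            # Mark the model before raising so that cleanup callbacks can
            # detect the hard termination from their on_train_end hooks.
            self.model._nan_terminated = True  # hypothetical attribute
            raise RuntimeError(f"NaN or Inf loss at batch {batch}: {loss}")


class BackupAwareCallback(Callback):
    def on_train_end(self, logs=None):
        if getattr(self.model, "_nan_terminated", False):
            return  # skip checkpoint cleanup after a hard termination
        # ... normal cleanup would run here ...
```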


# Create a fake file in the backup folder
fake_file = os.path.join(backup_dir, "checkpoint.txt")
open(fake_file, "w").write("dummy checkpoint")

medium

It's a best practice to use a with statement for file operations. This ensures the file is properly closed even if errors occur during the write operation.

For example:

with open(fake_file, "w") as f:
    f.write("dummy checkpoint")

Comment on lines 111 to 116
            assert os.path.exists(backup_dir), (
                f"Backup dir deleted: {backup_dir}"
            )
            assert os.path.exists(fake_file), (
                "Backup file missing unexpectedly."
            )

medium

For consistency with the testing.TestCase base class (which inherits from unittest.TestCase), it's better to use self.assertTrue() for assertions instead of the native assert statement.

            self.assertTrue(os.path.exists(backup_dir), f"Backup dir deleted: {backup_dir}")
            self.assertTrue(os.path.exists(fake_file), "Backup file missing unexpectedly.")

@codecov-commenter commented Nov 13, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.48%. Comparing base (19ca9c1) to head (7eb295b).
⚠️ Report is 4 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21841      +/-   ##
==========================================
- Coverage   82.66%   82.48%   -0.18%     
==========================================
  Files         577      577              
  Lines       59477    59512      +35     
  Branches     9329     9331       +2     
==========================================
- Hits        49167    49091      -76     
- Misses       7907     8010     +103     
- Partials     2403     2411       +8     
| Flag | Coverage Δ |
| --- | --- |
| keras | 82.31% <100.00%> (-0.18%) ⬇️ |
| keras-jax | 62.90% <100.00%> (-0.41%) ⬇️ |
| keras-numpy | 57.55% <12.50%> (+<0.01%) ⬆️ |
| keras-openvino | 34.35% <12.50%> (+<0.01%) ⬆️ |
| keras-tensorflow | 64.13% <100.00%> (+<0.01%) ⬆️ |
| keras-torch | 63.61% <100.00%> (+0.01%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@fchollet (Collaborator)

Thanks for the PR! This sounds like it should be an option in TerminateOnNaN instead of being a separate callback.
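A sketch of what folding the behavior into the existing callback might look like (editor's illustration; the graceful branch below approximates, rather than reproduces, the current `TerminateOnNaN` implementation):

```python
import numpy as np
from keras.callbacks import Callback


class TerminateOnNaN(Callback):
    def __init__(self, hard=False):
        super().__init__()
        self.hard = hard

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is None:
            return
        if np.isnan(loss) or np.isinf(loss):
            if self.hard:
                # Strict mode: fail fast with an exception.
                raise RuntimeError(
                    f"NaN or Inf loss encountered at batch {batch}: {loss}."
                )
            # Graceful mode: request a normal stop; on_train_end still runs
            # for every callback.
            print(f"Batch {batch}: Invalid loss, terminating training")
            self.model.stop_training = True
```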

@MalyalaKarthik66 changed the title from "Add HardTerminateOnNaN callback for strict NaN/Inf loss handling" to "Add hard option to TerminateOnNaN for immediate termination on NaN/Inf losses" on Nov 14, 2025
@MalyalaKarthik66 (Contributor, Author)

@fchollet
Thanks for the review! I’ve updated the PR.



Development

Successfully merging this pull request may close these issues.

TerminateOnNaN callback triggers on_train_end(), causing unintended side effects in other callbacks
