Add hard option to TerminateOnNaN for immediate termination on NaN/Inf losses #21841
Conversation
Summary of Changes: Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new callback, `HardTerminateOnNaN`.
Code Review
This pull request introduces a HardTerminateOnNaN callback designed to immediately halt training by raising a RuntimeError upon encountering a NaN or Inf loss. The goal is to prevent cleanup hooks like on_train_end from executing, thereby preserving the training state for debugging. My review identified a critical issue with this core premise: Keras' training loop utilizes a try...finally block, which ensures on_train_end is always called, even when an exception occurs. This means the callback does not function as described. The tests also appear to be based on this incorrect assumption. I have provided detailed feedback on this fundamental issue. Additionally, I've included a suggestion to improve the test implementation for better consistency with the existing test suite.
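To make the `try...finally` point concrete, here is a deliberately simplified sketch of the control flow the review describes (an illustration of the claim, not the actual Keras trainer source):

```python
# Simplified stand-ins for the trainer and callback list, to show why
# raising from a batch hook does not skip on_train_end when the training
# loop is wrapped in try...finally.
class CallbackList:
    def on_train_end(self):
        print("on_train_end called")  # cleanup hooks still run


def fit(callbacks):
    try:
        # ... epoch/batch loop; a hard-terminating callback raises here ...
        raise RuntimeError("NaN or Inf loss encountered")
    finally:
        callbacks.on_train_end()


try:
    fit(CallbackList())
except RuntimeError:
    pass  # the error still propagates to the caller, but cleanup already ran
```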
````python
class HardTerminateOnNaN(Callback):
    """Callback that terminates training immediately
    when NaN/Inf loss is detected.

    This callback raises a RuntimeError when a NaN or Inf loss is encountered,
    which immediately stops training without triggering `on_train_end()` hooks
    for other callbacks. This is useful when you want to preserve backup states
    or prevent early stopping from restoring weights after a NaN failure.

    Unlike `TerminateOnNaN`, which gracefully stops training using
    `model.stop_training = True` and triggers all callback cleanup methods,
    `HardTerminateOnNaN` crashes the training loop immediately.

    Example:

    ```
    callback = keras.callbacks.HardTerminateOnNaN()
    model.fit(x, y, callbacks=[callback])
    ```
    """

    def __init__(self):
        super().__init__()
        self._supports_tf_logs = True

    def on_batch_end(self, batch, logs=None):
        """Check for NaN/Inf loss at the end of each batch.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step()`.

        Raises:
            RuntimeError: If loss is NaN or Inf.
        """
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                raise RuntimeError(
                    f"NaN or Inf loss encountered at batch {batch}. "
                    f"Loss value: {loss}. Terminating training immediately."
                )
````
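Because the callback raises rather than setting `model.stop_training`, calling code has to handle the failure itself. A minimal usage sketch, assuming this PR's branch where the class is exported as `keras.callbacks.HardTerminateOnNaN`:

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(32, 4)
y = np.random.rand(32, 1)

try:
    # With the proposed callback, a NaN/Inf loss raises a RuntimeError
    # instead of ending training gracefully.
    model.fit(x, y, epochs=1, callbacks=[keras.callbacks.HardTerminateOnNaN()])
except RuntimeError as e:
    print(f"Training aborted: {e}")
```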
The premise of this callback appears to be that raising a RuntimeError will prevent on_train_end hooks from being called. However, the fit() method in Keras' Trainer class wraps the training loop in a try...finally block. This guarantees that callbacks.on_train_end() is executed even when an exception is raised during training.
This means that cleanup logic, such as deleting checkpoints in BackupAndRestore.on_train_end, will still run. This behavior contradicts the main goal of this "hard" termination callback, as stated in the docstring: "which immediately stops training without triggering on_train_end() hooks for other callbacks".
The tests for this callback also seem to reflect some confusion around this behavior:

- `test_hard_terminate_does_not_trigger_on_train_end` asserts that `on_train_end` is not called, which seems incorrect given the `fit` loop's implementation.
- `test_hard_terminate_preserves_backup` works around this by monkeypatching `on_train_end`, which sidesteps testing the actual behavior in a real-world scenario where the backup would be deleted.
To achieve the desired behavior, a different mechanism might be necessary. For example, you could introduce a state on the model that other callbacks can check within their on_train_end methods to determine if they should skip their cleanup logic due to a hard termination.
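A minimal sketch of that alternative, using a hypothetical `_hard_terminated` flag (the attribute name and the cooperating callback are illustrative, not existing Keras APIs):

```python
import numpy as np
import keras


class HardTerminateOnNaN(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (np.isnan(loss) or np.isinf(loss)):
            # Hypothetical marker other callbacks can inspect in on_train_end.
            self.model._hard_terminated = True
            raise RuntimeError(f"NaN or Inf loss encountered at batch {batch}.")


class BackupCleanupCallback(keras.callbacks.Callback):
    def on_train_end(self, logs=None):
        if getattr(self.model, "_hard_terminated", False):
            # Skip cleanup (e.g. deleting checkpoints) after a hard stop.
            return
        # ... normal cleanup logic would run here ...
```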
```python
# Create a fake file in the backup folder
fake_file = os.path.join(backup_dir, "checkpoint.txt")
open(fake_file, "w").write("dummy checkpoint")
```
```python
assert os.path.exists(backup_dir), (
    f"Backup dir deleted: {backup_dir}"
)
assert os.path.exists(fake_file), (
    "Backup file missing unexpectedly."
)
```
For consistency with the `testing.TestCase` base class (which inherits from `unittest.TestCase`), it's better to use `self.assertTrue()` for assertions instead of the native `assert` statement.

```python
self.assertTrue(os.path.exists(backup_dir), f"Backup dir deleted: {backup_dir}")
self.assertTrue(os.path.exists(fake_file), "Backup file missing unexpectedly.")
```
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #21841      +/-   ##
==========================================
- Coverage   82.66%   82.48%   -0.18%
==========================================
  Files         577      577
  Lines       59477    59512      +35
  Branches     9329     9331       +2
==========================================
- Hits        49167    49091      -76
- Misses       7907     8010     +103
- Partials    2403     2411       +8
```
Thanks for the PR! This sounds like it should be an option in `TerminateOnNaN`.
@fchollet |
Fix: #21771
Added a hard=True option to keras.callbacks.TerminateOnNaN that raises a RuntimeError immediately when NaN or Inf losses occur.
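Roughly, the shape that option takes (a sketch based on the description above; the exact code and message formatting in the branch may differ):

```python
import numpy as np
from keras.callbacks import Callback


class TerminateOnNaN(Callback):
    def __init__(self, hard=False):
        super().__init__()
        self.hard = hard

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (np.isnan(loss) or np.isinf(loss)):
            if self.hard:
                # hard=True: fail loudly instead of stopping gracefully.
                raise RuntimeError(
                    f"NaN or Inf loss encountered at batch {batch}. "
                    f"Loss value: {loss}."
                )
            print(f"Batch {batch}: Invalid loss, terminating training")
            self.model.stop_training = True
```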
This initial implementation focuses on strict failure behavior. It intentionally: