-
Notifications
You must be signed in to change notification settings - Fork 3
feat(adapter/nemo): add EventLoggingCallback for lifecycle monitoring #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
e5e5a51
feat: add StepTimerCallback for lifecycle monitoring
kkkapu 722a82c
refactor: simplify logging and unify StepTimerCallback hooks
kkkapu 8614a31
test: add unit tests and refine log format for StepTimerCallback
kkkapu a2ec8c5
refactor(adapter/nemo): rename to EventLoggingCallback and add some c…
kkkapu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
162 changes: 162 additions & 0 deletions
162
src/ml_flashpoint/adapter/nemo/event_logging_callback.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| # Copyright 2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Any, Dict | ||
|
|
||
| import lightning.pytorch as pl | ||
| import torch | ||
| from lightning.pytorch import callbacks as pl_callbacks | ||
| from lightning.pytorch.utilities.types import STEP_OUTPUT | ||
| from typing_extensions import override | ||
|
|
||
| from ml_flashpoint.core.mlf_logging import get_logger | ||
|
|
||
| _LOGGER = get_logger(__name__) | ||
|
|
||
|
|
||
| class EventLoggingCallback(pl_callbacks.Callback): | ||
| """ | ||
| A comprehensive logging callback to record timestamps for all key PyTorch Lightning | ||
| lifecycle events to monitor execution flow. | ||
| """ | ||
|
|
||
| def _log_event(self, hook_name: str) -> None: | ||
| _LOGGER.info(f"event={hook_name}") | ||
|
|
||
| @override | ||
| def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
| """Called when the train begins.""" | ||
| self._log_event("on_train_start") | ||
|
|
||
| @override | ||
| def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
| """Called when the train ends.""" | ||
| self._log_event("on_train_end") | ||
|
|
||
| @override | ||
| def on_train_batch_start( | ||
| self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int | ||
| ) -> None: | ||
| """Called when the train batch begins.""" | ||
| self._log_event("on_train_batch_start") | ||
|
|
||
| @override | ||
| def on_train_batch_end( | ||
| self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int | ||
| ) -> None: | ||
| """Called when the train batch ends.""" | ||
| self._log_event("on_train_batch_end") | ||
|
|
||
| @override | ||
| def on_validation_batch_start( | ||
| self, | ||
| trainer: "pl.Trainer", | ||
| pl_module: "pl.LightningModule", | ||
| batch: Any, | ||
| batch_idx: int, | ||
| dataloader_idx: int = 0, | ||
| ) -> None: | ||
| """Called when the validation batch begins.""" | ||
| self._log_event("on_validation_batch_start") | ||
|
|
||
| @override | ||
| def on_validation_batch_end( | ||
| self, | ||
| trainer: "pl.Trainer", | ||
| pl_module: "pl.LightningModule", | ||
| outputs: STEP_OUTPUT, | ||
| batch: Any, | ||
| batch_idx: int, | ||
| dataloader_idx: int = 0, | ||
| ) -> None: | ||
| """Called when the validation batch ends.""" | ||
| self._log_event("on_validation_batch_end") | ||
|
|
||
| @override | ||
| def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
| """Called when the test epoch begins.""" | ||
| self._log_event("on_test_epoch_start") | ||
|
|
||
| @override | ||
| def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
| """Called when the test epoch ends.""" | ||
| self._log_event("on_test_epoch_end") | ||
|
|
||
| @override | ||
| def on_test_batch_start( | ||
| self, | ||
| trainer: "pl.Trainer", | ||
| pl_module: "pl.LightningModule", | ||
| batch: Any, | ||
| batch_idx: int, | ||
| dataloader_idx: int = 0, | ||
| ) -> None: | ||
| """Called when the test batch begins.""" | ||
| self._log_event("on_test_batch_start") | ||
|
|
||
| @override | ||
| def on_test_batch_end( | ||
| self, | ||
| trainer: "pl.Trainer", | ||
| pl_module: "pl.LightningModule", | ||
| outputs: STEP_OUTPUT, | ||
| batch: Any, | ||
| batch_idx: int, | ||
| dataloader_idx: int = 0, | ||
| ) -> None: | ||
| """Called when the test batch ends.""" | ||
| self._log_event("on_test_batch_end") | ||
|
|
||
| @override | ||
| def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | ||
| """Called when loading a checkpoint, implement to reload callback state.""" | ||
| self._log_event("load_state_dict") | ||
|
|
||
| @override | ||
| def on_save_checkpoint( | ||
| self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] | ||
| ) -> None: | ||
| """Called when saving a checkpoint.""" | ||
| self._log_event("on_save_checkpoint") | ||
|
|
||
| @override | ||
| def on_load_checkpoint( | ||
| self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] | ||
| ) -> None: | ||
| """Called when loading a model checkpoint, use to reload state.""" | ||
| self._log_event("on_load_checkpoint") | ||
|
|
||
| @override | ||
| def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None: | ||
| """Called before loss.backward().""" | ||
| self._log_event("on_before_backward") | ||
|
|
||
| @override | ||
| def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: | ||
| """Called after loss.backward() and before optimizers are stepped.""" | ||
| self._log_event("on_after_backward") | ||
|
|
||
| @override | ||
| def on_before_optimizer_step( | ||
| self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: torch.optim.Optimizer | ||
| ) -> None: | ||
| """Called before optimizer.step().""" | ||
| self._log_event("on_before_optimizer_step") | ||
|
|
||
| @override | ||
| def on_before_zero_grad( | ||
| self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: torch.optim.Optimizer | ||
| ) -> None: | ||
| """Called before optimizer.zero_grad().""" | ||
| self._log_event("on_before_zero_grad") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # Copyright 2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| import lightning.pytorch as pl | ||
| import pytest | ||
| import torch | ||
|
|
||
| from ml_flashpoint.adapter.nemo.event_logging_callback import EventLoggingCallback | ||
|
|
||
| # Exhaustive list of all hooks implemented in EventLoggingCallback. | ||
| # Format: (method_name, extra_kwargs, expected_event_string) | ||
| HOOKS_TO_TEST = [ | ||
| ("on_train_start", {}, "on_train_start"), | ||
| ("on_train_end", {}, "on_train_end"), | ||
| ("on_train_batch_start", {"batch": None, "batch_idx": 0}, "on_train_batch_start"), | ||
| ("on_train_batch_end", {"outputs": None, "batch": None, "batch_idx": 0}, "on_train_batch_end"), | ||
| ("on_validation_batch_start", {"batch": None, "batch_idx": 0, "dataloader_idx": 0}, "on_validation_batch_start"), | ||
| ( | ||
| "on_validation_batch_end", | ||
| {"outputs": None, "batch": None, "batch_idx": 0, "dataloader_idx": 0}, | ||
| "on_validation_batch_end", | ||
| ), | ||
| ("on_test_epoch_start", {}, "on_test_epoch_start"), | ||
| ("on_test_epoch_end", {}, "on_test_epoch_end"), | ||
| ("on_test_batch_start", {"batch": None, "batch_idx": 0, "dataloader_idx": 0}, "on_test_batch_start"), | ||
| ("on_test_batch_end", {"outputs": None, "batch": None, "batch_idx": 0, "dataloader_idx": 0}, "on_test_batch_end"), | ||
| ("load_state_dict", {"state_dict": {}}, "load_state_dict"), | ||
| ("on_save_checkpoint", {"checkpoint": {}}, "on_save_checkpoint"), | ||
| ("on_load_checkpoint", {"checkpoint": {}}, "on_load_checkpoint"), | ||
| ("on_before_backward", {"loss": torch.tensor(0.0)}, "on_before_backward"), | ||
| ("on_after_backward", {}, "on_after_backward"), | ||
| ("on_before_optimizer_step", {"optimizer": MagicMock(spec=torch.optim.Optimizer)}, "on_before_optimizer_step"), | ||
| ("on_before_zero_grad", {"optimizer": MagicMock(spec=torch.optim.Optimizer)}, "on_before_zero_grad"), | ||
| ] | ||
|
|
||
|
|
||
| def test_is_subtype_of_pytorch_lightning_callback(): | ||
| """Verify inheritance to ensure compatibility with PyTorch Lightning.""" | ||
| assert issubclass(EventLoggingCallback, pl.callbacks.Callback) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("hook_name, kwargs, expected_event", HOOKS_TO_TEST) | ||
| def test_event_logging_hooks_log_correctly(mocker, hook_name, kwargs, expected_event): | ||
| """ | ||
| Tests that every lifecycle hook in EventLoggingCallback logs the correct event. | ||
| """ | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: can add "# Given # When # Then" comments to highlight different parts of the test. setup, execution, assertions. makes it easier to see what is being asserted and which parts are just setup vs execution |
||
| # Given | ||
| mock_logger = mocker.patch("ml_flashpoint.adapter.nemo.event_logging_callback._LOGGER") | ||
| callback = EventLoggingCallback() | ||
|
|
||
| # Mock Trainer and LightningModule as required by the PyTorch Lightning API. | ||
| trainer = mocker.MagicMock(spec=pl.Trainer) | ||
| pl_module = mocker.MagicMock(spec=pl.LightningModule) | ||
|
|
||
| # When | ||
| # Dynamically fetch the method to test. | ||
| hook_method = getattr(callback, hook_name) | ||
|
|
||
| # load_state_dict does not follow the (trainer, pl_module) signature. | ||
| if hook_name == "load_state_dict": | ||
| hook_method(**kwargs) | ||
| else: | ||
| hook_method(trainer=trainer, pl_module=pl_module, **kwargs) | ||
|
|
||
| # Then | ||
| # Verify the log content matches the implementation of _log_event. | ||
| # LOGGER.info(f"event={hook_name}") | ||
| expected_log_msg = f"event={expected_event}" | ||
| mock_logger.info.assert_called_once_with(expected_log_msg) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we generate some basic tests to validate some of the key events are logged, like
on_train_*,on_train_batch_*