Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions src/ml_flashpoint/adapter/nemo/event_logging_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict

import lightning.pytorch as pl
import torch
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override

from ml_flashpoint.core.mlf_logging import get_logger

_LOGGER = get_logger(__name__)


class EventLoggingCallback(pl_callbacks.Callback):
    """Logs every key PyTorch Lightning lifecycle event as it fires.

    Each hook emits a single ``event=<hook_name>`` log line, allowing the
    trainer's execution flow to be reconstructed from the log stream
    (timestamps are supplied by the logging framework's formatter).
    """

    def _log_event(self, event_name: str) -> None:
        # Single choke point so every hook produces an identically formatted line.
        _LOGGER.info(f"event={event_name}")

    @override
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Log that training has begun."""
        self._log_event("on_train_start")

    @override
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Log that training has finished."""
        self._log_event("on_train_end")

    @override
    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Log the start of a training batch."""
        self._log_event("on_train_batch_start")

    @override
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Log the end of a training batch."""
        self._log_event("on_train_batch_end")

    @override
    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log the start of a validation batch."""
        self._log_event("on_validation_batch_start")

    @override
    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log the end of a validation batch."""
        self._log_event("on_validation_batch_end")

    @override
    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Log the start of a test epoch."""
        self._log_event("on_test_epoch_start")

    @override
    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Log the end of a test epoch."""
        self._log_event("on_test_epoch_end")

    @override
    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log the start of a test batch."""
        self._log_event("on_test_batch_start")

    @override
    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log the end of a test batch."""
        self._log_event("on_test_batch_end")

    @override
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Log restoration of callback state from a checkpoint."""
        self._log_event("load_state_dict")

    @override
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Log that a checkpoint is being saved."""
        self._log_event("on_save_checkpoint")

    @override
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Log that a model checkpoint is being loaded."""
        self._log_event("on_load_checkpoint")

    @override
    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
        """Log the moment just before loss.backward() runs."""
        self._log_event("on_before_backward")

    @override
    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Log the moment after loss.backward(), before the optimizer steps."""
        self._log_event("on_after_backward")

    @override
    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: torch.optim.Optimizer
    ) -> None:
        """Log the moment just before optimizer.step() runs."""
        self._log_event("on_before_optimizer_step")

    @override
    def on_before_zero_grad(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: torch.optim.Optimizer
    ) -> None:
        """Log the moment just before optimizer.zero_grad() runs."""
        self._log_event("on_before_zero_grad")
83 changes: 83 additions & 0 deletions tests/adapter/nemo/test_event_logging_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import lightning.pytorch as pl
import pytest
import torch

from ml_flashpoint.adapter.nemo.event_logging_callback import EventLoggingCallback

# Exhaustive list of all hooks implemented in EventLoggingCallback.
# Format: (method_name, extra_kwargs, expected_event_string)
HOOKS_TO_TEST = [
    # Training lifecycle hooks.
    ("on_train_start", {}, "on_train_start"),
    ("on_train_end", {}, "on_train_end"),
    ("on_train_batch_start", {"batch": None, "batch_idx": 0}, "on_train_batch_start"),
    ("on_train_batch_end", {"outputs": None, "batch": None, "batch_idx": 0}, "on_train_batch_end"),
    # Validation hooks (carry an extra dataloader_idx argument).
    ("on_validation_batch_start", {"batch": None, "batch_idx": 0, "dataloader_idx": 0}, "on_validation_batch_start"),
    (
        "on_validation_batch_end",
        {"outputs": None, "batch": None, "batch_idx": 0, "dataloader_idx": 0},
        "on_validation_batch_end",
    ),
    # Test-stage hooks.
    ("on_test_epoch_start", {}, "on_test_epoch_start"),
    ("on_test_epoch_end", {}, "on_test_epoch_end"),
    ("on_test_batch_start", {"batch": None, "batch_idx": 0, "dataloader_idx": 0}, "on_test_batch_start"),
    ("on_test_batch_end", {"outputs": None, "batch": None, "batch_idx": 0, "dataloader_idx": 0}, "on_test_batch_end"),
    # Checkpointing hooks; load_state_dict takes no trainer/pl_module (see test below).
    ("load_state_dict", {"state_dict": {}}, "load_state_dict"),
    ("on_save_checkpoint", {"checkpoint": {}}, "on_save_checkpoint"),
    ("on_load_checkpoint", {"checkpoint": {}}, "on_load_checkpoint"),
    # Optimization-step hooks.
    ("on_before_backward", {"loss": torch.tensor(0.0)}, "on_before_backward"),
    ("on_after_backward", {}, "on_after_backward"),
    ("on_before_optimizer_step", {"optimizer": MagicMock(spec=torch.optim.Optimizer)}, "on_before_optimizer_step"),
    ("on_before_zero_grad", {"optimizer": MagicMock(spec=torch.optim.Optimizer)}, "on_before_zero_grad"),
]


def test_is_subtype_of_pytorch_lightning_callback():
    """The callback must plug into PyTorch Lightning's Callback API."""
    assert pl.callbacks.Callback in EventLoggingCallback.__mro__


@pytest.mark.parametrize("hook_name, kwargs, expected_event", HOOKS_TO_TEST)
def test_event_logging_hooks_log_correctly(mocker, hook_name, kwargs, expected_event):
    """Every lifecycle hook in EventLoggingCallback must log its event name."""
    # Given: a callback with the module logger patched out, plus stand-ins
    # for the Trainer/LightningModule required by the Lightning hook API.
    logger_mock = mocker.patch("ml_flashpoint.adapter.nemo.event_logging_callback._LOGGER")
    callback = EventLoggingCallback()
    trainer_stub = mocker.MagicMock(spec=pl.Trainer)
    module_stub = mocker.MagicMock(spec=pl.LightningModule)

    # When: the hook under test is invoked.
    hook = getattr(callback, hook_name)
    if hook_name == "load_state_dict":
        # load_state_dict is the one hook without the (trainer, pl_module) signature.
        hook(**kwargs)
    else:
        hook(trainer=trainer_stub, pl_module=module_stub, **kwargs)

    # Then: exactly one log line in the _log_event format was emitted.
    logger_mock.info.assert_called_once_with(f"event={expected_event}")
Loading