diff --git a/torchrec/metrics/cpu_offloaded_metric_module.py b/torchrec/metrics/cpu_offloaded_metric_module.py index f6ae2b7c1..83a9c6e48 100644 --- a/torchrec/metrics/cpu_offloaded_metric_module.py +++ b/torchrec/metrics/cpu_offloaded_metric_module.py @@ -474,9 +474,6 @@ def sync(self) -> None: ) self.comms_module.load_pre_compute_states(aggregated_states) - # Sync _trained_batches to comms module - self.comms_module._trained_batches.copy_(self._trained_batches) - logger.info("CPUOffloadedRecMetricModule synced.") @override diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 02baae328..4ab3bb162 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -202,9 +202,7 @@ def __init__( self.rec_metrics = rec_metrics if rec_metrics else RecMetricList([]) self.throughput_metric = throughput_metric self.state_metrics = state_metrics if state_metrics else {} - - self.register_buffer("_trained_batches", torch.tensor(0), persistent=True) - + self.trained_batches: int = 0 self.batch_size = batch_size self.world_size = world_size self.oom_count = 0 @@ -230,15 +228,6 @@ def __init__( ) self.last_compute_time = -1.0 - @property - def trained_batches(self) -> int: - # .trained_batches should return an int - return int(self._trained_batches.item()) - - @trained_batches.setter - def trained_batches(self, value: int) -> None: - self._trained_batches.fill_(int(value)) - def _update_rec_metrics( self, model_out: Dict[str, torch.Tensor], **kwargs: Any ) -> None: @@ -271,7 +260,7 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None: self._update_rec_metrics(model_out, **kwargs) if self.throughput_metric: self.throughput_metric.update() - self._trained_batches.add_(1) + self.trained_batches += 1 def _adjust_compute_interval(self) -> None: """ diff --git a/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py b/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py index 3f0865e2b..3c3d557e3 100644 --- a/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py +++ b/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py @@ -341,7 +341,6 @@ def test_state_dict_save_load(self) -> None: "rec_metrics.rec_metrics.0._metrics_computations.0.state_3": torch.tensor( [6.0] ), - "_trained_batches": torch.tensor(0), }, )