diff --git a/torchrec/metrics/cpu_comms_metric_module.py b/torchrec/metrics/cpu_comms_metric_module.py new file mode 100644 index 000000000..2eae617cc --- /dev/null +++ b/torchrec/metrics/cpu_comms_metric_module.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +from typing import Any, cast, Dict + +from torch import nn + +from torch.profiler import record_function + +from torchrec.metrics.metric_module import RecMetricModule +from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot +from torchrec.metrics.rec_metric import ( + RecComputeMode, + RecMetric, + RecMetricComputation, + RecMetricList, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +class CPUCommsRecMetricModule(RecMetricModule): + """ + A submodule of CPUOffloadedRecMetricModule. + + The comms module's main purposes are: + 1. All gather metric state tensors + 2. Load all gathered metric states + 3. Compute metrics + + This isolation frees CPUOffloadedRecMetricModule from having + to concern itself with aggregated states, so it can focus solely on + updating local state tensors and dumping snapshots to the comms module + for metric aggregations. 
+ """ + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Takes the same arguments as RecMetricModule; rec_metrics is then replaced by clones with automatic sync disabled. + """ + + super().__init__(*args, **kwargs) + + rec_metrics_clone = self._clone_rec_metrics() + self.rec_metrics: RecMetricList = rec_metrics_clone + + for metric in self.rec_metrics.rec_metrics: + # Disable automatic sync for all metrics - handled manually via + # RecMetricModule.get_pre_compute_states() + metric = cast(RecMetric, metric) + for computation in metric._metrics_computations: + computation = cast(RecMetricComputation, computation) + computation._to_sync = False + + def load_local_metric_state_snapshot( + self, state_snapshot: MetricStateSnapshot + ) -> None: + """ + Load local metric states before all gather. + MetricStateSnapshot provides already-reduced states. + + Args: + state_snapshot (MetricStateSnapshot): a snapshot of metric states to load; its throughput_metric, if not None, replaces self.throughput_metric. + """ + + # Load states into comms module to be shared across ranks. + + with record_function("## CPUCommsRecMetricModule: load_snapshot ##"): + for metric in self.rec_metrics.rec_metrics: + metric = cast(RecMetric, metric) + compute_mode = metric._compute_mode + if ( + compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION + or compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION + ): + prefix = compute_mode.name + computation = metric._metrics_computations[0] + self._load_metric_states( + prefix, computation, state_snapshot.metric_states + ) + for task, computation in zip( + metric._tasks, metric._metrics_computations + ): + self._load_metric_states( + task.name, computation, state_snapshot.metric_states + ) + + if state_snapshot.throughput_metric is not None: + self.throughput_metric = state_snapshot.throughput_metric + + def _load_metric_states( + self, prefix: str, computation: nn.Module, metric_states: Dict[str, Any] + ) -> None: + """ + Load metric states (e.g. aggregated states after all gather) into computation. + Only keys of the form `{prefix}_{computation class name}_{state name}` found in metric_states are applied. + """ + + # All update() calls were done prior. 
Clear previous computed state. + # Otherwise, we get warnings that compute() was called before + # update() which is not the case. + computation = cast(RecMetricComputation, computation) + set_update_called(computation) + computation._computed = None + + computation_name = f"{prefix}_{computation.__class__.__name__}" + # Restore all cached states from reductions + for attr_name in computation._reductions: + cache_key = f"{computation_name}_{attr_name}" + if cache_key in metric_states: + cached_value = metric_states[cache_key] + setattr(computation, attr_name, cached_value) + + def _clone_rec_metrics(self) -> RecMetricList: + """ + Clone rec_metrics. We need to keep references to the original tasks + and computation to load the state tensors. More importantly, we need to + remove the references to the original metrics to prevent concurrent access + from the update and compute threads. + """ + + cloned_metrics = [] + for metric in self.rec_metrics.rec_metrics: + metric = cast(RecMetric, metric) + cloned_metric = type(metric)( + world_size=metric._world_size, + my_rank=metric._my_rank, + batch_size=metric._batch_size, + tasks=metric._tasks, + compute_mode=metric._compute_mode, + # Standard initialization passes in the global window size. A RecMetric's + # window size is stored as the local window size, so scale it back up by world size. + window_size=metric._window_size * metric._world_size, + fused_update_limit=metric._fused_update_limit, + compute_on_all_ranks=metric._metrics_computations[ + 0 + ]._compute_on_all_ranks, + should_validate_update=metric._should_validate_update, + # Process group should be None to prevent unwanted distributed syncs. + # This is handled manually via RecMetricModule.get_pre_compute_states() + process_group=None, + ) + cloned_metrics.append(cloned_metric) + + return RecMetricList(cloned_metrics) + + +def set_update_called(computation: RecMetricComputation) -> None: + """ + Mark computation as updated: set _update_called to True, or _update_count to 1 if that raises AttributeError. + This is a workaround for torchmetrics 1.0.3+. 
+ """ + try: + computation._update_called = True + except AttributeError: + # pyre-ignore + computation._update_count = 1 diff --git a/torchrec/metrics/tests/test_cpu_comms_metric_module.py b/torchrec/metrics/tests/test_cpu_comms_metric_module.py new file mode 100644 index 000000000..5ec934879 --- /dev/null +++ b/torchrec/metrics/tests/test_cpu_comms_metric_module.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import cast, List + +import torch +from torchrec.metrics.auc import _state_reduction +from torchrec.metrics.cpu_comms_metric_module import CPUCommsRecMetricModule +from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecMetricList +from torchrec.metrics.test_utils import gen_test_tasks +from torchrec.metrics.test_utils.mock_metrics import ( + assert_tensor_dict_equals, + create_metric_states_dict, + create_tensor_states, + MockRecMetric, +) + + +class CPUCommsRecMetricModuleTest(unittest.TestCase): + """ + Tests cloning rec metrics and loading snapshots into CPUCommsRecMetricModule. 
+ """ + + def setUp(self) -> None: + self.world_size = 2 + self.batch_size = 4 + self.my_rank = 0 + self.tasks = gen_test_tasks(["test_task"]) + + def test_clone_rec_metrics_reference(self) -> None: + """Tests that rec metrics cloned at initialization keep the originals' configuration and state names, with sync disabled.""" + + mock_metric_1 = MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=self.tasks, + ) + + mock_metric_2 = MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=self.tasks, + ) + + rec_metrics = RecMetricList([mock_metric_1, mock_metric_2]) + + cpu_comms_module = CPUCommsRecMetricModule( + batch_size=self.batch_size, + world_size=self.world_size, + rec_tasks=self.tasks, + rec_metrics=rec_metrics, + ) + + original_metrics = rec_metrics.rec_metrics + cloned_metrics = cpu_comms_module.rec_metrics.rec_metrics + + self.assertEqual(len(original_metrics), len(cloned_metrics)) + for original_metric, cloned_metric in zip(original_metrics, cloned_metrics): + original_metric = cast(MockRecMetric, original_metric) + cloned_metric = cast(MockRecMetric, cloned_metric) + + # Verify basic properties are preserved + self.assertEqual(original_metric._world_size, cloned_metric._world_size) + self.assertEqual(original_metric._my_rank, cloned_metric._my_rank) + self.assertEqual(original_metric._batch_size, cloned_metric._batch_size) + self.assertEqual(original_metric._compute_mode, cloned_metric._compute_mode) + + # State tensor names must be the same in order to load into the correct + # state tensors. + original_metric_states = set( + original_metric.get_computation_states().keys() + ) + cloned_metric_states = set(cloned_metric.get_computation_states().keys()) + self.assertSetEqual(original_metric_states, cloned_metric_states) + + # Cloned metric should have torchmetrics.Metric's sync() disabled to prevent + # unwanted distributed syncs. All syncs will be called via cpu_comms_module. 
+ self.assertTrue(cloned_metric.verify_sync_disabled()) + + def test_load_metric_states(self) -> None: + """ + Test loading metric states into a single metric computation. + """ + + initial_states = { + "state_1": torch.tensor(1.0), + "state_2": torch.tensor(2.0), + "state_3": torch.tensor(3.0), + } + mock_metric = MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=self.tasks, + initial_states=initial_states, + ) + + cpu_comms_module = CPUCommsRecMetricModule( + batch_size=self.batch_size, + world_size=self.world_size, + rec_tasks=self.tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + metric_states = create_metric_states_dict( + metric_prefix="test_prefix", + computation_name="MockRecMetricComputation", + metric_states={ + **initial_states, + "ignored_key": torch.tensor(15.0), + }, + ) + + cloned_metric = cpu_comms_module.rec_metrics.rec_metrics[0] + cloned_computation = cloned_metric._metrics_computations[0] + + cpu_comms_module._load_metric_states( + "test_prefix", cloned_computation, metric_states + ) + + self.assertTrue(cloned_computation._update_called) + self.assertIsNone(cloned_computation._computed) + assert_tensor_dict_equals( + cloned_metric.get_computation_states(), + initial_states, + ) + + def test_snapshot_generation(self) -> None: + """Test that original metrics and comms module loaded metrics produce the same snapshot.""" + + original_states = { + "state_1": torch.tensor(7.5), + "state_2": torch.tensor(12.0), + "state_3": torch.tensor(49.0), + } + + original_metric = MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=self.tasks, + initial_states=original_states, + ) + + rec_metrics = RecMetricList([original_metric]) + original_snapshot = MetricStateSnapshot.from_metrics(rec_metrics) + + cpu_comms_module = CPUCommsRecMetricModule( + batch_size=self.batch_size, + world_size=self.world_size, + rec_tasks=self.tasks, 
+ rec_metrics=rec_metrics, + ) + cpu_comms_module.load_local_metric_state_snapshot(original_snapshot) + loaded_snapshot = MetricStateSnapshot.from_metrics(cpu_comms_module.rec_metrics) + + assert_tensor_dict_equals( + original_snapshot.metric_states, + loaded_snapshot.metric_states, + ) + + def test_load_metric_states_partial_load(self) -> None: + """Test loading metric states when some keys are missing from the snapshot. NOTE(review): the last assertion compares state_2 with 3.0; presumably state_3 was intended - confirm.""" + + initial_states = { + "state_1": torch.tensor(1.0), + "state_2": torch.tensor(2.0), + "state_3": torch.tensor(3.0), + } + mock_metric = MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=self.tasks, + initial_states=initial_states, + ) + + cpu_comms_module = CPUCommsRecMetricModule( + batch_size=self.batch_size, + world_size=self.world_size, + rec_tasks=self.tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + # The metric states dict contains only one of the initial keys + metric_states = create_metric_states_dict( + metric_prefix="test_prefix", + computation_name="MockRecMetricComputation", + metric_states={"state_1": torch.tensor(5.0)}, + ) + + cloned_metric = cpu_comms_module.rec_metrics.rec_metrics[0] + cloned_computation = cloned_metric._metrics_computations[0] + + cpu_comms_module._load_metric_states( + "test_prefix", cloned_computation, metric_states + ) + + torch.testing.assert_close( + cloned_metric.get_computation_states()["state_1"], torch.tensor(5.0) + ) + self.assertFalse( + torch.allclose( + cloned_metric.get_computation_states()["state_2"], torch.tensor(2.0) + ) + ) + self.assertFalse( + torch.allclose( + cloned_metric.get_computation_states()["state_2"], torch.tensor(3.0) + ) + ) + + def test_load_multiple_metrics_unfused(self) -> None: + """Test loading a snapshot into multiple single-task metrics (three NE-style and three AUC-style mocks).""" + + ne_tasks = gen_test_tasks(["task1", "task2", "task3"]) + auc_tasks = gen_test_tasks(["task4", "task5", "task6"]) + + ne_states = create_tensor_states(["state_1", "state_2", 
"state_3"], n_tasks=1) + mock_nes = [ + MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=[task], + initial_states=ne_states, + ) + for task in ne_tasks + ] + + auc_states = { + "state_1": [torch.tensor([[1.0, 2.0]])], + "state_2": [torch.tensor([[0.0, 1.0]])], + "state_3": [torch.tensor([[4.0, 1.0]])], + } + mock_aucs = [ + MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=[task], + reduction_fn=_state_reduction, + initial_states=auc_states, + is_tensor_list=True, + ) + for task in auc_tasks + ] + + rec_metrics_list: List[RecMetric] = [*mock_nes, *mock_aucs] + rec_metrics = RecMetricList(rec_metrics_list) + cpu_comms_module = CPUCommsRecMetricModule( + batch_size=self.batch_size, + world_size=self.world_size, + rec_tasks=ne_tasks + auc_tasks, + rec_metrics=rec_metrics, + ) + + snapshot = MetricStateSnapshot.from_metrics(rec_metrics) + cpu_comms_module.load_local_metric_state_snapshot(snapshot) + + ne_states_dict = {} + for task in ne_tasks: + ne_states_dict.update( + create_metric_states_dict( + metric_prefix=task.name, + computation_name="MockRecMetricComputation", + metric_states=ne_states, + ) + ) + + auc_states_dict = {} + for task in auc_tasks: + auc_states_dict.update( + create_metric_states_dict( + metric_prefix=task.name, + computation_name="MockRecMetricComputation", + metric_states=auc_states, + ) + ) + + expected_metric_states = {**ne_states_dict, **auc_states_dict} + self.assertEqual(len(cpu_comms_module.rec_metrics.rec_metrics), 6) + actual_metric_states_dict = {} + for task, metric in zip( + cpu_comms_module.rec_tasks, cpu_comms_module.rec_metrics.rec_metrics + ): + metric = cast(MockRecMetric, metric) + actual_metric_states_dict.update( + create_metric_states_dict( + metric_prefix=task.name, + computation_name="MockRecMetricComputation", + metric_states=metric.get_computation_states(), + ) + ) + assert_tensor_dict_equals( 
+ actual_metric_states_dict, + expected_metric_states, + ) + + def test_load_multiple_metrics_fused(self) -> None: + """Test loading a snapshot into one metric that handles three tasks in FUSED_TASKS_COMPUTATION mode.""" + + ne_tasks = gen_test_tasks(["task1", "task2", "task3"]) + + ne_states = create_tensor_states(["state_1", "state_2", "state_3"], n_tasks=3) + mock_ne = MockRecMetric( + world_size=self.world_size, + my_rank=self.my_rank, + batch_size=self.batch_size, + tasks=ne_tasks, + initial_states=ne_states, + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + rec_metrics = RecMetricList([mock_ne]) + cpu_comms_module = CPUCommsRecMetricModule( + batch_size=self.batch_size, + world_size=self.world_size, + rec_tasks=ne_tasks, + rec_metrics=rec_metrics, + ) + + snapshot = MetricStateSnapshot.from_metrics(rec_metrics) + cpu_comms_module.load_local_metric_state_snapshot(snapshot) + + self.assertEqual(len(cpu_comms_module.rec_metrics.rec_metrics), 1) + metric = cpu_comms_module.rec_metrics.rec_metrics[0] + metric = cast(MockRecMetric, metric) + assert_tensor_dict_equals( + metric.get_computation_states(), + ne_states, + ) + + +if __name__ == "__main__": + unittest.main()