From b3c23a3cccfe0ef0b9488d171d4108cc94ee129f Mon Sep 17 00:00:00 2001 From: Isuru Janith Ranawaka Date: Mon, 24 Nov 2025 15:07:06 -0800 Subject: [PATCH] Objectives with input dist latencies (#3575) Summary: Pull Request resolved: https://github.com/meta-pytorch/torchrec/pull/3575 This diff introduce two objectives for LP Planner considering input_dist in Critical Path. - BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_INPUT_DIST max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms for module) + max(bwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype} + sum_{module} max(input_dist_comms for module) - BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_COMBINED_FWD_COMMS_INPUT_DIST max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms + input_dist_comms for module) + max(bwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype} Differential Revision: D87389540 --- torchrec/distributed/planner/constants.py | 3 + .../distributed/planner/shard_estimators.py | 74 +++++++++++++++++++ .../planner/tests/test_shard_estimators.py | 55 ++++++++++++++ torchrec/distributed/planner/types.py | 3 + 4 files changed, 135 insertions(+) diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py index 56c7dc26f..324beaefc 100644 --- a/torchrec/distributed/planner/constants.py +++ b/torchrec/distributed/planner/constants.py @@ -42,6 +42,9 @@ WEIGHTED_KERNEL_MULTIPLIER: float = 1.1 # empirical studies DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22 # empirical studies +# TODO: This can be hardware dependent, need more empirical results to verify +A2A_INVERSE_BANDWITH_COEFFICIENT: float = 1 # empirical studies + def kernel_bw_lookup( compute_device: str, diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 3650324fc..b76108bdb 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -18,6 +18,7 @@ from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.logger import _torchrec_method_logger from torchrec.distributed.planner.constants import ( + A2A_INVERSE_BANDWITH_COEFFICIENT, BATCHED_COPY_PERF_FACTOR, BIGINT_DTYPE, DP_ELEMENTWISE_KERNELS_PERF_FACTOR, @@ -461,6 +462,49 @@ def _get_expected_cache_prefetch_time( prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size return prefetch_bytes / hbm_to_ddr_mem_bw + @classmethod + def _input_dist_expected_latency( + cls, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + num_poolings: List[float], + input_lengths: List[float], + fwd_a2a_comm_data_type_size: float, + comms_bandwidths: GeneralizedCommsBandwidth, + ) -> float: + """ + Calculates the expected latency for A2A input dist. + + Args: + batch_sizes (int): The batch size for each input feature. + world_size (int): The total number of devices in the distributed setup. + local_world_size (int): The number of devices on a single host. + num_poolings (List[float]): Number of poolings per sample for each input feature. + input_lengths (List[float]): Average number of lookups per input feature. + fwd_a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication. + comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths. + + Returns: + float: The expected latency (in seconds) for input distribution. + """ + batch_inputs = sum( + [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] + ) + input_read_size = math.ceil( + batch_inputs * world_size * fwd_a2a_comm_data_type_size + ) + + comms_bw = comms_bandwidths.get_bw( + world_size=world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.ALL_TO_ALL, + ) + message_bw = input_read_size / comms_bw + input_dist_latency = message_bw * A2A_INVERSE_BANDWITH_COEFFICIENT + + return input_dist_latency + @classmethod def _get_tw_sharding_perf( cls, @@ -551,6 +595,15 @@ def _get_tw_sharding_perf( hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size ) + input_dist_comms = cls._input_dist_expected_latency( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + num_poolings=num_poolings, + input_lengths=input_lengths, + fwd_a2a_comm_data_type_size=input_data_type_size, + comms_bandwidths=comms_bandwidths, + ) # in order of model parallel execution, starting with: # BWD DP -> BWD MP ... FWD MP -> FWD DP return Perf( @@ -559,6 +612,7 @@ def _get_tw_sharding_perf( bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, bwd_comms=bwd_comms, prefetch_compute=prefetch_compute, + input_dist_comms=input_dist_comms, ) @classmethod @@ -658,6 +712,15 @@ def _get_rw_sharding_perf( emb_dim, table_data_type_size, ) + input_dist_comms = cls._input_dist_expected_latency( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + num_poolings=num_poolings, + input_lengths=input_lengths, + fwd_a2a_comm_data_type_size=input_data_type_size, + comms_bandwidths=comms_bandwidths, + ) return Perf( fwd_compute=fwd_compute, @@ -665,6 +728,7 @@ def _get_rw_sharding_perf( bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, bwd_comms=bwd_comms + bwd_batched_copy, prefetch_compute=prefetch_compute, + input_dist_comms=input_dist_comms, ) @classmethod @@ -790,6 +854,15 @@ def _get_twrw_sharding_perf( emb_dim, table_data_type_size, ) + input_dist_comms = cls._input_dist_expected_latency( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + num_poolings=num_poolings, + input_lengths=input_lengths, + fwd_a2a_comm_data_type_size=input_data_type_size, + comms_bandwidths=comms_bandwidths, + ) return Perf( fwd_compute=fwd_compute, @@ -797,6 +870,7 @@ def _get_twrw_sharding_perf( bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, bwd_comms=bwd_comms + bwd_batched_copy, prefetch_compute=prefetch_compute, + input_dist_comms=input_dist_comms, ) @classmethod diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index a2c7ed5e6..bdb278bc0 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -141,6 +141,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm", "table_wise"): [ @@ -149,6 +150,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm_caching", "table_wise"): [ @@ -157,6 +159,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused", "column_wise"): [ @@ -165,6 +168,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm", "column_wise"): [ @@ -173,6 +177,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm_caching", "column_wise"): [ @@ -181,6 +186,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused", "table_column_wise"): [ @@ -189,6 +195,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm", "table_column_wise"): [ @@ -197,6 +204,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm_caching", "table_column_wise"): [ @@ -205,6 +213,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused", "row_wise"): [ @@ -213,12 +222,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm", "row_wise"): [ @@ -227,12 +238,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm_caching", "row_wise"): [ @@ -241,12 +254,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused", "table_row_wise"): [ @@ -255,12 +270,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm", "table_row_wise"): [ @@ -269,12 +286,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm_caching", "table_row_wise"): [ @@ -283,12 +302,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), ], # grid_shard is the same as table_row_wise @@ -298,12 +319,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm", "grid_shard"): [ @@ -312,12 +335,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm_caching", "grid_shard"): [ @@ -326,12 +351,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), ], } @@ -860,6 +887,7 @@ def test_1_table_perf(self) -> None: bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05 * 2, # bw is set to half in this test + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm", "table_wise"): [ @@ -868,6 +896,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm_caching", "table_wise"): [ @@ -876,6 +905,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused", "column_wise"): [ @@ -884,6 +914,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm", "column_wise"): [ @@ -892,6 +923,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm_caching", "column_wise"): [ @@ -900,6 +932,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused", "table_column_wise"): [ @@ -908,6 +941,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm", "table_column_wise"): [ @@ -916,6 +950,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm_caching", "table_column_wise"): [ @@ -924,6 +959,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused", "row_wise"): [ @@ -932,12 +968,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm", "row_wise"): [ @@ -946,12 +984,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm_caching", "row_wise"): [ @@ -960,12 +1000,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), ], ("fused", "table_row_wise"): [ @@ -974,12 +1016,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm", "table_row_wise"): [ @@ -988,12 +1032,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm_caching", "table_row_wise"): [ @@ -1002,12 +1048,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), ], # grid_shard is the same as table_row_wise @@ -1017,12 +1065,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm", "grid_shard"): [ @@ -1031,12 +1081,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm_caching", "grid_shard"): [ @@ -1045,12 +1097,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), ], } @@ -1070,6 +1124,7 @@ def test_1_table_perf(self) -> None: ): [shard.perf for shard in sharding_option.shards] for sharding_option in sharding_options2 } + self.assertEqual(expected_perfs, perfs) self.assertEqual(expected_perfs, perfs2) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 32a750290..4891bbc53 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -54,6 +54,7 @@ class Perf: fwd_comms: float bwd_compute: float bwd_comms: float + input_dist_comms: float = 0.0 prefetch_compute: float = 0.0 @property @@ -87,6 +88,7 @@ def __add__(self, other: "Perf") -> "Perf": fwd_comms=self.fwd_comms + other.fwd_comms, bwd_compute=self.bwd_compute + other.bwd_compute, bwd_comms=self.bwd_comms + other.bwd_comms, + input_dist_comms=self.input_dist_comms + other.input_dist_comms, prefetch_compute=self.prefetch_compute + other.prefetch_compute, ) @@ -97,6 +99,7 @@ def __hash__(self) -> int: self.fwd_comms, self.bwd_compute, self.bwd_comms, + self.input_dist_comms, self.prefetch_compute, ) )