From 7086082d629cb446eb76d264862c9902816d5c0b Mon Sep 17 00:00:00 2001 From: Levy Zhao Date: Tue, 14 May 2024 08:44:22 -0700 Subject: [PATCH] Support multi-pass prefetch config in CacheParams, and shard estimator will correctly propagate its storage cost (#2000) Summary: After FBGEMM TBE support multipass prefetch mode (see https://github.com/pytorch/FBGEMM/pull/2566 for the full context), this diff will enable TorchRec to pass it all through via CacheParams, and shard estimator will recognize the memory saving accordingly. Differential Revision: D57055184 --- .../distributed/planner/shard_estimators.py | 22 +++++++++++++++-- .../planner/tests/test_shard_estimators.py | 24 +++++++++++++++++++ torchrec/distributed/tests/test_utils.py | 5 ++++ torchrec/distributed/types.py | 3 +++ torchrec/distributed/utils.py | 4 ++++ 5 files changed, 56 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index ba6eafc01..a794590bf 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -902,6 +902,19 @@ def estimate( else sharding_option.tensor.element_size() ) + mpp_conf = ( + sharding_option.cache_params.multipass_prefetch_config + if sharding_option.cache_params + else None + ) + # TODO: remove after deprecating fused_params in sharder + if mpp_conf is None: + mpp_conf = ( + sharder.fused_params.get("multipass_prefetch_config", None) + if hasattr(sharder, "fused_params") and sharder.fused_params + else None + ) + shard_storages = calculate_shard_storages( sharder=sharder, sharding_type=sharding_option.sharding_type, @@ -920,6 +933,7 @@ def estimate( output_data_type_size=output_data_type_size, pipeline_type=self._pipeline_type, is_inference=self._is_inference, + multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None, ) for shard, storage in zip(sharding_option.shards, shard_storages): @@ -931,6 +945,7 @@ def calculate_pipeline_io_cost( output_size: int, prefetch_size: int, pipeline_type: PipelineType, + multipass_prefetch_max_pass: Optional[int], is_inference: bool = False, ) -> int: # These magical number comes from heuristical analysis of memory snapshot during @@ -944,11 +959,12 @@ def calculate_pipeline_io_cost( pipelining_hbm_input_factor = 2 return max(pipelining_hbm_input_factor * input_size, output_size) if pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST: + multipass_prefetch_max_pass = multipass_prefetch_max_pass or 1 pipelining_hbm_input_factor = 4 - prefetch_bursty_hbm_input_factor = 7 + prefetch_bursty_hbm_input_factor = 1 + 6 / multipass_prefetch_max_pass return max( pipelining_hbm_input_factor * input_size - + prefetch_bursty_hbm_input_factor * prefetch_size, + + int(prefetch_bursty_hbm_input_factor * prefetch_size), output_size, ) @@ -974,6 +990,7 @@ def calculate_shard_storages( output_data_type_size: float, pipeline_type: PipelineType = PipelineType.NONE, is_inference: bool = False, + multipass_prefetch_max_pass: Optional[int] = None, ) -> List[Storage]: """ Calculates estimated storage sizes for each sharded tensor, comprised of input, @@ -1057,6 +1074,7 @@ def calculate_shard_storages( output_size=output_size, prefetch_size=input_size if table_cached else 0, pipeline_type=pipeline_type, + multipass_prefetch_max_pass=multipass_prefetch_max_pass, is_inference=is_inference, ) if compute_device == "cuda" diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index d170eac92..7d928aa0c 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -40,6 +40,7 @@ CacheParams, CacheStatistics, ModuleSharder, + MultiPassPrefetchConfig, PipelineType, ShardingType, ) @@ -593,6 +594,12 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None: name="table_1", feature_names=["feature_1"], ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_2", + feature_names=["feature_2"], + ), ] constraints = { "table_0": ParameterConstraints( @@ -610,6 +617,16 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None: load_factor=None, ), ), + "table_2": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + sharding_types=[ShardingType.TABLE_WISE.value], + cache_params=CacheParams( + load_factor=0.1, + multipass_prefetch_config=MultiPassPrefetchConfig( + num_passes=10, + ), + ), + ), } enumerator = EmbeddingEnumerator( topology=topology, @@ -637,6 +654,7 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None: expected_storage = { ("table_0", "fused_uvm_caching", "table_wise"): [(100 + 3333, 100)], ("table_1", "fused", "table_wise"): [(100 + 3333, 100)], + ("table_2", "fused_uvm_caching", "table_wise"): [(100 + 3333, 100)], } elif pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST: expected_storage = { @@ -644,6 +662,9 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None: (100 + 1024 * 11, 100) ], ("table_1", "fused", "table_wise"): [(100 + 1024 * 4, 100)], + ("table_2", "fused_uvm_caching", "table_wise"): [ + (100 + 1024 * 4 + int(1024 * 1.6), 100) + ], } else: expected_storage = { @@ -651,6 +672,9 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None: (100 + 3333 + 1024, 100) ], ("table_1", "fused", "table_wise"): [(100 + 3333 + 1024, 100)], + ("table_2", "fused_uvm_caching", "table_wise"): [ + (100 + 3333 + 1024, 100) + ], } actual_storage = { ( diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py index 8474fab07..135d7f47a 100644 --- a/torchrec/distributed/tests/test_utils.py +++ b/torchrec/distributed/tests/test_utils.py @@ -27,6 +27,7 @@ CacheParams, DataType, ModuleSharder, + MultiPassPrefetchConfig, ParameterSharding, ) from torchrec.distributed.utils import ( @@ -479,6 +480,7 @@ def setUp(self) -> None: algorithm=CacheAlgorithm.LFU, reserved_memory=1.0, prefetch_pipeline=False, + multipass_prefetch_config=MultiPassPrefetchConfig(num_passes=2), ), enforce_hbm=False, stochastic_rounding=True, @@ -497,6 +499,7 @@ def test_add_params_from_parameter_sharding(self) -> None: "enforce_hbm": False, "stochastic_rounding": True, "bounds_check_mode": BoundsCheckMode.WARNING, + "multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=2), } self.assertEqual(fused_params, expected_fused_params) @@ -506,6 +509,7 @@ def test_add_params_from_parameter_sharding_override(self) -> None: "cache_algorithm": CacheAlgorithm.LRU, "stochastic_rounding": False, "prefetch_pipeline": True, + "multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=5), } fused_params = add_params_from_parameter_sharding( fused_params, self.parameter_sharding @@ -518,6 +522,7 @@ def test_add_params_from_parameter_sharding_override(self) -> None: "enforce_hbm": False, "stochastic_rounding": True, "bounds_check_mode": BoundsCheckMode.WARNING, + "multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=2), } self.assertEqual(fused_params, expected_fused_params) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index cb1a0e2fe..31857193a 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -28,6 +28,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, CacheAlgorithm, + MultiPassPrefetchConfig, ) from torch.autograd.profiler import record_function @@ -558,6 +559,7 @@ class CacheParams: precision: Optional[DataType] = None prefetch_pipeline: Optional[bool] = None stats: Optional[CacheStatistics] = None + multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None def __hash__(self) -> int: return hash( @@ -567,6 +569,7 @@ def __hash__(self) -> int: self.reserved_memory, self.precision, self.prefetch_pipeline, + self.multipass_prefetch_config, ) ) diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 37c13f437..58d83cc00 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -380,6 +380,10 @@ def add_params_from_parameter_sharding( fused_params["cache_precision"] = cache_params.precision if cache_params.prefetch_pipeline is not None: fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline + if cache_params.multipass_prefetch_config is not None: + fused_params["multipass_prefetch_config"] = ( + cache_params.multipass_prefetch_config + ) if parameter_sharding.enforce_hbm is not None: fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm