Support multi-pass prefetch config in CacheParams, and shard estimator will correctly propagate its storage cost (pytorch#2000)

Summary:

Now that FBGEMM TBE supports multipass prefetch mode (see pytorch/FBGEMM#2566 for the full context), this diff enables TorchRec to pass the config through via CacheParams, and the shard estimator will recognize the resulting memory saving accordingly.
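
For illustration, here is a minimal sketch of how a table could opt into multipass prefetch through planner constraints, mirroring the new test case below. The table name, load factor, and import paths are assumptions for the example, not part of this diff:

```python
# Hedged sketch: opting a cached table into multipass prefetch via ParameterConstraints.
# "table_0" and load_factor=0.1 are illustrative; import paths assume the current TorchRec layout.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import CacheParams, MultiPassPrefetchConfig, ShardingType

constraints = {
    "table_0": ParameterConstraints(
        compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
        sharding_types=[ShardingType.TABLE_WISE.value],
        cache_params=CacheParams(
            load_factor=0.1,
            # Split each prefetch into 10 smaller passes to flatten the HBM spike.
            multipass_prefetch_config=MultiPassPrefetchConfig(num_passes=10),
        ),
    ),
}
```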

Differential Revision: D57055184
levythu authored and facebook-github-bot committed May 14, 2024
1 parent 7924c6f commit 7086082
Showing 5 changed files with 56 additions and 2 deletions.
22 changes: 20 additions & 2 deletions torchrec/distributed/planner/shard_estimators.py
@@ -902,6 +902,19 @@ def estimate(
else sharding_option.tensor.element_size()
)

mpp_conf = (
sharding_option.cache_params.multipass_prefetch_config
if sharding_option.cache_params
else None
)
# TODO: remove after deprecating fused_params in sharder
if mpp_conf is None:
mpp_conf = (
sharder.fused_params.get("multipass_prefetch_config", None)
if hasattr(sharder, "fused_params") and sharder.fused_params
else None
)

shard_storages = calculate_shard_storages(
sharder=sharder,
sharding_type=sharding_option.sharding_type,
@@ -920,6 +933,7 @@ def estimate(
output_data_type_size=output_data_type_size,
pipeline_type=self._pipeline_type,
is_inference=self._is_inference,
multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None,
)

for shard, storage in zip(sharding_option.shards, shard_storages):
@@ -931,6 +945,7 @@ def calculate_pipeline_io_cost(
output_size: int,
prefetch_size: int,
pipeline_type: PipelineType,
multipass_prefetch_max_pass: Optional[int],
is_inference: bool = False,
) -> int:
# These magical number comes from heuristical analysis of memory snapshot during
@@ -944,11 +959,12 @@ def calculate_pipeline_io_cost(
pipelining_hbm_input_factor = 2
return max(pipelining_hbm_input_factor * input_size, output_size)
if pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST:
multipass_prefetch_max_pass = multipass_prefetch_max_pass or 1
pipelining_hbm_input_factor = 4
prefetch_bursty_hbm_input_factor = 7
prefetch_bursty_hbm_input_factor = 1 + 6 / multipass_prefetch_max_pass
return max(
pipelining_hbm_input_factor * input_size
+ prefetch_bursty_hbm_input_factor * prefetch_size,
+ int(prefetch_bursty_hbm_input_factor * prefetch_size),
output_size,
)

@@ -974,6 +990,7 @@ def calculate_shard_storages(
output_data_type_size: float,
pipeline_type: PipelineType = PipelineType.NONE,
is_inference: bool = False,
multipass_prefetch_max_pass: Optional[int] = None,
) -> List[Storage]:
"""
Calculates estimated storage sizes for each sharded tensor, comprised of input,
@@ -1057,6 +1074,7 @@
output_size=output_size,
prefetch_size=input_size if table_cached else 0,
pipeline_type=pipeline_type,
multipass_prefetch_max_pass=multipass_prefetch_max_pass,
is_inference=is_inference,
)
if compute_device == "cuda"
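
To sanity-check the new heuristic, the snippet below (a standalone sketch, not code from this diff) reproduces the TRAIN_PREFETCH_SPARSE_DIST branch of calculate_pipeline_io_cost with the 1 KiB input/prefetch sizes used in the test case further down. The heuristic keeps a constant 1x term for the bursty prefetch buffer and scales the remaining 6x term down by the number of passes, which is where the old 7x factor goes to 1.6x at 10 passes.

```python
# Standalone sketch of the TRAIN_PREFETCH_SPARSE_DIST branch of calculate_pipeline_io_cost.
def prefetch_hbm_cost(input_size: int, prefetch_size: int, output_size: int, num_passes: int) -> int:
    pipelining_hbm_input_factor = 4
    # Single-pass prefetch keeps the old bursty factor of 7; N passes shrink it to 1 + 6/N.
    prefetch_bursty_hbm_input_factor = 1 + 6 / num_passes
    return max(
        pipelining_hbm_input_factor * input_size
        + int(prefetch_bursty_hbm_input_factor * prefetch_size),
        output_size,
    )

print(prefetch_hbm_cost(1024, 1024, 100, num_passes=1))   # 11264 == 1024 * 11 (matches table_0 below)
print(prefetch_hbm_cost(1024, 1024, 100, num_passes=10))  # 5734 == 1024 * 4 + int(1024 * 1.6) (matches table_2)
```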
24 changes: 24 additions & 0 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -40,6 +40,7 @@
CacheParams,
CacheStatistics,
ModuleSharder,
MultiPassPrefetchConfig,
PipelineType,
ShardingType,
)
@@ -593,6 +594,12 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None:
name="table_1",
feature_names=["feature_1"],
),
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=10,
name="table_2",
feature_names=["feature_2"],
),
]
constraints = {
"table_0": ParameterConstraints(
@@ -610,6 +617,16 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None:
load_factor=None,
),
),
"table_2": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
sharding_types=[ShardingType.TABLE_WISE.value],
cache_params=CacheParams(
load_factor=0.1,
multipass_prefetch_config=MultiPassPrefetchConfig(
num_passes=10,
),
),
),
}
enumerator = EmbeddingEnumerator(
topology=topology,
@@ -637,20 +654,27 @@ def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None:
expected_storage = {
("table_0", "fused_uvm_caching", "table_wise"): [(100 + 3333, 100)],
("table_1", "fused", "table_wise"): [(100 + 3333, 100)],
("table_2", "fused_uvm_caching", "table_wise"): [(100 + 3333, 100)],
}
elif pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST:
expected_storage = {
("table_0", "fused_uvm_caching", "table_wise"): [
(100 + 1024 * 11, 100)
],
("table_1", "fused", "table_wise"): [(100 + 1024 * 4, 100)],
("table_2", "fused_uvm_caching", "table_wise"): [
(100 + 1024 * 4 + int(1024 * 1.6), 100)
],
}
else:
expected_storage = {
("table_0", "fused_uvm_caching", "table_wise"): [
(100 + 3333 + 1024, 100)
],
("table_1", "fused", "table_wise"): [(100 + 3333 + 1024, 100)],
("table_2", "fused_uvm_caching", "table_wise"): [
(100 + 3333 + 1024, 100)
],
}
actual_storage = {
(
5 changes: 5 additions & 0 deletions torchrec/distributed/tests/test_utils.py
@@ -27,6 +27,7 @@
CacheParams,
DataType,
ModuleSharder,
MultiPassPrefetchConfig,
ParameterSharding,
)
from torchrec.distributed.utils import (
@@ -479,6 +480,7 @@ def setUp(self) -> None:
algorithm=CacheAlgorithm.LFU,
reserved_memory=1.0,
prefetch_pipeline=False,
multipass_prefetch_config=MultiPassPrefetchConfig(num_passes=2),
),
enforce_hbm=False,
stochastic_rounding=True,
@@ -497,6 +499,7 @@ def test_add_params_from_parameter_sharding(self) -> None:
"enforce_hbm": False,
"stochastic_rounding": True,
"bounds_check_mode": BoundsCheckMode.WARNING,
"multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=2),
}
self.assertEqual(fused_params, expected_fused_params)

@@ -506,6 +509,7 @@ def test_add_params_from_parameter_sharding_override(self) -> None:
"cache_algorithm": CacheAlgorithm.LRU,
"stochastic_rounding": False,
"prefetch_pipeline": True,
"multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=5),
}
fused_params = add_params_from_parameter_sharding(
fused_params, self.parameter_sharding
@@ -518,6 +522,7 @@ def test_add_params_from_parameter_sharding_override(self) -> None:
"enforce_hbm": False,
"stochastic_rounding": True,
"bounds_check_mode": BoundsCheckMode.WARNING,
"multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=2),
}
self.assertEqual(fused_params, expected_fused_params)

3 changes: 3 additions & 0 deletions torchrec/distributed/types.py
@@ -28,6 +28,7 @@
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BoundsCheckMode,
CacheAlgorithm,
MultiPassPrefetchConfig,
)

from torch.autograd.profiler import record_function
@@ -558,6 +559,7 @@ class CacheParams:
precision: Optional[DataType] = None
prefetch_pipeline: Optional[bool] = None
stats: Optional[CacheStatistics] = None
multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None

def __hash__(self) -> int:
return hash(
@@ -567,6 +569,7 @@ def __hash__(self) -> int:
self.reserved_memory,
self.precision,
self.prefetch_pipeline,
self.multipass_prefetch_config,
)
)

4 changes: 4 additions & 0 deletions torchrec/distributed/utils.py
@@ -380,6 +380,10 @@ def add_params_from_parameter_sharding(
fused_params["cache_precision"] = cache_params.precision
if cache_params.prefetch_pipeline is not None:
fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline
if cache_params.multipass_prefetch_config is not None:
fused_params["multipass_prefetch_config"] = (
cache_params.multipass_prefetch_config
)

if parameter_sharding.enforce_hbm is not None:
fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
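
As a usage note, here is a rough sketch of the override behavior the updated tests above exercise: a multipass_prefetch_config carried by ParameterSharding.cache_params replaces one already present in fused_params. The ParameterSharding constructor arguments below are assumptions for illustration, not taken from this diff:

```python
# Hedged sketch: cache_params.multipass_prefetch_config from ParameterSharding
# overrides the value already sitting in fused_params.
from torchrec.distributed.types import CacheParams, MultiPassPrefetchConfig, ParameterSharding
from torchrec.distributed.utils import add_params_from_parameter_sharding

parameter_sharding = ParameterSharding(
    sharding_type="table_wise",
    compute_kernel="fused_uvm_caching",
    ranks=[0],
    cache_params=CacheParams(multipass_prefetch_config=MultiPassPrefetchConfig(num_passes=2)),
)
fused_params = {"multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=5)}
fused_params = add_params_from_parameter_sharding(fused_params, parameter_sharding)
# The sharding-level config (num_passes=2) wins over the pre-existing fused param (num_passes=5).
assert fused_params["multipass_prefetch_config"].num_passes == 2
```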
