From 01e3020f2dea14dad7167992dba392f34d0ca4ca Mon Sep 17 00:00:00 2001
From: Yernar Sadybekov
Date: Thu, 12 Jun 2025 17:47:49 -0700
Subject: [PATCH] Enhance ParameterConstraints configuration in the benchmarking script (#3082)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3082

Updated the `ParameterConstraints` in the TorchRec benchmarking script to include pooling factors, number of poolings, and batch sizes. This enhancement allows for more detailed configuration of parameter constraints for the planner.

Differential Revision: D76440004
---
 .../benchmark/benchmark_train_sparsenn.py | 39 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
index df6ebd917..09fc5f665 100644
--- a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
+++ b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
@@ -11,7 +11,7 @@
 
 import copy
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
 
 import click
@@ -26,6 +26,7 @@
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import ParameterConstraints
 
@@ -80,6 +81,9 @@ class RunOptions:
         planner_type (str): Type of sharding planner to use. Options are:
             - "embedding": EmbeddingShardingPlanner (default)
             - "hetero": HeteroEmbeddingShardingPlanner
+        pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
+            This is the average number of values each sample has for the feature.
+        num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
     """
 
     world_size: int = 2
@@ -89,6 +93,8 @@ class RunOptions:
     input_type: str = "kjt"
     profile: str = ""
     planner_type: str = "embedding"
+    pooling_factors: Optional[List[float]] = None
+    num_poolings: Optional[List[float]] = None
 
 
 @dataclass
@@ -111,7 +117,7 @@ class EmbeddingTablesConfig:
 
     num_unweighted_features: int = 100
     num_weighted_features: int = 100
-    embedding_feature_dim: int = 512
+    embedding_feature_dim: int = 128
 
     def generate_tables(
         self,
@@ -286,17 +292,36 @@ def _generate_planner(
     tables: Optional[List[EmbeddingBagConfig]],
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
+    compute_kernel: EmbeddingComputeKernel,
+    num_batches: int,
+    batch_size: int,
+    pooling_factors: Optional[List[float]],
+    num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
     # Create parameter constraints for tables
     constraints = {}
 
+    if pooling_factors is None:
+        pooling_factors = [POOLING_FACTOR] * num_batches
+
+    if num_poolings is None:
+        num_poolings = [NUM_POOLINGS] * num_batches
+
+    batch_sizes = [batch_size] * num_batches
+
+    assert (
+        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
+    ), "The length of pooling_factors and num_poolings must match the number of batches."
+
     if tables is not None:
         for table in tables:
             constraints[table.name] = ParameterConstraints(
                 sharding_types=[sharding_type.value],
                 compute_kernels=[compute_kernel.value],
                 device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
             )
 
     if weighted_tables is not None:
@@ -305,6 +330,10 @@ def _generate_planner(
                 sharding_types=[sharding_type.value],
                 compute_kernels=[compute_kernel.value],
                 device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
+                is_weighted=True,
             )
 
     if planner_type == "embedding":
@@ -413,6 +442,10 @@ def runner(
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
+        num_batches=run_option.num_batches,
+        batch_size=input_config.batch_size,
+        pooling_factors=run_option.pooling_factors,
+        num_poolings=run_option.num_poolings,
    )

    sharded_model, optimizer = _generate_sharded_model_and_optimizer(
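
Note: as a rough usage sketch (not part of the patch), the new options flow into ParameterConstraints roughly as below. The table name, sharding type, compute kernel, and sizes are illustrative assumptions, not values taken from this diff:

    from torchrec.distributed.planner.types import ParameterConstraints

    num_batches = 5        # hypothetical number of benchmark batches
    batch_size = 8192      # hypothetical per-rank batch size
    pooling_factors = [15.0] * num_batches  # one entry per batch, as the patch sizes them
    num_poolings = [1.0] * num_batches      # likewise one entry per batch

    constraints = {
        "table_0": ParameterConstraints(  # "table_0" is a made-up table name
            sharding_types=["table_wise"],
            compute_kernels=["fused"],
            device_group="cuda",
            pooling_factors=pooling_factors,
            num_poolings=num_poolings,
            batch_sizes=[batch_size] * num_batches,
        ),
    }

Supplying these per-batch lists gives the planner concrete input statistics for its cost estimates; when they are omitted, the patch falls back to the library defaults POOLING_FACTOR and NUM_POOLINGS for every batch.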