From 004146e97b9d8c7f3a9ddea95e53645428a9f5bf Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Sun, 19 Oct 2025 09:09:12 -0700
Subject: [PATCH] create ModelSelectionConfig and fix bugs (#3467)

Summary:
# context
* move `ModelSelectionConfig` to `test_utils/model_config.py` and make `create_model_config` a class method on it
* fix a previously failing test and some pyre-ignore issues
* bump the pre-commit check Python version from 3.9 to 3.12

Differential Revision: D84755189
---
 .github/workflows/pre-commit.yaml             |  2 +-
 examples/retrieval/knn_index.py               |  7 --
 examples/retrieval/modules/two_tower.py       |  2 +-
 examples/retrieval/two_tower_retrieval.py     | 11 ++-
 examples/retrieval/two_tower_train.py         |  1 -
 torchrec/distributed/benchmark/base.py        |  2 +-
 .../benchmark/benchmark_train_pipeline.py     | 72 ++----------------
 .../benchmark/yaml/sparse_data_dist_base.yml  |  2 +
 .../distributed/test_utils/model_config.py    | 76 ++++++++++++++-----
 torchrec/distributed/test_utils/test_model.py |  2 +-
 torchrec/modules/object_pool_lookups.py       |  2 +-
 11 files changed, 78 insertions(+), 101 deletions(-)

diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index 96d3402ac..82503e69e 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -18,7 +18,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
-          python-version: 3.9
+          python-version: 3.12
           architecture: x64
           packages: |
             ufmt==2.5.1
diff --git a/examples/retrieval/knn_index.py b/examples/retrieval/knn_index.py
index 9db4a6d7d..9ffb0a384 100644
--- a/examples/retrieval/knn_index.py
+++ b/examples/retrieval/knn_index.py
@@ -21,7 +21,6 @@ def get_index(
     num_subquantizers: int,
     bits_per_code: int,
     device: Optional[torch.device] = None,
-    # pyre-ignore[11]
 ) -> Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ]:
     """
     returns a FAISS IVFPQ index, placed on the device passed in
@@ -39,25 +38,19 @@ def get_index(
     """
     if device is not None and device.type == "cuda":
-        # pyre-fixme[16]
         res = faiss.StandardGpuResources()
-        # pyre-fixme[16]
         config = faiss.GpuIndexIVFPQConfig()
-        # pyre-ignore[16]
         index = faiss.GpuIndexIVFPQ(
             res,
             embedding_dim,
             num_centroids,
             num_subquantizers,
             bits_per_code,
-            # pyre-fixme[16]
             faiss.METRIC_L2,
             config,
         )
     else:
-        # pyre-fixme[16]
         quantizer = faiss.IndexFlatL2(embedding_dim)
-        # pyre-fixme[16]
         index = faiss.IndexIVFPQ(
             quantizer,
             embedding_dim,
diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py
index 704e02b63..3e2381db9 100644
--- a/examples/retrieval/modules/two_tower.py
+++ b/examples/retrieval/modules/two_tower.py
@@ -169,7 +169,6 @@ class TwoTowerRetrieval(nn.Module):
 
     def __init__(
         self,
-        # pyre-ignore[11]
         faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ],
         query_ebc: EmbeddingBagCollection,
         candidate_ebc: EmbeddingBagCollection,
@@ -222,6 +221,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
             (batch_size, self.k), device=self.device, dtype=torch.int64
         )
         query_embedding = query_embedding.to(torch.float32)  # required by faiss
+        # pyre-ignore[19]
         self.faiss_index.search(query_embedding, self.k, distances, candidates)
 
         # candidate lookup
diff --git a/examples/retrieval/two_tower_retrieval.py b/examples/retrieval/two_tower_retrieval.py
index 08a2de4b6..85f041d16 100644
--- a/examples/retrieval/two_tower_retrieval.py
+++ b/examples/retrieval/two_tower_retrieval.py
@@ -128,12 +128,9 @@ def infer(
     retrieval_sd = None
     if load_dir is not None:
         load_dir = load_dir.rstrip("/")
-        # pyre-ignore[16]
         index = faiss.index_cpu_to_gpu(
-            # pyre-ignore[16]
             faiss.StandardGpuResources(),
             faiss_device_idx,
-            # pyre-ignore[16]
             faiss.read_index(f"{load_dir}/faiss.index"),
         )
         two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True)
@@ -158,7 +155,13 @@ def infer(
         index.add(embeddings)
 
     retrieval_model = TwoTowerRetrieval(
-        index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+        index,  # pyre-ignore[6]
+        ebcs[0],
+        ebcs[1],
+        layer_sizes,
+        k,
+        device,
+        dtype=torch.float16,
     )
 
     constraints = {}
diff --git a/examples/retrieval/two_tower_train.py b/examples/retrieval/two_tower_train.py
index 5772a9069..ae991c545 100644
--- a/examples/retrieval/two_tower_train.py
+++ b/examples/retrieval/two_tower_train.py
@@ -227,7 +227,6 @@ def train(
             model, dtype=torch.qint8, inplace=True
         )
         torch.save(quant_model.state_dict(), f"{save_dir}/model.pt")
-        # pyre-ignore[16]
         faiss.write_index(index, f"{save_dir}/faiss.index")
diff --git a/torchrec/distributed/benchmark/base.py b/torchrec/distributed/benchmark/base.py
index cac1d99f9..0821fe579 100644
--- a/torchrec/distributed/benchmark/base.py
+++ b/torchrec/distributed/benchmark/base.py
@@ -849,7 +849,7 @@ class BenchFuncConfig:
     world_size: int
     num_profiles: int
     num_benchmarks: int
-    profile_dir: str
+    profile_dir: str = ""
     device_type: str = "cuda"
     pre_gpu_load: int = 0
     export_stacks: bool = False
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index a130778fb..4439d4e63 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -20,8 +20,8 @@ See benchmark_pipeline_utils.py for step-by-step instructions.
 """
 
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Type
+from dataclasses import dataclass
+from typing import List, Optional
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -37,8 +37,8 @@
 from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_config import (
     BaseModelConfig,
-    create_model_config,
     generate_sharded_model_and_optimizer,
+    ModelSelectionConfig,
 )
 from torchrec.distributed.test_utils.model_input import ModelInput
@@ -49,7 +49,6 @@
 from torchrec.distributed.test_utils.pipeline_config import PipelineConfig
 from torchrec.distributed.test_utils.sharding_config import PlannerConfig
 from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
-from torchrec.distributed.test_utils.test_model import TestOverArchLarge
 from torchrec.distributed.train_pipeline import TrainPipeline
 from torchrec.distributed.types import ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -94,11 +93,11 @@ class RunOptions(BenchFuncConfig):
     """
 
     world_size: int = 2
+    batch_size: int = 1024 * 32
+    num_float_features: int = 10
     num_batches: int = 10
     sharding_type: ShardingType = ShardingType.TABLE_WISE
     input_type: str = "kjt"
-    name: str = ""
-    profile_dir: str = ""
     num_benchmarks: int = 5
     num_profiles: int = 2
     num_poolings: Optional[List[float]] = None
@@ -113,39 +112,6 @@ class RunOptions(BenchFuncConfig):
     export_stacks: bool = False
 
 
-@dataclass
-class ModelSelectionConfig:
-    model_name: str = "test_sparse_nn"
-
-    # Common config for all model types
-    batch_size: int = 1024 * 32
-    batch_sizes: Optional[List[int]] = None
-    num_float_features: int = 10
-    feature_pooling_avg: int = 10
-    use_offsets: bool = False
-    dev_str: str = ""
-    long_kjt_indices: bool = True
-    long_kjt_offsets: bool = True
-    long_kjt_lengths: bool = True
-    pin_memory: bool = True
-
-    # TestSparseNN specific config
-    embedding_groups: Optional[Dict[str, List[str]]] = None
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
-    max_feature_lengths: Optional[Dict[str, int]] = None
-    over_arch_clazz: Type[nn.Module] = TestOverArchLarge
-    postproc_module: Optional[nn.Module] = None
-    zch: bool = False
-
-    # DeepFM specific config
-    hidden_layer_size: int = 20
-    deep_fm_dimension: int = 5
-
-    # DLRM specific config
-    dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
-    over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
-
-
 # single-rank runner
 def runner(
     rank: int,
@@ -303,35 +269,9 @@ def main(
     pipeline_config: PipelineConfig,
     input_config: ModelInputConfig,
     planner_config: PlannerConfig,
-    model_config: Optional[BaseModelConfig] = None,
 ) -> None:
     tables, weighted_tables, *_ = table_config.generate_tables()
-
-    if model_config is None:
-        model_config = create_model_config(
-            model_name=model_selection.model_name,
-            batch_size=model_selection.batch_size,
-            batch_sizes=model_selection.batch_sizes,
-            num_float_features=model_selection.num_float_features,
-            feature_pooling_avg=model_selection.feature_pooling_avg,
-            use_offsets=model_selection.use_offsets,
-            dev_str=model_selection.dev_str,
-            long_kjt_indices=model_selection.long_kjt_indices,
-            long_kjt_offsets=model_selection.long_kjt_offsets,
-            long_kjt_lengths=model_selection.long_kjt_lengths,
-            pin_memory=model_selection.pin_memory,
-            embedding_groups=model_selection.embedding_groups,
-            feature_processor_modules=model_selection.feature_processor_modules,
-            max_feature_lengths=model_selection.max_feature_lengths,
-            over_arch_clazz=model_selection.over_arch_clazz,
-            postproc_module=model_selection.postproc_module,
-            zch=model_selection.zch,
-            hidden_layer_size=model_selection.hidden_layer_size,
-            deep_fm_dimension=model_selection.deep_fm_dimension,
-            dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
-            over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
-        )
-
+    model_config = model_selection.create_model_config()
     # launch trainers
     run_multi_process_func(
         func=runner,
diff --git a/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml b/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
index ac2a90a1d..68f3b0a36 100644
--- a/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
+++ b/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
@@ -10,6 +10,8 @@ RunOptions:
   # export_stacks: True  # enable this to export stack traces
 PipelineConfig:
   pipeline: "sparse"
+ModelInputConfig:
+  feature_pooling_avg: 10
 EmbeddingTablesConfig:
   num_unweighted_features: 100
   num_weighted_features: 100
diff --git a/torchrec/distributed/test_utils/model_config.py b/torchrec/distributed/test_utils/model_config.py
index f72abcd94..0341a00c6 100644
--- a/torchrec/distributed/test_utils/model_config.py
+++ b/torchrec/distributed/test_utils/model_config.py
@@ -18,7 +18,7 @@
 
 import copy
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
@@ -31,6 +31,7 @@
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
+    TestOverArchLarge,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
@@ -51,8 +52,9 @@ class BaseModelConfig(ABC):
     and requires each concrete implementation to provide its own generate_model method.
     """
 
-    # Common parameters for all model types
-    num_float_features: int  # we assume all model arch has a single dense feature layer
+    ## Common parameters for all model types, please do not set default values here
+    # we assume all model arch has a single dense feature layer
+    num_float_features: int
 
     @abstractmethod
     def generate_model(
@@ -80,12 +82,12 @@ def generate_model(
 class TestSparseNNConfig(BaseModelConfig):
     """Configuration for TestSparseNN model."""
 
-    embedding_groups: Optional[Dict[str, List[str]]]
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
-    max_feature_lengths: Optional[Dict[str, int]]
-    over_arch_clazz: Type[nn.Module]
-    postproc_module: Optional[nn.Module]
-    zch: bool
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+    max_feature_lengths: Optional[Dict[str, int]] = None
+    over_arch_clazz: Type[nn.Module] = TestOverArchLarge
+    postproc_module: Optional[nn.Module] = None
+    zch: bool = False
 
     def generate_model(
         self,
@@ -113,8 +115,8 @@ def generate_model(
 class TestTowerSparseNNConfig(BaseModelConfig):
     """Configuration for TestTowerSparseNN model."""
 
-    embedding_groups: Optional[Dict[str, List[str]]]
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
 
     def generate_model(
         self,
@@ -138,8 +140,8 @@ def generate_model(
 class TestTowerCollectionSparseNNConfig(BaseModelConfig):
     """Configuration for TestTowerCollectionSparseNN model."""
 
-    embedding_groups: Optional[Dict[str, List[str]]]
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
 
     def generate_model(
         self,
@@ -163,8 +165,8 @@ def generate_model(
 class DeepFMConfig(BaseModelConfig):
     """Configuration for DeepFM model."""
 
-    hidden_layer_size: int
-    deep_fm_dimension: int
+    hidden_layer_size: int = 20
+    deep_fm_dimension: int = 5
 
     def generate_model(
         self,
@@ -189,8 +191,8 @@ def generate_model(
 class DLRMConfig(BaseModelConfig):
     """Configuration for DLRM model."""
 
-    dense_arch_layer_sizes: List[int]
-    over_arch_layer_sizes: List[int]
+    dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
+    over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
 
     def generate_model(
         self,
@@ -213,7 +215,9 @@ def generate_model(
 
 # pyre-ignore[2]: Missing parameter annotation
 def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
-
+    """
+    deprecated function, please use ModelSelectionConfig.create_model_config instead
+    """
     model_configs = {
         "test_sparse_nn": TestSparseNNConfig,
         "test_tower_sparse_nn": TestTowerSparseNNConfig,
@@ -309,3 +313,39 @@ def generate_sharded_model_and_optimizer(
     optimizer = optimizer_class(dense_params, **optimizer_kwargs)
 
     return sharded_model, optimizer
+
+
+@dataclass
+class ModelSelectionConfig:
+    model_name: str = "test_sparse_nn"
+    model_config: Dict[str, Any] = field(
+        default_factory=lambda: {"num_float_features": 10}
+    )
+
+    def get_model_config_class(self) -> Type[BaseModelConfig]:
+        match self.model_name:
+            case "test_sparse_nn":
+                return TestSparseNNConfig
+            case "test_tower_sparse_nn":
+                return TestTowerSparseNNConfig
+            case "test_tower_collection_sparse_nn":
+                return TestTowerCollectionSparseNNConfig
+            case "deepfm":
+                return DeepFMConfig
+            case "dlrm":
+                return DLRMConfig
+            case _:
+                raise ValueError(f"Unknown model name: {self.model_name}")
+
+    def create_model_config(self) -> BaseModelConfig:
+        config_class = self.get_model_config_class()
+        valid_field_names = {field.name for field in fields(config_class)}
+        filtered_kwargs = {
+            k: v for k, v in self.model_config.items() if k in valid_field_names
+        }
+        # pyre-ignore[45]: Invalid class instantiation
+        return config_class(**filtered_kwargs)
+
+    def create_test_model(self, **kwargs: Any) -> nn.Module:
+        model_config = self.create_model_config()
+        return model_config.generate_model(**kwargs)
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index 97e826e6b..cb7004670 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -2364,7 +2364,7 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         if max_sequence_length is None:
-            max_sequence_length = 10
+            max_sequence_length = 20
         if dense_arch_out_size is None:
             dense_arch_out_size = DENSE_LAYER_OUT_SIZE
         if over_arch_out_size is None:
diff --git a/torchrec/modules/object_pool_lookups.py b/torchrec/modules/object_pool_lookups.py
index b30358f19..5a3bd2e3b 100644
--- a/torchrec/modules/object_pool_lookups.py
+++ b/torchrec/modules/object_pool_lookups.py
@@ -413,7 +413,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             .sum(axis=1)
         )
         key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)
-        padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
+        padded_values: torch.Tensor = torch.ops.fbgemm.jagged_to_padded_dense(
             values.values(),
             [key_offsets],
             [self._bit_dims],