From d25c2b3d6bb354c3dde8838e1cd4cc5ab239cbf7 Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Thu, 12 Jun 2025 11:08:52 -0700 Subject: [PATCH] Add logic for fqn_to_feature_names (#3059) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/3059 # This Diff Added implementation for fqn_to_feature_names method along with initial testing framework and UTs for fqn_to_feature_names # ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training Reviewed By: kausv Differential Revision: D75908963 --- .../model_tracker/model_delta_tracker.py | 80 +++- .../tests/test_model_delta_tracker.py | 364 ++++++++++++++++++ .../distributed/model_tracker/tests/utils.py | 222 +++++++++++ 3 files changed, 658 insertions(+), 8 deletions(-) create mode 100644 torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py create mode 100644 torchrec/distributed/model_tracker/tests/utils.py diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 187bd08a4..0c7c2d82f 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -6,7 +6,9 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Dict, List, Optional, Union +import logging as logger +from collections import Counter, OrderedDict +from typing import Dict, Iterable, List, Optional, Union import torch @@ -30,7 +32,7 @@ } # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection. -SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection] +SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection) class ModelDeltaTracker: @@ -49,6 +51,8 @@ class ModelDeltaTracker: call. delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them. mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY. + fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None. + """ DEFAULT_CONSUMER: str = "default" @@ -59,11 +63,15 @@ def __init__( consumers: Optional[List[str]] = None, delete_on_read: bool = True, mode: TrackingMode = TrackingMode.ID_ONLY, + fqns_to_skip: Iterable[str] = (), ) -> None: self._model = model self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER] self._delete_on_read = delete_on_read self._mode = mode + self._fqn_to_feature_map: Dict[str, List[str]] = {} + self._fqns_to_skip: Iterable[str] = fqns_to_skip + self.fqn_to_feature_names() pass def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None: @@ -85,14 +93,70 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: """ return {} - def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]: + def fqn_to_feature_names(self) -> Dict[str, List[str]]: """ - Returns a mapping from FQN to feature names for a given module. - - Args: - module (nn.Module): the module to retrieve feature names for. + Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model. """ - return {} + if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0: + return self._fqn_to_feature_map + + table_to_feature_names: Dict[str, List[str]] = OrderedDict() + table_to_fqn: Dict[str, str] = OrderedDict() + for fqn, named_module in self._model.named_modules(): + split_fqn = fqn.split(".") + # Skipping partial FQNs present in fqns_to_skip + # TODO: Validate if we need to support more complex patterns for skipping fqns + should_skip = False + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in split_fqn: + logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") + should_skip = True + break + if should_skip: + continue + + # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. + if isinstance(named_module, SUPPORTED_MODULES): + for table_name, config in named_module._table_name_to_config.items(): + logger.info( + f"Found {table_name} for {fqn} with features {config.feature_names}" + ) + table_to_feature_names[table_name] = config.feature_names + for table_name in table_to_feature_names: + # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" + # will incorrectly match fqn with all the table names that have the same prefix + if table_name in split_fqn: + embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "") + if table_name in table_to_fqn: + # Sanity check for validating that we don't have more then one table mapping to same fqn. + logger.warning( + f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" + ) + table_to_fqn[table_name] = embedding_fqn + logger.info(f"Table to fqn: {table_to_fqn}") + flatten_names = [ + name for names in table_to_feature_names.values() for name in names + ] + # TODO: Validate if there is a better way to handle duplicate feature names. + # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case. + if len(set(flatten_names)) != len(flatten_names): + counts = Counter(flatten_names) + duplicates = [item for item, count in counts.items() if count > 1] + logger.warning(f"duplicate feature names found: {duplicates}") + + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() + for table_name in table_to_feature_names: + if table_name not in table_to_fqn: + # This is likely unexpected, where we can't locate the FQN associated with this table. + logger.warning( + f"Table {table_name} not found in {table_to_fqn}, skipping" + ) + continue + fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[ + table_name + ] + self._fqn_to_feature_map = fqn_to_feature_names + return fqn_to_feature_names def clear(self, consumer: Optional[str] = None) -> None: """ diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py new file mode 100644 index 000000000..fe80c2080 --- /dev/null +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import unittest +from dataclasses import dataclass +from typing import cast, Dict, List, Type, Union + +import torch +import torchrec +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType + +from parameterized import parameterized +from torch import nn +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding import EmbeddingCollectionSharder +from torchrec.distributed.embedding_types import ModuleSharder, ShardingType +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker +from torchrec.distributed.model_tracker.tests.utils import ( + EmbeddingTableProps, + generate_planner_constraints, + TestEBCModel, + TestECModel, +) + +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, + EmbeddingConfig, + PoolingType, +) +from torchrec.test_utils import skip_if_asan + +NUM_EMBEDDINGS: int = 16 +EMBEDDING_DIM: int = 256 + +HAS_2_GPU: bool = torch.cuda.device_count() >= 2 +HAS_1_GPU: bool = torch.cuda.device_count() >= 1 + + +# Helper function to create a model +def get_model( + rank: int, + world_size: int, + ctx: MultiProcessContext, + embedding_config_type: Union[Type[EmbeddingConfig], Type[EmbeddingBagConfig]], + embedding_tables: Dict[str, EmbeddingTableProps], +) -> DistributedModelParallel: + # Create the model + test_model = ( + TestECModel( + tables=[ + EmbeddingConfig( + name=table_name, + embedding_dim=table.embedding_dim, + num_embeddings=table.num_embeddings, + feature_names=table.feature_names, + ) + for table_name, table in embedding_tables.items() + ] + ) + if embedding_config_type == EmbeddingConfig + else TestEBCModel( + tables=[ + EmbeddingBagConfig( + name=table_name, + embedding_dim=table.embedding_dim, + num_embeddings=table.num_embeddings, + feature_names=table.feature_names, + pooling=table.pooling, + ) + for table_name, table in embedding_tables.items() + ] + ) + ) + + # Set up device + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + else: + device = torch.device("cpu") + + # Create planner and sharders + planner = EmbeddingShardingPlanner( + topology=Topology(world_size, "cuda"), + constraints=generate_planner_constraints(embedding_tables), + ) + sharders = [ + cast( + ModuleSharder[nn.Module], + EmbeddingCollectionSharder( + fused_params={ + "optimizer": OptimType.ADAM, + "beta1": 0.9, + "beta2": 0.99, + } + ), + ), + cast( + ModuleSharder[nn.Module], + EmbeddingBagCollectionSharder(fused_params={"optimizer": OptimType.ADAM}), + ), + ] + + # Create plan + plan = planner.collective_plan(test_model, sharders, ctx.pg) + + # Create DMP + if ctx.pg is None: + raise ValueError("Process group cannot be None") + + return DistributedModelParallel( + module=test_model, + device=device, + env=torchrec.distributed.ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=sharders, + ) + + +@dataclass +class ModelDeltaTrackerInputTestParams: + # input parameters + embedding_config_type: Union[Type[EmbeddingConfig], Type[EmbeddingBagConfig]] + embedding_tables: Dict[str, EmbeddingTableProps] + fqns_to_skip: List[str] + + +@dataclass +class FqnToFeatureNamesOutputTestParams: + # expected output parameters + expected_fqn_to_feature_names: Dict[str, List[str]] + + +class ModelDeltaTrackerTest(MultiProcessTestBase): + # pyre-fixme[2]: Parameter must be annotated. + def __init__(self, methodName="runTest") -> None: + super().__init__(methodName) + self.world_size = 2 + + @parameterized.expand( + [ + ( + "EC_model_test", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.NONE, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f4", "f5", "f6"], + pooling=PoolingType.NONE, + ), + }, + fqns_to_skip=[], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={ + "ec.embeddings.sparse_table_1": ["f1", "f2", "f3"], + "ec.embeddings.sparse_table_2": ["f4", "f5", "f6"], + }, + ), + ), + ( + "EBC_model_test", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.SUM, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f4", "f5", "f6"], + pooling=PoolingType.SUM, + ), + }, + fqns_to_skip=[], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={ + "ebc.embedding_bags.sparse_table_1": ["f1", "f2", "f3"], + "ebc.embedding_bags.sparse_table_2": ["f4", "f5", "f6"], + }, + ), + ), + ( + "EC_model_test_with_duplicate_feature_names", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.NONE, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f3", "f4", "f5"], + pooling=PoolingType.NONE, + ), + }, + fqns_to_skip=[], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={ + "ec.embeddings.sparse_table_1": ["f1", "f2", "f3"], + "ec.embeddings.sparse_table_2": ["f3", "f4", "f5"], + }, + ), + ), + ( + "fqns_to_skip_table_name", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.SUM, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f4", "f5", "f6"], + pooling=PoolingType.SUM, + ), + }, + fqns_to_skip=["sparse_table_1"], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={ + "ebc.embedding_bags.sparse_table_2": ["f4", "f5", "f6"], + }, + ), + ), + ( + "fqns_to_skip_mid_fqn", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.SUM, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f4", "f5", "f6"], + pooling=PoolingType.SUM, + ), + }, + fqns_to_skip=["embedding_bags"], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={}, + ), + ), + ( + "fqns_to_skip_parent_fqn", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.NONE, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f3", "f4", "f5"], + pooling=PoolingType.NONE, + ), + }, + fqns_to_skip=["ec"], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={}, + ), + ), + ] + ) + @skip_if_asan + @unittest.skipUnless(HAS_1_GPU, reason="Test requires at least 1 GPU") + def test_fqn_to_feature_names( + self, + _test_name: str, + input_params: ModelDeltaTrackerInputTestParams, + output_params: FqnToFeatureNamesOutputTestParams, + ) -> None: + self._run_multi_process_test( + callable=_test_fqn_to_feature_names, + world_size=self.world_size, + input_params=input_params, + output_params=output_params, + ) + + +def _test_fqn_to_feature_names( + rank: int, + world_size: int, + input_params: ModelDeltaTrackerInputTestParams, + output_params: FqnToFeatureNamesOutputTestParams, +) -> None: + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl" if torch.cuda.is_available() else "gloo", + ) as ctx: + # Get the model using the helper function + model = get_model( + rank=rank, + world_size=world_size, + ctx=ctx, + embedding_config_type=input_params.embedding_config_type, + embedding_tables=input_params.embedding_tables, + ) + + model_dt = ModelDeltaTracker(model, fqns_to_skip=input_params.fqns_to_skip) + actual_fqn_to_feature_names = model_dt.fqn_to_feature_names() + + unittest.TestCase().assertEqual( + actual_fqn_to_feature_names, + output_params.expected_fqn_to_feature_names, + ) diff --git a/torchrec/distributed/model_tracker/tests/utils.py b/torchrec/distributed/model_tracker/tests/utils.py new file mode 100644 index 000000000..88f6dbb64 --- /dev/null +++ b/torchrec/distributed/model_tracker/tests/utils.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.types import ShardingType +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, + PoolingType, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +@dataclass +class EmbeddingTableProps: + """ + Properties of an embedding table. + + Args: + name (str): name of the table + num_embeddings (int): number of embeddings in the table + embedding_dim (int): dimension of each embedding + sharding_type (ShardingType): sharding type of the table + feature_names (List[str]): list of feature names associated with the table + pooling (PoolingType): pooling type of the table + data_type (DataType): data type of the table + weight_type (WeightedType): weight + """ + + num_embeddings: int + embedding_dim: int + sharding: ShardingType + feature_names: List[str] + pooling: PoolingType + data_type: DataType = DataType.FP32 + is_weighted: bool = False + + +class TestECModel(nn.Module): + """ + Test model with EmbeddingCollection and Linear layers. + + Args: + tables (List[EmbeddingConfig]): list of embedding tables + device (Optional[torch.device]): device on which buffers will be initialized + + Example: + TestECModel(tables=[EmbeddingConfig(...)]) + """ + + def __init__( + self, tables: List[EmbeddingConfig], device: Optional[torch.device] = None + ) -> None: + super().__init__() + self.ec: EmbeddingCollection = EmbeddingCollection( + tables=tables, + device=device if device else torch.device("meta"), + ) + + embedding_dim = tables[0].embedding_dim + + self.seq: nn.Sequential = nn.Sequential( + *[nn.Linear(embedding_dim, embedding_dim) for _ in range(3)] + ) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + """ + Forward pass of the TestECModel. + + Args: + features (KeyedJaggedTensor): Input features for the model. + + Returns: + torch.Tensor: Output tensor after processing through the model. + """ + + lookup_result = self.ec(features) + return self.seq(torch.cat([jt.values() for _, jt in lookup_result.items()])) + + +class TestEBCModel(nn.Module): + """ + Test model with EmbeddingBagCollection and Linear layers. + + Args: + tables (List[EmbeddingBagConfig]): list of embedding tables + device (Optional[torch.device]): device on which buffers will be initialized + + Example: + TestEBCModel(tables=[EmbeddingBagConfig(...)]) + """ + + def __init__( + self, tables: List[EmbeddingBagConfig], device: Optional[torch.device] = None + ) -> None: + super().__init__() + self.ebc: EmbeddingBagCollection + self.ebc = EmbeddingBagCollection( + tables=tables, + device=device if device else torch.device("meta"), + ) + + embedding_dim = tables[0].embedding_dim + self.seq: nn.Sequential = nn.Sequential( + *[nn.Linear(embedding_dim, embedding_dim) for _ in range(3)] + ) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + """ + Forward pass of the TestEBCModel. + + Args: + features (KeyedJaggedTensor): Input features for the model. + + Returns: + torch.Tensor: Output tensor after processing through the model. + """ + + lookup_result = self.ebc(features).to_dict() + return self.seq(torch.cat(tuple(lookup_result.values()))) + + +def create_ec_model( + tables: Dict[str, EmbeddingTableProps], + device: Optional[torch.device] = None, +) -> nn.Module: + """ + Create an EmbeddingCollection model with the given tables. + + Args: + tables (List[EmbeddingTableProps]): list of embedding tables + device (Optional[torch.device]): device on which buffers will be initialized + + Returns: + nn.Module: EmbeddingCollection model + """ + return TestECModel( + tables=[ + EmbeddingConfig( + name=name, + embedding_dim=table.embedding_dim, + num_embeddings=table.num_embeddings, + feature_names=table.feature_names, + data_type=table.data_type, + ) + for name, table in tables.items() + ], + device=device, + ) + + +def create_ebc_model( + tables: Dict[str, EmbeddingTableProps], + device: Optional[torch.device] = None, +) -> nn.Module: + """ + Create an EmbeddinBagCollection model with the given tables. + + Args: + tables (List[EmbeddingTableProps]): list of embedding tables + device (Optional[torch.device]): device on which buffers will be initialized + + Returns: + nn.Module: EmbeddingCollection model + """ + return TestEBCModel( + tables=[ + EmbeddingBagConfig( + name=name, + embedding_dim=table.embedding_dim, + num_embeddings=table.num_embeddings, + feature_names=table.feature_names, + data_type=table.data_type, + pooling=table.pooling, + ) + for name, table in tables.items() + ], + device=device, + ) + + +def generate_planner_constraints( + tables: Dict[str, EmbeddingTableProps], +) -> dict[str, ParameterConstraints]: + """ + Generate planner constraints for the given tables. + + Args: + tables (List[EmbeddingTableProps]): list of embedding tables + + Returns: + Dict[str, ParameterConstraints]: planner constraints + """ + constraints: Dict[str, ParameterConstraints] = {} + for name, table in tables.items(): + sharding_types = [table.sharding.value] + constraints[name] = ParameterConstraints( + sharding_types=sharding_types, + compute_kernels=[EmbeddingComputeKernel.FUSED.value], + feature_names=table.feature_names, + pooling_factors=[1.0], + ) + return constraints