diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py
index 187bd08a4..8c622cf24 100644
--- a/torchrec/distributed/model_tracker/model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -6,13 +6,16 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-from typing import Dict, List, Optional, Union
+import logging as logger
+from collections import Counter, OrderedDict
+from typing import Dict, Iterable, List, Optional
 
 import torch
 from torch import nn
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+from torchrec.distributed.model_tracker.delta_store import DeltaStore
 from torchrec.distributed.model_tracker.types import (
     DeltaRows,
     EmbdUpdateMode,
@@ -30,7 +33,7 @@
 }
 
 # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
-SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection]
+SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)
 
 
 class ModelDeltaTracker:
@@ -39,16 +42,19 @@ class ModelDeltaTracker:
 
     ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules,
     along with optional support for tracking corresponding embeddings or states. This is
    useful for identifying and retrieving the latest delta or unique rows for a given model,
     which can help compute topk or to stream updated embeddings from predictors to trainers during
-    online training. Unique IDs or states can be retrieved by calling the get_unique() method.
+    online training. Unique IDs or states can be retrieved by calling the get_delta() method.
 
     Args:
         model (nn.Module): the model to track.
         consumers (List[str], optional): list of consumers to track. Each consumer will
-            have its own batch offset index. Every get_unique_ids invocation will
-            only return the new ids for the given consumer since last get_unique_ids
-            call.
+            have its own batch offset index. Every get_delta and get_delta_ids invocation
+            will only return the new values for the given consumer since the last call.
         delete_on_read (bool, optional): whether to delete the tracked ids after all
            consumers have read them.
+        auto_compact (bool, optional): Trigger compaction automatically during communication
+            at each train cycle. When set to False, compaction is triggered during the
+            get_delta() call. Default: False.
         mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
+        fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: ().
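+
+    Example (illustrative sketch; assumes ``sharded_model`` is a sharded model
+    that contains supported embedding modules)::
+
+        tracker = ModelDeltaTracker(sharded_model, consumers=["c1"])
+        # forward passes record lookups via record_lookup()
+        delta_rows = tracker.get_delta(consumer="c1")  # Dict[str, DeltaRows]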
+    """
 
     DEFAULT_CONSUMER: str = "default"
 
@@ -58,41 +64,242 @@ def __init__(
         model: nn.Module,
         consumers: Optional[List[str]] = None,
         delete_on_read: bool = True,
+        auto_compact: bool = False,
         mode: TrackingMode = TrackingMode.ID_ONLY,
+        fqns_to_skip: Iterable[str] = (),
     ) -> None:
         self._model = model
         self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
         self._delete_on_read = delete_on_read
+        self._auto_compact = auto_compact
         self._mode = mode
-        pass
+        self._fqn_to_feature_map: Dict[str, List[str]] = {}
+        self._fqns_to_skip: Iterable[str] = fqns_to_skip
+
+        # per_consumer_batch_idx is used to track the batch index for each consumer.
+        # This is used to retrieve the delta values for a given consumer as well as
+        # start_ids for the compaction window.
+        self.per_consumer_batch_idx: Dict[str, int] = {
+            c: -1 for c in self._consumers
+        }
+        self.curr_batch_idx: int = 0
+        self.curr_compact_index: int = 0
+
+        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
+
+        # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
+        self.tracked_modules: Dict[str, nn.Module] = {}
+        self.feature_to_fqn: Dict[str, str] = {}
+        # Generate the mapping from FQN to feature names.
+        self.fqn_to_feature_names()
+        # Validate that the mode is supported for the tracked modules.
+        self._validate_mode()
+
+        # Map feature names to their corresponding FQNs. This is used for retrieving
+        # the FQN associated with a given feature name in record_lookup().
+        for fqn, feature_names in self._fqn_to_feature_map.items():
+            for feature_name in feature_names:
+                if feature_name in self.feature_to_fqn:
+                    logger.warning(f"Duplicate feature name: {feature_name} in fqn {fqn}")
+                    continue
+                self.feature_to_fqn[feature_name] = fqn
+        logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
 
     def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         """
-        Record Ids from a given KeyedJaggedTensor and embeddings/ parameter states.
+        Records the IDs from a given KeyedJaggedTensor and their corresponding
+        embeddings/parameter states.
+
+        This method runs post-lookup, after the embedding lookup has been performed,
+        as it needs access to both the input IDs and the resulting embeddings.
+
+        It processes the input KeyedJaggedTensor and records either just the IDs
+        (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
+
+        Args:
+            kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
+            states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
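+
+        Example (shape sketch; ``kjt`` and ``states`` are assumed to come from an
+        embedding lookup)::
+
+            # kjt.values() holds N ids; in EMBEDDING mode ``states`` must flatten
+            # to [N * emb_dim] so it can be viewed as an (N, emb_dim) tensor.
+            tracker.record_lookup(kjt, states)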
+        """
+
+        # In ID_ONLY mode, we only track feature IDs received in the current batch.
+        if self._mode == TrackingMode.ID_ONLY:
+            self.record_ids(kjt)
+        # In EMBEDDING mode, we track per-feature IDs and the corresponding embeddings received in the current batch.
+        elif self._mode == TrackingMode.EMBEDDING:
+            self.record_embeddings(kjt, states)
+        else:
+            raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
+
+    def record_ids(self, kjt: KeyedJaggedTensor) -> None:
+        """
+        Record IDs from a given KeyedJaggedTensor.
 
         Args:
             kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
-            states (torch.Tensor): the states to record.
         """
-        pass
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        for key in kjt.keys():
+            table_fqn = self.feature_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            ids_list.append(kjt[key].values())
+            per_table_ids[table_fqn] = ids_list
 
-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                embeddings=None,
+            )
+
+    def record_embeddings(
+        self, kjt: KeyedJaggedTensor, embeddings: torch.Tensor
+    ) -> None:
+        """
+        Record IDs along with their embeddings from a given KeyedJaggedTensor and
+        embeddings tensor.
+
+        Args:
+            kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
+            embeddings (torch.Tensor): the embeddings to record.
         """
-        Return a dictionary of hit local IDs for each sparse feature. The IDs are first keyed by submodule FQN.
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        per_table_emb: Dict[str, List[torch.Tensor]] = {}
+        assert embeddings.numel() % kjt.values().numel() == 0, (
+            f"ids and embeddings size mismatch, expect [{kjt.values().numel()} * emb_dim], "
+            f"but got {embeddings.numel()}"
+        )
+        embeddings_2d = embeddings.view(kjt.values().numel(), -1)
+
+        offset: int = 0
+        for key in kjt.keys():
+            table_fqn = self.feature_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            emb_list: List[torch.Tensor] = per_table_emb.get(table_fqn, [])
+
+            ids = kjt[key].values()
+            ids_list.append(ids)
+            emb_list.append(embeddings_2d[offset : offset + ids.numel()])
+            offset += ids.numel()
+
+            per_table_ids[table_fqn] = ids_list
+            per_table_emb[table_fqn] = emb_list
+
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                embeddings=torch.cat(per_table_emb[table_fqn]),
+            )
+
+    def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
+        """
+        Return a dictionary of hit local IDs for each sparse feature. IDs are
+        keyed by submodule FQN.
 
         Args:
-            consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is used as the default consumer.
+            consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
         """
-        return {}
+        per_table_delta_rows = self.get_delta(consumer)
+        return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
 
-    def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
+    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         """
-        Returns a mapping from FQN to feature names for a given module.
+        Return a dictionary of hit local IDs and parameter states / embeddings for
+        each sparse feature. The values are keyed by submodule FQN.
 
         Args:
-            module (nn.Module): the module to retrieve feature names for.
+            consumer (str, optional): The consumer to retrieve delta values for. If not specified, "default" is used as the default consumer.
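+
+        Example (illustrative; assumes two registered consumers ``"c1"`` and ``"c2"``)::
+
+            rows = tracker.get_delta("c1")  # deltas since c1's last read
+            rows = tracker.get_delta("c1")  # empty unless new batches were recorded
+            rows = tracker.get_delta("c2")  # c2 still sees the earlier deltas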
+ """ + if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0: + return self._fqn_to_feature_map + + table_to_feature_names: Dict[str, List[str]] = OrderedDict() + table_to_fqn: Dict[str, str] = OrderedDict() + for fqn, named_module in self._model.named_modules(): + split_fqn = fqn.split(".") + # Skipping partial FQNs present in fqns_to_skip + # TODO: Validate if we need to support more complex patterns for skipping fqns + should_skip = False + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in split_fqn: + logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") + should_skip = True + break + if should_skip: + continue + # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. + if isinstance(named_module, SUPPORTED_MODULES): + for table_name, config in named_module._table_name_to_config.items(): + logger.info( + f"Found {table_name} for {fqn} with features {config.feature_names}" + ) + table_to_feature_names[table_name] = config.feature_names + self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module + for table_name in table_to_feature_names: + # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" + # will incorrectly match fqn with all the table names that have the same prefix + if table_name in split_fqn: + embedding_fqn = self._clean_fqn_fn(fqn) + if table_name in table_to_fqn: + # Sanity check for validating that we don't have more then one table mapping to same fqn. + logger.warning( + f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" + ) + table_to_fqn[table_name] = embedding_fqn + logger.info(f"Table to fqn: {table_to_fqn}") + flatten_names = [ + name for names in table_to_feature_names.values() for name in names + ] + # TODO: Validate if there is a better way to handle duplicate feature names. + # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case. + if len(set(flatten_names)) != len(flatten_names): + counts = Counter(flatten_names) + duplicates = [item for item, count in counts.items() if count > 1] + logger.warning(f"duplicate feature names found: {duplicates}") + + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() + for table_name in table_to_feature_names: + if table_name not in table_to_fqn: + # This is likely unexpected, where we can't locate the FQN associated with this table. + logger.warning( + f"Table {table_name} not found in {table_to_fqn}, skipping" + ) + continue + fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[ + table_name + ] + self._fqn_to_feature_map = fqn_to_feature_names + return fqn_to_feature_names def clear(self, consumer: Optional[str] = None) -> None: """ @@ -101,7 +308,19 @@ def clear(self, consumer: Optional[str] = None) -> None: Args: consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer. """ - pass + # 1. If consumer is None, delete globally. + if consumer is None: + self.store.delete() + return + + assert ( + consumer in self.per_consumer_batch_idx + ), f"consumer {consumer} not found in {self.per_consumer_batch_idx.values()}" + + # 2. 
+        """
+        if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
+            return self._fqn_to_feature_map
+
+        table_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        table_to_fqn: Dict[str, str] = OrderedDict()
+        for fqn, named_module in self._model.named_modules():
+            split_fqn = fqn.split(".")
+            # Skipping partial FQNs present in fqns_to_skip
+            # TODO: Validate if we need to support more complex patterns for skipping fqns
+            should_skip = False
+            for fqn_to_skip in self._fqns_to_skip:
+                if fqn_to_skip in split_fqn:
+                    logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
+                    should_skip = True
+                    break
+            if should_skip:
+                continue
+            # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
+            if isinstance(named_module, SUPPORTED_MODULES):
+                for table_name, config in named_module._table_name_to_config.items():
+                    logger.info(
+                        f"Found {table_name} for {fqn} with features {config.feature_names}"
+                    )
+                    table_to_feature_names[table_name] = config.feature_names
+                self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
+            for table_name in table_to_feature_names:
+                # Using the split FQN to get an exact table name match. Otherwise, checking
+                # "table_name in fqn" would incorrectly match any table name sharing the prefix.
+                if table_name in split_fqn:
+                    embedding_fqn = self._clean_fqn_fn(fqn)
+                    if table_name in table_to_fqn:
+                        # Sanity check: we should not have more than one table mapping to the same FQN.
+                        logger.warning(
+                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                        )
+                    table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {table_to_fqn}")
+        flatten_names = [
+            name for names in table_to_feature_names.values() for name in names
+        ]
+        # TODO: Validate if there is a better way to handle duplicate feature names.
+        # Log a warning if duplicate feature names are found across tables, but continue
+        # execution, as this could be a valid case.
+        if len(set(flatten_names)) != len(flatten_names):
+            counts = Counter(flatten_names)
+            duplicates = [item for item, count in counts.items() if count > 1]
+            logger.warning(f"duplicate feature names found: {duplicates}")
+
+        fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        for table_name in table_to_feature_names:
+            if table_name not in table_to_fqn:
+                # This is unexpected: we could not locate the FQN associated with this table.
+                logger.warning(
+                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                )
+                continue
+            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
+                table_name
+            ]
+        self._fqn_to_feature_map = fqn_to_feature_names
+        return fqn_to_feature_names
 
     def clear(self, consumer: Optional[str] = None) -> None:
         """
@@ -101,7 +308,19 @@ def clear(self, consumer: Optional[str] = None) -> None:
         Args:
             consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer.
         """
-        pass
+        # 1. If consumer is None, delete globally.
+        if consumer is None:
+            self.store.delete()
+            return
+
+        assert (
+            consumer in self.per_consumer_batch_idx
+        ), f"consumer {consumer} not found in {self.per_consumer_batch_idx.keys()}"
+
+        # 2. For a single consumer, we can just delete all IDs.
+        if len(self.per_consumer_batch_idx) == 1:
+            self.store.delete()
+            return
 
     def compact(self, start_idx: int, end_idx: int) -> None:
         """
@@ -111,4 +330,16 @@
         start_idx (int): Starting index for compaction.
         end_idx (int): Ending index for compaction.
         """
-        pass
+        self.store.compact(start_idx, end_idx)
+
+    def _clean_fqn_fn(self, fqn: str) -> str:
+        # Strip the DMP internal module FQN prefix to match state dict FQNs.
+        return fqn.replace("_dmp_wrapped_module.module.", "")
+
+    def _validate_mode(self) -> None:
+        """Validate that the tracking mode is supported for the tracked modules."""
+        for module in self.tracked_modules.values():
+            assert not (
+                isinstance(module, ShardedEmbeddingBagCollection)
+                and self._mode == TrackingMode.EMBEDDING
+            ), "EBC's lookup returns pooled embeddings; tracking raw embeddings is not currently supported."
diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
new file mode 100644
index 000000000..fe80c2080
--- /dev/null
+++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
@@ -0,0 +1,364 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import unittest
+from dataclasses import dataclass
+from typing import cast, Dict, List, Type, Union
+
+import torch
+import torchrec
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+
+from parameterized import parameterized
+from torch import nn
+from torchrec.distributed import DistributedModelParallel
+from torchrec.distributed.embedding import EmbeddingCollectionSharder
+from torchrec.distributed.embedding_types import ModuleSharder, ShardingType
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
+from torchrec.distributed.model_tracker.tests.utils import (
+    EmbeddingTableProps,
+    generate_planner_constraints,
+    TestEBCModel,
+    TestECModel,
+)
+
+from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
+from torchrec.modules.embedding_configs import (
+    EmbeddingBagConfig,
+    EmbeddingConfig,
+    PoolingType,
+)
+from torchrec.test_utils import skip_if_asan
+
+NUM_EMBEDDINGS: int = 16
+EMBEDDING_DIM: int = 256
+
+HAS_2_GPU: bool = torch.cuda.device_count() >= 2
+HAS_1_GPU: bool = torch.cuda.device_count() >= 1
+
+
+# Helper function to create a model
+def get_model(
+    rank: int,
+    world_size: int,
+    ctx: MultiProcessContext,
+    embedding_config_type: Union[Type[EmbeddingConfig], Type[EmbeddingBagConfig]],
+    embedding_tables: Dict[str, EmbeddingTableProps],
+) -> DistributedModelParallel:
+    # Create the model
+    test_model = (
+        TestECModel(
+            tables=[
+                EmbeddingConfig(
+                    name=table_name,
+                    embedding_dim=table.embedding_dim,
+                    num_embeddings=table.num_embeddings,
+                    feature_names=table.feature_names,
+                )
+                for table_name, table in embedding_tables.items()
+            ]
+        )
+        if embedding_config_type == EmbeddingConfig
+        else TestEBCModel(
+            tables=[
+                EmbeddingBagConfig(
+                    name=table_name,
+                    embedding_dim=table.embedding_dim,
+                    num_embeddings=table.num_embeddings,
+                    feature_names=table.feature_names,
+                    pooling=table.pooling,
+                )
+                for table_name, table in embedding_tables.items()
+            ]
+        )
+    )
+
+    # Set up device
+    if torch.cuda.is_available():
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+    else:
+        device = torch.device("cpu")
+
+    # Create planner and sharders
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(world_size, "cuda"),
+        constraints=generate_planner_constraints(embedding_tables),
+    )
+    sharders = [
+        cast(
+            ModuleSharder[nn.Module],
+            EmbeddingCollectionSharder(
+                fused_params={
+                    "optimizer": OptimType.ADAM,
+                    "beta1": 0.9,
+                    "beta2": 0.99,
+                }
+            ),
+        ),
+        cast(
+            ModuleSharder[nn.Module],
+            EmbeddingBagCollectionSharder(fused_params={"optimizer": OptimType.ADAM}),
+        ),
+    ]
+
+    # Create plan
+    plan = planner.collective_plan(test_model, sharders, ctx.pg)
+
+    # Create DMP
+    if ctx.pg is None:
+        raise ValueError("Process group cannot be None")
+
+    return DistributedModelParallel(
+        module=test_model,
+        device=device,
+        env=torchrec.distributed.ShardingEnv.from_process_group(ctx.pg),
+        plan=plan,
+        sharders=sharders,
+    )
+
+
+@dataclass
+class ModelDeltaTrackerInputTestParams:
+    # input parameters
+    embedding_config_type: Union[Type[EmbeddingConfig], Type[EmbeddingBagConfig]]
+    embedding_tables: Dict[str, EmbeddingTableProps]
+    fqns_to_skip: List[str]
+
+
+@dataclass
+class FqnToFeatureNamesOutputTestParams:
+    # expected output parameters
+    expected_fqn_to_feature_names: Dict[str, List[str]]
+
+
+class ModelDeltaTrackerTest(MultiProcessTestBase):
+    # pyre-fixme[2]: Parameter must be annotated.
+    def __init__(self, methodName="runTest") -> None:
+        super().__init__(methodName)
+        self.world_size = 2
+
+    @parameterized.expand(
+        [
+            (
+                "EC_model_test",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables={
+                        "sparse_table_1": EmbeddingTableProps(
+                            num_embeddings=NUM_EMBEDDINGS,
+                            embedding_dim=EMBEDDING_DIM,
+                            sharding=ShardingType.ROW_WISE,
+                            feature_names=["f1", "f2", "f3"],
+                            pooling=PoolingType.NONE,
+                        ),
+                        "sparse_table_2": EmbeddingTableProps(
+                            num_embeddings=NUM_EMBEDDINGS,
+                            embedding_dim=EMBEDDING_DIM,
+                            sharding=ShardingType.ROW_WISE,
+                            feature_names=["f4", "f5", "f6"],
+                            pooling=PoolingType.NONE,
+                        ),
+                    },
+                    fqns_to_skip=[],
+                ),
+                FqnToFeatureNamesOutputTestParams(
+                    expected_fqn_to_feature_names={
+                        "ec.embeddings.sparse_table_1": ["f1", "f2", "f3"],
+                        "ec.embeddings.sparse_table_2": ["f4", "f5", "f6"],
+                    },
+                ),
+            ),
+            (
+                "EBC_model_test",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingBagConfig,
+                    embedding_tables={
+                        "sparse_table_1": EmbeddingTableProps(
+                            num_embeddings=NUM_EMBEDDINGS,
+                            embedding_dim=EMBEDDING_DIM,
+                            sharding=ShardingType.ROW_WISE,
+                            feature_names=["f1", "f2", "f3"],
+                            pooling=PoolingType.SUM,
+                        ),
+                        "sparse_table_2": EmbeddingTableProps(
+                            num_embeddings=NUM_EMBEDDINGS,
+                            embedding_dim=EMBEDDING_DIM,
+                            sharding=ShardingType.ROW_WISE,
+                            feature_names=["f4", "f5", "f6"],
+                            pooling=PoolingType.SUM,
+                        ),
+                    },
+                    fqns_to_skip=[],
+                ),
+                FqnToFeatureNamesOutputTestParams(
+                    expected_fqn_to_feature_names={
+                        "ebc.embedding_bags.sparse_table_1": ["f1", "f2", "f3"],
+                        "ebc.embedding_bags.sparse_table_2": ["f4", "f5", "f6"],
+                    },
+                ),
+            ),
+            (
+                "EC_model_test_with_duplicate_feature_names",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables={
+                        "sparse_table_1": EmbeddingTableProps(
+                            num_embeddings=NUM_EMBEDDINGS,
+                            embedding_dim=EMBEDDING_DIM,
+                            sharding=ShardingType.ROW_WISE,
+ feature_names=["f1", "f2", "f3"], + pooling=PoolingType.NONE, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f3", "f4", "f5"], + pooling=PoolingType.NONE, + ), + }, + fqns_to_skip=[], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={ + "ec.embeddings.sparse_table_1": ["f1", "f2", "f3"], + "ec.embeddings.sparse_table_2": ["f3", "f4", "f5"], + }, + ), + ), + ( + "fqns_to_skip_table_name", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.SUM, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f4", "f5", "f6"], + pooling=PoolingType.SUM, + ), + }, + fqns_to_skip=["sparse_table_1"], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={ + "ebc.embedding_bags.sparse_table_2": ["f4", "f5", "f6"], + }, + ), + ), + ( + "fqns_to_skip_mid_fqn", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.SUM, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f4", "f5", "f6"], + pooling=PoolingType.SUM, + ), + }, + fqns_to_skip=["embedding_bags"], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={}, + ), + ), + ( + "fqns_to_skip_parent_fqn", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables={ + "sparse_table_1": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f1", "f2", "f3"], + pooling=PoolingType.NONE, + ), + "sparse_table_2": EmbeddingTableProps( + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + sharding=ShardingType.ROW_WISE, + feature_names=["f3", "f4", "f5"], + pooling=PoolingType.NONE, + ), + }, + fqns_to_skip=["ec"], + ), + FqnToFeatureNamesOutputTestParams( + expected_fqn_to_feature_names={}, + ), + ), + ] + ) + @skip_if_asan + @unittest.skipUnless(HAS_1_GPU, reason="Test requires at least 1 GPU") + def test_fqn_to_feature_names( + self, + _test_name: str, + input_params: ModelDeltaTrackerInputTestParams, + output_params: FqnToFeatureNamesOutputTestParams, + ) -> None: + self._run_multi_process_test( + callable=_test_fqn_to_feature_names, + world_size=self.world_size, + input_params=input_params, + output_params=output_params, + ) + + +def _test_fqn_to_feature_names( + rank: int, + world_size: int, + input_params: ModelDeltaTrackerInputTestParams, + output_params: FqnToFeatureNamesOutputTestParams, +) -> None: + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl" if torch.cuda.is_available() else "gloo", + ) as ctx: + # Get the model using the helper function + model = get_model( + rank=rank, + world_size=world_size, + ctx=ctx, + embedding_config_type=input_params.embedding_config_type, + embedding_tables=input_params.embedding_tables, 
+        )
+
+        model_dt = ModelDeltaTracker(model, fqns_to_skip=input_params.fqns_to_skip)
+        actual_fqn_to_feature_names = model_dt.fqn_to_feature_names()
+
+        unittest.TestCase().assertEqual(
+            actual_fqn_to_feature_names,
+            output_params.expected_fqn_to_feature_names,
+        )
diff --git a/torchrec/distributed/model_tracker/tests/utils.py b/torchrec/distributed/model_tracker/tests/utils.py
new file mode 100644
index 000000000..88f6dbb64
--- /dev/null
+++ b/torchrec/distributed/model_tracker/tests/utils.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+
+from torch import nn
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.planner import ParameterConstraints
+from torchrec.distributed.types import ShardingType
+from torchrec.modules.embedding_configs import (
+    DataType,
+    EmbeddingBagConfig,
+    EmbeddingConfig,
+    PoolingType,
+)
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+@dataclass
+class EmbeddingTableProps:
+    """
+    Properties of an embedding table.
+
+    Args:
+        num_embeddings (int): number of embeddings in the table
+        embedding_dim (int): dimension of each embedding
+        sharding (ShardingType): sharding type of the table
+        feature_names (List[str]): list of feature names associated with the table
+        pooling (PoolingType): pooling type of the table
+        data_type (DataType): data type of the table
+        is_weighted (bool): whether the table is weighted
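+
+    Example (as constructed in the tracker tests)::
+
+        EmbeddingTableProps(
+            num_embeddings=16,
+            embedding_dim=256,
+            sharding=ShardingType.ROW_WISE,
+            feature_names=["f1", "f2", "f3"],
+            pooling=PoolingType.SUM,
+        )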
+    """
+
+    num_embeddings: int
+    embedding_dim: int
+    sharding: ShardingType
+    feature_names: List[str]
+    pooling: PoolingType
+    data_type: DataType = DataType.FP32
+    is_weighted: bool = False
+
+
+class TestECModel(nn.Module):
+    """
+    Test model with EmbeddingCollection and Linear layers.
+
+    Args:
+        tables (List[EmbeddingConfig]): list of embedding tables
+        device (Optional[torch.device]): device on which buffers will be initialized
+
+    Example:
+        TestECModel(tables=[EmbeddingConfig(...)])
+    """
+
+    def __init__(
+        self, tables: List[EmbeddingConfig], device: Optional[torch.device] = None
+    ) -> None:
+        super().__init__()
+        self.ec: EmbeddingCollection = EmbeddingCollection(
+            tables=tables,
+            device=device if device else torch.device("meta"),
+        )
+
+        embedding_dim = tables[0].embedding_dim
+
+        self.seq: nn.Sequential = nn.Sequential(
+            *[nn.Linear(embedding_dim, embedding_dim) for _ in range(3)]
+        )
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        """
+        Forward pass of the TestECModel.
+
+        Args:
+            features (KeyedJaggedTensor): Input features for the model.
+
+        Returns:
+            torch.Tensor: Output tensor after processing through the model.
+        """
+
+        lookup_result = self.ec(features)
+        return self.seq(torch.cat([jt.values() for _, jt in lookup_result.items()]))
+
+
+class TestEBCModel(nn.Module):
+    """
+    Test model with EmbeddingBagCollection and Linear layers.
+
+    Args:
+        tables (List[EmbeddingBagConfig]): list of embedding tables
+        device (Optional[torch.device]): device on which buffers will be initialized
+
+    Example:
+        TestEBCModel(tables=[EmbeddingBagConfig(...)])
+    """
+
+    def __init__(
+        self, tables: List[EmbeddingBagConfig], device: Optional[torch.device] = None
+    ) -> None:
+        super().__init__()
+        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
+            tables=tables,
+            device=device if device else torch.device("meta"),
+        )
+
+        embedding_dim = tables[0].embedding_dim
+        self.seq: nn.Sequential = nn.Sequential(
+            *[nn.Linear(embedding_dim, embedding_dim) for _ in range(3)]
+        )
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        """
+        Forward pass of the TestEBCModel.
+
+        Args:
+            features (KeyedJaggedTensor): Input features for the model.
+
+        Returns:
+            torch.Tensor: Output tensor after processing through the model.
+        """
+
+        lookup_result = self.ebc(features).to_dict()
+        return self.seq(torch.cat(tuple(lookup_result.values())))
+
+
+def create_ec_model(
+    tables: Dict[str, EmbeddingTableProps],
+    device: Optional[torch.device] = None,
+) -> nn.Module:
+    """
+    Create an EmbeddingCollection model with the given tables.
+
+    Args:
+        tables (Dict[str, EmbeddingTableProps]): mapping of table name to table properties
+        device (Optional[torch.device]): device on which buffers will be initialized
+
+    Returns:
+        nn.Module: EmbeddingCollection model
+    """
+    return TestECModel(
+        tables=[
+            EmbeddingConfig(
+                name=name,
+                embedding_dim=table.embedding_dim,
+                num_embeddings=table.num_embeddings,
+                feature_names=table.feature_names,
+                data_type=table.data_type,
+            )
+            for name, table in tables.items()
+        ],
+        device=device,
+    )
+
+
+def create_ebc_model(
+    tables: Dict[str, EmbeddingTableProps],
+    device: Optional[torch.device] = None,
+) -> nn.Module:
+    """
+    Create an EmbeddingBagCollection model with the given tables.
+
+    Args:
+        tables (Dict[str, EmbeddingTableProps]): mapping of table name to table properties
+        device (Optional[torch.device]): device on which buffers will be initialized
+
+    Returns:
+        nn.Module: EmbeddingBagCollection model
+    """
+    return TestEBCModel(
+        tables=[
+            EmbeddingBagConfig(
+                name=name,
+                embedding_dim=table.embedding_dim,
+                num_embeddings=table.num_embeddings,
+                feature_names=table.feature_names,
+                data_type=table.data_type,
+                pooling=table.pooling,
+            )
+            for name, table in tables.items()
+        ],
+        device=device,
+    )
+
+
+def generate_planner_constraints(
+    tables: Dict[str, EmbeddingTableProps],
+) -> Dict[str, ParameterConstraints]:
+    """
+    Generate planner constraints for the given tables.
+
+    Args:
+        tables (Dict[str, EmbeddingTableProps]): mapping of table name to table properties
+
+    Returns:
+        Dict[str, ParameterConstraints]: planner constraints
+    """
+    constraints: Dict[str, ParameterConstraints] = {}
+    for name, table in tables.items():
+        sharding_types = [table.sharding.value]
+        constraints[name] = ParameterConstraints(
+            sharding_types=sharding_types,
+            compute_kernels=[EmbeddingComputeKernel.FUSED.value],
+            feature_names=table.feature_names,
+            pooling_factors=[1.0],
+        )
+    return constraints
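
A minimal end-to-end sketch of how the helpers and the tracker compose, mirroring
`_test_fqn_to_feature_names` above (`tables` is assumed to be a
`Dict[str, EmbeddingTableProps]` like the parameterized cases build, and the calls are
assumed to run inside a `MultiProcessContext`):

    model = get_model(rank, world_size, ctx, EmbeddingConfig, tables)
    tracker = ModelDeltaTracker(model, fqns_to_skip=[])
    # e.g. {"ec.embeddings.sparse_table_1": ["f1", "f2", "f3"], ...}
    print(tracker.fqn_to_feature_names())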