From 998bac3c93d1ff4b8446e65561dffe458d5767e9 Mon Sep 17 00:00:00 2001 From: Facebook Community Bot <6422482+facebook-github-bot@users.noreply.github.com> Date: Fri, 12 Nov 2021 10:08:39 -0800 Subject: [PATCH] Re-sync with internal repository --- torchrec/distributed/embedding.py | 730 +--------------- torchrec/distributed/embeddingbag.py | 794 ++++++++++++++++++ torchrec/distributed/model_parallel.py | 8 +- .../planner/new/tests/test_calculators.py | 4 +- .../planner/new/tests/test_enumerators.py | 2 +- .../planner/new/tests/test_partitioners.py | 2 +- .../planner/new/tests/test_placers.py | 2 +- .../planner/new/tests/test_rankers.py | 2 +- .../planner/tests/test_embedding_planner.py | 2 +- torchrec/distributed/tests/test_model.py | 6 +- .../distributed/tests/test_model_parallel.py | 6 +- .../tests/test_quant_model_parallel.py | 7 +- .../distributed/tests/test_train_pipeline.py | 8 +- torchrec/distributed/tests/test_utils.py | 4 +- torchrec/distributed/utils.py | 12 + torchrec/examples/__init__.py | 0 torchrec/examples/dlrm/README.MD | 11 + torchrec/examples/dlrm/__init__.py | 0 torchrec/examples/dlrm/dlrm_main.py | 186 ++++ torchrec/examples/dlrm/modules/__init__.py | 0 torchrec/examples/dlrm/modules/dlrm_train.py | 71 ++ .../{ => notebooks}/criteo_tutorial.ipynb | 0 .../{ => notebooks}/movielens_tutorial.ipynb | 0 23 files changed, 1119 insertions(+), 738 deletions(-) create mode 100644 torchrec/distributed/embeddingbag.py create mode 100644 torchrec/examples/__init__.py create mode 100644 torchrec/examples/dlrm/README.MD create mode 100644 torchrec/examples/dlrm/__init__.py create mode 100644 torchrec/examples/dlrm/dlrm_main.py create mode 100644 torchrec/examples/dlrm/modules/__init__.py create mode 100644 torchrec/examples/dlrm/modules/dlrm_train.py rename torchrec/examples/{ => notebooks}/criteo_tutorial.ipynb (100%) rename torchrec/examples/{ => notebooks}/movielens_tutorial.ipynb (100%) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 166214a49..af9e176fa 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 -import copy -from collections import OrderedDict from typing import ( List, Dict, @@ -9,224 +7,30 @@ Type, Any, TypeVar, - Mapping, - Union, - Tuple, - Iterator, - Set, ) import torch -from torch import Tensor from torch import nn -from torch.distributed._sharding_spec import ( - EnumerableShardingSpec, -) -from torch.nn.modules.module import _IncompatibleKeys -from torchrec.distributed.cw_sharding import CwEmbeddingSharding -from torchrec.distributed.dp_sharding import DpEmbeddingSharding -from torchrec.distributed.embedding_sharding import ( - EmbeddingSharding, - SparseFeaturesListAwaitable, -) from torchrec.distributed.embedding_types import ( - SparseFeatures, BaseEmbeddingSharder, - EmbeddingComputeKernel, - BaseEmbeddingLookup, SparseFeaturesList, ) -from torchrec.distributed.rw_sharding import RwEmbeddingSharding -from torchrec.distributed.tw_sharding import TwEmbeddingSharding -from torchrec.distributed.twrw_sharding import TwRwEmbeddingSharding from torchrec.distributed.types import ( Awaitable, LazyAwaitable, ParameterSharding, - ParameterStorage, ShardedModule, - ShardingType, ShardedModuleContext, - ShardedTensor, - ModuleSharder, ShardingEnv, ) -from torchrec.distributed.utils import append_prefix -from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType from torchrec.modules.embedding_modules import ( - EmbeddingBagCollection, 
- EmbeddingBagCollectionInterface, + EmbeddingCollection, ) from torchrec.optim.fused import FusedOptimizerModule -from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer -from torchrec.quant.embedding_modules import ( - EmbeddingBagCollection as QuantEmbeddingBagCollection, -) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor -def create_embedding_sharding( - sharding_type: str, - embedding_configs: List[ - Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] - ], - env: ShardingEnv, - device: Optional[torch.device] = None, -) -> EmbeddingSharding: - pg = env.process_group - if device is not None and device.type == "meta": - replace_placement_with_meta_device(embedding_configs) - if pg is not None: - if sharding_type == ShardingType.TABLE_WISE.value: - return TwEmbeddingSharding(embedding_configs, pg, device) - elif sharding_type == ShardingType.ROW_WISE.value: - return RwEmbeddingSharding(embedding_configs, pg, device) - elif sharding_type == ShardingType.DATA_PARALLEL.value: - return DpEmbeddingSharding(embedding_configs, env, device) - elif sharding_type == ShardingType.TABLE_ROW_WISE.value: - return TwRwEmbeddingSharding(embedding_configs, pg, device) - elif sharding_type == ShardingType.COLUMN_WISE.value: - return CwEmbeddingSharding(embedding_configs, pg, device) - else: - raise ValueError(f"Sharding not supported {sharding_type}") - else: - if sharding_type == ShardingType.DATA_PARALLEL.value: - return DpEmbeddingSharding(embedding_configs, env, device) - else: - raise ValueError(f"Sharding not supported {sharding_type}") - - -def replace_placement_with_meta_device( - embedding_configs: List[ - Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] - ] -) -> None: - """Placement device and tensor device could be unmatched in some - scenarios, e.g. passing meta device to DMP and passing cuda - to EmbeddingShardingPlanner. We need to make device consistent - after getting sharding planner. - """ - for config in embedding_configs: - sharding_spec = config[1].sharding_spec - if sharding_spec is None: - continue - if isinstance(sharding_spec, EnumerableShardingSpec): - for shard_metadata in sharding_spec.shards: - placement = shard_metadata.placement - if isinstance(placement, str): - placement = torch.distributed._remote_device(placement) - assert isinstance(placement, torch.distributed._remote_device) - placement._device = torch.device("meta") - shard_metadata.placement = placement - else: - # We only support EnumerableShardingSpec at present. - raise RuntimeError( - f"Unsupported ShardingSpec {type(sharding_spec)} with meta device" - ) - - -def filter_state_dict( - state_dict: "OrderedDict[str, torch.Tensor]", name: str -) -> "OrderedDict[str, torch.Tensor]": - rtn_dict = OrderedDict() - for key, value in state_dict.items(): - if key.startswith(name): - # + 1 to length is to remove the '.' 
after the key - rtn_dict[key[len(name) + 1 :]] = value - return rtn_dict - - -def _create_embedding_configs_by_sharding( - module: EmbeddingBagCollectionInterface, - table_name_to_parameter_sharding: Dict[str, ParameterSharding], - prefix: str, -) -> Dict[str, List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]]]: - shared_feature: Dict[str, bool] = {} - for embedding_config in module.embedding_bag_configs: - if not embedding_config.feature_names: - embedding_config.feature_names = [embedding_config.name] - for feature_name in embedding_config.feature_names: - if feature_name not in shared_feature: - shared_feature[feature_name] = False - else: - shared_feature[feature_name] = True - - sharding_type_to_embedding_configs: Dict[ - str, List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]] - ] = {} - state_dict = module.state_dict() - for config in module.embedding_bag_configs: - table_name = config.name - assert table_name in table_name_to_parameter_sharding - parameter_sharding = table_name_to_parameter_sharding[table_name] - if parameter_sharding.compute_kernel not in [ - kernel.value for kernel in EmbeddingComputeKernel - ]: - raise ValueError( - f"Compute kernel not supported {parameter_sharding.compute_kernel}" - ) - embedding_names: List[str] = [] - for feature_name in config.feature_names: - if shared_feature[feature_name]: - embedding_names.append(feature_name + "@" + config.name) - else: - embedding_names.append(feature_name) - - param_name = prefix + table_name + ".weight" - assert param_name in state_dict - param = state_dict[param_name] - - if parameter_sharding.sharding_type not in sharding_type_to_embedding_configs: - sharding_type_to_embedding_configs[parameter_sharding.sharding_type] = [] - sharding_type_to_embedding_configs[parameter_sharding.sharding_type].append( - ( - EmbeddingTableConfig( - num_embeddings=config.num_embeddings, - embedding_dim=config.embedding_dim, - name=config.name, - data_type=config.data_type, - feature_names=copy.deepcopy(config.feature_names), - pooling=config.pooling, - is_weighted=module.is_weighted, - has_feature_processor=False, - embedding_names=embedding_names, - weight_init_max=config.weight_init_max, - weight_init_min=config.weight_init_min, - ), - parameter_sharding, - param, - ) - ) - return sharding_type_to_embedding_configs - - -class EmbeddingCollectionAwaitable(LazyAwaitable[KeyedTensor]): - def __init__( - self, - awaitables: List[Awaitable[torch.Tensor]], - embedding_dims: List[int], - embedding_names: List[str], - ) -> None: - super().__init__() - self._awaitables = awaitables - self._embedding_dims = embedding_dims - self._embedding_names = embedding_names - - def wait(self) -> KeyedTensor: - embeddings = [w.wait() for w in self._awaitables] - if len(embeddings) == 1: - embeddings = embeddings[0] - else: - embeddings = torch.cat(embeddings, dim=1) - return KeyedTensor( - keys=self._embedding_names, - length_per_key=self._embedding_dims, - values=embeddings, - key_dim=1, - ) - - -class ShardedEmbeddingBagCollection( +class ShardedEmbeddingCollection( ShardedModule[ SparseFeaturesList, List[torch.Tensor], @@ -235,560 +39,64 @@ class ShardedEmbeddingBagCollection( FusedOptimizerModule, ): """ - Sharded implementation of EmbeddingBagCollection. + Sharded implementation of EmbeddingCollection. This is part of public API to allow for manual data dist pipelining. 
""" def __init__( self, - module: EmbeddingBagCollectionInterface, + module: EmbeddingCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, ) -> None: super().__init__() - sharding_type_to_embedding_configs = _create_embedding_configs_by_sharding( - module, table_name_to_parameter_sharding, "embedding_bags." - ) - self._sharding_type_to_sharding: Dict[str, EmbeddingSharding] = { - sharding_type: create_embedding_sharding( - sharding_type, embedding_confings, env, device - ) - for sharding_type, embedding_confings in sharding_type_to_embedding_configs.items() - } - - self._is_weighted: bool = module.is_weighted - self._device = device - self._create_lookups(fused_params) - self._output_dists: nn.ModuleList[nn.Module] = nn.ModuleList() - self._embedding_names: List[str] = [] - self._embedding_dims: List[int] = [] - self._input_dists: nn.ModuleList[nn.Module] = nn.ModuleList() - self._feature_splits: List[int] = [] - self._features_order: List[int] = [] - - # forward pass flow control - self._has_uninitialized_input_dist: bool = True - self._has_uninitialized_output_dist: bool = True - self._has_features_permute: bool = True - - # Get all fused optimizers and combine them. - optims = [] - for lookup in self._lookups: - for _, module in lookup.named_modules(): - if isinstance(module, FusedOptimizerModule): - # modify param keys to match EmbeddingBagCollection - params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} - for param_key, weight in module.fused_optimizer.params.items(): - params["embedding_bags." + param_key] = weight - module.fused_optimizer.params = params - optims.append(("", module.fused_optimizer)) - self._optim: CombinedOptimizer = CombinedOptimizer(optims) - - def _create_input_dist( - self, - input_feature_names: List[str], - ) -> None: - - feature_names: List[str] = [] - for sharding in self._sharding_type_to_sharding.values(): - self._input_dists.append(sharding.create_input_dist()) - feature_names.extend( - sharding.id_score_list_feature_names() - if self._is_weighted - else sharding.id_list_feature_names() - ) - self._feature_splits.append( - len( - sharding.id_score_list_feature_names() - if self._is_weighted - else sharding.id_list_feature_names() - ) - ) - - if feature_names == input_feature_names: - self._has_features_permute = False - else: - for f in feature_names: - self._features_order.append(input_feature_names.index(f)) - self.register_buffer( - "_features_order_tensor", - torch.tensor( - self._features_order, device=self._device, dtype=torch.int32 - ), - ) - - def _create_lookups( - self, - fused_params: Optional[Dict[str, Any]], - ) -> None: - self._lookups: nn.ModuleList[BaseEmbeddingLookup] = nn.ModuleList() - for sharding in self._sharding_type_to_sharding.values(): - self._lookups.append(sharding.create_lookup(fused_params)) - - def _create_output_dist(self) -> None: - for sharding in self._sharding_type_to_sharding.values(): - self._output_dists.append(sharding.create_pooled_output_dist()) - self._embedding_names.extend(sharding.embedding_names()) - self._embedding_dims.extend(sharding.embedding_dims()) # pyre-ignore [14] def input_dist( self, ctx: ShardedModuleContext, features: KeyedJaggedTensor ) -> Awaitable[SparseFeaturesList]: - if self._has_uninitialized_input_dist: - self._create_input_dist(features.keys()) - self._has_uninitialized_input_dist = False - with torch.no_grad(): - if self._has_features_permute: - 
features = features.permute( - self._features_order, - # pyre-ignore [6] - self._features_order_tensor, - ) - features_by_shards = features.split( - self._feature_splits, - ) - awaitables = [ - module( - SparseFeatures( - id_list_features=None - if self._is_weighted - else features_by_shard, - id_score_list_features=features_by_shard - if self._is_weighted - else None, - ) - ) - for module, features_by_shard in zip( - self._input_dists, features_by_shards - ) - ] - return SparseFeaturesListAwaitable(awaitables) + # pyre-ignore [7] + pass def compute( self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList ) -> List[torch.Tensor]: - return [lookup(features) for lookup, features in zip(self._lookups, dist_input)] + # pyre-ignore [7] + pass def output_dist( self, ctx: ShardedModuleContext, output: List[torch.Tensor] ) -> LazyAwaitable[KeyedTensor]: - if self._has_uninitialized_output_dist: - self._create_output_dist() - self._has_uninitialized_output_dist = False - return EmbeddingCollectionAwaitable( - awaitables=[ - dist(embeddings) for dist, embeddings in zip(self._output_dists, output) - ], - embedding_dims=self._embedding_dims, - embedding_names=self._embedding_names, - ) - - def compute_and_output_dist( - self, ctx: ShardedModuleContext, input: SparseFeaturesList - ) -> LazyAwaitable[KeyedTensor]: - if self._has_uninitialized_output_dist: - self._create_output_dist() - self._has_uninitialized_output_dist = False - return EmbeddingCollectionAwaitable( - awaitables=[ - dist(lookup(features)) - for lookup, dist, features in zip( - self._lookups, self._output_dists, input - ) - ], - embedding_dims=self._embedding_dims, - embedding_names=self._embedding_names, - ) - - def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - for lookup in self._lookups: - lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars) - return destination - - def named_modules( - self, - memo: Optional[Set[nn.Module]] = None, - prefix: str = "", - remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, nn.Module]]: - yield from [(prefix, self)] - - def named_parameters( - self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - for lookup in self._lookups: - yield from lookup.named_parameters( - append_prefix(prefix, "embedding_bags"), recurse - ) - - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for lookup, sharding_type in zip( - self._lookups, self._sharding_type_to_sharding.keys() - ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - continue - for name, _ in lookup.named_parameters( - append_prefix(prefix, "embedding_bags") - ): - yield name - - def named_buffers( - self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - for lookup in self._lookups: - yield from lookup.named_buffers( - append_prefix(prefix, "embedding_bags"), recurse - ) - - def load_state_dict( - self, - state_dict: "OrderedDict[str, torch.Tensor]", - strict: bool = True, - ) -> _IncompatibleKeys: - missing_keys = [] - unexpected_keys = [] - for lookup in self._lookups: - missing, unexpected = lookup.load_state_dict( - filter_state_dict(state_dict, "embedding_bags"), - strict, - ) - missing_keys.extend(missing) - unexpected_keys.extend(unexpected) - return _IncompatibleKeys( - missing_keys=missing_keys, 
unexpected_keys=unexpected_keys - ) - - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - for lookup in self._lookups: - lookup.sparse_grad_parameter_names( - destination, append_prefix(prefix, "embedding_bags") - ) - return destination - - @property - def fused_optimizer(self) -> KeyedOptimizer: - return self._optim + # pyre-ignore [7] + pass M = TypeVar("M", bound=nn.Module) -class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[M]): +class EmbeddingCollectionSharder(BaseEmbeddingSharder[M]): """ - This implementation uses non-fused EmbeddingBagCollection + This implementation uses non-fused EmbeddingCollection """ def shard( self, - module: EmbeddingBagCollection, + module: EmbeddingCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, - ) -> ShardedEmbeddingBagCollection: - return ShardedEmbeddingBagCollection( + ) -> ShardedEmbeddingCollection: + return ShardedEmbeddingCollection( module, params, env, self.fused_params, device ) def shardable_parameters( - self, module: EmbeddingBagCollection - ) -> Dict[str, nn.Parameter]: - return { - name.split(".")[0]: param - for name, param in module.embedding_bags.named_parameters() - } - - @property - def module_type(self) -> Type[EmbeddingBagCollection]: - return EmbeddingBagCollection - - -class QuantEmbeddingBagCollectionSharder(ModuleSharder[QuantEmbeddingBagCollection]): - def shard( - self, - module: QuantEmbeddingBagCollection, - params: Dict[str, ParameterSharding], - env: ShardingEnv, - device: Optional[torch.device] = None, - ) -> ShardedEmbeddingBagCollection: - return ShardedEmbeddingBagCollection(module, params, env, None, device) - - def sharding_types(self, compute_device_type: str) -> List[str]: - return [ShardingType.DATA_PARALLEL.value] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [ - EmbeddingComputeKernel.BATCHED_QUANT.value, - ] - - def storage_usage( - self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str - ) -> Dict[str, int]: - tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4 - assert compute_device_type in {"cuda", "cpu"} - storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} - return {storage_map[compute_device_type].value: tensor_bytes} - - def shardable_parameters( - self, module: QuantEmbeddingBagCollection + self, module: EmbeddingCollection ) -> Dict[str, nn.Parameter]: - return { - name.split(".")[-2]: param - for name, param in module.state_dict().items() - if name.endswith(".weight") - } - - @property - def module_type(self) -> Type[QuantEmbeddingBagCollection]: - return QuantEmbeddingBagCollection - - -class EmbeddingAwaitable(LazyAwaitable[torch.Tensor]): - def __init__( - self, - awaitable: Awaitable[torch.Tensor], - ) -> None: - super().__init__() - self._awaitable = awaitable - - def wait(self) -> torch.Tensor: - embedding = self._awaitable.wait() - return embedding - - -class ShardedEmbeddingBag( - ShardedModule[ - SparseFeatures, - torch.Tensor, - torch.Tensor, - ], - FusedOptimizerModule, -): - """ - Sharded implementation of nn.EmbeddingBag. - This is part of public API to allow for manual data dist pipelining. 
- """ - - def __init__( - self, - module: nn.EmbeddingBag, - table_name_to_parameter_sharding: Dict[str, ParameterSharding], - env: ShardingEnv, - fused_params: Optional[Dict[str, Any]] = None, - device: Optional[torch.device] = None, - ) -> None: - super().__init__() - - assert ( - len(table_name_to_parameter_sharding) == 1 - ), "expect 1 table, but got len(table_name_to_parameter_sharding)" - assert module.mode == "sum", "ShardedEmbeddingBag only supports sum pooling" - - self._dummy_embedding_table_name = "dummy_embedding_table_name" - self._dummy_feature_name = "dummy_feature_name" - self.parameter_sharding: ParameterSharding = next( - iter(table_name_to_parameter_sharding.values()) - ) - embedding_table_config = EmbeddingTableConfig( - num_embeddings=module.num_embeddings, - embedding_dim=module.embedding_dim, - name=self._dummy_embedding_table_name, - feature_names=[self._dummy_feature_name], - pooling=PoolingType.SUM, - # We set is_weighted to True for now, - # if per_sample_weights is None in forward(), - # we could assign a all-one vector to per_sample_weights - is_weighted=True, - embedding_names=[self._dummy_feature_name], - ) - - self._embedding_sharding: EmbeddingSharding = create_embedding_sharding( - sharding_type=self.parameter_sharding.sharding_type, - embedding_configs=[ - ( - embedding_table_config, - self.parameter_sharding, - next(iter(module.parameters())), - ) - ], - env=env, - device=device, - ) - self._input_dist: nn.Module = self._embedding_sharding.create_input_dist() - self._lookup: nn.Module = self._embedding_sharding.create_lookup(fused_params) - self._output_dist: nn.Module = ( - self._embedding_sharding.create_pooled_output_dist() - ) - - # Get all fused optimizers and combine them. - optims = [] - for _, module in self._lookup.named_modules(): - if isinstance(module, FusedOptimizerModule): - # modify param keys to match EmbeddingBag - params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} - for param_key, weight in module.fused_optimizer.params.items(): - params[param_key.split(".")[-1]] = weight - module.fused_optimizer.params = params - optims.append(("", module.fused_optimizer)) - self._optim: CombinedOptimizer = CombinedOptimizer(optims) - - # pyre-ignore [14] - def input_dist( - self, - ctx: ShardedModuleContext, - input: Tensor, - offsets: Optional[Tensor] = None, - per_sample_weights: Optional[Tensor] = None, - ) -> Awaitable[SparseFeatures]: - if per_sample_weights is None: - per_sample_weights = torch.ones_like(input, dtype=torch.float) - features = KeyedJaggedTensor( - keys=[self._dummy_feature_name], - values=input, - offsets=offsets, - weights=per_sample_weights, - ) - return self._input_dist( - SparseFeatures( - id_list_features=None, - id_score_list_features=features, - ) - ) - - def compute( - self, ctx: ShardedModuleContext, dist_input: SparseFeatures - ) -> torch.Tensor: - return self._lookup(dist_input) - - def output_dist( - self, ctx: ShardedModuleContext, output: torch.Tensor - ) -> LazyAwaitable[torch.Tensor]: - return EmbeddingAwaitable( - awaitable=self._output_dist(output), - ) - - def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - lookup_state_dict = self._lookup.state_dict(None, "", keep_vars) - # update key to match embeddingBag state_dict key - for key, item in lookup_state_dict.items(): - new_key = prefix + 
key.split(".")[-1] - destination[new_key] = item - return destination - - def named_modules( - self, - memo: Optional[Set[nn.Module]] = None, - prefix: str = "", - remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, nn.Module]]: - yield from [(prefix, self)] - - def named_parameters( - self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - for name, parameter in self._lookup.named_parameters("", recurse): - # update name to match embeddingBag parameter name - yield append_prefix(prefix, name.split(".")[-1]), parameter - - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - if self.parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: - yield from [] - else: - for name, _ in self._lookup.named_parameters(""): - yield append_prefix(prefix, name.split(".")[-1]) - - def named_buffers( - self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - for name, buffer in self._lookup.named_buffers("", recurse): - yield append_prefix(prefix, name.split(".")[-1]), buffer - - def load_state_dict( - self, - state_dict: "OrderedDict[str, torch.Tensor]", - strict: bool = True, - ) -> _IncompatibleKeys: - missing_keys = [] - unexpected_keys = [] - # update key to match embeddingBag state_dict key - for key, value in state_dict.items(): - new_key = ".".join([self._dummy_embedding_table_name, key]) - state_dict[new_key] = value - state_dict.pop(key) - missing, unexpected = self._lookup.load_state_dict( - state_dict, - strict, - ) - missing_keys.extend(missing) - unexpected_keys.extend(unexpected) - - return _IncompatibleKeys( - missing_keys=missing_keys, unexpected_keys=unexpected_keys - ) - - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - # pyre-ignore [29] - lookup_sparse_grad_parameter_names = self._lookup.sparse_grad_parameter_names( - None, "" - ) - for name in lookup_sparse_grad_parameter_names: - destination.append(name.split(".")[-1]) - return destination - - @property - def fused_optimizer(self) -> KeyedOptimizer: - return self._optim - - -class EmbeddingBagSharder(BaseEmbeddingSharder[M]): - """ - This implementation uses non-fused nn.EmbeddingBag - """ - - def shard( - self, - module: nn.EmbeddingBag, - params: Dict[str, ParameterSharding], - env: ShardingEnv, - device: Optional[torch.device] = None, - ) -> ShardedEmbeddingBag: - return ShardedEmbeddingBag(module, params, env, self.fused_params, device) - - def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Parameter]: - return {name: param for name, param in module.named_parameters()} + return {} @property - def module_type(self) -> Type[nn.EmbeddingBag]: - return nn.EmbeddingBag + def module_type(self) -> Type[EmbeddingCollection]: + return EmbeddingCollection diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py new file mode 100644 index 000000000..166214a49 --- /dev/null +++ b/torchrec/distributed/embeddingbag.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 + +import copy +from collections import OrderedDict +from typing import ( + List, + Dict, + Optional, + Type, + Any, + TypeVar, + Mapping, + Union, + Tuple, + Iterator, + Set, +) + +import torch +from torch import Tensor +from torch import nn +from torch.distributed._sharding_spec import ( + EnumerableShardingSpec, +) +from torch.nn.modules.module import _IncompatibleKeys +from 
torchrec.distributed.cw_sharding import CwEmbeddingSharding +from torchrec.distributed.dp_sharding import DpEmbeddingSharding +from torchrec.distributed.embedding_sharding import ( + EmbeddingSharding, + SparseFeaturesListAwaitable, +) +from torchrec.distributed.embedding_types import ( + SparseFeatures, + BaseEmbeddingSharder, + EmbeddingComputeKernel, + BaseEmbeddingLookup, + SparseFeaturesList, +) +from torchrec.distributed.rw_sharding import RwEmbeddingSharding +from torchrec.distributed.tw_sharding import TwEmbeddingSharding +from torchrec.distributed.twrw_sharding import TwRwEmbeddingSharding +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ParameterSharding, + ParameterStorage, + ShardedModule, + ShardingType, + ShardedModuleContext, + ShardedTensor, + ModuleSharder, + ShardingEnv, +) +from torchrec.distributed.utils import append_prefix +from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingBagCollectionInterface, +) +from torchrec.optim.fused import FusedOptimizerModule +from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +def create_embedding_sharding( + sharding_type: str, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + env: ShardingEnv, + device: Optional[torch.device] = None, +) -> EmbeddingSharding: + pg = env.process_group + if device is not None and device.type == "meta": + replace_placement_with_meta_device(embedding_configs) + if pg is not None: + if sharding_type == ShardingType.TABLE_WISE.value: + return TwEmbeddingSharding(embedding_configs, pg, device) + elif sharding_type == ShardingType.ROW_WISE.value: + return RwEmbeddingSharding(embedding_configs, pg, device) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + return DpEmbeddingSharding(embedding_configs, env, device) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + return TwRwEmbeddingSharding(embedding_configs, pg, device) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return CwEmbeddingSharding(embedding_configs, pg, device) + else: + raise ValueError(f"Sharding not supported {sharding_type}") + else: + if sharding_type == ShardingType.DATA_PARALLEL.value: + return DpEmbeddingSharding(embedding_configs, env, device) + else: + raise ValueError(f"Sharding not supported {sharding_type}") + + +def replace_placement_with_meta_device( + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ] +) -> None: + """Placement device and tensor device could be unmatched in some + scenarios, e.g. passing meta device to DMP and passing cuda + to EmbeddingShardingPlanner. We need to make device consistent + after getting sharding planner. 
+ """ + for config in embedding_configs: + sharding_spec = config[1].sharding_spec + if sharding_spec is None: + continue + if isinstance(sharding_spec, EnumerableShardingSpec): + for shard_metadata in sharding_spec.shards: + placement = shard_metadata.placement + if isinstance(placement, str): + placement = torch.distributed._remote_device(placement) + assert isinstance(placement, torch.distributed._remote_device) + placement._device = torch.device("meta") + shard_metadata.placement = placement + else: + # We only support EnumerableShardingSpec at present. + raise RuntimeError( + f"Unsupported ShardingSpec {type(sharding_spec)} with meta device" + ) + + +def filter_state_dict( + state_dict: "OrderedDict[str, torch.Tensor]", name: str +) -> "OrderedDict[str, torch.Tensor]": + rtn_dict = OrderedDict() + for key, value in state_dict.items(): + if key.startswith(name): + # + 1 to length is to remove the '.' after the key + rtn_dict[key[len(name) + 1 :]] = value + return rtn_dict + + +def _create_embedding_configs_by_sharding( + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + prefix: str, +) -> Dict[str, List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]]]: + shared_feature: Dict[str, bool] = {} + for embedding_config in module.embedding_bag_configs: + if not embedding_config.feature_names: + embedding_config.feature_names = [embedding_config.name] + for feature_name in embedding_config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + + sharding_type_to_embedding_configs: Dict[ + str, List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]] + ] = {} + state_dict = module.state_dict() + for config in module.embedding_bag_configs: + table_name = config.name + assert table_name in table_name_to_parameter_sharding + parameter_sharding = table_name_to_parameter_sharding[table_name] + if parameter_sharding.compute_kernel not in [ + kernel.value for kernel in EmbeddingComputeKernel + ]: + raise ValueError( + f"Compute kernel not supported {parameter_sharding.compute_kernel}" + ) + embedding_names: List[str] = [] + for feature_name in config.feature_names: + if shared_feature[feature_name]: + embedding_names.append(feature_name + "@" + config.name) + else: + embedding_names.append(feature_name) + + param_name = prefix + table_name + ".weight" + assert param_name in state_dict + param = state_dict[param_name] + + if parameter_sharding.sharding_type not in sharding_type_to_embedding_configs: + sharding_type_to_embedding_configs[parameter_sharding.sharding_type] = [] + sharding_type_to_embedding_configs[parameter_sharding.sharding_type].append( + ( + EmbeddingTableConfig( + num_embeddings=config.num_embeddings, + embedding_dim=config.embedding_dim, + name=config.name, + data_type=config.data_type, + feature_names=copy.deepcopy(config.feature_names), + pooling=config.pooling, + is_weighted=module.is_weighted, + has_feature_processor=False, + embedding_names=embedding_names, + weight_init_max=config.weight_init_max, + weight_init_min=config.weight_init_min, + ), + parameter_sharding, + param, + ) + ) + return sharding_type_to_embedding_configs + + +class EmbeddingCollectionAwaitable(LazyAwaitable[KeyedTensor]): + def __init__( + self, + awaitables: List[Awaitable[torch.Tensor]], + embedding_dims: List[int], + embedding_names: List[str], + ) -> None: + super().__init__() + self._awaitables = awaitables + self._embedding_dims = 
embedding_dims + self._embedding_names = embedding_names + + def wait(self) -> KeyedTensor: + embeddings = [w.wait() for w in self._awaitables] + if len(embeddings) == 1: + embeddings = embeddings[0] + else: + embeddings = torch.cat(embeddings, dim=1) + return KeyedTensor( + keys=self._embedding_names, + length_per_key=self._embedding_dims, + values=embeddings, + key_dim=1, + ) + + +class ShardedEmbeddingBagCollection( + ShardedModule[ + SparseFeaturesList, + List[torch.Tensor], + KeyedTensor, + ], + FusedOptimizerModule, +): + """ + Sharded implementation of EmbeddingBagCollection. + This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + sharding_type_to_embedding_configs = _create_embedding_configs_by_sharding( + module, table_name_to_parameter_sharding, "embedding_bags." + ) + self._sharding_type_to_sharding: Dict[str, EmbeddingSharding] = { + sharding_type: create_embedding_sharding( + sharding_type, embedding_confings, env, device + ) + for sharding_type, embedding_confings in sharding_type_to_embedding_configs.items() + } + + self._is_weighted: bool = module.is_weighted + self._device = device + self._create_lookups(fused_params) + self._output_dists: nn.ModuleList[nn.Module] = nn.ModuleList() + self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._input_dists: nn.ModuleList[nn.Module] = nn.ModuleList() + self._feature_splits: List[int] = [] + self._features_order: List[int] = [] + + # forward pass flow control + self._has_uninitialized_input_dist: bool = True + self._has_uninitialized_output_dist: bool = True + self._has_features_permute: bool = True + + # Get all fused optimizers and combine them. + optims = [] + for lookup in self._lookups: + for _, module in lookup.named_modules(): + if isinstance(module, FusedOptimizerModule): + # modify param keys to match EmbeddingBagCollection + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for param_key, weight in module.fused_optimizer.params.items(): + params["embedding_bags." 
+ param_key] = weight + module.fused_optimizer.params = params + optims.append(("", module.fused_optimizer)) + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + + def _create_input_dist( + self, + input_feature_names: List[str], + ) -> None: + + feature_names: List[str] = [] + for sharding in self._sharding_type_to_sharding.values(): + self._input_dists.append(sharding.create_input_dist()) + feature_names.extend( + sharding.id_score_list_feature_names() + if self._is_weighted + else sharding.id_list_feature_names() + ) + self._feature_splits.append( + len( + sharding.id_score_list_feature_names() + if self._is_weighted + else sharding.id_list_feature_names() + ) + ) + + if feature_names == input_feature_names: + self._has_features_permute = False + else: + for f in feature_names: + self._features_order.append(input_feature_names.index(f)) + self.register_buffer( + "_features_order_tensor", + torch.tensor( + self._features_order, device=self._device, dtype=torch.int32 + ), + ) + + def _create_lookups( + self, + fused_params: Optional[Dict[str, Any]], + ) -> None: + self._lookups: nn.ModuleList[BaseEmbeddingLookup] = nn.ModuleList() + for sharding in self._sharding_type_to_sharding.values(): + self._lookups.append(sharding.create_lookup(fused_params)) + + def _create_output_dist(self) -> None: + for sharding in self._sharding_type_to_sharding.values(): + self._output_dists.append(sharding.create_pooled_output_dist()) + self._embedding_names.extend(sharding.embedding_names()) + self._embedding_dims.extend(sharding.embedding_dims()) + + # pyre-ignore [14] + def input_dist( + self, ctx: ShardedModuleContext, features: KeyedJaggedTensor + ) -> Awaitable[SparseFeaturesList]: + if self._has_uninitialized_input_dist: + self._create_input_dist(features.keys()) + self._has_uninitialized_input_dist = False + with torch.no_grad(): + if self._has_features_permute: + features = features.permute( + self._features_order, + # pyre-ignore [6] + self._features_order_tensor, + ) + features_by_shards = features.split( + self._feature_splits, + ) + awaitables = [ + module( + SparseFeatures( + id_list_features=None + if self._is_weighted + else features_by_shard, + id_score_list_features=features_by_shard + if self._is_weighted + else None, + ) + ) + for module, features_by_shard in zip( + self._input_dists, features_by_shards + ) + ] + return SparseFeaturesListAwaitable(awaitables) + + def compute( + self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList + ) -> List[torch.Tensor]: + return [lookup(features) for lookup, features in zip(self._lookups, dist_input)] + + def output_dist( + self, ctx: ShardedModuleContext, output: List[torch.Tensor] + ) -> LazyAwaitable[KeyedTensor]: + if self._has_uninitialized_output_dist: + self._create_output_dist() + self._has_uninitialized_output_dist = False + return EmbeddingCollectionAwaitable( + awaitables=[ + dist(embeddings) for dist, embeddings in zip(self._output_dists, output) + ], + embedding_dims=self._embedding_dims, + embedding_names=self._embedding_names, + ) + + def compute_and_output_dist( + self, ctx: ShardedModuleContext, input: SparseFeaturesList + ) -> LazyAwaitable[KeyedTensor]: + if self._has_uninitialized_output_dist: + self._create_output_dist() + self._has_uninitialized_output_dist = False + return EmbeddingCollectionAwaitable( + awaitables=[ + dist(lookup(features)) + for lookup, dist, features in zip( + self._lookups, self._output_dists, input + ) + ], + embedding_dims=self._embedding_dims, + embedding_names=self._embedding_names, + 
) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + for lookup in self._lookups: + lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars) + return destination + + def named_modules( + self, + memo: Optional[Set[nn.Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ) -> Iterator[Tuple[str, nn.Module]]: + yield from [(prefix, self)] + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for lookup in self._lookups: + yield from lookup.named_parameters( + append_prefix(prefix, "embedding_bags"), recurse + ) + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for lookup, sharding_type in zip( + self._lookups, self._sharding_type_to_sharding.keys() + ): + if sharding_type == ShardingType.DATA_PARALLEL.value: + continue + for name, _ in lookup.named_parameters( + append_prefix(prefix, "embedding_bags") + ): + yield name + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for lookup in self._lookups: + yield from lookup.named_buffers( + append_prefix(prefix, "embedding_bags"), recurse + ) + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + for lookup in self._lookups: + missing, unexpected = lookup.load_state_dict( + filter_state_dict(state_dict, "embedding_bags"), + strict, + ) + missing_keys.extend(missing) + unexpected_keys.extend(unexpected) + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def sparse_grad_parameter_names( + self, + destination: Optional[List[str]] = None, + prefix: str = "", + ) -> List[str]: + destination = [] if destination is None else destination + for lookup in self._lookups: + lookup.sparse_grad_parameter_names( + destination, append_prefix(prefix, "embedding_bags") + ) + return destination + + @property + def fused_optimizer(self) -> KeyedOptimizer: + return self._optim + + +M = TypeVar("M", bound=nn.Module) + + +class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[M]): + """ + This implementation uses non-fused EmbeddingBagCollection + """ + + def shard( + self, + module: EmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBagCollection: + return ShardedEmbeddingBagCollection( + module, params, env, self.fused_params, device + ) + + def shardable_parameters( + self, module: EmbeddingBagCollection + ) -> Dict[str, nn.Parameter]: + return { + name.split(".")[0]: param + for name, param in module.embedding_bags.named_parameters() + } + + @property + def module_type(self) -> Type[EmbeddingBagCollection]: + return EmbeddingBagCollection + + +class QuantEmbeddingBagCollectionSharder(ModuleSharder[QuantEmbeddingBagCollection]): + def shard( + self, + module: QuantEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBagCollection: + return ShardedEmbeddingBagCollection(module, params, env, None, device) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.DATA_PARALLEL.value] + + def 
compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [ + EmbeddingComputeKernel.BATCHED_QUANT.value, + ] + + def storage_usage( + self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str + ) -> Dict[str, int]: + tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4 + assert compute_device_type in {"cuda", "cpu"} + storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + return {storage_map[compute_device_type].value: tensor_bytes} + + def shardable_parameters( + self, module: QuantEmbeddingBagCollection + ) -> Dict[str, nn.Parameter]: + return { + name.split(".")[-2]: param + for name, param in module.state_dict().items() + if name.endswith(".weight") + } + + @property + def module_type(self) -> Type[QuantEmbeddingBagCollection]: + return QuantEmbeddingBagCollection + + +class EmbeddingAwaitable(LazyAwaitable[torch.Tensor]): + def __init__( + self, + awaitable: Awaitable[torch.Tensor], + ) -> None: + super().__init__() + self._awaitable = awaitable + + def wait(self) -> torch.Tensor: + embedding = self._awaitable.wait() + return embedding + + +class ShardedEmbeddingBag( + ShardedModule[ + SparseFeatures, + torch.Tensor, + torch.Tensor, + ], + FusedOptimizerModule, +): + """ + Sharded implementation of nn.EmbeddingBag. + This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: nn.EmbeddingBag, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + + assert ( + len(table_name_to_parameter_sharding) == 1 + ), "expect 1 table, but got len(table_name_to_parameter_sharding)" + assert module.mode == "sum", "ShardedEmbeddingBag only supports sum pooling" + + self._dummy_embedding_table_name = "dummy_embedding_table_name" + self._dummy_feature_name = "dummy_feature_name" + self.parameter_sharding: ParameterSharding = next( + iter(table_name_to_parameter_sharding.values()) + ) + embedding_table_config = EmbeddingTableConfig( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + name=self._dummy_embedding_table_name, + feature_names=[self._dummy_feature_name], + pooling=PoolingType.SUM, + # We set is_weighted to True for now, + # if per_sample_weights is None in forward(), + # we could assign a all-one vector to per_sample_weights + is_weighted=True, + embedding_names=[self._dummy_feature_name], + ) + + self._embedding_sharding: EmbeddingSharding = create_embedding_sharding( + sharding_type=self.parameter_sharding.sharding_type, + embedding_configs=[ + ( + embedding_table_config, + self.parameter_sharding, + next(iter(module.parameters())), + ) + ], + env=env, + device=device, + ) + self._input_dist: nn.Module = self._embedding_sharding.create_input_dist() + self._lookup: nn.Module = self._embedding_sharding.create_lookup(fused_params) + self._output_dist: nn.Module = ( + self._embedding_sharding.create_pooled_output_dist() + ) + + # Get all fused optimizers and combine them. 
+ optims = [] + for _, module in self._lookup.named_modules(): + if isinstance(module, FusedOptimizerModule): + # modify param keys to match EmbeddingBag + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for param_key, weight in module.fused_optimizer.params.items(): + params[param_key.split(".")[-1]] = weight + module.fused_optimizer.params = params + optims.append(("", module.fused_optimizer)) + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + + # pyre-ignore [14] + def input_dist( + self, + ctx: ShardedModuleContext, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Awaitable[SparseFeatures]: + if per_sample_weights is None: + per_sample_weights = torch.ones_like(input, dtype=torch.float) + features = KeyedJaggedTensor( + keys=[self._dummy_feature_name], + values=input, + offsets=offsets, + weights=per_sample_weights, + ) + return self._input_dist( + SparseFeatures( + id_list_features=None, + id_score_list_features=features, + ) + ) + + def compute( + self, ctx: ShardedModuleContext, dist_input: SparseFeatures + ) -> torch.Tensor: + return self._lookup(dist_input) + + def output_dist( + self, ctx: ShardedModuleContext, output: torch.Tensor + ) -> LazyAwaitable[torch.Tensor]: + return EmbeddingAwaitable( + awaitable=self._output_dist(output), + ) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + lookup_state_dict = self._lookup.state_dict(None, "", keep_vars) + # update key to match embeddingBag state_dict key + for key, item in lookup_state_dict.items(): + new_key = prefix + key.split(".")[-1] + destination[new_key] = item + return destination + + def named_modules( + self, + memo: Optional[Set[nn.Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ) -> Iterator[Tuple[str, nn.Module]]: + yield from [(prefix, self)] + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for name, parameter in self._lookup.named_parameters("", recurse): + # update name to match embeddingBag parameter name + yield append_prefix(prefix, name.split(".")[-1]), parameter + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + if self.parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + yield from [] + else: + for name, _ in self._lookup.named_parameters(""): + yield append_prefix(prefix, name.split(".")[-1]) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for name, buffer in self._lookup.named_buffers("", recurse): + yield append_prefix(prefix, name.split(".")[-1]), buffer + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + # update key to match embeddingBag state_dict key + for key, value in state_dict.items(): + new_key = ".".join([self._dummy_embedding_table_name, key]) + state_dict[new_key] = value + state_dict.pop(key) + missing, unexpected = self._lookup.load_state_dict( + state_dict, + strict, + ) + missing_keys.extend(missing) + unexpected_keys.extend(unexpected) + + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def sparse_grad_parameter_names( + self, + 
destination: Optional[List[str]] = None, + prefix: str = "", + ) -> List[str]: + destination = [] if destination is None else destination + # pyre-ignore [29] + lookup_sparse_grad_parameter_names = self._lookup.sparse_grad_parameter_names( + None, "" + ) + for name in lookup_sparse_grad_parameter_names: + destination.append(name.split(".")[-1]) + return destination + + @property + def fused_optimizer(self) -> KeyedOptimizer: + return self._optim + + +class EmbeddingBagSharder(BaseEmbeddingSharder[M]): + """ + This implementation uses non-fused nn.EmbeddingBag + """ + + def shard( + self, + module: nn.EmbeddingBag, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBag: + return ShardedEmbeddingBag(module, params, env, self.fused_params, device) + + def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Parameter]: + return {name: param for name, param in module.named_parameters()} + + @property + def module_type(self) -> Type[nn.EmbeddingBag]: + return nn.EmbeddingBag diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 178571fd3..10f70870c 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -8,10 +8,9 @@ from torch import nn from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel -from torchrec.distributed.embedding import ( +from torchrec.distributed.embeddingbag import ( EmbeddingBagCollectionSharder, QuantEmbeddingBagCollectionSharder, - filter_state_dict, ) from torchrec.distributed.planner import EmbeddingShardingPlanner, sharder_name from torchrec.distributed.types import ( @@ -21,6 +20,7 @@ ShardingEnv, ) from torchrec.distributed.utils import append_prefix +from torchrec.distributed.utils import filter_state_dict from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer @@ -53,7 +53,9 @@ def init_weights(m): device: this device, defaults to cpu, plan: plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan(), sharders: ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder(), - init_data_parallel: data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass True if that's a case to delay initialization of data parallel modules. Do first forward pass and then call DistributedModelParallel.init_data_parallel(). + init_data_parallel: data-parallel modules can be lazy, i.e. they delay parameter initialization until + the first forward pass. Pass True if that's a case to delay initialization of data parallel modules. + Do first forward pass and then call DistributedModelParallel.init_data_parallel(). init_parameters: initialize parameters for modules still on meta device. 
Call Args: diff --git a/torchrec/distributed/planner/new/tests/test_calculators.py b/torchrec/distributed/planner/new/tests/test_calculators.py index 8f823001d..a42fec6da 100644 --- a/torchrec/distributed/planner/new/tests/test_calculators.py +++ b/torchrec/distributed/planner/new/tests/test_calculators.py @@ -2,10 +2,10 @@ import unittest -from torchrec.distributed.embedding import ( +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.distributed.embeddingbag import ( EmbeddingBagCollectionSharder, ) -from torchrec.distributed.embedding_types import EmbeddingTableConfig from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator from torchrec.distributed.planner.new.enumerators import ShardingEnumerator from torchrec.distributed.planner.new.types import Topology diff --git a/torchrec/distributed/planner/new/tests/test_enumerators.py b/torchrec/distributed/planner/new/tests/test_enumerators.py index fe301ae49..a195533f7 100644 --- a/torchrec/distributed/planner/new/tests/test_enumerators.py +++ b/torchrec/distributed/planner/new/tests/test_enumerators.py @@ -4,10 +4,10 @@ import unittest from typing import List -from torchrec.distributed.embedding import EmbeddingBagCollectionSharder from torchrec.distributed.embedding_types import ( EmbeddingComputeKernel, ) +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.new.constants import ( BIGINT_DTYPE, ) diff --git a/torchrec/distributed/planner/new/tests/test_partitioners.py b/torchrec/distributed/planner/new/tests/test_partitioners.py index b6915a8ec..e84fafc47 100644 --- a/torchrec/distributed/planner/new/tests/test_partitioners.py +++ b/torchrec/distributed/planner/new/tests/test_partitioners.py @@ -5,8 +5,8 @@ from typing import List from torch import nn -from torchrec.distributed.embedding import EmbeddingBagCollectionSharder from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.new.enumerators import ShardingEnumerator from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner from torchrec.distributed.planner.new.types import Storage, Topology, PartitionByType diff --git a/torchrec/distributed/planner/new/tests/test_placers.py b/torchrec/distributed/planner/new/tests/test_placers.py index 5e2f58683..ceff6770c 100644 --- a/torchrec/distributed/planner/new/tests/test_placers.py +++ b/torchrec/distributed/planner/new/tests/test_placers.py @@ -5,8 +5,8 @@ import torch from torch import nn -from torchrec.distributed.embedding import EmbeddingBagCollectionSharder from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator from torchrec.distributed.planner.new.enumerators import ShardingEnumerator from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner diff --git a/torchrec/distributed/planner/new/tests/test_rankers.py b/torchrec/distributed/planner/new/tests/test_rankers.py index cce4f7ba8..29d2e10d3 100644 --- a/torchrec/distributed/planner/new/tests/test_rankers.py +++ b/torchrec/distributed/planner/new/tests/test_rankers.py @@ -2,7 +2,7 @@ import unittest -from torchrec.distributed.embedding import ( +from torchrec.distributed.embeddingbag import ( EmbeddingBagCollectionSharder, ) from 
torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator diff --git a/torchrec/distributed/planner/tests/test_embedding_planner.py b/torchrec/distributed/planner/tests/test_embedding_planner.py index ed5758d2c..850aa1bf0 100644 --- a/torchrec/distributed/planner/tests/test_embedding_planner.py +++ b/torchrec/distributed/planner/tests/test_embedding_planner.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch, call from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec -from torchrec.distributed.embedding import EmbeddingBagCollectionSharder from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner from torchrec.distributed.planner.parameter_sharding import _rw_shard_table_rows from torchrec.distributed.planner.types import ParameterHints diff --git a/torchrec/distributed/tests/test_model.py b/torchrec/distributed/tests/test_model.py index 069cc7af8..0da6122ab 100644 --- a/torchrec/distributed/tests/test_model.py +++ b/torchrec/distributed/tests/test_model.py @@ -5,11 +5,11 @@ import torch import torch.nn as nn -from torchrec.distributed.embedding import ( - EmbeddingBagCollectionSharder, +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.distributed.embeddingbag import ( EmbeddingBagSharder, + EmbeddingBagCollectionSharder, ) -from torchrec.distributed.embedding_types import EmbeddingTableConfig from torchrec.modules.embedding_configs import EmbeddingBagConfig, BaseEmbeddingConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor diff --git a/torchrec/distributed/tests/test_model_parallel.py b/torchrec/distributed/tests/test_model_parallel.py index 924500c98..628195017 100644 --- a/torchrec/distributed/tests/test_model_parallel.py +++ b/torchrec/distributed/tests/test_model_parallel.py @@ -13,11 +13,9 @@ import torch.nn as nn from fbgemm_gpu.split_embedding_configs import EmbOptimType from hypothesis import Verbosity, given, settings -from torchrec.distributed.embedding import ( - EmbeddingBagCollectionSharder, - EmbeddingBagSharder, -) from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.embeddingbag import EmbeddingBagSharder from torchrec.distributed.model_parallel import ( DistributedModelParallel, default_sharders, diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py index 2caa08551..6a84ed70b 100644 --- a/torchrec/distributed/tests/test_quant_model_parallel.py +++ b/torchrec/distributed/tests/test_quant_model_parallel.py @@ -1,15 +1,12 @@ #!/usr/bin/env python3 import copy -import os import unittest from typing import List import torch -import torch.distributed as dist from torch import nn from torch import quantization as quant -from torchrec.distributed.embedding import QuantEmbeddingBagCollectionSharder from torchrec.distributed.embedding_lookup import ( GroupedEmbeddingBag, BatchedFusedEmbeddingBag, @@ -17,6 +14,9 @@ QuantBatchedEmbeddingBag, ) from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import ( + QuantEmbeddingBagCollectionSharder, +) from 
diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py
index 5f05edb37..660b18fdc 100644
--- a/torchrec/distributed/tests/test_train_pipeline.py
+++ b/torchrec/distributed/tests/test_train_pipeline.py
@@ -9,14 +9,14 @@
 import torch.distributed as dist
 from torch import nn, optim
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding import (
-    ShardedEmbeddingBagCollection,
-    EmbeddingBagCollectionSharder,
-)
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embedding_types import (
     SparseFeaturesList,
 )
+from torchrec.distributed.embeddingbag import (
+    ShardedEmbeddingBagCollection,
+    EmbeddingBagCollectionSharder,
+)
 from torchrec.distributed.tests.test_model import (
     TestSparseNN,
     ModelInput,
diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py
index 05fa05dfc..f4b4140e1 100644
--- a/torchrec/distributed/tests/test_utils.py
+++ b/torchrec/distributed/tests/test_utils.py
@@ -10,10 +10,10 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from torchrec.distributed.embedding import (
+from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
+from torchrec.distributed.embeddingbag import (
     EmbeddingBagCollectionSharder,
 )
-from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.tests.test_model import TestSparseNN
 from torchrec.distributed.utils import get_unsharded_module_names
diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index 43e5953d7..5a46babbc 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+from collections import OrderedDict
 from typing import List, Set, Union
 
 import torch
@@ -13,6 +14,17 @@ def append_prefix(prefix: str, name: str) -> str:
     return prefix + name
 
 
+def filter_state_dict(
+    state_dict: "OrderedDict[str, torch.Tensor]", name: str
+) -> "OrderedDict[str, torch.Tensor]":
+    rtn_dict = OrderedDict()
+    for key, value in state_dict.items():
+        if key.startswith(name):
+            # + 1 to length is to remove the '.' after the key
+            rtn_dict[key[len(name) + 1 :]] = value
+    return rtn_dict
+
+
 def _get_unsharded_module_names_helper(
     model: torch.nn.Module,
     path: str,
diff --git a/torchrec/examples/__init__.py b/torchrec/examples/__init__.py
new file mode 100644
index 000000000..e69de29bb
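A note on the new `filter_state_dict` helper above: it keeps only the entries of a state dict whose keys start with the given submodule name and strips that prefix (plus the trailing dot), which is the form a child module's `load_state_dict` expects. Below is a minimal sketch of how it might be used once this patch is applied; the key names and tensor shapes are made up for illustration.

```python
# Illustrative only: assumes this patch is applied and torchrec is importable.
from collections import OrderedDict

import torch
from torchrec.distributed.utils import filter_state_dict

# A state dict whose keys are namespaced under a submodule prefix.
full_state = OrderedDict(
    [
        ("sparse.embedding_bags.t_cat_0.weight", torch.zeros(4, 8)),
        ("sparse.embedding_bags.t_cat_1.weight", torch.zeros(4, 8)),
        ("dense.linear.weight", torch.zeros(8, 8)),
    ]
)

# Keep only the "sparse" entries and strip the "sparse." prefix, so the result
# can be loaded directly into the child module.
sparse_only = filter_state_dict(full_state, "sparse")
print(list(sparse_only.keys()))
# ['embedding_bags.t_cat_0.weight', 'embedding_bags.t_cat_1.weight']
```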
diff --git a/torchrec/examples/dlrm/README.MD b/torchrec/examples/dlrm/README.MD
new file mode 100644
index 000000000..197721d46
--- /dev/null
+++ b/torchrec/examples/dlrm/README.MD
@@ -0,0 +1,11 @@
+# Running
+
+## Torchx
+We recommend using [torchx](https://pytorch.org/torchx/main/quickstart.html) to run this example.
+Here we use the [DDP builtin](https://pytorch.org/torchx/main/components/distributed.html).
+
+1. pip install torchx
+2. (optional) setup a slurm or kubernetes cluster
+3.
+   a. locally: torchx run dist.ddp -j 1x2 --script dlrm_main.py
+   b. remotely: torchx run -s slurm dist.ddp -j 1x8 --script dlrm_main.py
diff --git a/torchrec/examples/dlrm/__init__.py b/torchrec/examples/dlrm/__init__.py
new file mode 100644
index 000000000..e69de29bb
+ " When not supplied, no undersampling occurs.", + ) + parser.add_argument( + "--seed", + type=float, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--pin_memory", + dest="pin_memory", + action="store_true", + help="Use pinned memory when loading data.", + ) + parser.set_defaults(pin_memory=False) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> None: + args = parse_args(argv) + + rank = int(os.environ["LOCAL_RANK"]) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + backend = "nccl" + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + backend = "gloo" + + if not torch.distributed.is_initialized(): + dist.init_process_group(backend=backend) + + if args.num_embeddings_per_feature is not None: + num_embeddings_per_feature = list( + map(int, args.num_embeddings_per_feature.split(",")) + ) + num_embeddings = None + else: + num_embeddings_per_feature = None + num_embeddings = args.num_embeddings + + dataloader = DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=num_embeddings, + hash_sizes=num_embeddings_per_feature, + manual_seed=args.seed, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), + batch_size=None, + batch_sampler=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + ) + iterator = iter(dataloader) + # TODO add criteo support and add random_dataloader arg + + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=args.embedding_dim, + num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx] + if num_embeddings is None + else num_embeddings, + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) + ] + sharded_module_kwargs = {} + if args.over_arch_layer_sizes is not None: + sharded_module_kwargs["over_arch_layer_sizes"] = list( + map(int, args.over_arch_layer_sizes.split(",")) + ) + + train_model = DLRMTrain( + embedding_bag_collection=EmbeddingBagCollection( + tables=eb_configs, device=torch.device("meta") + ), + dense_in_features=len(DEFAULT_INT_NAMES), + dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))), + over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))), + dense_device=device, + ) + + model = DistributedModelParallel( + module=train_model, + device=device, + ) + optimizer = KeyedOptimizerWrapper( + dict(model.named_parameters()), + lambda params: torch.optim.SGD(params, lr=0.01), + ) + + train_pipeline = TrainPipelineSparseDist( + model, + optimizer, + device, + ) + + for _ in range(10): + loss, logits, labels = train_pipeline.progress(iterator) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/examples/dlrm/modules/__init__.py b/torchrec/examples/dlrm/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/examples/dlrm/modules/dlrm_train.py b/torchrec/examples/dlrm/modules/dlrm_train.py new file mode 100644 index 000000000..87d27435e --- /dev/null +++ b/torchrec/examples/dlrm/modules/dlrm_train.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +from typing import Tuple, Optional, List + +import torch +from torch import nn +from torchrec.datasets.utils import Batch +from torchrec.models.dlrm import DLRM +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +class DLRMTrain(nn.Module): + """ + nn.Module to wrap DLRM model to use with train_pipeline. 
diff --git a/torchrec/examples/dlrm/modules/dlrm_train.py b/torchrec/examples/dlrm/modules/dlrm_train.py
new file mode 100644
index 000000000..87d27435e
--- /dev/null
+++ b/torchrec/examples/dlrm/modules/dlrm_train.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+from typing import Tuple, Optional, List
+
+import torch
+from torch import nn
+from torchrec.datasets.utils import Batch
+from torchrec.models.dlrm import DLRM
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+
+class DLRMTrain(nn.Module):
+    """
+    nn.Module to wrap DLRM model to use with train_pipeline.
+
+    DLRM Recsys model from "Deep Learning Recommendation Model for Personalization and
+    Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse
+    features by learning pooled embeddings for each feature. Learns the relationship
+    between dense features and sparse features by projecting dense features into the
+    same embedding space. Also, learns the pairwise relationships between sparse
+    features.
+
+    The module assumes all sparse features have the same embedding dimension
+    (i.e., each EmbeddingBagConfig uses the same embedding_dim)
+
+    Constructor Args:
+        embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
+            used to define SparseArch.
+        dense_in_features (int): the dimensionality of the dense input features.
+        dense_arch_layer_sizes (list[int]): the layer sizes for the DenseArch.
+        over_arch_layer_sizes (list[int]): the layer sizes for the OverArch. NOTE: The
+            output dimension of the InteractionArch should not be manually specified
+            here.
+        dense_device: (Optional[torch.device]).
+
+    Call Args:
+        batch: batch used with criteo and random data from torchrec.datasets
+
+    Returns:
+        Tuple[loss, Tuple[loss, logits, labels]]
+
+    Example:
+        >>> TODO
+    """
+
+    def __init__(
+        self,
+        embedding_bag_collection: EmbeddingBagCollection,
+        dense_in_features: int,
+        dense_arch_layer_sizes: List[int],
+        over_arch_layer_sizes: List[int],
+        dense_device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        self.model = DLRM(
+            embedding_bag_collection=embedding_bag_collection,
+            dense_in_features=dense_in_features,
+            dense_arch_layer_sizes=dense_arch_layer_sizes,
+            over_arch_layer_sizes=over_arch_layer_sizes,
+            dense_device=dense_device,
+        )
+        self.loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    def forward(
+        self, batch: Batch
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        logits = self.model(batch.dense_features, batch.sparse_features)
+        logits = logits.squeeze()
+        loss = self.loss_fn(logits, batch.labels.float())
+
+        return loss, (loss.detach(), logits.detach(), batch.labels.detach())
diff --git a/torchrec/examples/criteo_tutorial.ipynb b/torchrec/examples/notebooks/criteo_tutorial.ipynb
similarity index 100%
rename from torchrec/examples/criteo_tutorial.ipynb
rename to torchrec/examples/notebooks/criteo_tutorial.ipynb
diff --git a/torchrec/examples/movielens_tutorial.ipynb b/torchrec/examples/notebooks/movielens_tutorial.ipynb
similarity index 100%
rename from torchrec/examples/movielens_tutorial.ipynb
rename to torchrec/examples/notebooks/movielens_tutorial.ipynb
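The DLRMTrain docstring above leaves its Example as a TODO. Below is a rough, untested sketch of how the module might be exercised on CPU with a hand-built batch. It assumes that `Batch` can be constructed directly from dense features, a `KeyedJaggedTensor`, and labels, and that `KeyedJaggedTensor.from_offsets_sync` is available; the table sizes, layer sizes, and ids are arbitrary.

```python
import torch
from torchrec.datasets.utils import Batch
from torchrec.examples.dlrm.modules.dlrm_train import DLRMTrain
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

B = 2  # batch size
# Two sparse features sharing one embedding_dim, as DLRMTrain requires.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t_f1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        ),
        EmbeddingBagConfig(
            name="t_f2", embedding_dim=8, num_embeddings=100, feature_names=["f2"]
        ),
    ],
    device=torch.device("cpu"),
)

train_model = DLRMTrain(
    embedding_bag_collection=ebc,
    dense_in_features=13,
    dense_arch_layer_sizes=[32, 8],  # last size matches embedding_dim
    over_arch_layer_sizes=[16, 1],   # last size of 1 yields a single logit
)

# One id per feature per sample: offsets has len(keys) * B + 1 entries.
sparse = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4]),
    offsets=torch.tensor([0, 1, 2, 3, 4]),
)
batch = Batch(
    dense_features=torch.rand(B, 13),
    sparse_features=sparse,
    labels=torch.randint(2, (B,)),
)

loss, (detached_loss, logits, labels) = train_model(batch)
loss.backward()
```

Note that the last entry of `dense_arch_layer_sizes` matches `embedding_dim` and the last entry of `over_arch_layer_sizes` is 1, mirroring the defaults used in dlrm_main.py.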