From 804e6ffca6fa838fb15b929a7ebaa6bb6a8902df Mon Sep 17 00:00:00 2001 From: Leon Gao Date: Tue, 11 Jan 2022 17:39:29 -0800 Subject: [PATCH] multi-threading sparse arch for GPU inference Reviewed By: zyan0 Differential Revision: D31853854 fbshipit-source-id: ab263f039f035b4f8877fcd0acde16e336673f2f --- torchrec/distributed/cw_sharding.py | 13 +- torchrec/distributed/dist_data.py | 75 ++++++- torchrec/distributed/dp_sharding.py | 44 ++++- torchrec/distributed/embedding_lookup.py | 101 +++++++++- torchrec/distributed/embedding_sharding.py | 131 ++++++++++-- torchrec/distributed/embedding_types.py | 33 +++- torchrec/distributed/embeddingbag.py | 186 +++++++++--------- torchrec/distributed/model_parallel.py | 40 +++- torchrec/distributed/quant_embeddingbag.py | 184 +++++++++++++++++ torchrec/distributed/rw_sharding.py | 22 ++- .../tests/test_quant_model_parallel.py | 4 +- .../distributed/tests/test_train_pipeline.py | 4 +- torchrec/distributed/tw_sharding.py | 107 ++++++++-- torchrec/distributed/twrw_sharding.py | 22 ++- torchrec/distributed/types.py | 2 +- 15 files changed, 796 insertions(+), 172 deletions(-) create mode 100644 torchrec/distributed/quant_embeddingbag.py diff --git a/torchrec/distributed/cw_sharding.py b/torchrec/distributed/cw_sharding.py index ae691fbbe..17a8af23d 100644 --- a/torchrec/distributed/cw_sharding.py +++ b/torchrec/distributed/cw_sharding.py @@ -8,7 +8,7 @@ from typing import Set, Callable, Dict, List, Optional, Tuple import torch -import torch.distributed as dist +import torch.distributed as dist # noqa from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings from torchrec.distributed.embedding_types import ( ShardedEmbeddingTable, @@ -16,6 +16,7 @@ ) from torchrec.distributed.tw_sharding import TwEmbeddingSharding, TwPooledEmbeddingDist from torchrec.distributed.types import ( + ShardingEnv, ShardedTensorMetadata, ShardMetadata, ParameterSharding, @@ -34,13 +35,12 @@ def __init__( embedding_configs: List[ Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] ], - # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type. 
-        pg: dist.ProcessGroup,
+        env: ShardingEnv,
         device: Optional[torch.device] = None,
         permute_embeddings: bool = False,
     ) -> None:
         super().__init__(
-            embedding_configs, pg, device, permute_embeddings=permute_embeddings
+            embedding_configs, env, device, permute_embeddings=permute_embeddings
         )
         if self._permute_embeddings:
             self._init_combined_embeddings()
@@ -162,7 +162,10 @@ def embedding_names(self) -> List[str]:
             else super().embedding_names()
         )
 
-    def create_pooled_output_dist(self) -> TwPooledEmbeddingDist:
+    def create_train_pooled_output_dist(
+        self,
+        device: Optional[torch.device] = None,
+    ) -> TwPooledEmbeddingDist:
         embedding_permute_op: Optional[PermutePooledEmbeddings] = None
         callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None
         if self._permute_embeddings and self._embedding_order != list(
diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 7d4fb8db5..4f01864be 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -17,12 +17,16 @@
     alltoall_sequence,
     reduce_scatter_pooled,
 )
-from torchrec.distributed.types import Awaitable
+from torchrec.distributed.types import Awaitable, NoWait
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
+    )
 except OSError:
     pass
 
@@ -416,6 +420,54 @@ def forward(
     )
 
 
+class KJTOneToAll(nn.Module):
+    """
+    Redistributes a KeyedJaggedTensor to all devices.
+
+    The implementation essentially P2P-copies each feature split to its target device.
+
+    Constructor Args:
+        splits (List[int]): the lengths of features to split the KeyedJaggedTensor features into
+            before copying them.
+        world_size (int): the number of devices; must equal len(splits).
+
+    Call Args:
+        kjt (KeyedJaggedTensor): The input features.
+
+    Returns:
+        Awaitable[List[KeyedJaggedTensor]].
+    """
+
+    def __init__(
+        self,
+        splits: List[int],
+        world_size: int,
+    ) -> None:
+        super().__init__()
+        self._splits = splits
+        self._world_size = world_size
+        assert self._world_size == len(splits)
+
+    def forward(self, kjt: KeyedJaggedTensor) -> Awaitable[List[KeyedJaggedTensor]]:
+        """
+        Splits the features first and then sends the slices to the corresponding devices.
+
+
+        Call Args:
+            input (KeyedJaggedTensor): KeyedJaggedTensor of values to distribute.
+
+        Returns:
+            Awaitable[List[KeyedJaggedTensor]]: awaitable of the KeyedJaggedTensor splits.
+
+        """
+        kjts: List[KeyedJaggedTensor] = kjt.split(self._splits)
+        dist_kjts = [
+            split_kjt.to(torch.device("cuda", rank), non_blocking=True)
+            for rank, split_kjt in enumerate(kjts)
+        ]
+        return NoWait(dist_kjts)
+
+
 class PooledEmbeddingsAwaitable(Awaitable[torch.Tensor]):
     """
     Awaitable for pooled embeddings after collective operation.
@@ -541,6 +593,27 @@ def callbacks(self) -> List[Callable[[torch.Tensor], torch.Tensor]]: return self._callbacks +class PooledEmbeddingsAllToOne(nn.Module): + def __init__( + self, + device: torch.device, + world_size: int, + ) -> None: + super().__init__() + self._device = device + self._world_size = world_size + + def forward(self, tensors: List[torch.Tensor]) -> Awaitable[torch.Tensor]: + assert len(tensors) == self._world_size + return NoWait( + torch.ops.fbgemm.merge_pooled_embeddings( + tensors, + tensors[0].size(0), + self._device, + ) + ) + + class PooledEmbeddingsReduceScatter(nn.Module): """ The module class that wraps reduce-scatter communication primitive for pooled diff --git a/torchrec/distributed/dp_sharding.py b/torchrec/distributed/dp_sharding.py index 5cc218b3d..2009b17ad 100644 --- a/torchrec/distributed/dp_sharding.py +++ b/torchrec/distributed/dp_sharding.py @@ -33,7 +33,7 @@ from torchrec.modules.embedding_configs import EmbeddingTableConfig -class DpSparseFeaturesDist(BaseSparseFeaturesDist): +class DpSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): """ Distributes sparse features (input) to be data-parallel. """ @@ -58,7 +58,7 @@ def forward( return NoWait(cast(Awaitable[SparseFeatures], NoWait(sparse_features))) -class DpPooledEmbeddingDist(BasePooledEmbeddingDist): +class DpPooledEmbeddingDist(BasePooledEmbeddingDist[torch.Tensor]): """ Distributes pooled embeddings to be data-parallel. """ @@ -104,7 +104,9 @@ def forward( return NoWait(local_embs) -class DpEmbeddingSharding(EmbeddingSharding): +class DpEmbeddingSharding( + EmbeddingSharding[SparseFeatures, torch.Tensor, SparseFeatures, torch.Tensor] +): """ Shards embedding bags using data-parallel, with no table sharding i.e.. a given embedding table is replicated across all ranks. 
@@ -123,6 +125,8 @@ def __init__( self._env = env self._device = device self._is_sequence = is_sequence + self._rank: int = self._env.rank + self._world_size: int = self._env.world_size sharded_tables_per_rank = self._shard(embedding_configs) self._grouped_embedding_configs_per_rank: List[ List[GroupedEmbeddingConfig] @@ -147,7 +151,7 @@ def _shard( Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] ], ) -> List[List[ShardedEmbeddingTable]]: - world_size = self._env.world_size + world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ [] for i in range(world_size) ] @@ -175,10 +179,10 @@ def _shard( ) return tables_per_rank - def create_input_dist(self) -> DpSparseFeaturesDist: + def create_train_input_dist(self) -> DpSparseFeaturesDist: return DpSparseFeaturesDist() - def create_lookup( + def create_train_lookup( self, fused_params: Optional[Dict[str, Any]], feature_processor: Optional[BaseGroupedFeatureProcessor] = None, @@ -200,12 +204,36 @@ def create_lookup( feature_processor=feature_processor, ) - def create_pooled_output_dist(self) -> DpPooledEmbeddingDist: + def create_train_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> DpPooledEmbeddingDist: return DpPooledEmbeddingDist() - def create_sequence_output_dist(self) -> DpSequenceEmbeddingDist: + def create_train_sequence_output_dist(self) -> DpSequenceEmbeddingDist: return DpSequenceEmbeddingDist() + def create_infer_input_dist(self) -> DpSparseFeaturesDist: + return DpSparseFeaturesDist() + + def create_infer_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[SparseFeatures, torch.Tensor]: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + grouped_score_configs=self._score_grouped_embedding_configs, + fused_params=fused_params, + device=self._device, + ) + + def create_infer_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> DpPooledEmbeddingDist: + return DpPooledEmbeddingDist() + def embedding_dims(self) -> List[int]: embedding_dims = [] for grouped_config in self._grouped_embedding_configs: diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 4c0d4f171..460a94edb 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -29,6 +29,7 @@ from torch.nn.modules.module import _IncompatibleKeys from torchrec.distributed.embedding_types import ( ShardedEmbeddingTable, + SparseFeaturesList, GroupedEmbeddingConfig, BaseEmbeddingLookup, SparseFeatures, @@ -681,7 +682,7 @@ def named_parameters( ) -class GroupedEmbeddingsLookup(BaseEmbeddingLookup): +class GroupedEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]): def __init__( self, grouped_configs: List[GroupedEmbeddingConfig], @@ -1238,6 +1239,7 @@ def __init__( self._local_rows, config.embedding_tables ) ], + device=device, pooling_mode=self._pooling, feature_table_map=self._feature_table_map, ) @@ -1371,7 +1373,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType: return ret -class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup): +class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]): def __init__( self, grouped_configs: List[GroupedEmbeddingConfig], @@ -1383,6 +1385,7 @@ def __init__( ) -> None: def _create_lookup( config: GroupedEmbeddingConfig, + device: Optional[torch.device] = None, ) -> BaseEmbeddingBag: 
if config.compute_kernel == EmbeddingComputeKernel.BATCHED_DENSE: return BatchedDenseEmbeddingBag( @@ -1425,13 +1428,13 @@ def _create_lookup( # take parameters. self._emb_modules: nn.ModuleList[BaseEmbeddingBag] = nn.ModuleList() for config in grouped_configs: - self._emb_modules.append(_create_lookup(config)) + self._emb_modules.append(_create_lookup(config, device)) # pyre-fixme[24]: Non-generic type `nn.modules.container.ModuleList` cannot # take parameters. self._score_emb_modules: nn.ModuleList[BaseEmbeddingBag] = nn.ModuleList() for config in grouped_score_configs: - self._score_emb_modules.append(_create_lookup(config)) + self._score_emb_modules.append(_create_lookup(config, device)) self._id_list_feature_splits: List[int] = [] for config in grouped_configs: @@ -1560,3 +1563,93 @@ def sparse_grad_parameter_names( for emb_module in self._score_emb_modules: emb_module.sparse_grad_parameter_names(destination, prefix) return destination + + +class InferGroupedPooledEmbeddingsLookup( + BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]] +): + def __init__( + self, + grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], + grouped_score_configs_per_rank: List[List[GroupedEmbeddingConfig]], + world_size: int, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self._embedding_lookups_per_rank: List[GroupedPooledEmbeddingsLookup] = [] + for rank in range(world_size): + self._embedding_lookups_per_rank.append( + GroupedPooledEmbeddingsLookup( + grouped_configs=grouped_configs_per_rank[rank], + grouped_score_configs=grouped_score_configs_per_rank[rank], + fused_params=fused_params, + device=torch.device("cuda", rank), + ) + ) + + def forward( + self, + sparse_features: SparseFeaturesList, + ) -> List[torch.Tensor]: + embeddings: List[torch.Tensor] = [] + for sparse_features_rank, embedding_lookup in zip( + sparse_features, self._embedding_lookups_per_rank + ): + assert ( + sparse_features_rank.id_list_features is not None + or sparse_features_rank.id_score_list_features is not None + ) + embeddings.append(embedding_lookup(sparse_features_rank)) + return embeddings + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + for rank_modules in self._embedding_lookups_per_rank: + rank_modules.state_dict(destination, prefix, keep_vars) + + return destination + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + for rank_modules in self._embedding_lookups_per_rank: + incompatible_keys = rank_modules.load_state_dict(state_dict) + missing_keys.extend(incompatible_keys.missing_keys) + unexpected_keys.extend(incompatible_keys.unexpected_keys) + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for rank_modules in self._embedding_lookups_per_rank: + yield from rank_modules.named_parameters(prefix, recurse) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for rank_modules in self._embedding_lookups_per_rank: + yield from rank_modules.named_buffers(prefix, recurse) + + def 
sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + for rank_modules in self._embedding_lookups_per_rank: + rank_modules.sparse_grad_parameter_names(destination, prefix) + return destination diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index bad0f5b6f..ac5929fdc 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -7,7 +7,7 @@ import abc from dataclasses import dataclass, field -from typing import List, Tuple, Optional, Dict, Any +from typing import TypeVar, Generic, List, Tuple, Optional, Dict, Any import torch import torch.distributed as dist @@ -15,6 +15,7 @@ from torch.distributed._sharding_spec import ShardMetadata from torchrec.distributed.dist_data import ( KJTAllToAll, + KJTOneToAll, KJTAllToAllIndicesAwaitable, ) from torchrec.distributed.embedding_types import ( @@ -25,8 +26,9 @@ ShardedEmbeddingTable, BaseGroupedFeatureProcessor, SparseFeaturesList, + ListOfSparseFeaturesList, ) -from torchrec.distributed.types import Awaitable +from torchrec.distributed.types import NoWait, Awaitable from torchrec.modules.embedding_configs import ( PoolingType, DataType, @@ -311,10 +313,10 @@ def __init__( stagger: int = 1, ) -> None: super().__init__() - self._id_list_features_all2all = KJTAllToAll( + self._id_list_features_all2all: KJTAllToAll = KJTAllToAll( pg, id_list_features_per_rank, device, stagger ) - self._id_score_list_features_all2all = KJTAllToAll( + self._id_score_list_features_all2all: KJTAllToAll = KJTAllToAll( pg, id_score_list_features_per_rank, device, stagger ) @@ -348,6 +350,52 @@ def forward( ) +class SparseFeaturesOneToAll(nn.Module): + def __init__( + self, + id_list_features_per_rank: List[int], + id_score_list_features_per_rank: List[int], + world_size: int, + ) -> None: + super().__init__() + self._world_size = world_size + self._id_list_features_one2all: KJTOneToAll = KJTOneToAll( + id_list_features_per_rank, + world_size, + ) + self._id_score_list_features_one2all: KJTOneToAll = KJTOneToAll( + id_score_list_features_per_rank, world_size + ) + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeaturesList]: + return NoWait( + SparseFeaturesList( + [ + SparseFeatures( + id_list_features=id_list_features, + id_score_list_features=id_score_list_features, + ) + for id_list_features, id_score_list_features in zip( + self._id_list_features_one2all.forward( + sparse_features.id_list_features + ).wait() + if sparse_features.id_list_features is not None + else [None] * self._world_size, + self._id_score_list_features_one2all.forward( + sparse_features.id_score_list_features + ).wait() + if sparse_features.id_score_list_features is not None + else [None] * self._world_size, + ) + ] + ) + ) + + +# group tables by DataType, PoolingType, Weighted, and EmbeddingComputeKernel. def group_tables( tables_per_rank: List[List[ShardedEmbeddingTable]], ) -> Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]: @@ -442,6 +490,14 @@ def _group_tables_per_rank( ) +F = TypeVar("F", bound=Multistreamable) +T = TypeVar("T") +TRAIN_F = TypeVar("TRAIN_F", bound=Multistreamable) +INFER_F = TypeVar("INFER_F", bound=Multistreamable) +TRAIN_T = TypeVar("TRAIN_T") +INFER_T = TypeVar("INFER_T") + + class SparseFeaturesListAwaitable(Awaitable[SparseFeaturesList]): """ Awaitable of SparseFeaturesList. 
@@ -496,7 +552,35 @@ def _wait_impl(self) -> List[Awaitable[SparseFeatures]]: return [m.wait() for m in self.awaitables] -class BaseSparseFeaturesDist(abc.ABC, nn.Module): +class ListOfSparseFeaturesListAwaitable(Awaitable[ListOfSparseFeaturesList]): + """ + This module handles the tables-wise sharding input features distribution for inference. + For inference, we currently do not separate lengths from indices. + + Constructor Args: + awaitables: List[Awaitable[SparseFeaturesList]] + + """ + + def __init__( + self, + awaitables: List[Awaitable[SparseFeaturesList]], + ) -> None: + super().__init__() + self.awaitables = awaitables + + def _wait_impl(self) -> ListOfSparseFeaturesList: + """ + Syncs sparse features in List of SparseFeaturesList. + + Returns: + ListOfSparseFeaturesList: synced ListOfSparseFeaturesList. + + """ + return ListOfSparseFeaturesList([w.wait() for w in self.awaitables]) + + +class BaseSparseFeaturesDist(abc.ABC, nn.Module, Generic[F]): """ Converts input from data-parallel to model-parallel. """ @@ -505,17 +589,17 @@ class BaseSparseFeaturesDist(abc.ABC, nn.Module): def forward( self, sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: + ) -> Awaitable[Awaitable[F]]: pass -class BasePooledEmbeddingDist(abc.ABC, nn.Module): +class BasePooledEmbeddingDist(abc.ABC, nn.Module, Generic[T]): """ Converts output of pooled EmbeddingLookup from model-parallel to data-parallel. """ @abc.abstractmethod - def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: + def forward(self, local_embs: T) -> Awaitable[torch.Tensor]: pass @@ -533,7 +617,7 @@ def forward( pass -class EmbeddingSharding(abc.ABC): +class EmbeddingSharding(abc.ABC, Generic[TRAIN_F, TRAIN_T, INFER_F, INFER_T]): """ Used to implement different sharding types for EmbeddingBagCollection, e.g. table_wise. 
@@ -543,25 +627,44 @@ def __init__(self, permute_embeddings: bool = False) -> None: self._permute_embeddings: bool = permute_embeddings @abc.abstractmethod - def create_input_dist(self) -> BaseSparseFeaturesDist: + def create_train_input_dist(self) -> BaseSparseFeaturesDist[TRAIN_F]: pass @abc.abstractmethod - def create_pooled_output_dist(self) -> BasePooledEmbeddingDist: + def create_train_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BasePooledEmbeddingDist[TRAIN_T]: pass @abc.abstractmethod - def create_sequence_output_dist(self) -> BaseSequenceEmbeddingDist: + def create_train_sequence_output_dist(self) -> BaseSequenceEmbeddingDist: pass @abc.abstractmethod - def create_lookup( + def create_train_lookup( self, fused_params: Optional[Dict[str, Any]], feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup: + ) -> BaseEmbeddingLookup[TRAIN_F, TRAIN_T]: pass + def create_infer_input_dist(self) -> BaseSparseFeaturesDist[INFER_F]: + raise NotImplementedError + + def create_infer_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BasePooledEmbeddingDist[INFER_T]: + raise NotImplementedError + + def create_infer_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[INFER_F, INFER_T]: + raise NotImplementedError + @abc.abstractmethod def embedding_dims(self) -> List[int]: pass diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 694e9c564..f184927f2 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -8,7 +8,7 @@ import abc from dataclasses import dataclass from enum import Enum, unique -from typing import List, Optional, Dict, Any, TypeVar, Iterator +from typing import Generic, List, Optional, Dict, Any, TypeVar, Iterator import torch from torch import nn @@ -83,6 +83,27 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: feature.record_stream(stream) +class ListOfSparseFeaturesList(Multistreamable): + def __init__(self, features: List[SparseFeaturesList]) -> None: + self.features_list = features + + def __len__(self) -> int: + return len(self.features_list) + + def __setitem__(self, key: int, item: SparseFeaturesList) -> None: + self.features_list[key] = item + + def __getitem__(self, key: int) -> SparseFeaturesList: + return self.features_list[key] + + def __iter__(self) -> Iterator[SparseFeaturesList]: + return iter(self.features_list) + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + for feature in self.features_list: + feature.record_stream(stream) + + @dataclass class ShardedConfig: local_rows: int = 0 @@ -162,7 +183,11 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: return embedding_shard_metadata -class BaseEmbeddingLookup(abc.ABC, nn.Module): +F = TypeVar("F", bound=Multistreamable) +T = TypeVar("T") + + +class BaseEmbeddingLookup(abc.ABC, nn.Module, Generic[F, T]): """ Interface implemented by different embedding implementations: e.g. one, which relies on nn.EmbeddingBag or table-batched one, etc. 
@@ -171,8 +196,8 @@ class BaseEmbeddingLookup(abc.ABC, nn.Module): @abc.abstractmethod def forward( self, - sparse_features: SparseFeatures, - ) -> torch.Tensor: + sparse_features: F, + ) -> T: pass def sparse_grad_parameter_names( diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 653d498f1..b4ae3205f 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -40,6 +40,7 @@ EmbeddingComputeKernel, BaseEmbeddingLookup, SparseFeaturesList, + ListOfSparseFeaturesList, ) from torchrec.distributed.rw_sharding import RwEmbeddingSharding from torchrec.distributed.tw_sharding import TwEmbeddingSharding @@ -49,12 +50,10 @@ Awaitable, LazyAwaitable, ParameterSharding, - ParameterStorage, ShardedModule, ShardingType, ShardedModuleContext, ShardedTensor, - ModuleSharder, ShardingEnv, ) from torchrec.distributed.utils import append_prefix @@ -65,12 +64,9 @@ ) from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer -from torchrec.quant.embedding_modules import ( - EmbeddingBagCollection as QuantEmbeddingBagCollection, -) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor - +# pyre-fixme [3] def create_embedding_sharding( sharding_type: str, embedding_configs: List[ @@ -79,13 +75,13 @@ def create_embedding_sharding( env: ShardingEnv, device: Optional[torch.device] = None, permute_embeddings: bool = False, -) -> EmbeddingSharding: +) -> EmbeddingSharding[Any, Any, Any, Any]: pg = env.process_group if device is not None and device.type == "meta": replace_placement_with_meta_device(embedding_configs) if pg is not None: if sharding_type == ShardingType.TABLE_WISE.value: - return TwEmbeddingSharding(embedding_configs, pg, device) + return TwEmbeddingSharding(embedding_configs, env, device) elif sharding_type == ShardingType.ROW_WISE.value: return RwEmbeddingSharding(embedding_configs, pg, device) elif sharding_type == ShardingType.DATA_PARALLEL.value: @@ -94,17 +90,19 @@ def create_embedding_sharding( return TwRwEmbeddingSharding(embedding_configs, pg, device) elif sharding_type == ShardingType.COLUMN_WISE.value: return CwEmbeddingSharding( - embedding_configs, pg, device, permute_embeddings=permute_embeddings + embedding_configs, env, device, permute_embeddings=permute_embeddings ) elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value: return TwCwEmbeddingSharding( - embedding_configs, pg, device, permute_embeddings=permute_embeddings + embedding_configs, env, device, permute_embeddings=permute_embeddings ) else: raise ValueError(f"Sharding type not supported {sharding_type}") else: if sharding_type == ShardingType.DATA_PARALLEL.value: return DpEmbeddingSharding(embedding_configs, env, device) + elif sharding_type == ShardingType.TABLE_WISE.value: + return TwEmbeddingSharding(embedding_configs, env, device) else: raise ValueError(f"Sharding type not supported {sharding_type}") @@ -239,13 +237,16 @@ def _wait_impl(self) -> KeyedTensor: ) -class ShardedEmbeddingBagCollection( +F = TypeVar("F", SparseFeaturesList, ListOfSparseFeaturesList) +T = TypeVar("T", List[torch.Tensor], List[List[torch.Tensor]]) + + +class ShardedEmbeddingBagCollectionBase( ShardedModule[ - SparseFeaturesList, - List[torch.Tensor], + F, + T, KeyedTensor, ], - FusedOptimizerModule, ): """ Sharded implementation of EmbeddingBagCollection. 
@@ -264,7 +265,10 @@ def __init__( sharding_type_to_embedding_configs = _create_embedding_configs_by_sharding( module, table_name_to_parameter_sharding, "embedding_bags." ) - self._sharding_type_to_sharding: Dict[str, EmbeddingSharding] = { + # pyre-fixme[4] + self._sharding_type_to_sharding: Dict[ + str, EmbeddingSharding[Any, Any, Any, Any] + ] = { sharding_type: create_embedding_sharding( sharding_type, embedding_confings, env, device, permute_embeddings=True ) @@ -274,11 +278,13 @@ def __init__( self._is_weighted: bool = module.is_weighted self._device = device self._create_lookups(fused_params) + # pyre-fixme[24]: Non-generic type `nn.modules.container.ModuleList` cannot # take parameters. self._output_dists: nn.ModuleList[nn.Module] = nn.ModuleList() self._embedding_names: List[str] = [] self._embedding_dims: List[int] = [] + # pyre-fixme[24]: Non-generic type `nn.modules.container.ModuleList` cannot # take parameters. self._input_dists: nn.ModuleList[nn.Module] = nn.ModuleList() @@ -290,28 +296,14 @@ def __init__( self._has_uninitialized_output_dist: bool = True self._has_features_permute: bool = True - # Get all fused optimizers and combine them. - optims = [] - for lookup in self._lookups: - for _, module in lookup.named_modules(): - if isinstance(module, FusedOptimizerModule): - # modify param keys to match EmbeddingBagCollection - params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} - for param_key, weight in module.fused_optimizer.params.items(): - # pyre-fixme[16]: `Mapping` has no attribute `__setitem__`. - params["embedding_bags." + param_key] = weight - module.fused_optimizer.params = params - optims.append(("", module.fused_optimizer)) - self._optim: CombinedOptimizer = CombinedOptimizer(optims) - def _create_input_dist( self, input_feature_names: List[str], + device: torch.device, ) -> None: - feature_names: List[str] = [] for sharding in self._sharding_type_to_sharding.values(): - self._input_dists.append(sharding.create_input_dist()) + self._input_dists.append(sharding.create_train_input_dist()) feature_names.extend( sharding.id_score_list_feature_names() if self._is_weighted @@ -332,33 +324,31 @@ def _create_input_dist( self._features_order.append(input_feature_names.index(f)) self.register_buffer( "_features_order_tensor", - torch.tensor( - self._features_order, device=self._device, dtype=torch.int32 - ), + torch.tensor(self._features_order, device=device, dtype=torch.int32), ) def _create_lookups( self, fused_params: Optional[Dict[str, Any]], ) -> None: - # pyre-fixme[24]: Non-generic type `nn.modules.container.ModuleList` cannot - # take parameters. 
- self._lookups: nn.ModuleList[BaseEmbeddingLookup] = nn.ModuleList() + # pyre-fixme[24] + self._lookups: nn.ModuleList[BaseEmbeddingLookup[Any, Any]] = nn.ModuleList() for sharding in self._sharding_type_to_sharding.values(): - self._lookups.append(sharding.create_lookup(fused_params)) + self._lookups.append(sharding.create_train_lookup(fused_params)) - def _create_output_dist(self) -> None: + def _create_output_dist(self, device: Optional[torch.device] = None) -> None: for sharding in self._sharding_type_to_sharding.values(): - self._output_dists.append(sharding.create_pooled_output_dist()) + self._output_dists.append(sharding.create_train_pooled_output_dist(device)) self._embedding_names.extend(sharding.embedding_names()) self._embedding_dims.extend(sharding.embedding_dims()) + # pyre-ignore [3] # pyre-ignore [14] def input_dist( self, ctx: ShardedModuleContext, features: KeyedJaggedTensor - ) -> Awaitable[SparseFeaturesList]: + ) -> Awaitable[Any]: if self._has_uninitialized_input_dist: - self._create_input_dist(features.keys()) + self._create_input_dist(features.keys(), features.device()) self._has_uninitialized_input_dist = False with torch.no_grad(): if self._has_features_permute: @@ -386,15 +376,20 @@ def input_dist( return SparseFeaturesListAwaitable(awaitables) def compute( - self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList - ) -> List[torch.Tensor]: + self, + ctx: ShardedModuleContext, + dist_input: F, + ) -> T: + # pyre-fixme [7] return [lookup(features) for lookup, features in zip(self._lookups, dist_input)] def output_dist( - self, ctx: ShardedModuleContext, output: List[torch.Tensor] + self, + ctx: ShardedModuleContext, + output: T, ) -> LazyAwaitable[KeyedTensor]: if self._has_uninitialized_output_dist: - self._create_output_dist() + self._create_output_dist(self._device) self._has_uninitialized_output_dist = False return EmbeddingCollectionAwaitable( awaitables=[ @@ -405,16 +400,18 @@ def output_dist( ) def compute_and_output_dist( - self, ctx: ShardedModuleContext, input: SparseFeaturesList + self, ctx: ShardedModuleContext, input: F ) -> LazyAwaitable[KeyedTensor]: if self._has_uninitialized_output_dist: - self._create_output_dist() + self._create_output_dist(self._device) self._has_uninitialized_output_dist = False return EmbeddingCollectionAwaitable( awaitables=[ dist(lookup(features)) for lookup, dist, features in zip( - self._lookups, self._output_dists, input + self._lookups, + self._output_dists, + input, ) ], embedding_dims=self._embedding_dims, @@ -500,6 +497,41 @@ def sparse_grad_parameter_names( ) return destination + +class ShardedEmbeddingBagCollection( + ShardedEmbeddingBagCollectionBase[F, T], + FusedOptimizerModule, +): + """ + Sharded implementation of EmbeddingBagCollection. + This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__( + module, table_name_to_parameter_sharding, env, fused_params, device + ) + # Get all fused optimizers and combine them. 
+ optims = [] + for lookup in self._lookups: + for _, module in lookup.named_modules(): + if isinstance(module, FusedOptimizerModule): + # modify param keys to match EmbeddingBagCollection + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for param_key, weight in module.fused_optimizer.params.items(): + # pyre-fixme[16]: `Mapping` has no attribute `__setitem__`. + params["embedding_bags." + param_key] = weight + module.fused_optimizer.params = params + optims.append(("", module.fused_optimizer)) + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + @property def fused_optimizer(self) -> KeyedOptimizer: return self._optim @@ -519,7 +551,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, - ) -> ShardedEmbeddingBagCollection: + ) -> ShardedEmbeddingBagCollection[SparseFeaturesList, List[torch.Tensor]]: return ShardedEmbeddingBagCollection( module, params, env, self.fused_params, device ) @@ -537,48 +569,6 @@ def module_type(self) -> Type[EmbeddingBagCollection]: return EmbeddingBagCollection -class QuantEmbeddingBagCollectionSharder(ModuleSharder[QuantEmbeddingBagCollection]): - def shard( - self, - module: QuantEmbeddingBagCollection, - params: Dict[str, ParameterSharding], - env: ShardingEnv, - device: Optional[torch.device] = None, - ) -> ShardedEmbeddingBagCollection: - return ShardedEmbeddingBagCollection(module, params, env, None, device) - - def sharding_types(self, compute_device_type: str) -> List[str]: - return [ShardingType.DATA_PARALLEL.value] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [ - EmbeddingComputeKernel.BATCHED_QUANT.value, - ] - - def storage_usage( - self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str - ) -> Dict[str, int]: - tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4 - assert compute_device_type in {"cuda", "cpu"} - storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} - return {storage_map[compute_device_type].value: tensor_bytes} - - def shardable_parameters( - self, module: QuantEmbeddingBagCollection - ) -> Dict[str, nn.Parameter]: - return { - name.split(".")[-2]: param - for name, param in module.state_dict().items() - if name.endswith(".weight") - } - - @property - def module_type(self) -> Type[QuantEmbeddingBagCollection]: - return QuantEmbeddingBagCollection - - class EmbeddingAwaitable(LazyAwaitable[torch.Tensor]): def __init__( self, @@ -644,7 +634,9 @@ def __init__( "table-wise sharding on a single EmbeddingBag is not supported yet" ) - self._embedding_sharding: EmbeddingSharding = create_embedding_sharding( + self._embedding_sharding: EmbeddingSharding[ + SparseFeatures, torch.Tensor, SparseFeaturesList, List[torch.Tensor] + ] = create_embedding_sharding( sharding_type=self.parameter_sharding.sharding_type, embedding_configs=[ ( @@ -657,10 +649,12 @@ def __init__( device=device, permute_embeddings=True, ) - self._input_dist: nn.Module = self._embedding_sharding.create_input_dist() - self._lookup: nn.Module = self._embedding_sharding.create_lookup(fused_params) + self._input_dist: nn.Module = self._embedding_sharding.create_train_input_dist() + self._lookup: nn.Module = self._embedding_sharding.create_train_lookup( + fused_params + ) self._output_dist: nn.Module = ( - self._embedding_sharding.create_pooled_output_dist() + self._embedding_sharding.create_train_pooled_output_dist() ) # Get all fused optimizers and combine them. 
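
Note on the embeddingbag.py refactor above: ShardedEmbeddingBagCollectionBase is generic over the dist-input type F and the lookup-output type T, so the training module (SparseFeaturesList -> List[torch.Tensor]) and the inference module added in quant_embeddingbag.py below (ListOfSparseFeaturesList -> List[List[torch.Tensor]]) share the same input_dist/compute/output_dist plumbing, while ShardedEmbeddingBagCollection layers the fused-optimizer handling on top. A minimal, self-contained sketch of that pattern with illustrative names (not the torchrec API):

import abc
from typing import Generic, List, TypeVar

import torch
from torch import nn

F = TypeVar("F")  # type produced by the input dist, per sharding type
T = TypeVar("T")  # type produced by the lookups, per sharding type


class SketchShardedBase(abc.ABC, nn.Module, Generic[F, T]):
    """Shared plumbing: dist input -> lookups -> output dist (no optimizer logic)."""

    @abc.abstractmethod
    def compute(self, dist_input: F) -> T:
        ...


class SketchTrainEBC(SketchShardedBase[List[torch.Tensor], List[torch.Tensor]]):
    # Training parameterization: one lookup output per sharding type.
    def compute(self, dist_input: List[torch.Tensor]) -> List[torch.Tensor]:
        return [x + 1 for x in dist_input]  # stand-in for the embedding lookups


class SketchInferEBC(
    SketchShardedBase[List[List[torch.Tensor]], List[List[torch.Tensor]]]
):
    # Inference parameterization: one list of per-rank (per-GPU) outputs per sharding type.
    def compute(
        self, dist_input: List[List[torch.Tensor]]
    ) -> List[List[torch.Tensor]]:
        return [[x + 1 for x in per_rank] for per_rank in dist_input]

The real classes use the same base-class shape (abc.ABC, nn.Module, Generic[F, T]) as BaseEmbeddingLookup in embedding_types.py.
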
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index aa35ebe27..fafc84a22 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -15,13 +15,13 @@ from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.embeddingbag import ( EmbeddingBagCollectionSharder, - QuantEmbeddingBagCollectionSharder, ) from torchrec.distributed.planner import ( EmbeddingShardingPlanner, sharder_name, Topology, ) +from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder from torchrec.distributed.types import ( ShardingPlan, ModuleSharder, @@ -74,6 +74,10 @@ def init_weights(m): None """ + SHARE_SHARDED: bool = False + # pyre-fixme [4] + SHARED_SHARDED_MODULE: Dict[str, ShardedModule[Any, Any, Any]] = {} + def __init__( self, module: nn.Module, @@ -175,15 +179,33 @@ def _shard_modules_impl( for name, child in module.named_children(): curr_path = path + name sharded_params = self._plan.get_plan_for_module(curr_path) + sharder_key = sharder_name(type(child)) if sharded_params: - # Shard module - sharder_key = sharder_name(type(child)) - sharded_child = self._sharder_map[sharder_key].shard( - child, - sharded_params, - self._env, - self.device, - ) + if DistributedModelParallel.SHARE_SHARDED: + if name in DistributedModelParallel.SHARED_SHARDED_MODULE: + sharded_child = DistributedModelParallel.SHARED_SHARDED_MODULE[ + name + ] + else: + # Shard module device-agnostic + # This is the multi-threading programming model case + sharded_child = self._sharder_map[sharder_key].shard( + child, + sharded_params, + self._env, + self.device, + ) + DistributedModelParallel.SHARED_SHARDED_MODULE[ + name + ] = sharded_child + else: + # Shard module + sharded_child = self._sharder_map[sharder_key].shard( + child, + sharded_params, + self._env, + self.device, + ) setattr(module, name, sharded_child) if isinstance(sharded_child, FusedOptimizerModule): fused_optims.append((curr_path, sharded_child.fused_optimizer)) diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py new file mode 100644 index 000000000..7e0025e6a --- /dev/null +++ b/torchrec/distributed/quant_embeddingbag.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Dict, Optional, Type, Any + +import torch +from torch import nn +from torchrec.distributed.embedding_sharding import ( + ListOfSparseFeaturesListAwaitable, +) +from torchrec.distributed.embedding_types import ( + SparseFeatures, + EmbeddingComputeKernel, + ListOfSparseFeaturesList, +) +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollectionBase +from torchrec.distributed.types import ( + Awaitable, + ParameterSharding, + ParameterStorage, + ShardingType, + ShardedModuleContext, + ModuleSharder, + ShardingEnv, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollectionInterface, +) +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class ShardedQuantEmbeddingBagCollection( + ShardedEmbeddingBagCollectionBase[ + ListOfSparseFeaturesList, + List[List[torch.Tensor]], + ], +): + """ + Sharded implementation of EmbeddingBagCollection. 
+ This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__( + module, table_name_to_parameter_sharding, env, fused_params, device + ) + + # pyre-ignore [3] + def input_dist( + self, ctx: ShardedModuleContext, features: KeyedJaggedTensor + ) -> Awaitable[Any]: + if self._has_uninitialized_input_dist: + self._create_input_dist(features.keys(), features.device()) + self._has_uninitialized_input_dist = False + if self._has_uninitialized_output_dist: + self._create_output_dist(features.device()) + self._has_uninitialized_output_dist = False + with torch.no_grad(): + if self._has_features_permute: + features = features.permute( + self._features_order, + # pyre-ignore [6] + self._features_order_tensor, + ) + features_by_shards = features.split( + self._feature_splits, + ) + awaitables = [ + module( + SparseFeatures( + id_list_features=None + if self._is_weighted + else features_by_shard, + id_score_list_features=features_by_shard + if self._is_weighted + else None, + ) + ).wait() # a dummy wait since now length indices comm is splited + for module, features_by_shard in zip( + self._input_dists, features_by_shards + ) + ] + return ListOfSparseFeaturesListAwaitable(awaitables) + + def _create_input_dist( + self, + input_feature_names: List[str], + device: torch.device, + ) -> None: + feature_names: List[str] = [] + for sharding in self._sharding_type_to_sharding.values(): + self._input_dists.append(sharding.create_infer_input_dist()) + feature_names.extend( + sharding.id_score_list_feature_names() + if self._is_weighted + else sharding.id_list_feature_names() + ) + self._feature_splits.append( + len( + sharding.id_score_list_feature_names() + if self._is_weighted + else sharding.id_list_feature_names() + ) + ) + + if feature_names == input_feature_names: + self._has_features_permute = False + else: + for f in feature_names: + self._features_order.append(input_feature_names.index(f)) + self.register_buffer( + "_features_order_tensor", + torch.tensor(self._features_order, device=device, dtype=torch.int32), + ) + + def _create_lookups( + self, + fused_params: Optional[Dict[str, Any]], + ) -> None: + self._lookups = nn.ModuleList() + for sharding in self._sharding_type_to_sharding.values(): + self._lookups.append(sharding.create_infer_lookup(fused_params)) + + def _create_output_dist(self, device: Optional[torch.device] = None) -> None: + for sharding in self._sharding_type_to_sharding.values(): + self._output_dists.append(sharding.create_infer_pooled_output_dist(device)) + self._embedding_names.extend(sharding.embedding_names()) + self._embedding_dims.extend(sharding.embedding_dims()) + + +class QuantEmbeddingBagCollectionSharder(ModuleSharder[QuantEmbeddingBagCollection]): + def shard( + self, + module: QuantEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedQuantEmbeddingBagCollection: + return ShardedQuantEmbeddingBagCollection(module, params, env, None, device) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.DATA_PARALLEL.value, ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [ + 
EmbeddingComputeKernel.BATCHED_QUANT.value, + ] + + def storage_usage( + self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str + ) -> Dict[str, int]: + tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4 + assert compute_device_type in {"cuda", "cpu"} + storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + return {storage_map[compute_device_type].value: tensor_bytes} + + def shardable_parameters( + self, module: QuantEmbeddingBagCollection + ) -> Dict[str, nn.Parameter]: + return { + name.split(".")[-2]: param + for name, param in module.state_dict().items() + if name.endswith(".weight") + } + + @property + def module_type(self) -> Type[QuantEmbeddingBagCollection]: + return QuantEmbeddingBagCollection diff --git a/torchrec/distributed/rw_sharding.py b/torchrec/distributed/rw_sharding.py index 58914c1a6..04e6ec040 100644 --- a/torchrec/distributed/rw_sharding.py +++ b/torchrec/distributed/rw_sharding.py @@ -30,6 +30,7 @@ bucketize_kjt_before_all2all, ) from torchrec.distributed.embedding_types import ( + SparseFeaturesList, ShardedEmbeddingTable, GroupedEmbeddingConfig, SparseFeatures, @@ -44,7 +45,7 @@ from torchrec.modules.embedding_configs import EmbeddingTableConfig -class RwSparseFeaturesDist(BaseSparseFeaturesDist): +class RwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): def __init__( self, # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type. @@ -133,7 +134,7 @@ def forward( return self._dist(bucketized_sparse_features) -class RwPooledEmbeddingDist(BasePooledEmbeddingDist): +class RwPooledEmbeddingDist(BasePooledEmbeddingDist[torch.Tensor]): def __init__( self, pg: dist.ProcessGroup, @@ -167,7 +168,11 @@ def forward( ) -class RwEmbeddingSharding(EmbeddingSharding): +class RwEmbeddingSharding( + EmbeddingSharding[ + SparseFeatures, torch.Tensor, SparseFeaturesList, List[torch.Tensor] + ] +): """ Shards embedding bags row-wise, i.e.. a given embedding table is evenly distributed by rows and table slices are placed on all ranks. 
@@ -255,7 +260,7 @@ def _shard( ) return tables_per_rank - def create_input_dist(self) -> BaseSparseFeaturesDist: + def create_train_input_dist(self) -> BaseSparseFeaturesDist[SparseFeatures]: num_id_list_features = self._get_id_list_features_num() num_id_score_list_features = self._get_id_score_list_features_num() id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() @@ -271,7 +276,7 @@ def create_input_dist(self) -> BaseSparseFeaturesDist: has_feature_processor=self._has_feature_processor, ) - def create_lookup( + def create_train_lookup( self, fused_params: Optional[Dict[str, Any]], feature_processor: Optional[BaseGroupedFeatureProcessor] = None, @@ -293,10 +298,13 @@ def create_lookup( feature_processor=feature_processor, ) - def create_pooled_output_dist(self) -> RwPooledEmbeddingDist: + def create_train_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> RwPooledEmbeddingDist: return RwPooledEmbeddingDist(self._pg) - def create_sequence_output_dist(self) -> RwSequenceEmbeddingDist: + def create_train_sequence_output_dist(self) -> RwSequenceEmbeddingDist: return RwSequenceEmbeddingDist( self._pg, self._get_id_list_features_num(), diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py index b04032f65..4972d365c 100644 --- a/torchrec/distributed/tests/test_quant_model_parallel.py +++ b/torchrec/distributed/tests/test_quant_model_parallel.py @@ -19,10 +19,10 @@ QuantBatchedEmbeddingBag, ) from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.embeddingbag import ( +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.quant_embeddingbag import ( QuantEmbeddingBagCollectionSharder, ) -from torchrec.distributed.model_parallel import DistributedModelParallel from torchrec.distributed.tests.test_model import ( TestSparseNN, TestEBCSharder, diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py index c2d6117c1..16daf3284 100644 --- a/torchrec/distributed/tests/test_train_pipeline.py +++ b/torchrec/distributed/tests/test_train_pipeline.py @@ -48,7 +48,9 @@ from torchrec.tests.utils import get_free_port, init_distributed_single_host -class TestShardedEmbeddingBagCollection(ShardedEmbeddingBagCollection): +class TestShardedEmbeddingBagCollection( + ShardedEmbeddingBagCollection[SparseFeaturesList, List[torch.Tensor]] +): def input_dist( self, ctx: ShardedModuleContext, diff --git a/torchrec/distributed/tw_sharding.py b/torchrec/distributed/tw_sharding.py index 440b404ad..d1b35ce75 100644 --- a/torchrec/distributed/tw_sharding.py +++ b/torchrec/distributed/tw_sharding.py @@ -11,16 +11,19 @@ import torch.distributed as dist from torch.distributed._sharding_spec import ShardMetadata from torchrec.distributed.dist_data import ( + PooledEmbeddingsAllToOne, PooledEmbeddingsAllToAll, SequenceEmbeddingAllToAll, ) from torchrec.distributed.embedding_lookup import ( GroupedPooledEmbeddingsLookup, GroupedEmbeddingsLookup, + InferGroupedPooledEmbeddingsLookup, ) from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, SparseFeaturesAllToAll, + SparseFeaturesOneToAll, group_tables, BasePooledEmbeddingDist, BaseSequenceEmbeddingDist, @@ -29,6 +32,7 @@ BaseEmbeddingLookup, ) from torchrec.distributed.embedding_types import ( + SparseFeaturesList, GroupedEmbeddingConfig, SparseFeatures, ShardedEmbeddingTable, @@ -36,14 +40,53 @@ 
BaseGroupedFeatureProcessor, ) from torchrec.distributed.types import ( + ShardingEnv, ShardedTensorMetadata, Awaitable, + NoWait, ParameterSharding, ) from torchrec.modules.embedding_configs import EmbeddingTableConfig -class TwSparseFeaturesDist(BaseSparseFeaturesDist): +class TwInferenceSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeaturesList]): + def __init__( + self, + id_list_features_per_rank: List[int], + id_score_list_features_per_rank: List[int], + world_size: int, + ) -> None: + super().__init__() + self._dist: SparseFeaturesOneToAll = SparseFeaturesOneToAll( + id_list_features_per_rank, + id_score_list_features_per_rank, + world_size, + ) + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[Awaitable[SparseFeaturesList]]: + return NoWait(self._dist.forward(sparse_features)) + + +class TwInferencePooledEmbeddingDist(BasePooledEmbeddingDist[List[torch.Tensor]]): + def __init__( + self, + device: torch.device, + world_size: int, + ) -> None: + super().__init__() + self._dist: PooledEmbeddingsAllToOne = PooledEmbeddingsAllToOne( + device, + world_size, + ) + + def forward(self, local_embs: List[torch.Tensor]) -> Awaitable[torch.Tensor]: + return self._dist.forward(local_embs) + + +class TwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): def __init__( self, # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type. @@ -67,7 +110,7 @@ def forward( return self._dist(sparse_features) -class TwPooledEmbeddingDist(BasePooledEmbeddingDist): +class TwPooledEmbeddingDist(BasePooledEmbeddingDist[torch.Tensor]): def __init__( self, pg: dist.ProcessGroup, @@ -104,7 +147,11 @@ def forward( ) -class TwEmbeddingSharding(EmbeddingSharding): +class TwEmbeddingSharding( + EmbeddingSharding[ + SparseFeatures, torch.Tensor, SparseFeaturesList, List[torch.Tensor] + ] +): """ Shards embedding bags table-wise, i.e.. a given embedding table is entirely placed on a selected rank. @@ -115,16 +162,18 @@ def __init__( embedding_configs: List[ Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] ], - pg: dist.ProcessGroup, + env: ShardingEnv, device: Optional[torch.device] = None, is_sequence: bool = False, permute_embeddings: bool = False, ) -> None: super().__init__(permute_embeddings) - # pyre-fixme[4]: Attribute must be annotated. 
- self._pg = pg + self._env = env self._device = device self._is_sequence = is_sequence + self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank sharded_tables_per_rank = self._shard(embedding_configs) self._grouped_embedding_configs_per_rank: List[ List[GroupedEmbeddingConfig] @@ -138,10 +187,10 @@ def __init__( ) = group_tables(sharded_tables_per_rank) self._grouped_embedding_configs: List[ GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[dist.get_rank(pg)] + ] = self._grouped_embedding_configs_per_rank[self._rank] self._score_grouped_embedding_configs: List[ GroupedEmbeddingConfig - ] = self._score_grouped_embedding_configs_per_rank[dist.get_rank(pg)] + ] = self._score_grouped_embedding_configs_per_rank[self._rank] def _shard( self, @@ -149,7 +198,7 @@ def _shard( Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] ], ) -> List[List[ShardedEmbeddingTable]]: - world_size = self._pg.size() + world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ [] for i in range(world_size) ] @@ -185,7 +234,7 @@ def _shard( ) return tables_per_rank - def create_input_dist(self) -> BaseSparseFeaturesDist: + def create_train_input_dist(self) -> BaseSparseFeaturesDist[SparseFeatures]: return TwSparseFeaturesDist( self._pg, self._id_list_features_per_rank(), @@ -193,7 +242,7 @@ def create_input_dist(self) -> BaseSparseFeaturesDist: self._device, ) - def create_lookup( + def create_train_lookup( self, fused_params: Optional[Dict[str, Any]], feature_processor: Optional[BaseGroupedFeatureProcessor] = None, @@ -215,14 +264,17 @@ def create_lookup( feature_processor=feature_processor, ) - def create_pooled_output_dist(self) -> TwPooledEmbeddingDist: + def create_train_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BasePooledEmbeddingDist[torch.Tensor]: return TwPooledEmbeddingDist( self._pg, self._dim_sum_per_rank(), self._device, ) - def create_sequence_output_dist( + def create_train_sequence_output_dist( self, ) -> BaseSequenceEmbeddingDist: return TwSequenceEmbeddingDist( @@ -231,6 +283,35 @@ def create_sequence_output_dist( self._device, ) + def create_infer_input_dist(self) -> BaseSparseFeaturesDist[SparseFeaturesList]: + return TwInferenceSparseFeaturesDist( + self._id_list_features_per_rank(), + self._id_score_list_features_per_rank(), + self._world_size, + ) + + def create_infer_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]]: + return InferGroupedPooledEmbeddingsLookup( + grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, + grouped_score_configs_per_rank=self._score_grouped_embedding_configs_per_rank, + world_size=self._world_size, + fused_params=fused_params, + ) + + def create_infer_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> TwInferencePooledEmbeddingDist: + return TwInferencePooledEmbeddingDist( + # pyre-fixme [6] + device, + self._world_size, + ) + def _dim_sum_per_rank(self) -> List[int]: dim_sum_per_rank = [] for grouped_embedding_configs, score_grouped_embedding_configs in zip( diff --git a/torchrec/distributed/twrw_sharding.py b/torchrec/distributed/twrw_sharding.py index 5c92e99f8..995843cc2 100644 --- a/torchrec/distributed/twrw_sharding.py +++ b/torchrec/distributed/twrw_sharding.py @@ -32,6 +32,7 @@ 
bucketize_kjt_before_all2all, ) from torchrec.distributed.embedding_types import ( + SparseFeaturesList, GroupedEmbeddingConfig, SparseFeatures, ShardedEmbeddingTable, @@ -46,7 +47,7 @@ from torchrec.modules.embedding_configs import EmbeddingTableConfig -class TwRwSparseFeaturesDist(BaseSparseFeaturesDist): +class TwRwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): """ Bucketizes sparse features in TWRW fashion and then redistributes with to AlltoAll collective operation. @@ -237,7 +238,7 @@ def _staggered_shuffle(self, features_per_rank: List[int]) -> List[int]: ] -class TwRwEmbeddingDist(BasePooledEmbeddingDist): +class TwRwEmbeddingDist(BasePooledEmbeddingDist[torch.Tensor]): """ Redistributes pooled embedding tensor in TWRW fashion by performing a reduce-scatter operation row wise on the host level and then an AlltoAll operation table wise on @@ -283,7 +284,11 @@ def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: return self._cross_dist(self._intra_dist(local_embs).wait()) -class TwRwEmbeddingSharding(EmbeddingSharding): +class TwRwEmbeddingSharding( + EmbeddingSharding[ + SparseFeatures, torch.Tensor, SparseFeaturesList, List[torch.Tensor] + ] +): """ Shards embedding bags table-wise then row-wise. """ @@ -400,7 +405,7 @@ def _shard( return tables_per_rank - def create_input_dist(self) -> BaseSparseFeaturesDist: + def create_train_input_dist(self) -> BaseSparseFeaturesDist[SparseFeatures]: num_id_list_features = self._get_id_list_features_num() num_id_score_list_features = self._get_id_score_list_features_num() id_list_features_per_rank = self._features_per_rank( @@ -424,7 +429,7 @@ def create_input_dist(self) -> BaseSparseFeaturesDist: has_feature_processor=self._has_feature_processor, ) - def create_lookup( + def create_train_lookup( self, fused_params: Optional[Dict[str, Any]], feature_processor: Optional[BaseGroupedFeatureProcessor] = None, @@ -440,7 +445,10 @@ def create_lookup( feature_processor=feature_processor, ) - def create_pooled_output_dist(self) -> BasePooledEmbeddingDist: + def create_train_pooled_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BasePooledEmbeddingDist[torch.Tensor]: return TwRwEmbeddingDist( cross_pg=cast(dist.ProcessGroup, self._cross_pg), intra_pg=cast(dist.ProcessGroup, self._intra_pg), @@ -448,7 +456,7 @@ def create_pooled_output_dist(self) -> BasePooledEmbeddingDist: device=self._device, ) - def create_sequence_output_dist(self) -> BaseSequenceEmbeddingDist: + def create_train_sequence_output_dist(self) -> BaseSequenceEmbeddingDist: raise NotImplementedError def embedding_dims(self) -> List[int]: diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 97a4624d8..c07024f10 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -439,7 +439,7 @@ def shard( module: M, params: Dict[str, ParameterSharding], env: ShardingEnv, - device: torch.device, + device: Optional[torch.device] = None, ) -> ShardedModule[Any, Any, Any]: """ Does the actual sharding. It will allocate parameters on the requested locations
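
Note on the overall inference flow added in this patch: KJTOneToAll splits and P2P-copies the features to each GPU, InferGroupedPooledEmbeddingsLookup runs the pooled lookups on each device, and PooledEmbeddingsAllToOne merges the pooled outputs onto one device via torch.ops.fbgemm.merge_pooled_embeddings. A rough, self-contained sketch of that shape using plain PyTorch stand-ins (the device count, table sizes, and torch.cat in place of the fused merge are assumptions for illustration only):

from typing import List

import torch
from torch import nn

# Assumptions: >= 1 CUDA device, one table-wise "shard" (a whole table) per device,
# 64-dim sum-pooled embeddings, rank 0 as the output device.
world_size = max(torch.cuda.device_count(), 1)
output_device = torch.device("cuda", 0)

tables: List[nn.EmbeddingBag] = [
    nn.EmbeddingBag(1000, 64, mode="sum").to(torch.device("cuda", rank))
    for rank in range(world_size)
]


def infer(
    ids_per_table: List[torch.Tensor], offsets_per_table: List[torch.Tensor]
) -> torch.Tensor:
    # "One-to-all" (KJTOneToAll): P2P-copy each table's feature slice to the
    # device that owns that table.
    inputs = [
        (
            ids.to(torch.device("cuda", rank), non_blocking=True),
            offs.to(torch.device("cuda", rank), non_blocking=True),
        )
        for rank, (ids, offs) in enumerate(zip(ids_per_table, offsets_per_table))
    ]
    # Per-device pooled lookups (InferGroupedPooledEmbeddingsLookup).
    pooled = [tables[rank](ids, offs) for rank, (ids, offs) in enumerate(inputs)]
    # "All-to-one" (PooledEmbeddingsAllToOne): gather the pooled outputs on the
    # output device and concatenate along the embedding dimension; fbgemm's
    # merge_pooled_embeddings performs this as a single fused op.
    return torch.cat([p.to(output_device) for p in pooled], dim=1)

In the patch itself the per-device lookups run quantized BATCHED_QUANT kernels, and the copies are issued with non_blocking=True and wrapped in NoWait awaitables so they can overlap within a single process driving multiple GPUs.
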