13 changes: 8 additions & 5 deletions torchrec/distributed/cw_sharding.py
@@ -8,14 +8,15 @@
from typing import Set, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed as dist # noqa
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torchrec.distributed.embedding_types import (
ShardedEmbeddingTable,
EmbeddingComputeKernel,
)
from torchrec.distributed.tw_sharding import TwEmbeddingSharding, TwPooledEmbeddingDist
from torchrec.distributed.types import (
ShardingEnv,
ShardedTensorMetadata,
ShardMetadata,
ParameterSharding,
@@ -34,13 +35,12 @@ def __init__(
embedding_configs: List[
Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]
],
# pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
pg: dist.ProcessGroup,
env: ShardingEnv,
device: Optional[torch.device] = None,
permute_embeddings: bool = False,
) -> None:
super().__init__(
embedding_configs, pg, device, permute_embeddings=permute_embeddings
embedding_configs, env, device, permute_embeddings=permute_embeddings
)
if self._permute_embeddings:
self._init_combined_embeddings()
@@ -162,7 +162,10 @@ def embedding_names(self) -> List[str]:
else super().embedding_names()
)

def create_pooled_output_dist(self) -> TwPooledEmbeddingDist:
def create_train_pooled_output_dist(
self,
device: Optional[torch.device] = None,
) -> TwPooledEmbeddingDist:
embedding_permute_op: Optional[PermutePooledEmbeddings] = None
callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None
if self._permute_embeddings and self._embedding_order != list(
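For context on the cw_sharding.py change above: the sharder is now constructed from a `ShardingEnv` rather than a raw `ProcessGroup`, and the pooled output-dist factory becomes `create_train_pooled_output_dist` with an optional `device`. Below is a minimal, hypothetical call-site sketch; the names `CwEmbeddingSharding` (the sharder defined in this file) and `ShardingEnv.from_process_group` are assumptions about the surrounding torchrec API, and a default process group is assumed to be initialized already.

```python
from typing import List, Tuple

import torch
import torch.distributed as dist

from torchrec.distributed.cw_sharding import CwEmbeddingSharding
from torchrec.distributed.types import ParameterSharding, ShardingEnv
from torchrec.modules.embedding_configs import EmbeddingTableConfig


def build_cw_sharding(
    embedding_configs: List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]],
    device: torch.device,
) -> CwEmbeddingSharding:
    # ShardingEnv wraps the process group plus rank/world_size bookkeeping.
    env = ShardingEnv.from_process_group(dist.group.WORLD)
    sharding = CwEmbeddingSharding(
        embedding_configs,
        env,  # previously a raw dist.ProcessGroup was passed here
        device,
        permute_embeddings=True,
    )
    # The train/infer split renames the factory; it now also takes an optional
    # target device. Shown here only to illustrate the new call shape.
    output_dist = sharding.create_train_pooled_output_dist(device=device)
    return sharding
```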
75 changes: 74 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -17,12 +17,16 @@
alltoall_sequence,
reduce_scatter_pooled,
)
from torchrec.distributed.types import Awaitable
from torchrec.distributed.types import Awaitable, NoWait
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
)
except OSError:
pass

@@ -416,6 +420,54 @@ def forward(
)


class KJTOneToAll(nn.Module):
"""
Redistributes KeyedJaggedTensor to all devices.

The implementation uses a one-to-all pattern, which essentially P2P-copies each feature split to its destination device.

Constructor Args:
splits (List[int]): lengths of features to split the KeyedJaggedTensor into
before copying them to the devices.
world_size (int): number of devices.

Call Args:
kjt (KeyedJaggedTensor): The input features.

Returns:
Awaitable[List[KeyedJaggedTensor]].
"""
def __init__(
self,
splits: List[int],
world_size: int,
) -> None:
super().__init__()
self._splits = splits
self._world_size = world_size
assert self._world_size == len(splits)

def forward(self, kjt: KeyedJaggedTensor) -> Awaitable[List[KeyedJaggedTensor]]:
"""
Splits the features and sends each slice to its corresponding device.

Call Args:
kjt (KeyedJaggedTensor): KeyedJaggedTensor of values to distribute.

Returns:
Awaitable[List[KeyedJaggedTensor]]: awaitable of the KeyedJaggedTensor splits.

"""
kjts: List[KeyedJaggedTensor] = kjt.split(self._splits)
dist_kjts = [
split_kjt.to(torch.device("cuda", rank), non_blocking=True)
for rank, split_kjt in enumerate(kjts)
]
return NoWait(dist_kjts)
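
A minimal usage sketch of `KJTOneToAll` as added above, assuming a host with two CUDA devices; the feature names and values are hypothetical, and the returned awaitable is a `NoWait`, so `.wait()` returns immediately.

```python
import torch

from torchrec.distributed.dist_data import KJTOneToAll
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features, batch size 2 (lengths has len(keys) * batch_size entries).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 2]),
)

# Send feature "f1" to cuda:0 and feature "f2" to cuda:1 (one feature per rank).
one_to_all = KJTOneToAll(splits=[1, 1], world_size=2)
kjt_per_rank = one_to_all(kjt).wait()  # List[KeyedJaggedTensor], one per device
```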


class PooledEmbeddingsAwaitable(Awaitable[torch.Tensor]):
"""
Awaitable for pooled embeddings after collective operation.
@@ -541,6 +593,27 @@ def callbacks(self) -> List[Callable[[torch.Tensor], torch.Tensor]]:
return self._callbacks


class PooledEmbeddingsAllToOne(nn.Module):
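"""
Merges pooled embedding tensors from each device onto a single target device
via `torch.ops.fbgemm.merge_pooled_embeddings`.

Constructor Args:
device (torch.device): device onto which the merged pooled embeddings are placed.
world_size (int): number of devices (one input tensor is expected per device).

Call Args:
tensors (List[torch.Tensor]): pooled embeddings, one tensor per device.

Returns:
Awaitable[torch.Tensor]: awaitable of the merged pooled embeddings.
"""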
def __init__(
self,
device: torch.device,
world_size: int,
) -> None:
super().__init__()
self._device = device
self._world_size = world_size

def forward(self, tensors: List[torch.Tensor]) -> Awaitable[torch.Tensor]:
assert len(tensors) == self._world_size
return NoWait(
torch.ops.fbgemm.merge_pooled_embeddings(
tensors,
tensors[0].size(0),
self._device,
)
)
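
A short sketch of `PooledEmbeddingsAllToOne`, assuming two CUDA devices and that the fbgemm_gpu `merge_pooled_embeddings` operator loaded above is available; the shapes are hypothetical.

```python
import torch

from torchrec.distributed.dist_data import PooledEmbeddingsAllToOne

batch_size, dim = 4, 8
# One pooled tensor per device, e.g. the per-rank lookup outputs.
per_rank_pooled = [
    torch.randn(batch_size, dim, device=torch.device("cuda", rank)) for rank in range(2)
]

all_to_one = PooledEmbeddingsAllToOne(device=torch.device("cuda", 0), world_size=2)
merged = all_to_one(per_rank_pooled).wait()  # shape (batch_size, 2 * dim) on cuda:0
```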


class PooledEmbeddingsReduceScatter(nn.Module):
"""
The module class that wraps reduce-scatter communication primitive for pooled
44 changes: 36 additions & 8 deletions torchrec/distributed/dp_sharding.py
@@ -33,7 +33,7 @@
from torchrec.modules.embedding_configs import EmbeddingTableConfig


class DpSparseFeaturesDist(BaseSparseFeaturesDist):
class DpSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]):
"""
Distributes sparse features (input) to be data-parallel.
"""
@@ -58,7 +58,7 @@ def forward(
return NoWait(cast(Awaitable[SparseFeatures], NoWait(sparse_features)))


class DpPooledEmbeddingDist(BasePooledEmbeddingDist):
class DpPooledEmbeddingDist(BasePooledEmbeddingDist[torch.Tensor]):
"""
Distributes pooled embeddings to be data-parallel.
"""
@@ -104,7 +104,9 @@ def forward(
return NoWait(local_embs)


class DpEmbeddingSharding(EmbeddingSharding):
class DpEmbeddingSharding(
EmbeddingSharding[SparseFeatures, torch.Tensor, SparseFeatures, torch.Tensor]
):
"""
Shards embedding bags using data-parallel, with no table sharding, i.e. a given
embedding table is replicated across all ranks.
@@ -123,6 +125,8 @@ def __init__(
self._env = env
self._device = device
self._is_sequence = is_sequence
self._rank: int = self._env.rank
self._world_size: int = self._env.world_size
sharded_tables_per_rank = self._shard(embedding_configs)
self._grouped_embedding_configs_per_rank: List[
List[GroupedEmbeddingConfig]
@@ -147,7 +151,7 @@ def _shard(
Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]
],
) -> List[List[ShardedEmbeddingTable]]:
world_size = self._env.world_size
world_size = self._world_size
tables_per_rank: List[List[ShardedEmbeddingTable]] = [
[] for i in range(world_size)
]
@@ -175,10 +179,10 @@
)
return tables_per_rank

def create_input_dist(self) -> DpSparseFeaturesDist:
def create_train_input_dist(self) -> DpSparseFeaturesDist:
return DpSparseFeaturesDist()

def create_lookup(
def create_train_lookup(
self,
fused_params: Optional[Dict[str, Any]],
feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
@@ -200,12 +204,36 @@ def create_lookup(
feature_processor=feature_processor,
)

def create_pooled_output_dist(self) -> DpPooledEmbeddingDist:
def create_train_pooled_output_dist(
self,
device: Optional[torch.device] = None,
) -> DpPooledEmbeddingDist:
return DpPooledEmbeddingDist()

def create_sequence_output_dist(self) -> DpSequenceEmbeddingDist:
def create_train_sequence_output_dist(self) -> DpSequenceEmbeddingDist:
return DpSequenceEmbeddingDist()

def create_infer_input_dist(self) -> DpSparseFeaturesDist:
return DpSparseFeaturesDist()

def create_infer_lookup(
self,
fused_params: Optional[Dict[str, Any]],
feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
) -> BaseEmbeddingLookup[SparseFeatures, torch.Tensor]:
return GroupedPooledEmbeddingsLookup(
grouped_configs=self._grouped_embedding_configs,
grouped_score_configs=self._score_grouped_embedding_configs,
fused_params=fused_params,
device=self._device,
)

def create_infer_pooled_output_dist(
self,
device: Optional[torch.device] = None,
) -> DpPooledEmbeddingDist:
return DpPooledEmbeddingDist()

def embedding_dims(self) -> List[int]:
embedding_dims = []
for grouped_config in self._grouped_embedding_configs:
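The dp_sharding.py change above splits the factory methods into `create_train_*` and `create_infer_*` variants; for data-parallel sharding the input and output dists in both paths are local pass-throughs. Below is a hypothetical sketch of how an inference caller might wire the new infer factories together; the function itself is not part of this PR, and the double `.wait()` mirrors the nested `NoWait` returned by `DpSparseFeaturesDist.forward`.

```python
from typing import Optional

import torch

from torchrec.distributed.dp_sharding import DpEmbeddingSharding
from torchrec.distributed.embedding_types import SparseFeatures


def dp_infer_forward(
    sharding: DpEmbeddingSharding,
    sparse_features: SparseFeatures,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # The infer factories mirror the train ones; no collectives are issued here.
    input_dist = sharding.create_infer_input_dist()
    lookup = sharding.create_infer_lookup(fused_params=None)
    output_dist = sharding.create_infer_pooled_output_dist(device=device)

    features = input_dist(sparse_features).wait().wait()  # nested NoWait pass-through
    pooled = lookup(features)
    return output_dist(pooled).wait()
```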
101 changes: 97 additions & 4 deletions torchrec/distributed/embedding_lookup.py
@@ -29,6 +29,7 @@
from torch.nn.modules.module import _IncompatibleKeys
from torchrec.distributed.embedding_types import (
ShardedEmbeddingTable,
SparseFeaturesList,
GroupedEmbeddingConfig,
BaseEmbeddingLookup,
SparseFeatures,
@@ -681,7 +682,7 @@ def named_parameters(
)


class GroupedEmbeddingsLookup(BaseEmbeddingLookup):
class GroupedEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]):
def __init__(
self,
grouped_configs: List[GroupedEmbeddingConfig],
@@ -1238,6 +1239,7 @@ def __init__(
self._local_rows, config.embedding_tables
)
],
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
)
@@ -1371,7 +1373,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
return ret


class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup):
class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]):
def __init__(
self,
grouped_configs: List[GroupedEmbeddingConfig],
@@ -1383,6 +1385,7 @@ def __init__(
) -> None:
def _create_lookup(
config: GroupedEmbeddingConfig,
device: Optional[torch.device] = None,
) -> BaseEmbeddingBag:
if config.compute_kernel == EmbeddingComputeKernel.BATCHED_DENSE:
return BatchedDenseEmbeddingBag(
@@ -1425,13 +1428,13 @@ def _create_lookup(
# take parameters.
self._emb_modules: nn.ModuleList[BaseEmbeddingBag] = nn.ModuleList()
for config in grouped_configs:
self._emb_modules.append(_create_lookup(config))
self._emb_modules.append(_create_lookup(config, device))

# pyre-fixme[24]: Non-generic type `nn.modules.container.ModuleList` cannot
# take parameters.
self._score_emb_modules: nn.ModuleList[BaseEmbeddingBag] = nn.ModuleList()
for config in grouped_score_configs:
self._score_emb_modules.append(_create_lookup(config))
self._score_emb_modules.append(_create_lookup(config, device))

self._id_list_feature_splits: List[int] = []
for config in grouped_configs:
@@ -1560,3 +1563,93 @@ def sparse_grad_parameter_names(
for emb_module in self._score_emb_modules:
emb_module.sparse_grad_parameter_names(destination, prefix)
return destination


class InferGroupedPooledEmbeddingsLookup(
BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]]
):
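"""
Pooled embedding lookup for inference: builds one `GroupedPooledEmbeddingsLookup`
per rank, each on its own CUDA device, and runs each rank's lookup over the
corresponding entry of a `SparseFeaturesList`, returning one pooled tensor per rank.
"""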
def __init__(
self,
grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]],
grouped_score_configs_per_rank: List[List[GroupedEmbeddingConfig]],
world_size: int,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self._embedding_lookups_per_rank: List[GroupedPooledEmbeddingsLookup] = []
for rank in range(world_size):
self._embedding_lookups_per_rank.append(
GroupedPooledEmbeddingsLookup(
grouped_configs=grouped_configs_per_rank[rank],
grouped_score_configs=grouped_score_configs_per_rank[rank],
fused_params=fused_params,
device=torch.device("cuda", rank),
)
)

def forward(
self,
sparse_features: SparseFeaturesList,
) -> List[torch.Tensor]:
embeddings: List[torch.Tensor] = []
for sparse_features_rank, embedding_lookup in zip(
sparse_features, self._embedding_lookups_per_rank
):
assert (
sparse_features_rank.id_list_features is not None
or sparse_features_rank.id_score_list_features is not None
)
embeddings.append(embedding_lookup(sparse_features_rank))
return embeddings

def state_dict(
self,
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()
# pyre-ignore [16]
destination._metadata = OrderedDict()

for rank_modules in self._embedding_lookups_per_rank:
rank_modules.state_dict(destination, prefix, keep_vars)

return destination

def load_state_dict(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
strict: bool = True,
) -> _IncompatibleKeys:
missing_keys = []
unexpected_keys = []
for rank_modules in self._embedding_lookups_per_rank:
incompatible_keys = rank_modules.load_state_dict(state_dict)
missing_keys.extend(incompatible_keys.missing_keys)
unexpected_keys.extend(incompatible_keys.unexpected_keys)
return _IncompatibleKeys(
missing_keys=missing_keys, unexpected_keys=unexpected_keys
)

def named_parameters(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]:
for rank_modules in self._embedding_lookups_per_rank:
yield from rank_modules.named_parameters(prefix, recurse)

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for rank_modules in self._embedding_lookups_per_rank:
yield from rank_modules.named_buffers(prefix, recurse)

def sparse_grad_parameter_names(
self, destination: Optional[List[str]] = None, prefix: str = ""
) -> List[str]:
destination = [] if destination is None else destination
for rank_modules in self._embedding_lookups_per_rank:
rank_modules.sparse_grad_parameter_names(destination, prefix)
return destination
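
Taken together with the dist_data.py additions, `InferGroupedPooledEmbeddingsLookup` enables a fan-out/fan-in inference flow: split the input KJT across devices, run per-rank pooled lookups, and merge the results onto one device. Below is a conceptual sketch under stated assumptions; the function and the `SparseFeatures`/`SparseFeaturesList` construction shown here are assumed glue, not part of this PR.

```python
from typing import List

import torch

from torchrec.distributed.dist_data import KJTOneToAll, PooledEmbeddingsAllToOne
from torchrec.distributed.embedding_lookup import InferGroupedPooledEmbeddingsLookup
from torchrec.distributed.embedding_types import SparseFeatures, SparseFeaturesList
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def sharded_inference_forward(
    kjt: KeyedJaggedTensor,
    feature_splits: List[int],
    lookup: InferGroupedPooledEmbeddingsLookup,
    world_size: int,
    output_device: torch.device,
) -> torch.Tensor:
    # Fan out: P2P-copy one KJT slice to each device.
    kjt_per_rank = KJTOneToAll(feature_splits, world_size)(kjt).wait()
    # Wrap each slice as id-list SparseFeatures (constructor usage is an assumption).
    features = SparseFeaturesList(
        [SparseFeatures(id_list_features=k) for k in kjt_per_rank]
    )
    # Per-rank pooled lookups: one tensor per device.
    pooled_per_rank = lookup(features)
    # Fan in: merge the per-rank pooled embeddings onto the output device.
    return PooledEmbeddingsAllToOne(output_device, world_size)(pooled_per_rank).wait()
```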