29 changes: 16 additions & 13 deletions torchrec/distributed/mc_modules.py
@@ -60,6 +60,7 @@
from torchrec.distributed.types import (
Awaitable,
LazyAwaitable,
NullShardedModuleContext,
ParameterSharding,
QuantizedCommCodecs,
ShardedModule,
@@ -1292,8 +1293,9 @@ def _create_input_dists(
# pyre-ignore
def input_dist(
self,
ctx: ManagedCollisionCollectionContext,
ctx: Union[ManagedCollisionCollectionContext, NullShardedModuleContext],
features: KeyedJaggedTensor,
is_sequence_embedding: bool = True,
) -> ListOfKJTList:
if self._has_uninitialized_input_dists:
self._create_input_dists(
@@ -1345,19 +1347,20 @@ def input_dist(
for feature_split, input_dist in zip(feature_splits, self._input_dists):
out = input_dist(feature_split)
input_dist_result_list.append(out.features)
ctx.sharding_contexts.append(
InferSequenceShardingContext(
features=out.features,
features_before_input_dist=features,
unbucketize_permute_tensor=(
out.unbucketize_permute_tensor
if isinstance(input_dist, InferRwSparseFeaturesDist)
else None
),
bucket_mapping_tensor=out.bucket_mapping_tensor,
bucketized_length=out.bucketized_length,
if is_sequence_embedding:
ctx.sharding_contexts.append(
InferSequenceShardingContext(
features=out.features,
features_before_input_dist=features,
unbucketize_permute_tensor=(
out.unbucketize_permute_tensor
if isinstance(input_dist, InferRwSparseFeaturesDist)
else None
),
bucket_mapping_tensor=out.bucket_mapping_tensor,
bucketized_length=out.bucketized_length,
)
)
)

return ListOfKJTList(input_dist_result_list)

270 changes: 267 additions & 3 deletions torchrec/distributed/quant_embeddingbag.py
@@ -8,7 +8,7 @@
# pyre-strict

import copy
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Type, Union

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
@@ -43,6 +43,11 @@
is_fused_param_register_tbe,
)
from torchrec.distributed.global_settings import get_propogate_device
from torchrec.distributed.mc_modules import (
InferManagedCollisionCollectionSharder,
ShardedMCCRemapper,
ShardedQuantManagedCollisionCollection,
)
from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState
from torchrec.distributed.sharding.cw_sharding import InferCwPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding
@@ -54,7 +59,7 @@
ShardingEnv,
ShardingType,
)
from torchrec.distributed.utils import copy_to_device
from torchrec.distributed.utils import append_prefix, copy_to_device
from torchrec.modules.embedding_configs import (
data_type_to_sparse_type,
dtype_to_data_type,
@@ -67,8 +72,9 @@
EmbeddingBagCollection as QuantEmbeddingBagCollection,
FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
QuantManagedCollisionEmbeddingBagCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


def get_device_from_parameter_sharding(
@@ -722,3 +728,261 @@ def forward(self, features: KeyedJaggedTensor) -> ListOfKJTList:
for i in range(len(self._input_dists))
]
)


class ShardedMCEBCLookup(torch.nn.Module):
"""
This module implements the distributed compute of a ShardedQuantManagedCollisionEmbeddingBagCollection: it applies managed collision remapping and then performs the embedding bag lookup for a single sharding and rank.

Args:
sharding (int): sharding index
rank (int): rank index
mcc_remapper (ShardedMCCRemapper): managed collision collection remapper
ebc_lookup (nn.Module): embedding bag collection lookup

Example::

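    # A minimal, illustrative sketch (not part of this PR): ``sharded_mcc``,
    # ``ebc_lookups`` and ``features`` are assumed to come from an already
    # sharded QuantManagedCollisionEmbeddingBagCollection setup.
    remappers = sharded_mcc.create_mcc_remappers()  # List[List[ShardedMCCRemapper]]
    lookup = ShardedMCEBCLookup(
        sharding=0,
        rank=0,
        mcc_remapper=remappers[0][0],
        ebc_lookup=ebc_lookups[0],
    )
    out = lookup(features)  # remaps feature ids, then performs the embedding lookup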
"""

def __init__(
self,
sharding: int,
rank: int,
mcc_remapper: ShardedMCCRemapper,
ebc_lookup: nn.Module,
) -> None:
super().__init__()
self._sharding = sharding
self._rank = rank
self._mcc_remapper = mcc_remapper
self._ebc_lookup = ebc_lookup

def forward(
self,
features: KeyedJaggedTensor,
) -> Tuple[
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
]:
"""
Applies managed collision collection remapping and performs embedding lookup.

Args:
features (KeyedJaggedTensor): input features

Returns:
Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]]: embedding lookup result
"""
remapped_kjt = self._mcc_remapper(features)
return self._ebc_lookup(remapped_kjt)


class ShardedQuantManagedCollisionEmbeddingBagCollection(
ShardedQuantEmbeddingBagCollection
):
def __init__(
self,
module: QuantManagedCollisionEmbeddingBagCollection,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
mc_sharder: InferManagedCollisionCollectionSharder,
# TODO - maybe we need this to manage unsharded/sharded state consistency
env: Union[ShardingEnv, Dict[str, ShardingEnv]],
fused_params: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
) -> None:
super().__init__(
module, table_name_to_parameter_sharding, env, fused_params, device
)

self._device = device
self._env = env

# TODO: This is a hack. _embedding_module doesn't need an input
# dist, so we would like to eliminate it so that all fused a2a calls ignore it.
# However, we're using the ec input_dist directly, so this cannot be avoided yet.
# self._has_uninitialized_input_dist = False
embedding_shardings = list(
self._sharding_type_device_group_to_sharding.values()
)

self._managed_collision_collection: ShardedQuantManagedCollisionCollection = (
mc_sharder.shard(
module._managed_collision_collection,
table_name_to_parameter_sharding,
env=env,
device=device,
# pyre-ignore
embedding_shardings=embedding_shardings,
)
)
self._return_remapped_features: bool = module._return_remapped_features
self._create_mcebc_lookups()

def _create_mcebc_lookups(self) -> None:
mcebc_lookups: List[nn.ModuleList] = []
mcc_remappers: List[List[ShardedMCCRemapper]] = (
self._managed_collision_collection.create_mcc_remappers()
)
for sharding in range(
len(self._managed_collision_collection._embedding_shardings)
):
ebc_sharding_lookups = self._lookups[sharding]
sharding_mcebc_lookups: List[ShardedMCEBCLookup] = []
for j, ec_lookup in enumerate(
ebc_sharding_lookups._embedding_lookups_per_rank # pyre-ignore
):
sharding_mcebc_lookups.append(
ShardedMCEBCLookup(
sharding,
j,
mcc_remappers[sharding][j],
ec_lookup,
)
)
mcebc_lookups.append(nn.ModuleList(sharding_mcebc_lookups))
self._mcebc_lookup: nn.ModuleList = nn.ModuleList(mcebc_lookups)

def input_dist(
self,
ctx: NullShardedModuleContext,
features: KeyedJaggedTensor,
) -> ListOfKJTList:
# TODO: resolve incompatibility with different contexts
if self._has_uninitialized_output_dist:
self._create_output_dist(features.device())
self._has_uninitialized_output_dist = False

return self._managed_collision_collection.input_dist(
# pyre-fixme [6]
ctx,
features,
is_sequence_embedding=False,
)

def compute(
self,
ctx: NullShardedModuleContext,
dist_input: ListOfKJTList,
) -> List[List[torch.Tensor]]:
ret: List[List[torch.Tensor]] = []
for i in range(len(self._managed_collision_collection._embedding_shardings)):
dist_input_i = dist_input[i]
lookups = self._mcebc_lookup[i]
sharding_ret: List[torch.Tensor] = []
for j, lookup in enumerate(lookups):
rank_ret = lookup(
features=dist_input_i[j],
)
sharding_ret.append(rank_ret)
ret.append(sharding_ret)
return ret

# pyre-ignore
def output_dist(
self,
ctx: NullShardedModuleContext,
output: List[List[torch.Tensor]],
) -> Tuple[
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
]:

# pyre-ignore [6]
ebc_out = super().output_dist(ctx, output)

kjt_out: Optional[KeyedJaggedTensor] = None

return ebc_out, kjt_out

def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
for fqn, _ in self.named_parameters():
yield append_prefix(prefix, fqn)
for fqn, _ in self.named_buffers():
yield append_prefix(prefix, fqn)


class QuantManagedCollisionEmbeddingBagCollectionSharder(
BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingBagCollection]
):
"""
Sharder for QuantManagedCollisionEmbeddingBagCollection.

This implementation uses the non-fused EmbeddingBagCollection and handles sharding of
both the embedding bag collection and the managed collision collection.

Args:
e_sharder (QuantEmbeddingBagCollectionSharder): sharder for embedding bag collection
mc_sharder (InferManagedCollisionCollectionSharder): sharder for managed collision collection

Example::

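    # A minimal, illustrative sketch (not part of this PR): ``quant_mc_ebc`` (an
    # unsharded QuantManagedCollisionEmbeddingBagCollection), ``plan`` (a
    # Dict[str, ParameterSharding]), ``env`` (a ShardingEnv) and ``device``
    # are assumptions, not created here.
    sharder = QuantManagedCollisionEmbeddingBagCollectionSharder(
        QuantEmbeddingBagCollectionSharder(),
        InferManagedCollisionCollectionSharder(),
    )
    sharded_mc_ebc = sharder.shard(quant_mc_ebc, plan, env, device)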
"""

def __init__(
self,
e_sharder: QuantEmbeddingBagCollectionSharder,
mc_sharder: InferManagedCollisionCollectionSharder,
) -> None:
super().__init__()
self._e_sharder: QuantEmbeddingBagCollectionSharder = e_sharder
self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder

def shardable_parameters(
self, module: QuantManagedCollisionEmbeddingBagCollection
) -> Dict[str, torch.nn.Parameter]:
return self._e_sharder.shardable_parameters(module)

def compute_kernels(
self,
sharding_type: str,
compute_device_type: str,
) -> List[str]:
return [
EmbeddingComputeKernel.QUANT.value,
]

def sharding_types(self, compute_device_type: str) -> List[str]:
return list(
set.intersection(
set(self._e_sharder.sharding_types(compute_device_type)),
set(self._mc_sharder.sharding_types(compute_device_type)),
)
)

@property
def fused_params(self) -> Optional[Dict[str, Any]]:
# TODO: to be deprecated after the planner gets cache_load_factor from ParameterConstraints
return self._e_sharder.fused_params

def shard(
self,
module: QuantManagedCollisionEmbeddingBagCollection,
params: Dict[str, ParameterSharding],
env: Union[ShardingEnv, Dict[str, ShardingEnv]],
device: Optional[torch.device] = None,
module_fqn: Optional[str] = None,
) -> ShardedQuantManagedCollisionEmbeddingBagCollection:
fused_params = self.fused_params if self.fused_params else {}
fused_params["output_dtype"] = data_type_to_sparse_type(
dtype_to_data_type(module.output_dtype())
)
if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params:
fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr(
module,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
False,
)
if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params:
fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
module, FUSED_PARAM_REGISTER_TBE_BOOL, False
)
return ShardedQuantManagedCollisionEmbeddingBagCollection(
module,
params,
self._mc_sharder,
env,
fused_params,
device,
)

@property
def module_type(self) -> Type[QuantManagedCollisionEmbeddingBagCollection]:
return QuantManagedCollisionEmbeddingBagCollection
9 changes: 6 additions & 3 deletions torchrec/distributed/quant_state.py
@@ -470,11 +470,14 @@ def sharded_tbes_weights_spec(
is_sfpebc: bool = (
"ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name
)
is_sqmcebc: bool = (
"ShardedQuantManagedCollisionEmbeddingBagCollection" in type_name
)

if is_sqebc or is_sqec or is_sqmcec or is_sfpebc:
if is_sqebc or is_sqec or is_sqmcec or is_sfpebc or is_sqmcebc:
assert (
is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection are true"
is_sqec + is_sqebc + is_sqmcec + is_sfpebc + is_sqmcebc == 1
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection, ShardedQuantFeatureProcessedEmbeddingBagCollection and ShardedQuantManagedCollisionEmbeddingBagCollection are true"
tbes_configs: Dict[
IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
] = module.tbes_configs()
Expand Down