diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 017025796..332e247fa 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -60,6 +60,7 @@ from torchrec.distributed.types import ( Awaitable, LazyAwaitable, + NullShardedModuleContext, ParameterSharding, QuantizedCommCodecs, ShardedModule, @@ -1292,8 +1293,9 @@ def _create_input_dists( # pyre-ignore def input_dist( self, - ctx: ManagedCollisionCollectionContext, + ctx: Union[ManagedCollisionCollectionContext, NullShardedModuleContext], features: KeyedJaggedTensor, + is_sequence_embedding: bool = True, ) -> ListOfKJTList: if self._has_uninitialized_input_dists: self._create_input_dists( @@ -1345,19 +1347,20 @@ def input_dist( for feature_split, input_dist in zip(feature_splits, self._input_dists): out = input_dist(feature_split) input_dist_result_list.append(out.features) - ctx.sharding_contexts.append( - InferSequenceShardingContext( - features=out.features, - features_before_input_dist=features, - unbucketize_permute_tensor=( - out.unbucketize_permute_tensor - if isinstance(input_dist, InferRwSparseFeaturesDist) - else None - ), - bucket_mapping_tensor=out.bucket_mapping_tensor, - bucketized_length=out.bucketized_length, + if is_sequence_embedding: + ctx.sharding_contexts.append( + InferSequenceShardingContext( + features=out.features, + features_before_input_dist=features, + unbucketize_permute_tensor=( + out.unbucketize_permute_tensor + if isinstance(input_dist, InferRwSparseFeaturesDist) + else None + ), + bucket_mapping_tensor=out.bucket_mapping_tensor, + bucketized_length=out.bucketized_length, + ) ) - ) return ListOfKJTList(input_dist_result_list) diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py index 7c061a0ae..eb5f8ea21 100644 --- a/torchrec/distributed/quant_embeddingbag.py +++ b/torchrec/distributed/quant_embeddingbag.py @@ -8,7 +8,7 @@ # pyre-strict import copy -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union import torch from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( @@ -43,6 +43,11 @@ is_fused_param_register_tbe, ) from torchrec.distributed.global_settings import get_propogate_device +from torchrec.distributed.mc_modules import ( + InferManagedCollisionCollectionSharder, + ShardedMCCRemapper, + ShardedQuantManagedCollisionCollection, +) from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState from torchrec.distributed.sharding.cw_sharding import InferCwPooledEmbeddingSharding from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding @@ -54,7 +59,7 @@ ShardingEnv, ShardingType, ) -from torchrec.distributed.utils import copy_to_device +from torchrec.distributed.utils import append_prefix, copy_to_device from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, dtype_to_data_type, @@ -67,8 +72,9 @@ EmbeddingBagCollection as QuantEmbeddingBagCollection, FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + QuantManagedCollisionEmbeddingBagCollection, ) -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor def get_device_from_parameter_sharding( @@ -722,3 +728,261 @@ def forward(self, features: KeyedJaggedTensor) -> ListOfKJTList: for i in 
range(len(self._input_dists)) ] ) + + +class ShardedMCEBCLookup(torch.nn.Module): + """ + This module implements distributed compute of a ShardedQuantManagedCollisionEmbeddingBagCollection. + + Args: + sharding (int): sharding index + rank (int): rank index + mcc_remapper (ShardedMCCRemapper): managed collision collection remapper + ebc_lookup (nn.Module): embedding bag collection lookup + + Example:: + + """ + + def __init__( + self, + sharding: int, + rank: int, + mcc_remapper: ShardedMCCRemapper, + ebc_lookup: nn.Module, + ) -> None: + super().__init__() + self._sharding = sharding + self._rank = rank + self._mcc_remapper = mcc_remapper + self._ebc_lookup = ebc_lookup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + """ + Applies managed collision collection remapping and performs embedding lookup. + + Args: + features (KeyedJaggedTensor): input features + + Returns: + Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]]: embedding lookup result + """ + remapped_kjt = self._mcc_remapper(features) + return self._ebc_lookup(remapped_kjt) + + +class ShardedQuantManagedCollisionEmbeddingBagCollection( + ShardedQuantEmbeddingBagCollection +): + def __init__( + self, + module: QuantManagedCollisionEmbeddingBagCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + mc_sharder: InferManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__( + module, table_name_to_parameter_sharding, env, fused_params, device + ) + + self._device = device + self._env = env + + # TODO: This is a hack since _embedding_module doesn't need input + # dist, so eliminating it so all fused a2a will ignore it. + # we're using ec input_dist directly, so this cannot be escaped. 
+ # self._has_uninitialized_input_dist = False + embedding_shardings = list( + self._sharding_type_device_group_to_sharding.values() + ) + + self._managed_collision_collection: ShardedQuantManagedCollisionCollection = ( + mc_sharder.shard( + module._managed_collision_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + # pyre-ignore + embedding_shardings=embedding_shardings, + ) + ) + self._return_remapped_features: bool = module._return_remapped_features + self._create_mcebc_lookups() + + def _create_mcebc_lookups(self) -> None: + mcebc_lookups: List[nn.ModuleList] = [] + mcc_remappers: List[List[ShardedMCCRemapper]] = ( + self._managed_collision_collection.create_mcc_remappers() + ) + for sharding in range( + len(self._managed_collision_collection._embedding_shardings) + ): + ebc_sharding_lookups = self._lookups[sharding] + sharding_mcebc_lookups: List[ShardedMCEBCLookup] = [] + for j, ec_lookup in enumerate( + ebc_sharding_lookups._embedding_lookups_per_rank # pyre-ignore + ): + sharding_mcebc_lookups.append( + ShardedMCEBCLookup( + sharding, + j, + mcc_remappers[sharding][j], + ec_lookup, + ) + ) + mcebc_lookups.append(nn.ModuleList(sharding_mcebc_lookups)) + self._mcebc_lookup: nn.ModuleList = nn.ModuleList(mcebc_lookups) + + def input_dist( + self, + ctx: NullShardedModuleContext, + features: KeyedJaggedTensor, + ) -> ListOfKJTList: + # TODO: resolve incompatibility with different contexts + if self._has_uninitialized_output_dist: + self._create_output_dist(features.device()) + self._has_uninitialized_output_dist = False + + return self._managed_collision_collection.input_dist( + # pyre-fixme [6] + ctx, + features, + is_sequence_embedding=False, + ) + + def compute( + self, + ctx: NullShardedModuleContext, + dist_input: ListOfKJTList, + ) -> List[List[torch.Tensor]]: + ret: List[List[torch.Tensor]] = [] + for i in range(len(self._managed_collision_collection._embedding_shardings)): + dist_input_i = dist_input[i] + lookups = self._mcebc_lookup[i] + sharding_ret: List[torch.Tensor] = [] + for j, lookup in enumerate(lookups): + rank_ret = lookup( + features=dist_input_i[j], + ) + sharding_ret.append(rank_ret) + ret.append(sharding_ret) + return ret + + # pyre-ignore + def output_dist( + self, + ctx: NullShardedModuleContext, + output: List[List[torch.Tensor]], + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + + # pyre-ignore [6] + ebc_out = super().output_dist(ctx, output) + + kjt_out: Optional[KeyedJaggedTensor] = None + + return ebc_out, kjt_out + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for fqn, _ in self.named_parameters(): + yield append_prefix(prefix, fqn) + for fqn, _ in self.named_buffers(): + yield append_prefix(prefix, fqn) + + +class QuantManagedCollisionEmbeddingBagCollectionSharder( + BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingBagCollection] +): + """ + Sharder for QuantManagedCollisionEmbeddingBagCollection. + + This implementation uses non-fused EmbeddingBagCollection and manages both + embedding bag collection sharding and managed collision collection sharding.
+ + Args: + e_sharder (QuantEmbeddingBagCollectionSharder): sharder for embedding bag collection + mc_sharder (InferManagedCollisionCollectionSharder): sharder for managed collision collection + + Example:: + + """ + + def __init__( + self, + e_sharder: QuantEmbeddingBagCollectionSharder, + mc_sharder: InferManagedCollisionCollectionSharder, + ) -> None: + super().__init__() + self._e_sharder: QuantEmbeddingBagCollectionSharder = e_sharder + self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder + + def shardable_parameters( + self, module: QuantManagedCollisionEmbeddingBagCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._e_sharder.shardable_parameters(module) + + def compute_kernels( + self, + sharding_type: str, + compute_device_type: str, + ) -> List[str]: + return [ + EmbeddingComputeKernel.QUANT.value, + ] + + def sharding_types(self, compute_device_type: str) -> List[str]: + return list( + set.intersection( + set(self._e_sharder.sharding_types(compute_device_type)), + set(self._mc_sharder.sharding_types(compute_device_type)), + ) + ) + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + # TODO: to be deprecated after planner gets cache_load_factor from ParameterConstraints + return self._e_sharder.fused_params + + def shard( + self, + module: QuantManagedCollisionEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedQuantManagedCollisionEmbeddingBagCollection: + fused_params = self.fused_params if self.fused_params else {} + fused_params["output_dtype"] = data_type_to_sparse_type( + dtype_to_data_type(module.output_dtype()) + ) + if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params: + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + False, + ) + if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params: + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, FUSED_PARAM_REGISTER_TBE_BOOL, False + ) + return ShardedQuantManagedCollisionEmbeddingBagCollection( + module, + params, + self._mc_sharder, + env, + fused_params, + device, + ) + + @property + def module_type(self) -> Type[QuantManagedCollisionEmbeddingBagCollection]: + return QuantManagedCollisionEmbeddingBagCollection diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 75e87c013..946a9848a 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -470,11 +470,14 @@ def sharded_tbes_weights_spec( is_sfpebc: bool = ( "ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name ) + is_sqmcebc: bool = ( + "ShardedQuantManagedCollisionEmbeddingBagCollection" in type_name + ) - if is_sqebc or is_sqec or is_sqmcec or is_sfpebc: + if is_sqebc or is_sqec or is_sqmcec or is_sfpebc or is_sqmcebc: assert ( - is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1 - ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection are true" + is_sqec + is_sqebc + is_sqmcec + is_sfpebc + is_sqmcebc == 1 + ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection, ShardedQuantFeatureProcessedEmbeddingBagCollection and
ShardedQuantManagedCollisionEmbeddingBagCollection are true" tbes_configs: Dict[ IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig ] = module.tbes_configs() diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index ddd9de087..b07c585eb 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -57,6 +57,7 @@ FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection, ) from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection as OriginalManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection, ) from torchrec.modules.mc_modules import ManagedCollisionCollection @@ -218,6 +219,7 @@ def quantize_state_dict( splits = key.split(".") assert splits[-1] == "weight" table_name = splits[-2] + data_type = table_name_to_data_type[table_name] num_rows = tensor.shape[0] @@ -371,7 +373,7 @@ def __init__( self._device: torch.device = device self._table_name_to_quantized_weights: Optional[ Dict[str, Tuple[Tensor, Tensor]] - ] = None + ] = table_name_to_quantized_weights self.row_alignment = row_alignment self._kjt_to_jt_dict = ComputeKJTToJTDict() @@ -1194,3 +1196,159 @@ def from_float( return_remapped_features=mc_ec._return_remapped_features, cache_features_order=getattr(ec, MODULE_ATTR_CACHE_FEATURES_ORDER, False), ) + + +class QuantManagedCollisionEmbeddingBagCollection(EmbeddingBagCollection): + """ + QuantManagedCollisionEmbeddingBagCollection represents a quantized EBC module and a set of managed collision modules. + The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection. + + Args: + tables (List[EmbeddingBagConfig]): list of embedding bag configs + is_weighted (bool): whether the embedding bag collection is weighted + device (torch.device): device on which the embedding collection will be allocated + output_dtype (torch.dtype): data type of the output embeddings. Defaults to torch.float + table_name_to_quantized_weights (Optional[Dict[str, Tuple[Tensor, Tensor]]]): dictionary mapping table names to their corresponding quantized weights. Defaults to None + register_tbes (bool): whether to register the TBEs in the model. Defaults to False + quant_state_dict_split_scale_bias (bool): whether to split the scale and bias parameters when saving the quantized state dict. Defaults to False + row_alignment (int): alignment of rows in the quantized weights. Defaults to DEFAULT_ROW_ALIGNMENT + managed_collision_collection (Optional[ManagedCollisionCollection]): managed collision collection to use for managing collisions. Defaults to None + return_remapped_features (bool): whether to return the remapped input features in addition to the embeddings. Defaults to False + cache_features_order (bool): whether to cache the features order. 
Defaults to False + + Example:: + + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + is_weighted: bool, + device: torch.device, + output_dtype: torch.dtype = torch.float, + table_name_to_quantized_weights: Optional[ + Dict[str, Tuple[Tensor, Tensor]] + ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + managed_collision_collection: Optional[ManagedCollisionCollection] = None, + return_remapped_features: bool = False, + cache_features_order: bool = False, + ) -> None: + super().__init__( + tables, + is_weighted, + device, + output_dtype, + table_name_to_quantized_weights, + register_tbes, + quant_state_dict_split_scale_bias, + row_alignment, + cache_features_order, + ) + assert ( + managed_collision_collection + ), "Managed collision collection cannot be None" + self._managed_collision_collection: ManagedCollisionCollection = ( + managed_collision_collection + ) + self._return_remapped_features = return_remapped_features + + assert str(self.embedding_bag_configs()) == str( + self._managed_collision_collection.embedding_configs() + ), "Embedding Bag Collection and Managed Collision Collection must contain the same Embedding Configs" + + # Assuming quantized MCEBC is used in inference only + for ( + managed_collision_module + ) in self._managed_collision_collection._managed_collision_modules.values(): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + managed_collision_module.reset_inference_mode() + + def to( + self, *args: List[Any], **kwargs: Dict[str, Any] + ) -> "QuantManagedCollisionEmbeddingBagCollection": + device, dtype, non_blocking, _ = torch._C._nn._parse_to( + *args, # pyre-ignore + **kwargs, # pyre-ignore + ) + for param in self.parameters(): + if param.device.type != "meta": + param.to(device) + + for buffer in self.buffers(): + if buffer.device.type != "meta": + buffer.to(device) + # Skip device movement and continue with other args + super().to( + dtype=dtype, + non_blocking=non_blocking, + ) + return self + + # pyre-ignore + def forward( + self, + features: KeyedJaggedTensor, + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + features = self._managed_collision_collection(features) + + return (super().forward(features), features) + + def _get_name(self) -> str: + return "QuantManagedCollisionEmbeddingBagCollection" + + @classmethod + # pyre-ignore + def from_float( + cls, + module: OriginalManagedCollisionEmbeddingBagCollection, + return_remapped_features: bool = False, + ) -> "QuantManagedCollisionEmbeddingBagCollection": + mc_ebc = module + ebc = module._embedding_module + + # pyre-ignore[9] + qconfig: torch.quantization.QConfig = module.qconfig + assert hasattr( + module, "qconfig" + ), "QuantManagedCollisionEmbeddingBagCollection input float module must have qconfig defined" + + # pyre-ignore[29] + embedding_bag_configs = copy.deepcopy(ebc.embedding_bag_configs()) + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_bag_configs), + qconfig, + ) + _update_embedding_configs( + mc_ebc._managed_collision_collection._embedding_configs, + qconfig, + ) + + # pyre-ignore[9] + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] | None = ( + ebc._table_name_to_quantized_weights + if hasattr(ebc, "_table_name_to_quantized_weights") + else None + ) + device = _get_device(ebc) + return cls( + embedding_bag_configs, + ebc.is_weighted(), + device, + output_dtype=qconfig.activation().dtype, 
+ table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + quant_state_dict_split_scale_bias=getattr( + ebc, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + ebc, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + managed_collision_collection=mc_ebc._managed_collision_collection, + return_remapped_features=mc_ebc._return_remapped_features, + cache_features_order=getattr(ebc, MODULE_ATTR_CACHE_FEATURES_ORDER, False), + )
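
Reviewer note: a minimal usage sketch of the inference path this diff adds. It is not part of the change; the table/feature names, sizes, and CPU device are illustrative, the eager modules (MCHManagedCollisionModule, DistanceLFU_EvictionPolicy, ManagedCollisionEmbeddingBagCollection) are pre-existing torchrec APIs, and the flow is assumed to mirror the existing QuantManagedCollisionEmbeddingCollection (EC) path. The planner / _shard_modules step is only referenced in a comment rather than spelled out::

    import torch
    from torchrec.distributed.mc_modules import InferManagedCollisionCollectionSharder
    from torchrec.distributed.quant_embeddingbag import (
        QuantEmbeddingBagCollectionSharder,
        QuantManagedCollisionEmbeddingBagCollectionSharder,
    )
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
    from torchrec.modules.mc_modules import (
        DistanceLFU_EvictionPolicy,
        ManagedCollisionCollection,
        MCHManagedCollisionModule,
    )
    from torchrec.quant.embedding_modules import QuantManagedCollisionEmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    device = torch.device("cpu")
    tables = [
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"]
        )
    ]
    # Eager MC-EBC: an EmbeddingBagCollection fronted by a managed collision collection.
    mc_ebc = ManagedCollisionEmbeddingBagCollection(
        EmbeddingBagCollection(tables=tables, device=device),
        ManagedCollisionCollection(
            managed_collision_modules={
                "t1": MCHManagedCollisionModule(
                    zch_size=16,
                    device=device,
                    eviction_interval=1,
                    eviction_policy=DistanceLFU_EvictionPolicy(),
                )
            },
            embedding_configs=tables,
        ),
        return_remapped_features=True,
    )

    # from_float reads qconfig to pick output/weight dtypes; the quant module's
    # __init__ then switches the managed collision modules to inference mode.
    mc_ebc.qconfig = torch.quantization.QConfig(
        activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.float),
        weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
    )
    quant_mc_ebc = QuantManagedCollisionEmbeddingBagCollection.from_float(mc_ebc)

    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([100, 2, 3, 100]),
        lengths=torch.tensor([2, 2]),
    )
    pooled, remapped = quant_mc_ebc(features)  # (KeyedTensor, remapped KeyedJaggedTensor)

    # The composite sharder pairs the quant EBC sharder with the inference MCC
    # sharder; it is what the usual planner / _shard_modules inference flow would
    # consume to produce a ShardedQuantManagedCollisionEmbeddingBagCollection.
    sharder = QuantManagedCollisionEmbeddingBagCollectionSharder(
        QuantEmbeddingBagCollectionSharder(),
        InferManagedCollisionCollectionSharder(),
    )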