diff --git a/torchrec/distributed/mc_embeddingbag.py b/torchrec/distributed/mc_embeddingbag.py
index d9d3bc3ad..582d9c38a 100644
--- a/torchrec/distributed/mc_embeddingbag.py
+++ b/torchrec/distributed/mc_embeddingbag.py
@@ -91,6 +91,9 @@ def __init__(
                 device=device,
             )
         )
+        # TODO: This is a hack. _embedding_bag_collection doesn't need an input
+        # dist, so we disable it here so that fused a2a will ignore it.
+        self._embedding_bag_collection._has_uninitialized_input_dist = False
         self._managed_collision_collection: ShardedManagedCollisionCollection = mc_sharder.shard(
             module._managed_collision_collection,
             table_name_to_parameter_sharding,
diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py
index 439cfd343..dba476d08 100644
--- a/torchrec/distributed/tests/test_mc_embeddingbag.py
+++ b/torchrec/distributed/tests/test_mc_embeddingbag.py
@@ -11,6 +11,7 @@
 
 import torch
 import torch.nn as nn
+from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
     ShardedManagedCollisionEmbeddingBagCollection,
@@ -220,6 +221,22 @@ def _test_sharding_and_remapping(  # noqa C901
     assert isinstance(
         sharded_sparse_arch._mc_ebc, ShardedManagedCollisionEmbeddingBagCollection
     )
+    assert isinstance(
+        sharded_sparse_arch._mc_ebc._embedding_bag_collection,
+        ShardedEmbeddingBagCollection,
+    )
+    assert (
+        sharded_sparse_arch._mc_ebc._embedding_bag_collection._has_uninitialized_input_dist
+        is False
+    )
+    assert (
+        not hasattr(
+            sharded_sparse_arch._mc_ebc._embedding_bag_collection, "_input_dists"
+        )
+        or len(sharded_sparse_arch._mc_ebc._embedding_bag_collection._input_dists)
+        == 0
+    )
+
     assert isinstance(
         sharded_sparse_arch._mc_ebc._managed_collision_collection,
         ShardedManagedCollisionCollection,
diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index 3fdcaf20e..4cacc78ad 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -43,6 +43,11 @@ def apply_mc_method_to_jt_dict(
     return mc_output
 
 
+@torch.fx.wrap
+def coalesce_feature_dict(features_dict: Dict[str, JaggedTensor]) -> KeyedJaggedTensor:
+    return KeyedJaggedTensor.from_jt_dict(features_dict)
+
+
 class ManagedCollisionModule(nn.Module):
     """
     Abstract base class for ManagedCollisionModule.
@@ -190,7 +195,7 @@ def forward(
             table_to_features=self._table_to_features,
             managed_collisions=self._managed_collision_modules,
         )
-        return KeyedJaggedTensor.from_jt_dict(features_dict)
+        return coalesce_feature_dict(features_dict)
 
     def evict(self) -> Dict[str, Optional[torch.Tensor]]:
         evictions: Dict[str, Optional[torch.Tensor]] = {}
@@ -818,7 +823,7 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
 
         remapped_features: Dict[str, JaggedTensor] = {}
         for name, feature in features.items():
-            values = feature.values()
+            values = feature.values().to(torch.int64)
             remapped_ids = torch.empty_like(values)
 
             # compute overlap between incoming IDs and remapping table
diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py
index 49d43e92f..4fdcd4006 100644
--- a/torchrec/modules/tests/test_mc_embedding_modules.py
+++ b/torchrec/modules/tests/test_mc_embedding_modules.py
@@ -256,6 +256,39 @@ def test_zch_ebc_eval(self) -> None:
 
         assert torch.all(remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values())
 
+    def test_mc_collection_traceable(self) -> None:
+        device = torch.device("cpu")
+        zch_size = 20
+        update_interval = 2
+
+        embedding_configs = [
+            EmbeddingBagConfig(
+                name="t1",
+                embedding_dim=8,
+                num_embeddings=zch_size,
+                feature_names=["f1", "f2"],
+            ),
+        ]
+        mc_modules = {
+            "t1": cast(
+                ManagedCollisionModule,
+                MCHManagedCollisionModule(
+                    zch_size=zch_size,
+                    device=device,
+                    input_hash_size=2 * zch_size,
+                    eviction_interval=update_interval,
+                    eviction_policy=DistanceLFU_EvictionPolicy(),
+                ),
+            ),
+        }
+        mcc = ManagedCollisionCollection(
+            managed_collision_modules=mc_modules,
+            # pyre-ignore[6]
+            embedding_configs=embedding_configs,
+        )
+        gm: torch.fx.GraphModule = torch.fx.symbolic_trace(mcc)
+        gm.print_readable()
+
     def test_mch_ebc(self) -> None:
         device = torch.device("cpu")
         zch_size = 10