From 33d25d0cc8003ba594bbd17cf88afe0f19036c94 Mon Sep 17 00:00:00 2001
From: Felicity Liao
Date: Thu, 24 Apr 2025 16:29:40 -0700
Subject: [PATCH] Add unsharded module property to sharded modules and EBC
 (#2912)

Summary:
Adding a simple unsharded module reference to sharded modules. This will be
used in Dynamic Sharding by DistributedModelParallel (DMP) to reshard an
already-sharded module. As DMP is created with only a one-way relationship in
mind, accessing the unsharded module type helps determine which sharder to use
when resharding. See the comment under `types.py`.

Differential Revision: D73537260
---
 torchrec/distributed/embedding_types.py | 23 ++++++++++++++++++++++-
 torchrec/distributed/embeddingbag.py    |  8 ++++++++
 torchrec/distributed/mc_modules.py      |  8 ++++++++
 torchrec/distributed/object_pool.py     |  6 +++++-
 torchrec/distributed/types.py           | 13 +++++++++++++
 5 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 3b2b127a0..b66f92f16 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -11,7 +11,18 @@
 import copy
 from dataclasses import dataclass
 from enum import Enum, unique
-from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
@@ -399,6 +410,16 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
         return self
 
+    @property
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        As this is the generic ShardedEmbeddingModule class, simply
+        return the generic nn.Module type. Subclasses of
+        ShardedEmbeddingModule return their specific unsharded module
+        type instead.
+        """
+        return nn.Module
+
 
 M = TypeVar("M", bound=nn.Module)
 
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index d7ae684e3..4dd6f286d 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1627,6 +1627,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
     def extend_shard_name(shard_name: str) -> str:
         return f"embedding_bags.{shard_name}.weight"
 
+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
+        return EmbeddingBagCollection
+
 
 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """
@@ -1916,6 +1920,10 @@ def fused_optimizer(self) -> KeyedOptimizer:
     def create_context(self) -> NullShardedModuleContext:
         return NullShardedModuleContext()
 
+    @property
+    def unsharded_module_type(self) -> Type[nn.EmbeddingBag]:
+        return nn.EmbeddingBag
+
 
 class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]):
     """
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index 6126f3ddb..34e4ac672 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -821,6 +821,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             for name, _ in module.named_parameters():
                 yield append_prefix(module_prefix, name)
 
+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
+        return ManagedCollisionCollection
+
 
 class ManagedCollisionCollectionSharder(
     BaseEmbeddingSharder[ManagedCollisionCollection]
@@ -1303,6 +1307,10 @@ def output_dist(
     def create_context(self) -> ManagedCollisionCollectionContext:
         return ManagedCollisionCollectionContext(sharding_contexts=[])
 
+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
+        return ManagedCollisionCollection
+
 
 class InferManagedCollisionCollectionSharder(ManagedCollisionCollectionSharder):
     # pyre-ignore
diff --git a/torchrec/distributed/object_pool.py b/torchrec/distributed/object_pool.py
index 3ff78bc82..62d5e932c 100644
--- a/torchrec/distributed/object_pool.py
+++ b/torchrec/distributed/object_pool.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 from abc import abstractmethod
-from typing import Generic
+from typing import Generic, Type
 
 import torch
 from torch._prims_common import is_integer_dtype
@@ -144,3 +144,7 @@ def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
     # `None`.
     def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
         pass
+
+    @property
+    def unsharded_module_type(self) -> Type[ObjectPool[Out]]:
+        return ObjectPool[Out]
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 45300f733..59fe483dc 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -1034,6 +1034,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key
 
+    @property
+    @abc.abstractmethod
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        This property is added as part of the dynamic sharding implementation.
+
+        When resharding an already-sharded module wrapped in DMP, the unsharded
+        module type is needed to identify the proper sharder to reshard with,
+        because DistributedModelParallel (DMP) references module sharders by
+        the unsharded module type.
+        """
+        ...
+
 
 def get_tensor_size_bytes(t: torch.Tensor) -> int:
     b: int = t.numel() * t.element_size()