23 changes: 22 additions & 1 deletion torchrec/distributed/embedding_types.py
@@ -11,7 +11,18 @@
import copy
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
@@ -399,6 +410,16 @@ def train(self, mode: bool = True): # pyre-ignore[3]

return self

@property
def unsharded_module_type(self) -> Type[nn.Module]:
"""
As this is the generic ShardedEmbeddingModule class, simply
return the generic nn.Module type. Subclasses of
ShardedEmbeddingModule override this property to return the
specific unsharded module type they shard.
"""
return nn.Module


M = TypeVar("M", bound=nn.Module)

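The base class above defaults to nn.Module, while each concrete sharded module reports its actual unsharded counterpart, as the overrides in the files below show. A minimal sketch of the resulting invariant, assuming the existing ModuleSharder.module_type property; sharded_ebc is only a placeholder name and is not constructed here:

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# The sharder already advertises which unsharded module type it handles.
sharder = EmbeddingBagCollectionSharder()
assert sharder.module_type is EmbeddingBagCollection

# With this PR, a sharded module reports the same type back, so the pair can
# be matched up again later. For any ShardedEmbeddingBagCollection instance
# `sharded_ebc` produced by sharder.shard(...):
#     sharded_ebc.unsharded_module_type is EmbeddingBagCollection
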
8 changes: 8 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -1627,6 +1627,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
def extend_shard_name(shard_name: str) -> str:
return f"embedding_bags.{shard_name}.weight"

@property
def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
return EmbeddingBagCollection


class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
"""
@@ -1916,6 +1920,10 @@ def fused_optimizer(self) -> KeyedOptimizer:
def create_context(self) -> NullShardedModuleContext:
return NullShardedModuleContext()

@property
def unsharded_module_type(self) -> Type[nn.EmbeddingBag]:
return nn.EmbeddingBag


class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]):
"""
8 changes: 8 additions & 0 deletions torchrec/distributed/mc_modules.py
@@ -821,6 +821,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
for name, _ in module.named_parameters():
yield append_prefix(module_prefix, name)

@property
def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
return ManagedCollisionCollection


class ManagedCollisionCollectionSharder(
BaseEmbeddingSharder[ManagedCollisionCollection]
@@ -1303,6 +1307,10 @@ def output_dist(
def create_context(self) -> ManagedCollisionCollectionContext:
return ManagedCollisionCollectionContext(sharding_contexts=[])

@property
def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
return ManagedCollisionCollection


class InferManagedCollisionCollectionSharder(ManagedCollisionCollectionSharder):
# pyre-ignore
6 changes: 5 additions & 1 deletion torchrec/distributed/object_pool.py
@@ -8,7 +8,7 @@
# pyre-strict

from abc import abstractmethod
from typing import Generic
from typing import Generic, Type

import torch
from torch._prims_common import is_integer_dtype
@@ -144,3 +144,7 @@ def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
# `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
pass

@property
def unsharded_module_type(self) -> Type[ObjectPool[Out]]:
return ObjectPool[Out]
13 changes: 13 additions & 0 deletions torchrec/distributed/types.py
@@ -1034,6 +1034,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
for key, _ in self.named_parameters(prefix):
yield key

@property
@abc.abstractmethod
def unsharded_module_type(self) -> Type[nn.Module]:
"""
This property is added as part of the dynamic sharding implementation.

When resharding an already-sharded module wrapped in DMP, the unsharded
module type is needed to identify the proper sharder to use, because
DistributedModelParallel (DMP) registers module sharders keyed by the
unsharded module type.
...


def get_tensor_size_bytes(t: torch.Tensor) -> int:
b: int = t.numel() * t.element_size()
Expand Down