torchrec/distributed/model_parallel.py (20 changes: 19 additions & 1 deletion)

@@ -12,6 +12,7 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embeddingbag import (
@@ -100,6 +101,14 @@ def wrap(
dmp.module._set_static_graph()


def _strip_DDP(module: nn.Module) -> nn.Module:
if isinstance(module, FullyShardedDataParallel) or isinstance(
module, DistributedDataParallel
):
module = module.module
return module


class DistributedModelParallel(nn.Module, FusedOptimizerModule):
"""
Entry point to model parallelism.
@@ -208,6 +217,7 @@ def dmp_module(self) -> nn.Module:
return (
self.module.module
if isinstance(self.module, DistributedDataParallel)
or isinstance(self.module, FullyShardedDataParallel)
else self.module
)

@@ -220,7 +230,9 @@ def init_data_parallel(self) -> None:
See init_data_parallel c-tor argument for usage.
It's safe to call this method multiple times.
"""
if not isinstance(self.module, DistributedDataParallel):
if not isinstance(self.module, DistributedDataParallel) and not isinstance(
self.module, FullyShardedDataParallel
):
# Allocate any 'meta' tensors
if self.init_parameters:
self._init_parameters(self.module)
@@ -316,6 +328,7 @@ def sparse_grad_parameter_names(
def _sparse_grad_parameter_names(
self, module: nn.Module, destination: List[str], prefix: str = ""
) -> List[str]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
module.sparse_grad_parameter_names(destination, prefix)
elif isinstance(module, nn.Embedding):
@@ -351,6 +364,7 @@ def _state_dict(
prefix: str,
keep_vars: bool,
) -> Dict[str, Any]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
module.state_dict(destination, prefix, keep_vars)
else:
@@ -376,6 +390,7 @@ def _load_state_dict(
) -> _IncompatibleKeys:
missing_keys = []
unexpected_keys = []
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
return module.load_state_dict(state_dict, strict=strict)
else:
@@ -398,6 +413,7 @@ def _load_state_dict(
def _named_parameters(
self, module: nn.Module, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
yield from module.named_parameters(prefix, recurse)
else:
@@ -414,6 +430,7 @@ def named_parameters(

@staticmethod
def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
yield from module.sharded_parameter_names(prefix)
else:
@@ -425,6 +442,7 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
def _named_buffers(
self, module: nn.Module, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
yield from module.named_buffers(prefix, recurse)
else:
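For context, the traversal helpers touched above (`_sparse_grad_parameter_names`, `_state_dict`, `_load_state_dict`, `_named_parameters`, `_sharded_parameter_names`, `_named_buffers`) now call `_strip_DDP` before their type checks, so a module wrapped in DDP or FSDP is unwrapped first and the `ShardedModule` dispatch still fires. Below is a minimal, self-contained sketch of that unwrap-then-dispatch pattern; the `strip_ddp` / `walk_named_parameters` names and the plain `nn.Linear` demo are illustrative, not code from this PR.

```python
from typing import Iterator, Tuple

import torch
from torch import nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel


def strip_ddp(module: nn.Module) -> nn.Module:
    # Same idea as _strip_DDP in this PR: DDP and FSDP both expose the
    # wrapped model as `.module`, so peel one layer if present.
    if isinstance(module, (FullyShardedDataParallel, DistributedDataParallel)):
        module = module.module
    return module


def walk_named_parameters(
    module: nn.Module, prefix: str = ""
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    # Unwrap first so any type-based dispatch (on ShardedModule in the
    # real code) sees the inner module rather than the wrapper.
    module = strip_ddp(module)
    yield from module.named_parameters(prefix)


if __name__ == "__main__":
    # No process group is initialized here, so the demo uses a plain
    # module; strip_ddp is simply a no-op in that case.
    model = nn.Linear(4, 2)
    for name, param in walk_named_parameters(model, prefix="model"):
        print(name, tuple(param.shape))
```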
torchrec/distributed/train_pipeline.py (7 changes: 5 additions & 2 deletions)

@@ -23,6 +23,7 @@

import torch
from torch.autograd.profiler import record_function
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.fx.node import Node
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
@@ -364,8 +365,10 @@ def _rewrite_model( # noqa C901
) -> List[ShardedModule]:

# Get underlying nn.Module
while isinstance(model, DistributedModelParallel) or isinstance(
model, DistributedDataParallel
while (
isinstance(model, DistributedModelParallel)
or isinstance(model, DistributedDataParallel)
or isinstance(model, FullyShardedDataParallel)
):
model = model.module

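The `_rewrite_model` change follows the same idea: FSDP, like DDP and DistributedModelParallel, hides the user model behind a `.module` attribute, so the unwrapping loop must also peel FSDP wrappers before the fx-based rewrite. A minimal sketch of that loop in isolation (the `unwrap` name and the wrapper tuple are illustrative; the real code also checks `DistributedModelParallel`, omitted here so the sketch is importable without torchrec):

```python
from torch import nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel

# The real loop in _rewrite_model also includes DistributedModelParallel.
_WRAPPER_TYPES = (DistributedDataParallel, FullyShardedDataParallel)


def unwrap(model: nn.Module) -> nn.Module:
    # Keep peeling `.module` while the object is a known data-parallel
    # wrapper, so nested combinations (e.g. FSDP around DDP) resolve to
    # the underlying nn.Module.
    while isinstance(model, _WRAPPER_TYPES):
        model = model.module
    return model
```

Without the added `FullyShardedDataParallel` check, the loop would stop at an FSDP wrapper instead of reaching the underlying model that the pipeline needs to rewrite.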