From 813393e785d5af03cbd678c87dce43abbc4d50d1 Mon Sep 17 00:00:00 2001
From: Xing Liu
Date: Thu, 10 Feb 2022 10:46:39 -0800
Subject: [PATCH] enable FSDP

Summary: As title.

Reviewed By: dstaay-fb, bigrabithong, liangluofb

Differential Revision: D33712372

fbshipit-source-id: 846385abe6106bd2e93ba6797c1ea8caf16307e6
---
 torchrec/distributed/model_parallel.py | 20 +++++++++++++++++++-
 torchrec/distributed/train_pipeline.py |  7 +++++--
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 9cac3d02b..75bdebe10 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -12,6 +12,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embeddingbag import (
@@ -100,6 +101,14 @@ def wrap(
         dmp.module._set_static_graph()
 
 
+def _strip_DDP(module: nn.Module) -> nn.Module:
+    if isinstance(module, FullyShardedDataParallel) or isinstance(
+        module, DistributedDataParallel
+    ):
+        module = module.module
+    return module
+
+
 class DistributedModelParallel(nn.Module, FusedOptimizerModule):
     """
     Entry point to model parallelism.
@@ -208,6 +217,7 @@ def dmp_module(self) -> nn.Module:
         return (
             self.module.module
             if isinstance(self.module, DistributedDataParallel)
+            or isinstance(self.module, FullyShardedDataParallel)
             else self.module
         )
 
@@ -220,7 +230,9 @@ def init_data_parallel(self) -> None:
         See init_data_parallel c-tor argument for usage.
         It's safe to call this method multiple times.
         """
-        if not isinstance(self.module, DistributedDataParallel):
+        if not isinstance(self.module, DistributedDataParallel) and not isinstance(
+            self.module, FullyShardedDataParallel
+        ):
             # Allocate any 'meta' tensors
             if self.init_parameters:
                 self._init_parameters(self.module)
@@ -316,6 +328,7 @@ def sparse_grad_parameter_names(
     def _sparse_grad_parameter_names(
         self, module: nn.Module, destination: List[str], prefix: str = ""
     ) -> List[str]:
+        module = _strip_DDP(module)
         if isinstance(module, ShardedModule):
             module.sparse_grad_parameter_names(destination, prefix)
         elif isinstance(module, nn.Embedding):
@@ -351,6 +364,7 @@ def _state_dict(
         prefix: str,
         keep_vars: bool,
     ) -> Dict[str, Any]:
+        module = _strip_DDP(module)
         if isinstance(module, ShardedModule):
             module.state_dict(destination, prefix, keep_vars)
         else:
@@ -376,6 +390,7 @@ def _load_state_dict(
     ) -> _IncompatibleKeys:
         missing_keys = []
         unexpected_keys = []
+        module = _strip_DDP(module)
         if isinstance(module, ShardedModule):
             return module.load_state_dict(state_dict, strict=strict)
         else:
@@ -398,6 +413,7 @@ def _load_state_dict(
     def _named_parameters(
         self, module: nn.Module, prefix: str = "", recurse: bool = True
     ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        module = _strip_DDP(module)
         if isinstance(module, ShardedModule):
             yield from module.named_parameters(prefix, recurse)
         else:
@@ -414,6 +430,7 @@ def named_parameters(
 
     @staticmethod
     def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
+        module = _strip_DDP(module)
         if isinstance(module, ShardedModule):
             yield from module.sharded_parameter_names(prefix)
         else:
@@ -425,6 +442,7 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[st
     def _named_buffers(
         self, module: nn.Module, prefix: str = "", recurse: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        module = _strip_DDP(module)
         if isinstance(module, ShardedModule):
             yield from module.named_buffers(prefix, recurse)
         else:
diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline.py
index 324747e12..8e92cecf0 100644
--- a/torchrec/distributed/train_pipeline.py
+++ b/torchrec/distributed/train_pipeline.py
@@ -23,6 +23,7 @@
 
 import torch
 from torch.autograd.profiler import record_function
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 from torch.fx.node import Node
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
@@ -364,8 +365,10 @@ def _rewrite_model(  # noqa C901
 ) -> List[ShardedModule]:
 
     # Get underlying nn.Module
-    while isinstance(model, DistributedModelParallel) or isinstance(
-        model, DistributedDataParallel
+    while (
+        isinstance(model, DistributedModelParallel)
+        or isinstance(model, DistributedDataParallel)
+        or isinstance(model, FullyShardedDataParallel)
     ):
         model = model.module
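
Note (not part of the patch): a minimal sketch of the unwrapping pattern this change relies on. DistributedModelParallel, DistributedDataParallel, and FullyShardedDataParallel all expose the wrapped module through a `.module` attribute, so state_dict/named_parameters traversal and the train pipeline can peel the wrappers off before dispatching to ShardedModule logic. The helper name `unwrap_model` below is hypothetical, not a torchrec API.

# Hedged sketch, mirroring the while-loop added to _rewrite_model in this patch.
# `unwrap_model` is a hypothetical helper for illustration only.
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.model_parallel import DistributedModelParallel


def unwrap_model(model: nn.Module) -> nn.Module:
    # All three wrappers expose the inner module via `.module`,
    # so one loop handles arbitrary nesting (e.g. DMP around FSDP).
    while isinstance(
        model,
        (DistributedModelParallel, DistributedDataParallel, FullyShardedDataParallel),
    ):
        model = model.module
    return model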