From 79fbe5ff06aaa1f97deccc24298e4d0a3a355735 Mon Sep 17 00:00:00 2001
From: Shruti Bhosale
Date: Tue, 30 Mar 2021 21:41:28 -0700
Subject: [PATCH 1/4] add gradient predivide factor to FSDP

---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 16862cd9c..5ee1648af 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -173,7 +173,9 @@ def __init__(
         move_grads_to_cpu: Optional[bool] = None,
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
+        gradient_predivide_factor: Optional[int] = 32,
     ):
+        print(f"inside this new FSDP...")
         super().__init__()
         self.process_group = process_group or dist.new_group()
         self.rank = self.process_group.rank()
@@ -187,6 +189,7 @@ def __init__(
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.gradient_predivide_factor = gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
         self.compute_device = compute_device
 
@@ -1071,7 +1074,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         if self.world_size > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
-            param.grad.data.div_(self.world_size)
+            param.grad.data.div_(self.gradient_predivide_factor)
 
         callback_fn = functools.partial(self._post_reduction_hook, param)
         if param._is_sharded:
@@ -1098,7 +1101,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
         assert torch.cuda.current_stream() == self._streams["post_backward"]
         assert param.grad is not None
         self.assert_state(TrainingState.BACKWARD_POST)
-        param.grad.data = reduced_grad
+        param.grad.data = reduced_grad.div_(self.world_size/self.gradient_predivide_factor)
         # Cast grad to param's dtype (typically FP32). Note: we do this
         # before the move_grads_to_cpu step so that this entire hook remains
         # non-blocking. The downside is a bit more D2H transfer in that case.
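
Why patch 1 preserves DDP-style averaging: dividing each local gradient by gradient_predivide_factor before the reduction and by world_size / gradient_predivide_factor afterwards scales the reduced sum by 1 / world_size overall, exactly as the previous single div_(self.world_size) did. A minimal standalone sketch of that identity (plain Python, not FSDP code; the per-rank gradients and world size below are made up for illustration):

    # Splitting the 1/world_size average into a pre-reduce and a post-reduce
    # division leaves the reduced gradient unchanged.
    world_size = 64
    predivide_factor = 32  # patch 1's default
    postdivide_factor = world_size / predivide_factor

    local_grads = [float(rank + 1) for rank in range(world_size)]  # made-up per-rank grads

    # Original behaviour: divide each local grad by world_size, then sum (the all-reduce).
    baseline = sum(g / world_size for g in local_grads)

    # Patched behaviour: pre-divide, sum across ranks, then post-divide.
    split = sum(g / predivide_factor for g in local_grads) / postdivide_factor

    assert abs(baseline - split) < 1e-9
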
From b1b2ca994f839775aa38be01caf1b5aa2a1f12a7 Mon Sep 17 00:00:00 2001
From: Shruti Bhosale
Date: Wed, 31 Mar 2021 16:27:18 -0700
Subject: [PATCH 2/4] infer the possible best gradient predivide factor

---
 .../nn/data_parallel/fully_sharded_data_parallel.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 5ee1648af..2364506fb 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -173,7 +173,6 @@ def __init__(
         move_grads_to_cpu: Optional[bool] = None,
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
-        gradient_predivide_factor: Optional[int] = 32,
     ):
         print(f"inside this new FSDP...")
         super().__init__()
@@ -189,7 +188,7 @@ def __init__(
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
-        self.gradient_predivide_factor = gradient_predivide_factor
+        self.gradient_predivide_factor = self.get_gradient_predivide_factor(self.world_size)
         self.numel_padded_per_param: List[int] = []
         self.compute_device = compute_device
 
@@ -255,6 +254,13 @@ def __init__(
         # user explicitly requests the local state dict via local_state_dict().
         self._return_full_state_dict = True
 
+    def get_gradient_predivide_factor(self, world_size: int) -> int:
+        factor = 1
+        while world_size % factor == 0 and world_size / factor > factor:
+            factor = factor * 2
+        print(f"factor = {factor}")
+        return factor
+
     @property
     def module(self) -> nn.Module:
         return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance

From 92a655eaef9eb3794e596fa4244ea27e632341e0 Mon Sep 17 00:00:00 2001
From: Shruti Bhosale
Date: Wed, 31 Mar 2021 17:15:19 -0700
Subject: [PATCH 3/4] remove debugging st

---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 2364506fb..36385b08a 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -174,7 +174,6 @@ def __init__(
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
     ):
-        print(f"inside this new FSDP...")
         super().__init__()
         self.process_group = process_group or dist.new_group()
         self.rank = self.process_group.rank()
@@ -258,7 +257,6 @@ def get_gradient_predivide_factor(self, world_size: int) -> int:
         factor = 1
         while world_size % factor == 0 and world_size / factor > factor:
             factor = factor * 2
-        print(f"factor = {factor}")
         return factor
 
     @property
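
What the heuristic from patch 2 actually picks (patch 3 only strips its debug prints): the factor keeps doubling while it still divides world_size evenly and while world_size / factor remains larger than the factor, so for power-of-two world sizes it returns the smallest power of two whose square is at least world_size (roughly the square root). A standalone copy of the helper with a few illustrative world sizes (the sample sizes are assumptions for the example, not from the patch):

    def get_gradient_predivide_factor(world_size: int) -> int:
        # Double the factor while it divides world_size evenly and the
        # remaining post-divide factor (world_size / factor) is still larger.
        factor = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor = factor * 2
        return factor

    for ws in (1, 2, 8, 16, 64, 128):
        print(ws, get_gradient_predivide_factor(ws))
    # 1 -> 1, 2 -> 2, 8 -> 4, 16 -> 4, 64 -> 8, 128 -> 16
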
From 1c28c3c46a25c19218a820aba33414e5b5e839f4 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Sat, 3 Apr 2021 06:26:14 -0700
Subject: [PATCH 4/4] Fix lint, CR comments

---
 .../nn/data_parallel/fully_sharded_data_parallel.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 36385b08a..73f359d9e 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -187,7 +187,8 @@ def __init__(
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
-        self.gradient_predivide_factor = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
         self.compute_device = compute_device
 
@@ -1076,7 +1077,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # Cast grad to FP32.
             param.grad.data = param.grad.data.to(param.dtype)
 
-        if self.world_size > 1:
+        if self.gradient_predivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             param.grad.data.div_(self.gradient_predivide_factor)
 
@@ -1105,7 +1106,10 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
         assert torch.cuda.current_stream() == self._streams["post_backward"]
         assert param.grad is not None
         self.assert_state(TrainingState.BACKWARD_POST)
-        param.grad.data = reduced_grad.div_(self.world_size/self.gradient_predivide_factor)
+        param.grad.data = reduced_grad
+        if self.gradient_postdivide_factor > 1:
+            # Average grad by world_size for consistency with PyTorch DDP.
+            param.grad.data.div_(self.gradient_postdivide_factor)
         # Cast grad to param's dtype (typically FP32). Note: we do this
         # before the move_grads_to_cpu step so that this entire hook remains
         # non-blocking. The downside is a bit more D2H transfer in that case.
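
The final form keeps gradient_predivide_factor * gradient_postdivide_factor == world_size, so the two divisions in _post_backward_hook and _post_reduction_hook still average by world_size overall, while each individual division is by a much smaller number. The practical motivation (an assumption here, not stated in the commit messages) is numerical: with FP16 gradients, dividing by the full world size before the reduction can push values into the half-precision subnormal range and lose precision, whereas the staged divisions keep both steps better conditioned. A small NumPy sketch of that effect, with a made-up gradient value and world size:

    import numpy as np

    world_size = 1024
    predivide = 32                     # what get_gradient_predivide_factor(1024) returns
    postdivide = world_size / predivide
    assert predivide * postdivide == world_size  # overall scaling is still 1/world_size

    grad = np.float16(3e-3)            # made-up FP16 gradient value

    # Dividing by the full world size up front drops the value below the smallest
    # normal FP16 number (~6.1e-5), so it is stored as a subnormal with only a few
    # significant bits left.
    naive = grad / np.float16(world_size)

    # Pre-dividing by the smaller factor keeps the value in the normal FP16 range;
    # the remaining division happens only after the (larger) reduced sum is formed.
    staged = grad / np.float16(predivide)

    # Relative error when scaling each result back up to the original magnitude:
    err_naive = abs(float(naive) * world_size - float(grad)) / float(grad)
    err_staged = abs(float(staged) * predivide - float(grad)) / float(grad)
    print(err_naive, err_staged)       # the naive path loses noticeably more precision
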