[FSDP] Upstream fairseq big changes (#956)
* made gradient predivide factor configurable

* fix lints

Co-authored-by: Your Name <you@example.com>
m3rlin45 and Your Name committed Mar 16, 2022
1 parent 3c24beb commit 1bc96fa
Showing 1 changed file with 4 additions and 1 deletion.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py: 4 additions & 1 deletion
@@ -331,6 +331,7 @@ def __init__(
         cpu_offload: bool = False,
         offload_config: Optional[OffloadConfig] = None,
         state_dict_on_rank_0_only: bool = False,
+        gradient_predivide_factor: Optional[float] = None,
     ):
         try:
             import torch._C
@@ -399,7 +400,9 @@ def __init__(
         # Experimental feature for now. Use at your own risk.
         self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False

-        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
+        self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
+            self.world_size
+        )
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

         self.numel_padded_per_param: List[int] = []
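For context, a minimal usage sketch of the new keyword argument (not part of the commit itself). It assumes a CUDA device and an already-initialized torch.distributed process group; the model and the factor value are illustrative placeholders:

# Illustrative sketch, assuming torch.distributed is already initialized
# (e.g. via torchrun) and a GPU is available; the model and the chosen
# factor are placeholders, not taken from the commit.
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = nn.Linear(1024, 1024).cuda()

# With this change an explicit predivide factor can be passed instead of
# the default derived from world size. As the diff shows, the postdivide
# factor is world_size / gradient_predivide_factor, so the two divisions
# still combine to a plain average over ranks (e.g. world_size=8,
# predivide=4 -> postdivide=2, and 4 * 2 = 8) while splitting the division
# into two steps.
sharded_model = FSDP(model, gradient_predivide_factor=4.0)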
