[minor] help pure fp16 FSDP init a bit (#1068)
* [minor] [FSDP] add a better error for pure fp16

* [minor] [wrap] add a flag to help fsdp pure fp16 wrapping
min-xu-ai committed Sep 10, 2022
1 parent 454537d commit 73bf596
Showing 2 changed files with 13 additions and 0 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (6 additions, 0 deletions)
@@ -475,6 +475,12 @@ def __init__(
         self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
         self._param_name_groups = param_name_groups
 
+        # Check to see if the mixed precision setting is correct.
+        if self.compute_dtype is torch.float16 and self.mixed_precision is False:
+            for p in self.params:
+                if p.dtype is not torch.float16:
+                    raise ValueError("Expecting FP16 param type in pure FP16 mode.")
+
         # Shard module parameters in place
         self._shard_parameters_()
fairscale/nn/wrap/auto_wrap.py (7 additions, 0 deletions)
@@ -184,6 +184,8 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
         assert isinstance(module_overrides, dict)
         wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}
         assert ConfigAutoWrap.wrapper_cls is not None
+        if ConfigAutoWrap.move_module_cuda_half:
+            module = module.cuda().half()
         return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
     return module
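The move_module_cuda_half flag is passed through the same kwargs as wrapper_cls and is popped off in enable_autowrap_context (see the hunk below) before the remaining kwargs reach the wrapper class. A minimal sketch of how a caller might combine it with the pure-FP16 check above, assuming a CUDA device, an initialized process group, and a made-up toy layer:

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.wrap import enable_wrap, wrap

    fsdp_kwargs = dict(mixed_precision=False, compute_dtype=torch.float16)

    # With move_module_cuda_half=True, wrap() moves the module to CUDA and casts
    # it to FP16 before handing it to FSDP, so the pure-FP16 check passes without
    # the caller converting the module manually.
    with enable_wrap(wrapper_cls=FSDP, move_module_cuda_half=True, **fsdp_kwargs):
        sharded = wrap(nn.Linear(8, 8))  # hypothetical layer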

@@ -236,6 +238,7 @@ class ConfigAutoWrap:
     """
 
     in_autowrap_context: bool = False  # Context flag
+    move_module_cuda_half: bool = False  # A flag to control the wrap() function.
     wrapper_cls: Optional[Callable] = None  # The wrapper class
     kwargs: Dict[str, Any] = {}  # Wrapper's args
     auto_wrap_policy: Optional[Callable] = None  # Used only in auto_wrap
@@ -252,6 +255,9 @@ def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
             )
         ConfigAutoWrap.in_autowrap_context = True
         # Get and save the wrapper cls for the context.
+        if "move_module_cuda_half" in kwargs.keys():
+            ConfigAutoWrap.move_module_cuda_half = cast(bool, kwargs["move_module_cuda_half"])
+            del kwargs["move_module_cuda_half"]
         assert "wrapper_cls" in kwargs.keys()
         ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
         del kwargs["wrapper_cls"]
@@ -262,6 +268,7 @@ def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
     @staticmethod
     def disable_autowrap_context() -> None:
         ConfigAutoWrap.in_autowrap_context = False
+        ConfigAutoWrap.move_module_cuda_half = False
         ConfigAutoWrap.wrapper_cls = None
         ConfigAutoWrap.kwargs = {}
         ConfigAutoWrap.auto_wrap_policy = None
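Because disable_autowrap_context() now also resets move_module_cuda_half, the flag only affects wrap() calls made inside an enable_wrap context. A small sketch of that lifecycle, assuming ConfigAutoWrap can be imported from fairscale.nn.wrap.auto_wrap:

    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.wrap import enable_wrap
    from fairscale.nn.wrap.auto_wrap import ConfigAutoWrap

    assert ConfigAutoWrap.move_module_cuda_half is False  # class default

    with enable_wrap(wrapper_cls=FSDP, move_module_cuda_half=True):
        # Inside the context, wrap() will move modules to CUDA and cast them to FP16.
        assert ConfigAutoWrap.move_module_cuda_half is True

    # The reset above runs on context exit, so later contexts start from the default.
    assert ConfigAutoWrap.move_module_cuda_half is False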
