Skip to content

Commit

Permalink
[fix] Add option to wrap root module in auto_wrap (#930)
Browse files Browse the repository at this point in the history
* [fix] Add option to wrap root module in auto_wrap

* Fix unit-test comment

* adding a few more tests to make expected behavior clear

* move changes to wrap policy as suggested

* set default to false

* revert pre-commit change

* revert pre-commit change 2

Co-authored-by: Ruan Silva <ruanrms@fb.com>
  • Loading branch information
ruanslv and Ruan Silva committed Feb 15, 2022
1 parent fae2995 commit 3b8f445
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
35 changes: 27 additions & 8 deletions fairscale/nn/wrap/auto_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@ def default_auto_wrap_policy(
module: nn.Module,
recurse: bool,
unwrapped_params: int,
module_is_root: bool,
# These are customizable for this default policy function.
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
skip_params_check_for_root: bool = False,
) -> bool:
"""Default policy function for :func:`auto_wrap`.
Return if a module should be wrapped during :func:`auto_wrap`.
The first three parameters are used by :func:`auto_wrap`. If
The first four parameters are used by :func:`auto_wrap`. If
you write a custom version of this policy function, your version
needs to at least accept the first three parameters and free
needs to at least accept the first four parameters and free
to do whatever you want in the function.
Args:
Expand All @@ -37,6 +39,8 @@ def default_auto_wrap_policy(
on whether we should wrap the said module.
unwrapped_params (int):
The number of parameters yet to be wrapped in this module.
module_is_root (bool):
Indicates if current module is the root.
min_num_params (int):
Customizable policy input. It controls the size threshold
Expand All @@ -45,6 +49,9 @@ def default_auto_wrap_policy(
keep as leaves, i.e., their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]):
Customizable set of module types to be excluded in wrapping.
skip_params_check_for_root (bool):
If module_is_root is True, then this includes the root in
wrapping regardless of their number of unwrapped params.
"""
force_leaf_modules = (
default_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore
Expand All @@ -63,7 +70,9 @@ def default_auto_wrap_policy(
return is_large and not isinstance(module, tuple(force_leaf_modules))
else:
# If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))
return ((module_is_root and skip_params_check_for_root) or is_large) and not isinstance(
module, tuple(exclude_wrap_modules)
)


# Set those defaults to the default_auto_wrap_policy function. Make them easy to be imported.
Expand All @@ -75,6 +84,7 @@ def config_auto_wrap_policy(
module: nn.Module,
recurse: bool,
unwrapped_params: int,
module_is_root: bool,
) -> bool:
"""Config based policy function for :func:`auto_wrap`.
Expand All @@ -92,6 +102,9 @@ def config_auto_wrap_policy(
unwrapped_params (int):
The number of parameters yet to be wrapped in this module.
Unused by this function.
module_is_root (bool):
Indicates if current module is the root.
Unused by this function.
"""
if recurse:
# We should always recurse.
Expand Down Expand Up @@ -209,7 +222,9 @@ def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **
(default: wrap if > 100M parameters)
"""
if ConfigAutoWrap.in_autowrap_context:
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(module, auto_wrap_policy=auto_wrap_policy, **kwargs)
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
)
return wrapped_module
return module

Expand Down Expand Up @@ -258,7 +273,9 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.disable_autowrap_context()

@staticmethod
def recursive_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable], **kwargs: Any) -> Tuple[nn.Module, int]:
def recursive_wrap(
module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
) -> Tuple[nn.Module, int]:
"""
Automatically wrap child modules of *module* that meet the given
criteria with :func:`auto_wrap`.
Expand All @@ -284,20 +301,22 @@ def recursive_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable], **kw
num_params = sum([p.numel() for p in module.parameters()])

assert auto_wrap_policy is not None
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
total_wrapped_params = 0
# Iterate through the children, recursively wrap if necessary
for name, child in module.named_children():
wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
module=child, auto_wrap_policy=auto_wrap_policy, **kwargs
module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
)
setattr(module, name, wrapped_child)
# Keep track of how many parameters have been wrapped
total_wrapped_params += num_wrapped_params
# decide if we need to wrap the current module,
# since the left over parameters exceed the number of params to wrap
remainder = num_params - total_wrapped_params
if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder):
if auto_wrap_policy(
module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root
):
# Leaf node or final wrapping of the remainder both happen here.
return wrap(module, **kwargs), num_params
else:
Expand Down
30 changes: 23 additions & 7 deletions tests/nn/wrap/test_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,35 @@ def test_auto_wrap(self):
"""
Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
Root is not wrapped given there are not enough unwrapped params left and skip_params_check_for_root
is not set.
"""
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=60)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], nn.Linear)
assert isinstance(model[1], FSDP)
assert isinstance(model[1].module[0], nn.Linear)
assert isinstance(model[1].module[1], nn.Linear)

def test_auto_wrap_skip_root_checks(self):
"""
Similar test as before but this time we set skip_params_check_for_root=True in the wrap policy.
So in this case the root is wrapped even without enough remaining unwrapped params.
"""
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
my_auto_wrap_policy = functools.partial(
default_auto_wrap_policy, min_num_params=60, skip_params_check_for_root=True
)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, FSDP)
assert isinstance(model.module[0], nn.Linear)
assert isinstance(model.module[1], nn.Linear)
assert isinstance(model.module[2], FSDP)
assert isinstance(model.module[2].module[0], nn.Linear)
assert isinstance(model.module[2].module[1], nn.Linear)
assert isinstance(model.module[1], FSDP)
assert isinstance(model.module[1].module[0], nn.Linear)
assert isinstance(model.module[1].module[1], nn.Linear)

def test_auto_wrap_preset_exclude_wrap(self):
"""
Expand Down

0 comments on commit 3b8f445

Please sign in to comment.