Skip to content
20 changes: 14 additions & 6 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def wrapper(module, *args, **kwargs):
self._post_init_method(_module)
return _module

wrapper._ds_has_wrapped = True
return wrapper

def post_wrapper_to_empty(f):
Expand All @@ -465,8 +466,10 @@ def wrapper(*args, **kwargs):
return wrapper

def _enable_class_apply(cls):
    """Install the empty-init partition hook on ``cls._apply`` at most once.

    Saves the current ``cls._apply`` in ``cls._old_apply_of_skip_init_hook``
    and replaces it with ``partition_after_empty_init(cls._apply)``.

    Args:
        cls: the class whose ``_apply`` should be hooked.
    """
    # `_ds_has_wrapped` marks a callable that has already been wrapped.
    # `cls._apply` is resolved through the MRO, so a class that merely
    # inherits a wrapped `_apply` is skipped as well.
    # Wrapping again would stack a second hook and overwrite the saved
    # original with an already-wrapped function, after which
    # `_disable_class_apply` could no longer restore the real `_apply`.
    if hasattr(cls._apply, '_ds_has_wrapped'):
        return
    cls._old_apply_of_skip_init_hook = cls._apply
    cls._apply = partition_after_empty_init(cls._apply)

# Restore `cls._apply` from `cls._old_apply_of_skip_init_hook`, the attribute
# that `_enable_class_apply` stores the pre-hook `_apply` in.
# NOTE(review): raises AttributeError if `_old_apply_of_skip_init_hook` was
# never set on `cls` or one of its bases -- confirm callers always pair this
# with a prior `_enable_class_apply`.
def _disable_class_apply(cls):
cls._apply = cls._old_apply_of_skip_init_hook
Expand Down Expand Up @@ -519,15 +522,20 @@ def wrapper(module, *args, **kwargs):
if init_on_meta:
self.skip_init_depth -= 1

wrapper._ds_has_wrapped = True
return wrapper

def _enable_class(cls):
    """Wrap ``cls.__init__`` with ``partition_after`` at most once.

    Saves the current ``cls.__init__`` in ``cls._old_init`` and replaces it
    with ``partition_after(cls.__init__)``.

    Args:
        cls: the class whose constructor should be hooked.
    """
    # `_ds_has_wrapped` marks a callable that has already been wrapped.
    # `cls.__init__` is resolved through the MRO, so a class inheriting a
    # wrapped `__init__` from a base is skipped too.
    # Re-wrapping would nest the hook and replace `_old_init` with an
    # already-wrapped function, losing the true original constructor.
    if hasattr(cls.__init__, '_ds_has_wrapped'):
        return
    cls._old_init = cls.__init__
    cls.__init__ = partition_after(cls.__init__)

def _init_subclass(cls, **kwargs):
    """Wrap the ``__init__`` of a newly defined subclass at most once.

    Saves the current ``cls.__init__`` in ``cls._old_init`` and replaces it
    with ``partition_after(cls.__init__)``.

    Args:
        cls: the subclass being hooked.
        **kwargs: accepted for signature compatibility; not used here.
    """
    # A subclass that does not define its own `__init__` resolves
    # `cls.__init__` to the base-class constructor, which may already carry
    # the `_ds_has_wrapped` marker. Wrapping it again would nest the hook
    # and store an already-wrapped function in `_old_init`.
    if hasattr(cls.__init__, '_ds_has_wrapped'):
        return
    cls._old_init = cls.__init__
    cls.__init__ = partition_after(cls.__init__)

# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
Expand Down