diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 6ef87f9e00aa..34f10ec56210 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -228,10 +228,34 @@ def __enter__(self): def partition_after(f): @functools.wraps(f) def wrapper(module, *args, **kwargs): + + # important logic: We want to run post_init only after child's __init__ is + # completed, and do nothing after __init__ of any of its parents and grandparents in + # the inheritance ancestry. This way the partitioning will need to happen only once + # when the whole object is ready to be partitioned and not before. This is because + # often the child module will need to tweak the weights - for example running a + # custom weights init function. So if a parent created the weights param, the child + # won't need to gather it in order to tweak it + print_rank_0(f'Before initializing {module.__class__.__name__}', force=False) + + is_child_module = False + if not hasattr(module, "_ds_child_entered"): + # child's __init__ was called, since parents all see the same object they can now skip post_init + is_child_module = True + setattr(module, "_ds_child_entered", True) + f(module, *args, **kwargs) - self._post_init_method(module) + + if is_child_module: + # child's __init__ is done, now we can run a single post_init on the child object + delattr(module, "_ds_child_entered") + + print_rank_0(f'Running post_init for {module.__class__.__name__}', + force=False) + self._post_init_method(module) + print_rank_0( f'After initializing followed by post init for {module.__class__.__name__}', force=False) diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py index 98ee0c7ad00b..edeeb1c6e960 100644 --- a/tests/unit/test_zero_context.py +++ b/tests/unit/test_zero_context.py @@ -283,6 +283,7 @@ def test_stage_3_output_type(output_type): engine.step() +# test that no sub-class or super-class is missed class ConvX(torch.nn.Conv1d): def __init__(self, *args): super().__init__(*args) @@ -310,3 +311,52 @@ def test_subclass_param(): assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE assert model.conv1.param_in.ds_status == ZeroParamStatus.NOT_AVAILABLE + + +# test that sub-classes get params that aren't prematurely partitioned and thus requiring gathering +# fixed by https://github.com/microsoft/DeepSpeed/pull/1202 +class GrandPa(torch.nn.Module): + def __init__(self, *args): + super().__init__(*args) + self.param_grandpa = torch.nn.Parameter(torch.ones(5)) + self.param_grandpa.data = (self.param_grandpa.data + + 1).data # test param is not yet partitioned + + +class Pa(GrandPa): + def __init__(self, *args): + super().__init__(*args) + self.param_pa = torch.nn.Parameter(torch.ones(5)) + self.param_pa.data = (self.param_pa.data + + 1).data # test param is not yet partitioned + self.param_grandpa.data = (self.param_grandpa.data + + 1).data # test param is not yet partitioned + + +class Son(Pa): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.ones(5)) + self.param.data = (self.param.data + 1).data # test param is not yet partitioned + self.param_pa.data = (self.param_pa.data + + 1).data # test param is not yet partitioned + self.param_grandpa.data = (self.param_grandpa.data + + 1).data # test param is not yet partitioned + + +def test_subclass_param_init(): + setup_serial_env() + with deepspeed.zero.Init(config=config): + model = Son().cpu() + + # test that all params have been partitioned + assert model.param_grandpa.ds_status == ZeroParamStatus.NOT_AVAILABLE + assert model.param_pa.ds_status == ZeroParamStatus.NOT_AVAILABLE + assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE + + # test that the weights manipulation during each __init__ worked in all w/o needing gathering + ones = torch.ones(5).half().cuda() + with deepspeed.zero.GatheredParameters(list(model.parameters(recurse=False))): + assert torch.equal(model.param, ones + 1) + assert torch.equal(model.param_pa, ones + 2) + assert torch.equal(model.param_grandpa, ones + 3)