Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,34 @@ def __enter__(self):
def partition_after(f):
@functools.wraps(f)
def wrapper(module, *args, **kwargs):
    # Purpose: call the wrapped ``f`` (a module ``__init__``) and run
    # ``self._post_init_method`` exactly once per constructed object.
    #
    # Important logic: post_init must run only after the child's __init__ has
    # completed, and must NOT run after the __init__ of any of its parents or
    # grandparents in the inheritance chain. That way partitioning happens a
    # single time, when the whole object is ready. A child module often needs
    # to tweak the weights (e.g. by running a custom weight-init function); if
    # a parent created the param, the child should not have to gather it first
    # in order to modify it.
    #
    # NOTE(review): ``self`` and ``print_rank_0`` come from the enclosing
    # scope, which is not fully visible here.
    print_rank_0(f'Before initializing {module.__class__.__name__}',
                 force=False)

    # The outermost (child) __init__ is the first wrapper to see the object, so
    # the marker attribute is absent only for it. Parent __init__ calls reached
    # through super() operate on the same object and find the marker set.
    is_child_module = False
    if not hasattr(module, "_ds_child_entered"):
        is_child_module = True
        setattr(module, "_ds_child_entered", True)

    # Run the original __init__. No post_init here: running it unconditionally
    # would process the object after every ancestor's __init__, before the
    # child has finished touching the weights.
    f(module, *args, **kwargs)

    if is_child_module:
        # The child's __init__ is done: drop the marker and run the single
        # post_init on the fully constructed object.
        delattr(module, "_ds_child_entered")

        print_rank_0(f'Running post_init for {module.__class__.__name__}',
                     force=False)
        self._post_init_method(module)

    print_rank_0(
        f'After initializing followed by post init for {module.__class__.__name__}',
        force=False)
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_zero_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def test_stage_3_output_type(output_type):
engine.step()


# test that no sub-class or super-class is missed
class ConvX(torch.nn.Conv1d):
def __init__(self, *args):
super().__init__(*args)
Expand Down Expand Up @@ -310,3 +311,52 @@ def test_subclass_param():

assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
assert model.conv1.param_in.ds_status == ZeroParamStatus.NOT_AVAILABLE


# test that sub-classes receive params that have not been prematurely partitioned and thus do not require gathering
# fixed by https://github.com/microsoft/DeepSpeed/pull/1202
class GrandPa(torch.nn.Module):
    """Root of the test hierarchy: creates one parameter and edits it in place."""

    def __init__(self, *args):
        super().__init__(*args)
        self.param_grandpa = torch.nn.Parameter(torch.ones(5))
        # Modify the freshly created weights right away; this is the check
        # that the param is not yet partitioned at this point.
        incremented = self.param_grandpa.data + 1
        self.param_grandpa.data = incremented.data


class Pa(GrandPa):
    """Middle of the test hierarchy: adds a parameter and edits it plus the inherited one."""

    def __init__(self, *args):
        super().__init__(*args)
        self.param_pa = torch.nn.Parameter(torch.ones(5))
        # Bump this class's own param first, then the one created by GrandPa;
        # each edit is the check that the param is not yet partitioned.
        for weight in (self.param_pa, self.param_grandpa):
            weight.data = (weight.data + 1).data


class Son(Pa):
    """Leaf of the test hierarchy: adds a parameter and edits all three params."""

    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(5))
        # Bump the params in order: own, parent's, grandparent's. Each edit is
        # the check that the param is not yet partitioned.
        for weight in (self.param, self.param_pa, self.param_grandpa):
            weight.data = (weight.data + 1).data


def test_subclass_param_init():
    """Check that params edited in every ``__init__`` of a class chain end up partitioned with the expected values."""
    setup_serial_env()
    with deepspeed.zero.Init(config=config):
        model = Son().cpu()

    # every param must have been partitioned once construction has finished
    for attr_name in ("param_grandpa", "param_pa", "param"):
        assert getattr(model, attr_name).ds_status == ZeroParamStatus.NOT_AVAILABLE

    # the weight edits made during each __init__ must all have taken effect,
    # without any gathering having been needed at that time
    base = torch.ones(5).half().cuda()
    own_params = list(model.parameters(recurse=False))
    with deepspeed.zero.GatheredParameters(own_params):
        assert torch.equal(model.param, base + 1)
        assert torch.equal(model.param_pa, base + 2)
        assert torch.equal(model.param_grandpa, base + 3)