diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 94bba1d61a22..6d1538e17b6e 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1637,6 +1637,7 @@ def __exit__(self, *exc): if not self.enabled: return if self.src_rank is None: + self.params[0].partition(param_list=self.params, has_been_updated=False) return handles = [ diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index b3ce68806444..3b3858b7d5a1 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -209,23 +209,26 @@ def forward(self, x, y): # make sure all sides saved it dist.barrier() - if zero_stage == 3: - with deepspeed.zero.GatheredParameters(list( - model.module.parameters(recurse=True)), - modifier_rank=None): - pass # this forces gathering the model - - #dump_state_dict(model) - orig_state_dict = {} for name, param in model.module.named_parameters(): - orig_state_dict[name] = param.detach().cpu() + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(param, modifier_rank=None): + orig_state_dict[name] = param.detach().cpu() + else: + orig_state_dict[name] = param.detach().cpu() - if dist.get_rank() == 0: + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(model.parameters(), + modifier_rank=None): + fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) + fp32_state_dict = fp32_model.state_dict() + else: fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) - #dump_state_dict(fp32_model) - fp32_state_dict = fp32_model.state_dict() + + #dump_state_dict(fp32_model) + + if dist.get_rank() == 0: for name in orig_state_dict.keys(): # float() workaround for torch<1.6 assert torch.allclose(orig_state_dict[name].float(), @@ -308,23 +311,28 @@ def forward(self, x, y): # make sure all sides saved it dist.barrier() - if zero_stage == 3: - with deepspeed.zero.GatheredParameters(list( - model.module.parameters(recurse=True)), - modifier_rank=None): - pass # this forces gathering the model - #dump_state_dict(model) orig_state_dict = {} for name, param in model.module.named_parameters(): - orig_state_dict[name] = param.detach().cpu() + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(param, modifier_rank=None): + orig_state_dict[name] = param.detach().cpu() + else: + orig_state_dict[name] = param.detach().cpu() - if dist.get_rank() == 0: + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(model.parameters(), + modifier_rank=None): + fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) + fp32_state_dict = fp32_model.state_dict() + else: fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) - #dump_state_dict(fp32_model) - fp32_state_dict = fp32_model.state_dict() + + #dump_state_dict(fp32_model) + + if dist.get_rank() == 0: for name in orig_state_dict.keys(): # float() workaround for torch<1.6 assert torch.allclose(orig_state_dict[name].float(), diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index 130d446d8792..c349e03028a4 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -51,6 +51,28 @@ def forward(self, x): } +class TestZeroGatheredParametersFree(DistributedTest): + world_size = 1 + + def test(self): + config_dict = {"train_batch_size": 1, "zero_optimization": {"stage": 3}} + hidden_dim = 10 + + class MyModel(torch.nn.Module): + def __init__(self, hidden_dim): + super(MyModel, self).__init__() + self.l1 = torch.nn.Linear(hidden_dim, hidden_dim) + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = MyModel(hidden_dim) + + with deepspeed.zero.GatheredParameters(list(model.parameters())): + assert model.l1.weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor" + + # on exit from `GatheredParameters` the gathered params should be freed and not leak memory + assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized" + + class TestSerialContext(DistributedTest): world_size = 1 init_distributed = False