Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,7 @@ def __exit__(self, *exc):
if not self.enabled:
return
if self.src_rank is None:
self.params[0].partition(param_list=self.params, has_been_updated=False)
return

handles = [
Expand Down
52 changes: 30 additions & 22 deletions tests/unit/runtime/zero/test_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,26 @@ def forward(self, x, y):
# make sure all sides saved it
dist.barrier()

if zero_stage == 3:
with deepspeed.zero.GatheredParameters(list(
model.module.parameters(recurse=True)),
modifier_rank=None):
pass # this forces gathering the model

#dump_state_dict(model)

orig_state_dict = {}
for name, param in model.module.named_parameters():
orig_state_dict[name] = param.detach().cpu()
if zero_stage == 3:
with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
orig_state_dict[name] = param.detach().cpu()
else:
orig_state_dict[name] = param.detach().cpu()

if dist.get_rank() == 0:
if zero_stage == 3:
with deepspeed.zero.GatheredParameters(model.parameters(),
modifier_rank=None):
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
fp32_state_dict = fp32_model.state_dict()
else:
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
#dump_state_dict(fp32_model)

fp32_state_dict = fp32_model.state_dict()

#dump_state_dict(fp32_model)

if dist.get_rank() == 0:
for name in orig_state_dict.keys():
# float() workaround for torch<1.6
assert torch.allclose(orig_state_dict[name].float(),
Expand Down Expand Up @@ -308,23 +311,28 @@ def forward(self, x, y):
# make sure all sides saved it
dist.barrier()

if zero_stage == 3:
with deepspeed.zero.GatheredParameters(list(
model.module.parameters(recurse=True)),
modifier_rank=None):
pass # this forces gathering the model

#dump_state_dict(model)

orig_state_dict = {}
for name, param in model.module.named_parameters():
orig_state_dict[name] = param.detach().cpu()
if zero_stage == 3:
with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
orig_state_dict[name] = param.detach().cpu()
else:
orig_state_dict[name] = param.detach().cpu()

if dist.get_rank() == 0:
if zero_stage == 3:
with deepspeed.zero.GatheredParameters(model.parameters(),
modifier_rank=None):
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
fp32_state_dict = fp32_model.state_dict()
else:
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
#dump_state_dict(fp32_model)

fp32_state_dict = fp32_model.state_dict()

#dump_state_dict(fp32_model)

if dist.get_rank() == 0:
for name in orig_state_dict.keys():
# float() workaround for torch<1.6
assert torch.allclose(orig_state_dict[name].float(),
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/runtime/zero/test_zero_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ def forward(self, x):
}


class TestZeroGatheredParametersFree(DistributedTest):
    """Regression test: parameters gathered via ``GatheredParameters`` are released on exit."""
    world_size = 1

    def test(self):
        """Gather a ZeRO stage-3 model's parameters and check their size inside and after the context."""
        width = 10
        ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}

        class MyModel(torch.nn.Module):
            """Minimal module holding a single square linear layer."""

            def __init__(self, hidden_dim):
                super().__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)

        # Construct the model under zero.Init using the stage-3 config above.
        with deepspeed.zero.Init(config_dict_or_path=ds_config):
            model = MyModel(width)

        all_params = list(model.parameters())
        with deepspeed.zero.GatheredParameters(all_params):
            # Inside the context the weight must be materialized.
            assert model.l1.weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"

        # Leaving the context without a modifier rank must still release the
        # gathered tensors; otherwise they would leak memory.
        assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"


class TestSerialContext(DistributedTest):
world_size = 1
init_distributed = False
Expand Down