Skip to content

Commit

Permalink
small fix: skip None params when sharding loaded optimizer states; clarify param_groups comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Fridge003 committed Sep 25, 2023
1 parent 0a6e09b commit 49f53a3
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def _get_param_id_from_optimizer_param(
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"]
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})

Expand All @@ -755,6 +755,8 @@ def _get_param_id_from_optimizer_param(

# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
Expand Down

0 comments on commit 49f53a3

Please sign in to comment.