Commit de7de69
support unsharded loading for optimizer
Fridge003 committed Sep 21, 2023
1 parent a16056e commit de7de69
Showing 1 changed file with 58 additions and 4 deletions.
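In effect, this commit replaces the old NotImplementedError stub in HybridParallelCheckpointIO.load_unsharded_optimizer with a working implementation: the complete (unsharded) optimizer state dict is read with load_state_dict, the saved param_groups are merged with the live ones, saved states are matched to the parameters held by the local pipeline stage via their param ids, and each state tensor is then re-sharded with shard_from_complete_optimizer_state to fit the local tensor-parallel/ZeRO layout.

A minimal usage sketch of the code path this enables follows. The launch/boost boilerplate, the toy GPT-2 model, the parallel sizes, and the checkpoint path are illustrative assumptions, not part of this commit; run it with torchrun using a world size that is a multiple of tp_size * pp_size.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

model = GPT2LMHeadModel(GPT2Config())                      # hypothetical toy model with a ShardFormer policy
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)        # placeholder parallel sizes
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# (in real use you would run some training steps here before saving)
# Save a single unsharded optimizer checkpoint, then load it back;
# the load path is what this commit implements.
booster.save_optimizer(optimizer, "optimizer.pt", shard=False)
booster.load_optimizer(optimizer, "optimizer.pt")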
62 changes: 58 additions & 4 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -651,7 +651,7 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
-        Save optimizer state dict to a checkpoint file with given path.
+        Save optimizer state dict to a file with given path.

         Args:
             optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
@@ -663,6 +663,7 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,

         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"

+        # optimizer states of parameters kept by the local device (i.e., the local pipeline stage)
         local_states = dict()

         for param, state in optimizer.optim.state.items():
@@ -708,11 +709,64 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
         state_dict["state"].update(_states)
         save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
-        # TODO(Baizhou): support this feature after implementing complete state_dict collection
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the checkpoint file.
"""

+        def _get_param_id_from_optimizer_param(
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+        ):
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            return optimizer.param_info["param2id"][id(working_param)]
+
         if self.coordinator.is_master():
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
-        raise NotImplementedError
+
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+        # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
+        state_dict = load_state_dict(checkpoint)
+
+        # Load param_groups.
+        updated_groups = []
+        saved_groups = state_dict["param_groups"]
+        for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+            new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = old_pg["params"]
+            updated_groups.append(new_pg)
+        optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+        # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
+        master_to_working_map = optimizer.get_master_to_working_map()
+        id_map = {}
+        for pg in optimizer.optim.param_groups:
+            for param in pg["params"]:
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                id_map[param_id] = param
+        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+
+        # Then shard the loaded optimizer states if using tp/zero.
+        for param, state in optimizer.optim.state.items():
+            device = param.device
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            original_shape = optimizer.param_info["param2shape"][id(working_param)]
+            sharded_state = self.shard_from_complete_optimizer_state(
+                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+            )
+            optimizer.optim.state[param] = sharded_state
+
+        sharded_optimizer_loading_epilogue(optimizer.optim)

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
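For reference, the object returned by load_state_dict(checkpoint) above follows the standard torch.optim state-dict layout: a "state" dict keyed by integer parameter ids plus "param_groups" listing those ids. That is why the new code builds id_map from param_info["param2id"] and get_master_to_working_map(), translating the integer keys back to the live (master) parameters held on this rank. A quick plain-PyTorch illustration of that layout (not ColossalAI code):

import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optim.step()                                   # populate optimizer states

sd = optim.state_dict()
print(list(sd.keys()))                         # ['state', 'param_groups']
print(list(sd["state"].keys()))                # [0, 1] -> integer ids for weight and bias
print(sorted(sd["state"][0].keys()))           # ['exp_avg', 'exp_avg_sq', 'step']
print(sd["param_groups"][0]["params"])         # [0, 1]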

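The final loop re-shards each fully gathered state tensor back to the shape the local rank actually holds. shard_from_complete_optimizer_state itself is not part of this diff; the sketch below only illustrates the underlying idea for the tensor-parallel case, with hypothetical names (tp_rank, tp_size, dim) and without the padding and flattening details the real implementation has to handle:

import torch

def shard_state_tensor(full_state: torch.Tensor, tp_rank: int, tp_size: int, dim: int = 0) -> torch.Tensor:
    # Slice the rank-local chunk out of a complete (unsharded) optimizer state tensor.
    return torch.chunk(full_state, tp_size, dim=dim)[tp_rank].clone()

# e.g. the exp_avg buffer of a weight split across 2 TP ranks along dim 0
full_exp_avg = torch.randn(8, 4)
local_exp_avg = shard_state_tensor(full_exp_avg, tp_rank=0, tp_size=2)   # shape: (4, 4)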