[checkpointio] optimize zero optim checkpoint io (#4591)
* [zero] update checkpoint io to save memory

* [checkpointio] add device map to save memory
ver217 committed Sep 4, 2023
1 parent cfa6070 commit 63ecafb
Showing 4 changed files with 43 additions and 22 deletions.
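Both changes target peak memory while loading sharded ZeRO optimizer checkpoints: shard files are now deserialized onto CPU (map_location), and each rank keeps only a cloned slice of every flattened state tensor instead of a view of, or a full copy of, the whole thing. The padding-and-splitting arithmetic recurs in every diff below; the following standalone sketch (helper name and example values are illustrative, not part of the commit) shows what it computes:

import torch

def shard_flat_state(v: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Flatten, pad to a multiple of world_size, and keep only this rank's slice.
    v = v.flatten()
    padding_size = (world_size - v.numel() % world_size) % world_size
    if padding_size > 0:
        v = torch.nn.functional.pad(v, [0, padding_size])
    v_list = v.split(v.numel() // world_size)
    # clone() materializes just this slice instead of a view into the padded tensor.
    return v_list[rank].detach().clone()

# Example: 10 elements over 4 ranks -> padded to 12, each rank holds 3.
print(shard_flat_state(torch.arange(10, dtype=torch.float32), world_size=4, rank=1))
# tensor([3., 4., 5.])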
51 changes: 38 additions & 13 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -17,8 +17,13 @@
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
+    load_param_groups_into_optimizer,
+    load_shard_state_dict,
+    load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -126,19 +131,39 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
-        current_rank_state_dict = optimizer.optim.state_dict()['state']
-        for param_idx, state in current_rank_state_dict.items():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
-                    padding_size = (self.coordinator.world_size -
-                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
-                    with torch.no_grad():
-                        v = v.flatten()
-                        if padding_size > 0:
-                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // self.coordinator.world_size)
-                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            # shard state dict
+            for param_idx, state in state_dict.items():
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != 'step':
+                        padding_size = (self.coordinator.world_size -
+                                        v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                        with torch.no_grad():
+                            v = v.flatten()
+                            if padding_size > 0:
+                                v = torch.nn.functional.pad(v, [0, padding_size])
+                            v_list = v.split(v.numel() // self.coordinator.world_size)
+                            state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+
+        sharded_optimizer_loading_epilogue(optimizer)


 class LowLevelZeroModel(ModelWrapper):
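For contrast with the removed block above: the old path first materialized the complete, unsharded optimizer state via super().load_sharded_optimizer() (apparently the generic loop in general_checkpoint_io.py changed below) and only then cut optimizer.optim.state_dict() down per rank. The new loop shards each checkpoint file before it ever reaches the optimizer, so at most one shard plus its rank-local slices is resident at a time. A rough sketch of that flow, reusing the hypothetical shard_flat_state helper from the sketch above and the colossalai helper names visible in the diff (the real method is the plugin's load_sharded_optimizer shown above):

from pathlib import Path

import torch
from colossalai.checkpoint_io.utils import (
    load_shard_state_dict,
    load_states_into_optimizer,
    sharded_optimizer_loading_epilogue,
)

def load_shards_rank_local(optimizer, checkpoint_files, id_map, world_size, rank):
    # Illustrative outline only, not the plugin code itself.
    for shard_file in checkpoint_files:
        # One shard at a time, deserialized onto CPU (see the utils.py change below).
        state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
        for param_idx, state in state_dict.items():
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and k != 'step':
                    # Replace the full flat tensor with this rank's cloned slice.
                    state_dict[param_idx][k] = shard_flat_state(v, world_size, rank)
        load_states_into_optimizer(optimizer, state_dict, id_map)
        # Rebinding state_dict on the next iteration releases the previous shard.
    sharded_optimizer_loading_epilogue(optimizer)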
2 changes: 0 additions & 2 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -78,8 +78,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             load_states_into_optimizer(optimizer, state_dict, id_map)
-            del state_dict
-            gc.collect()

         sharded_optimizer_loading_epilogue(optimizer)

6 changes: 3 additions & 3 deletions colossalai/checkpoint_io/utils.py
@@ -237,7 +237,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
                 f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
-        return torch.load(checkpoint_file)
+        return torch.load(checkpoint_file, map_location=torch.device('cpu'))


 def load_state_dict_into_model(model: nn.Module,
@@ -297,7 +297,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str

     # Load list of param_groups from given file path.
     # The params in saved_groups are in the form of integer indices.
-    saved_groups = torch.load(param_group_path)
+    saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
     if not isinstance(saved_groups, List):
         raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')

@@ -608,7 +608,7 @@ def load_state_dict(checkpoint_file_path: Path):

     else:
         # load with torch
-        return torch.load(checkpoint_file_path)
+        return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))


 def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
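The three utils.py edits are the "device map" half of the commit message: without map_location, torch.load restores each tensor on the device it was saved from, so reading a checkpoint written from GPU allocates GPU memory before anything is sharded. A small, self-contained illustration (file path and tensors are made up for the example):

import torch

ckpt_path = '/tmp/demo_optimizer_shard.pt'   # hypothetical path for the demo
state = {'exp_avg': torch.randn(1024)}
if torch.cuda.is_available():
    state = {k: v.cuda() for k, v in state.items()}
torch.save(state, ckpt_path)

# Default behaviour: tensors come back on the device they were saved from,
# which can mean an unwanted GPU allocation just to read the file.
on_original_device = torch.load(ckpt_path)

# With map_location, everything lands on CPU; callers move only the slices
# they actually need (e.g. the rank-local shard) to the GPU afterwards.
on_cpu = torch.load(ckpt_path, map_location=torch.device('cpu'))
print(on_cpu['exp_avg'].device)   # cpu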
6 changes: 2 additions & 4 deletions colossalai/zero/low_level/low_level_optim.py
@@ -307,7 +307,7 @@ def _add_to_bucket(self, param, group_id):
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
         if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
-            group_id != self._bucket_store.current_group_id:
+                group_id != self._bucket_store.current_group_id:
             self._run_reduction()

         padding_size = self._param_store.get_param_padding_size(param)
@@ -553,11 +553,9 @@ def load_state_dict(self, state_dict: Dict):
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
                         v_list = v.split(v.numel() // self._world_size)
-                        device = 'cpu' if self._cpu_offload else 'cuda'
-                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()

         self.optim.load_state_dict(zero_state_dict)
-        zero_state_dict = dict()

     def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
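A plausible reading of the .to(device).detach() -> .detach().clone() change above: torch.Tensor.split returns views that share storage with the full padded tensor, so keeping a view (which .to('cpu') on an already-CPU tensor amounts to) keeps the whole flat buffer alive, while .clone() copies only the rank-local slice and lets that buffer be freed. A tiny demonstration:

import torch

flat = torch.zeros(4 * 1024 * 1024)              # full flattened state, ~16 MB
view = flat.split(flat.numel() // 4)[0]          # rank 0's slice as a view
print(view.data_ptr() == flat.data_ptr())        # True: still backed by the 16 MB buffer
owned = view.detach().clone()                    # independent ~4 MB copy
print(owned.data_ptr() == flat.data_ptr())       # False: the original buffer can be freed

The explicit device selection could presumably be dropped because the checkpoint is now loaded on CPU and torch's Optimizer.load_state_dict casts state tensors onto each parameter's device anyway.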
