[zero] fix zero ckptIO with offload (#4529)
* fix zero ckptIO with offload

* fix load device

* saved tensors in ckpt should be on CPU

* fix unit test

* fix unit test

* add clear cache

* save memory for CI
Gy-Lu committed Sep 1, 2023
1 parent c7b60f7 commit cbac782
Showing 3 changed files with 22 additions and 16 deletions.
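
The crux of the change: with cpu_offload=True the low-level ZeRO optimizer keeps its states on the CPU, but dist.all_gather under the NCCL backend only accepts CUDA tensors, and tensors written into a checkpoint should end up on the CPU so the file can be loaded on any device. A minimal sketch of the resulting gather-on-GPU, save-on-CPU pattern that the state_dict and state_dict_shard hunks below implement (pg and world_size are stand-in names for this illustration, not ColossalAI API):

    import torch
    import torch.distributed as dist

    def gather_full_state(shard: torch.Tensor, target_numel: int, pg, world_size: int) -> torch.Tensor:
        # NCCL can only gather CUDA tensors, so stage the (possibly
        # CPU-offloaded) shard and its receive buffers on the GPU.
        buffers = [torch.zeros(shard.shape, device='cuda', dtype=shard.dtype)
                   for _ in range(world_size)]
        dist.all_gather(buffers, shard.cuda(), group=pg)
        # Drop the padding added for even sharding, then move the result
        # back to the CPU so the checkpoint is device-agnostic.
        return torch.stack(buffers).view(-1)[:target_numel].cpu()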
22 changes: 12 additions & 10 deletions colossalai/zero/low_level/low_level_optim.py
@@ -80,9 +80,6 @@ def __init__(
                  tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):

-        # TODO:
-        # 1. state_dict for checkpoint IO
-
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -528,9 +525,12 @@ def state_dict(self) -> Dict:
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != 'step':
                     working_param = self._param_store.master_to_working_param[id(param)]
-                    gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
-                    dist.all_gather(gather_tensor, v, group=self.dp_pg)
-                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    gather_tensor = [
+                        torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
+                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
+                        working_param).cpu()
                     zero_state[param][k] = param_state

         states_dict = self._pack_state(zero_state)
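
Why the [:working_param.numel()] slice before .cpu(): ZeRO pads each flattened working parameter so that it splits evenly across ranks, so the gathered tensor can be longer than the original parameter. A toy illustration with made-up numbers:

    import torch

    world_size = 4
    working_numel = 10                                     # original parameter: 10 elements
    shards = [torch.arange(3) for _ in range(world_size)]  # padded to 12, i.e. 3 per rank
    full = torch.stack(shards).view(-1)                    # what all_gather reassembles
    state = full[:working_numel]                           # strip the 2 padding elements
    assert state.numel() == working_numel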
@@ -553,7 +553,8 @@ def load_state_dict(self, state_dict: Dict):
                     if padding_size > 0:
                         v = torch.nn.functional.pad(v, [0, padding_size])
                     v_list = v.split(v.numel() // self._world_size)
-                    zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
+                    device = 'cpu' if self._cpu_offload else 'cuda'
+                    zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()

         self.optim.load_state_dict(zero_state_dict)
         zero_state_dict = dict()
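
Loading is the mirror image: the saved full state is flattened, padded to a multiple of the world size, split, and only the local rank's slice is kept, placed according to the offload setting. A hedged sketch of that logic (take_local_shard is an illustrative helper, not the actual method):

    import torch

    def take_local_shard(full_state: torch.Tensor, rank: int, world_size: int,
                         cpu_offload: bool) -> torch.Tensor:
        flat = full_state.flatten()
        padding = (world_size - flat.numel() % world_size) % world_size
        if padding > 0:
            flat = torch.nn.functional.pad(flat, [0, padding])
        shard = flat.split(flat.numel() // world_size)[rank]
        # The shard must land where the optimizer expects its states:
        # on the CPU when offloading, otherwise on the GPU.
        return shard.to('cpu' if cpu_offload else 'cuda').detach()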
@@ -585,9 +586,10 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != 'step':
-                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
-                    dist.all_gather(state_tensor, v, group=self.dp_pg)
-                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
+                        working_param).cpu()
                     current_block_size += state_tensor.numel()
                     current_block[k] = state_tensor

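state_dict_shard streams the same gathered, CPU-resident tensors out in bounded blocks: tensors accumulate in current_block until the size budget is exceeded, then the block is yielded and a new one starts. A simplified sketch of the batching loop (sizes are counted in elements purely for illustration; the real method's bookkeeping may differ):

    from typing import Dict, Iterator, Tuple
    import torch

    def pack_into_blocks(states: Dict[str, torch.Tensor],
                         max_block_numel: int) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
        block: Dict[str, torch.Tensor] = {}
        block_size = 0
        for name, tensor in states.items():
            # Start a new block when the next tensor would overflow the budget.
            if block and block_size + tensor.numel() > max_block_numel:
                yield block, block_size
                block, block_size = {}, 0
            block[name] = tensor.cpu()    # checkpointed tensors stay on the CPU
            block_size += tensor.numel()
        if block:
            yield block, block_size
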
14 changes: 9 additions & 5 deletions tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -16,19 +16,21 @@
 )


 # stage 1 and 2 process the optimizer/model the same way
 # only test 2 is fine
+@clear_cache_before_run()
 @parameterize('stage', [2])
 @parameterize('shard', [True, False])
-def check_low_level_zero_checkpointIO(stage: int, shard: bool):
-    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
+@parameterize('offload', [False, True])
+def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
+    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
     booster = Booster(plugin=plugin)
     model = resnet18()
     criterion = lambda x: x.mean()
     optimizer = HybridAdam((model.parameters()), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

-    x = torch.randn(4, 3, 224, 224)
-    x = x.to('cuda')
+    x = torch.randn(1, 3, 224, 224, device='cuda')
     output = model(x)
     loss = criterion(output)
     booster.backward(loss, optimizer)
@@ -50,15 +52,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
     check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)


 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
     check_low_level_zero_checkpointIO()
+    torch.cuda.empty_cache()


 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_low_level_zero_checkpointIO():
     spawn(run_dist, 2)
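
For context, a condensed sketch of the user-facing path this test exercises, assuming it runs in a process already launched via colossalai.launch (paths and hyperparameters are placeholders):

    import torch
    from torchvision.models import resnet18
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.nn.optimizer import HybridAdam

    plugin = LowLevelZeroPlugin(stage=2, cpu_offload=True)
    booster = Booster(plugin=plugin)
    model = resnet18()
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    criterion = lambda x: x.mean()
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    loss = criterion(model(torch.randn(1, 3, 224, 224, device='cuda')))
    booster.backward(loss, optimizer)
    optimizer.step()

    # Optimizer states are gathered on the GPU but saved on the CPU,
    # then re-sharded onto the offload device at load time.
    booster.save_optimizer(optimizer, '/tmp/optim.ckpt')
    booster.load_optimizer(optimizer, '/tmp/optim.ckpt')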

2 changes: 1 addition & 1 deletion tests/test_zero/test_low_level/test_zero_ckpt.py
@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
         atol = 4e-3

     a = a.detach().to(dtype)
-    b = b.detach().to(dtype)
+    b = b.detach().to(dtype).to(a.device)

     assert_close(a, b, rtol=rtol, atol=atol)

