Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3414,7 +3414,7 @@ def _load_checkpoint(self,
if self.optimizer is not None and hasattr(self.optimizer, 'refresh_fp32_params'):
self.optimizer.refresh_fp32_params()
else:
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
has_zero_optimizer_state = self.zero_optimization()
if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
if self.has_moe_layers:
largest_group_name = groups._get_max_expert_size_name()
Expand Down Expand Up @@ -3883,7 +3883,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parame

save_path = self._get_ckpt_name(save_dir, tag)

zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
zero_optimizer_state = self.zero_optimization()

save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/checkpoint/test_zero_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage):

checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)

@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer', [(1, False, 'Adam'), (2, False, 'Adam'),
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer', [(0, False, 'Adam'), (1, False, 'Adam'),
(2, False, 'Adam'),
(2, True, 'deepspeed_adam'),
(3, False, 'Adam'),
(3, True, 'deepspeed_adam')])
Expand Down