diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 01468d87e42a..2249678c190f 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3414,7 +3414,7 @@ def _load_checkpoint(self,
             if self.optimizer is not None and hasattr(self.optimizer, 'refresh_fp32_params'):
                 self.optimizer.refresh_fp32_params()
         else:
-            has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
+            has_zero_optimizer_state = self.zero_optimization()
             if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
                 if self.has_moe_layers:
                     largest_group_name = groups._get_max_expert_size_name()
@@ -3883,7 +3883,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parame

         save_path = self._get_ckpt_name(save_dir, tag)

-        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
+        zero_optimizer_state = self.zero_optimization()

         save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters

diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py
index cca92348fc62..9ad785071602 100644
--- a/tests/unit/checkpoint/test_zero_optimizer.py
+++ b/tests/unit/checkpoint/test_zero_optimizer.py
@@ -45,7 +45,8 @@ def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage):

         checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)

-    @pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer', [(1, False, 'Adam'), (2, False, 'Adam'),
+    @pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer', [(0, False, 'Adam'), (1, False, 'Adam'),
+                                                                             (2, False, 'Adam'),
                                                                              (2, True, 'deepspeed_adam'),
                                                                              (3, False, 'Adam'), (3, True, 'deepspeed_adam')])
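
Not part of the patch: a minimal sketch of the kind of configuration the engine.py change affects and the new (0, False, 'Adam') test case exercises. The config dict below is illustrative only (its exact values are assumptions, not taken from the test file); it uses standard DeepSpeed config keys.

# Illustrative only: bf16 training with ZeRO stage 0 (i.e. no ZeRO optimizer).
# With this patch, self.zero_optimization() is False for such a config, so
# has_zero_optimizer_state / zero_optimizer_state are False and optimizer
# state is saved and loaded through the regular (non-ZeRO) checkpoint path,
# even though bfloat16 is enabled.
config_dict = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-5},
    },
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 0},
}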