diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 9b7645261eae..b0a3ab778f2a 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -15,7 +15,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -1413,7 +1413,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm - return total_norm + return total_norm.cpu() @instrument_w_nvtx def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: @@ -2028,7 +2028,7 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) + scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups)) # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.loss_scale @@ -2112,8 +2112,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm): if self.clip_grad > 0.: # norm is in fact norm*scale clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale) diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py index 5bfec399e19f..8ae99e2237e2 100644 --- a/tests/unit/runtime/zero/test_zero_offloadpp.py +++ b/tests/unit/runtime/zero/test_zero_offloadpp.py @@ -43,6 +43,7 @@ def test(self, h_dim: int, n_layers: int) -> None: config_dict = { "train_batch_size": 256, "steps_per_print": 1, + "gradient_clipping": 1.0, "optimizer": { "type": "Adam", "params": {