From 43792bb7d4911cdd363c68313a89209339e8db4e Mon Sep 17 00:00:00 2001
From: Nadav Elyahu
Date: Tue, 7 May 2024 10:18:36 +0300
Subject: [PATCH 1/2] z3 scaled_global_grad_norm: replace get_global_norm with torch.norm

---
 deepspeed/runtime/zero/stage3.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 13ca29c9fceb..b67c3cf4a61f 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -15,7 +15,7 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -2027,7 +2027,7 @@ def step(self, closure=None):
             return
 
         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

From 5dd50c32fe3b2b92affccd122c9f29a2105c6924 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu
Date: Mon, 27 May 2024 14:53:14 +0300
Subject: [PATCH 2/2] fix grad norm calc in cpu offload and use torch.clip for grad clipping

---
 deepspeed/runtime/zero/stage3.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 7b3dbb35e01c..44a398f72d2b 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -1412,7 +1412,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
         total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
 
-        return total_norm
+        return total_norm.cpu()
 
     @instrument_w_nvtx
     def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
@@ -2111,8 +2111,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale
 
         self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
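
Note on the patch series: both replacements are intended to be numerically equivalent to the code they remove. get_global_norm(norm_list) takes the square root of the sum of squared per-group norms, which is exactly the L2 norm of the stacked norm tensor, and clamping clip to a minimum of 1.0 reproduces the old "if clip > 1:" branch while keeping everything in tensor form (no Python-side branch on a device tensor). Below is a minimal standalone sketch of that equivalence; the loss_scale and clip_grad values are stand-ins, not DeepSpeed's internal state.

    import torch

    # Stand-in per-parameter-group gradient norms; in stage3.py these come
    # from self._get_norm_groups().
    norm_groups = [torch.tensor(3.0), torch.tensor(4.0)]

    # Old path: sum the squared norms and take a square root.
    old_global_norm = sum(n**2.0 for n in norm_groups)**0.5

    # New path: the L2 norm of the stacked per-group norms is the same
    # quantity, computed in a single torch op.
    new_global_norm = torch.linalg.norm(torch.stack(norm_groups))

    assert torch.allclose(old_global_norm, new_global_norm)  # both equal 5.0

    # Clipping: clamping at 1.0 matches the old "if clip > 1" branch but
    # operates directly on the tensor (hypothetical scale/clip values).
    loss_scale, clip_grad = 1024.0, 1.0
    total_norm = new_global_norm * loss_scale     # "norm is in fact norm*scale"
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    combined_scale = clip * loss_scale            # grads are later scaled by 1/combined_scale

The return total_norm.cpu() change in complete_grad_norm_calculation_for_cpu_offload is presumably there so the offload path hands back a CPU-resident norm regardless of where the reduction ran; that part is not reproduced in the sketch.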