From 75540d230f58d93c0b1e3bc2a0db1fc853f12bc2 Mon Sep 17 00:00:00 2001
From: "Wei (Will) Feng"
Date: Tue, 22 Oct 2024 11:23:32 -0700
Subject: [PATCH] return total grad norm in torchrec grad clipping

Summary: This keeps the return value consistent with torch.nn.utils.clip_grad_norm_, which returns the total gradient norm.

Differential Revision: D64712277
---
 torchrec/optim/clipping.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 66270cfcb..bd916c6f5 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -136,7 +136,7 @@ def step(self, closure: Any = None) -> None:
         self._step_num += 1
 
     @torch.no_grad()
-    def clip_grad_norm_(self) -> None:
+    def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
         norm_type = float(self._norm_type)
@@ -224,6 +224,7 @@ def clip_grad_norm_(self) -> None:
         clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
+        return total_grad_norm
 
 
 def _batch_cal_norm(
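
For context, torch.nn.utils.clip_grad_norm_ returns the total norm of the parameters' gradients (before clipping), which callers commonly log or feed into metrics. The sketch below illustrates that reference behavior and how a training loop might consume the value now returned by the TorchRec wrapper; the GradientClippingOptimizer wiring shown in the trailing comment is an assumption for illustration and is not taken from this patch.

    # Minimal sketch of the return-value convention this patch aligns with.
    import torch

    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(8, 16)).sum()
    loss.backward()

    # Reference behavior: the plain PyTorch helper clips in place and
    # returns the total gradient norm as a tensor.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"total grad norm before clipping: {total_norm.item():.4f}")

    optimizer.step()
    optimizer.zero_grad()

    # With this patch, the TorchRec wrapper can be used the same way.
    # Hypothetical wiring (constructor arguments are assumptions; see
    # torchrec/optim/clipping.py for the actual interface):
    #
    #   from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
    #   clipping_optimizer = GradientClippingOptimizer(
    #       optimizer=keyed_optimizer,
    #       clipping=GradientClipping.NORM,
    #       max_gradient=1.0,
    #   )
    #   total_norm = clipping_optimizer.clip_grad_norm_()  # now returns the norm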