diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 66270cfcb..bd916c6f5 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -136,7 +136,7 @@ def step(self, closure: Any = None) -> None:
         self._step_num += 1
 
     @torch.no_grad()
-    def clip_grad_norm_(self) -> None:
+    def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
        """Clip the gradient norm of all parameters."""
        max_norm = self._max_gradient
        norm_type = float(self._norm_type)
@@ -224,6 +224,7 @@ def clip_grad_norm_(self) -> None:
        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        torch._foreach_mul_(all_grads, clip_coef_clamped)
+       return total_grad_norm
 
 
 def _batch_cal_norm(
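
For context, the change makes `clip_grad_norm_` return the pre-clipping total gradient norm (mirroring `torch.nn.utils.clip_grad_norm_`) instead of `None`. Below is a minimal usage sketch of how a caller might consume that value for logging; the `GradientClippingOptimizer` / `KeyedOptimizerWrapper` construction and argument names are assumptions based on the torchrec API and are not part of this diff.

```python
# Minimal usage sketch (not part of the diff). It assumes the torchrec
# wrappers GradientClippingOptimizer / KeyedOptimizerWrapper and the
# GradientClipping.NORM mode; exact constructor arguments may differ.
import torch
from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
from torchrec.optim.keyed import KeyedOptimizerWrapper

model = torch.nn.Linear(8, 1)

keyed_optim = KeyedOptimizerWrapper(
    dict(model.named_parameters()),
    lambda params: torch.optim.SGD(params, lr=0.1),
)
optim = GradientClippingOptimizer(
    keyed_optim,
    clipping=GradientClipping.NORM,
    max_gradient=1.0,
)

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# Before this change clip_grad_norm_() returned None; now it returns the
# total gradient norm computed before clipping, which can be logged or
# monitored, similar to torch.nn.utils.clip_grad_norm_.
total_norm = optim.clip_grad_norm_()
print(f"grad norm before clipping: {total_norm}")

optim.step()  # step() is assumed to apply clipping internally when clipping=NORM
```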