From 52b4f14e91cb2e82ba62a4f39c7030d209ca8277 Mon Sep 17 00:00:00 2001
From: Yassine
Date: Fri, 8 Dec 2023 11:14:51 +0100
Subject: [PATCH] Unscale gradients before clipping

---
 train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/train.py b/train.py
index c7d27e2a3..6c5a51a29 100644
--- a/train.py
+++ b/train.py
@@ -112,6 +112,7 @@ def train_model(
         optimizer.zero_grad(set_to_none=True)
         grad_scaler.scale(loss).backward()
+        grad_scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
         grad_scaler.step(optimizer)
         grad_scaler.update()
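
Rationale: under AMP, grad_scaler.scale(loss).backward() leaves the parameter
gradients multiplied by the current scale factor, so clipping them directly
against `gradient_clipping` applies the threshold to scaled gradients rather
than their true magnitudes. Calling grad_scaler.unscale_(optimizer) first
restores the unscaled gradients, and grad_scaler.step() then knows they have
already been unscaled and will not unscale them again.

A minimal sketch of the resulting step order, assuming a CUDA device and the
torch.cuda.amp API; the model, data, and max_norm value below are placeholders
for illustration and are not taken from train.py:

    import torch

    model = torch.nn.Linear(16, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    grad_scaler = torch.cuda.amp.GradScaler()
    max_norm = 1.0  # placeholder for the gradient_clipping value

    for _ in range(10):
        x = torch.randn(8, 16, device="cuda")
        y = torch.randn(8, 1, device="cuda")
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(x), y)
        grad_scaler.scale(loss).backward()   # gradients are scaled here
        grad_scaler.unscale_(optimizer)      # restore true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        grad_scaler.step(optimizer)          # skips the step if inf/nan gradients are found
        grad_scaler.update()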