diff --git a/train.py b/train.py index c7d27e2a3f..6c5a51a29c 100644 --- a/train.py +++ b/train.py @@ -112,6 +112,7 @@ def train_model( optimizer.zero_grad(set_to_none=True) grad_scaler.scale(loss).backward() + grad_scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) grad_scaler.step(optimizer) grad_scaler.update()