Suggestion: Adding more stats (e.g., grad_norm, loss_scale) to the training log #5134

Closed
Emrys365 opened this issue Apr 21, 2023 · 0 comments

Emrys365 (Collaborator) commented Apr 21, 2023

Hi, I think the current training log is not very informative when the training curve behaves oddly, for example with a sudden loss increase. It would be very helpful if some general statistics could be added to the training log:

  • grad_norm: the average gradient norm since the last report
  • clip: the percentage of batches whose gradients were clipped since the last report (see the averaging sketch after this list)
  • loss_scale: the average loss scale since the last report (when scaler is not None)
  • etc.
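
All of these rely on the reporter averaging whatever is registered per batch over the interval since the last report; in particular, registering 100 for a clipped batch and 0 otherwise makes the interval average a clipping percentage. Below is a minimal stand-alone sketch of that averaging (IntervalMean is a hypothetical stand-in for illustration, not ESPnet's Reporter):

class IntervalMean:
    """Average registered values over the batches since the last report."""

    def __init__(self):
        self.totals = {}
        self.count = 0

    def register(self, values):
        for key, value in values.items():
            self.totals[key] = self.totals.get(key, 0.0) + float(value)
        self.count += 1

    def report(self):
        # Return the per-key means and reset for the next interval.
        means = {k: v / self.count for k, v in self.totals.items()}
        self.totals, self.count = {}, 0
        return means

r = IntervalMean()
r.register({"clip": 100.0})  # this batch was clipped
r.register({"clip": 0.0})    # this batch was not
print(r.report())            # {'clip': 50.0}, i.e. 50% of batches clipped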

Referring to the implementation in fairseq, I think the above can be implemented easily by modifying espnet2/train/trainer.py as follows:

--- a/espnet2/train/trainer.py
+++ b/espnet2/train/trainer.py

@@ -678,6 +678,19 @@ class Trainer:
                             scaler.update()
 
                 else:
+                    reporter.register(
+                        {
+                            "grad_norm": grad_norm,
+                            # 100 if this batch was clipped, 0 otherwise;
+                            # the interval average is then a percentage.
+                            "clip": torch.where(
+                                grad_norm > grad_clip,
+                                grad_norm.new_tensor(100),
+                                grad_norm.new_tensor(0),
+                            ),
+                            "loss_scale": scaler.get_scale() if scaler else 1.0,
+                        }
+                    )
                     all_steps_are_invalid = False
                     with reporter.measure_time("optim_step_time"):
                         for iopt, (optimizer, scheduler) in enumerate(

With the above modifications, the training log of a speech enhancement task will look like this:

2023-04-22 13:52:12,339 (trainer:732) INFO: 8epoch:train:2901-3000batch: iter_time=0.070, forward_time=0.360,
si_snr_loss=-17.036, loss=-17.036, backward_time=0.294, grad_norm=37.450, clip=100.000, loss_scale=1.000,
optim_step_time=0.022, optim0_lr0=3.765e-04, train_time=1.137
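
(Note that clip=100.000 above means every batch in the reported interval was clipped.) For reference, the same three statistics can be collected outside ESPnet in a plain PyTorch AMP loop; the sketch below assumes a CUDA device, and the stats list and grad_clip value are illustrative names, not ESPnet code:

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
grad_clip = 5.0
stats = []  # one entry per batch; average these at report time

for step in range(100):
    x = torch.randn(8, 10, device="cuda")
    y = torch.randn(8, 1, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    # Unscale first so the norm is measured in true (unscaled) gradient units.
    scaler.unscale_(optimizer)
    # clip_grad_norm_ returns the total norm computed *before* clipping.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    stats.append({
        "grad_norm": grad_norm.item(),
        "clip": 100.0 if grad_norm > grad_clip else 0.0,
        "loss_scale": scaler.get_scale(),
    })
    scaler.step(optimizer)
    scaler.update()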