
Train loss is nan or inf #10

Closed
Butterfly-c opened this issue Jun 30, 2022 · 29 comments · Fixed by #11

@Butterfly-c

Butterfly-c commented Jun 30, 2022

When using the fast_rnnt loss in my environment, the training loss always ends up as nan or inf.
The configuration of my ConformerTransducer environment is as follows:

  • GPUs: 2 × (4 × V100-32G), i.e. 8 GPUs
  • platform: fairseq
  • max_tokens: 5000 and update_freq: 13 (i.e. batch size 5000 * 13 * 8)
  • warmup_lr: 1e-7, lr: 1e-4, lr_scheduler: inverse_sqrt, warmup_updates: 8000
  • optimizer: adam
  • pruned_loss_scaled schedule:
      pruned_loss_scaled = 0 if num_updates <= 10000
      pruned_loss_scaled = 0.1 if 10000 < num_updates <= 20000
      pruned_loss_scaled = 1 if num_updates > 20000
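The pruned-loss schedule above can be sketched in Python as follows. This is a minimal sketch, not fairseq or fast_rnnt code: the function names are hypothetical, and the way the two losses are combined (simple_loss_scale * simple_loss + scale * pruned_loss, with simple_loss_scale = 0.5 as in the joiner settings listed later in this issue) is an assumption. One detail worth checking in any such schedule: under IEEE floating point, 0 * inf is nan, so multiplying an already-infinite pruned loss by a scale of 0 yields nan instead of removing the term. For that reason the sketch leaves the pruned term out entirely while its scale is 0.

```python
def pruned_loss_scale(num_updates: int) -> float:
    """Piecewise-constant warm-up weight for the pruned loss (thresholds from this issue)."""
    if num_updates <= 10000:
        return 0.0
    if num_updates <= 20000:
        return 0.1
    return 1.0


def combined_loss(simple_loss, pruned_loss, num_updates: int,
                  simple_loss_scale: float = 0.5):
    """Hypothetical weighted sum of the simple and pruned RNN-T losses."""
    scale = pruned_loss_scale(num_updates)
    if scale == 0.0:
        # Skip the pruned term instead of computing 0 * pruned_loss,
        # which would be nan if pruned_loss were inf.
        return simple_loss_scale * simple_loss
    return simple_loss_scale * simple_loss + scale * pruned_loss
```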

6k hours of training data are used to train the RNN-T model. During the warmup stage (i.e. pruned_loss_scaled = 0), the loss always becomes nan, and when pruned_loss_scaled is set to 0.1, the loss always becomes inf.

Are there any suggestions for solving this problem?

@pkufool
Contributor

pkufool commented Jun 30, 2022

When did it turn into nan or inf: at the beginning of training or in the middle? Could you please upload the training log here? Thanks!

@Butterfly-c
Author

The loss turns into nan at the beginning of training (pruned_loss_scaled = 0). After 10000 updates, pruned_loss_scaled is set to 0.1 and the loss turns into inf.

Sorry, something went wrong when I tried to upload the log.

@pkufool
Contributor

pkufool commented Jun 30, 2022

Do you have any sequences where U > T, i.e. where the number of tokens in the transcript is greater than the number of frames?

@Butterfly-c
Author

The subsampling rate is 4, due to 2 max-pooling layers, so the number of tokens U is unlikely to be greater than T.

I put some logs here:

epoch 3 ; loss inf; num updates 16100 ; lr 0.000704907
epoch 3 ; loss 1.13339; num updates 16200 ; lr 0.000702728
epoch 3 ; loss 1.13215; num updates 16300 ; lr 0.000700569
epoch 3 ; loss inf; num updates 16400 ; lr 0.000698043

@danpovey
Collaborator

What iteration did the loss become inf on, and what kind of model were you using?

@Butterfly-c
Author

Butterfly-c commented Jun 30, 2022

The loss became inf at epoch 2, when pruned_loss_scaled was set to 0.1.

The ConformerTransducer model is configured as follows:
Encoder: 2 VGG blocks + 12 Conformer layers + 1 LSTMP layer + 1 LayerNorm
Decoder: 2 LSTM layers + dropout
Joiner: configured as in k2

@Butterfly-c
Author

Other joiner settings are as follows:
lm_only_scale = 0.25
am_only_scale = 0
prune_range = 4
simple_loss_scale = 0.5

pruned_loss_scaled = 0 if num_updates <= 10000
pruned_loss_scaled = 0.1 if 10000 < num_updates <= 20000
pruned_loss_scaled = 1 if num_updates > 20000

@pkufool
Contributor

pkufool commented Jun 30, 2022

Can you dump the inputs of the batches that lead to the inf loss, so we can use them to debug this issue? Thanks.

@danpovey
Collaborator

danpovey commented Jul 3, 2022

@pkufool perhaps it was not obvious to him how to do this?
Also, @Butterfly-c , are you using fp16 / half-precision for training? It can be tricky to tune a network to perform OK with fp16.
One possibility is to detect inf in the loss, e.g. by comparing (loss - loss) to 0, and then skip the update and print a warning.
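The (loss - loss) check can be sketched as follows. It relies on inf - inf and nan - nan both being nan, and nan comparing unequal to everything, so the difference equals 0 only for a finite loss. The helper names and the backward_and_step callback are hypothetical placeholders for whatever the training loop actually does.

```python
def loss_is_finite(loss) -> bool:
    # inf - inf and nan - nan are nan, and nan == 0 is False.
    # bool() works for Python floats and for single-element torch tensors.
    return bool((loss - loss) == 0)


def maybe_step(loss, backward_and_step, log=print) -> bool:
    """Run the update only for a finite loss; otherwise print a warning and skip it."""
    if not loss_is_finite(loss):
        log(f"WARNING: non-finite loss ({loss}), skipping this update")
        return False
    backward_and_step(loss)
    return True
```

In multi-GPU data-parallel training, every worker would likely need to make the same skip decision (for example by reducing the flag across workers first); if only some workers skip the backward pass, the gradient all-reduce can hang.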
If you have any utterances in your training set that have too-long transcripts for the utterance length, those could lead to inf loss. It's possible that the model is training OK, if the individual losses on most batches stay finite, even though the overall loss may be infinite. Cases with too-long transcripts will generate infinite loss but will not generate infinite gradients.
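Since the overall loss can be infinite while most batches are fine, one option is to log the mean over finite batch losses together with a count of non-finite ones, instead of only a running sum. A minimal sketch on plain Python floats; the function name is hypothetical.

```python
import math
from typing import Iterable, Tuple


def summarize_losses(losses: Iterable[float]) -> Tuple[float, float, int]:
    """Return (sum of all losses, mean of the finite ones, number of non-finite ones)."""
    values = list(losses)
    total = sum(values)  # a single inf (or any nan) makes the total non-finite
    finite = [x for x in values if math.isfinite(x)]
    mean_finite = sum(finite) / len(finite) if finite else float("nan")
    return total, mean_finite, len(values) - len(finite)
```

For instance, summarize_losses([1.13339, float("inf"), 1.13215]) gives a total of inf, a finite mean of about 1.13277, and 1 non-finite batch.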

@pkufool
Contributor

pkufool commented Jul 3, 2022

@Butterfly-c Suppose you used pruned loss like this:

import fast_rnnt

# Assumed to exist in your training code: encoder_out [B, T, C],
# decoder_out [B, S + 1, C], symbols [B, S], boundary, joiner,
# blank_id, lm_scale, am_scale and prune_range.

# Simple (smoothed) loss; px_grad/py_grad are then used to choose the pruning ranges
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
    lm=decoder_out,
    am=encoder_out,
    symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=lm_scale,
    am_only_scale=am_scale,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)

# ranges : [B, T, prune_range]
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
    am=encoder_out, lm=decoder_out, ranges=ranges
)

# logits : [B, T, prune_range, C]
logits = joiner(am_pruned, lm_pruned)

pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)

You can dump the bad cases as follows:

import torch

# (loss - loss) != 0 is True only when the loss is inf or nan.
# Note: each save overwrites the previous file with the same name.
if simple_loss - simple_loss != 0:
    simple_input = {"encoder_out": encoder_out, "decoder_out": decoder_out,
                    "symbols": symbols, "boundary": boundary}
    torch.save(simple_input, "simple_bad_case.pt")

if pruned_loss - pruned_loss != 0:
    pruned_input = {"logits": logits, "ranges": ranges,
                    "symbols": symbols, "boundary": boundary}
    torch.save(pruned_input, "pruned_bad_case.pt")

@Butterfly-c
Author

Thanks for your kind reply!
I decoded a model from epoch 4 and the decoding result is OK, but I'm still confused about the inf loss.
max_frame is set to 2500 (i.e. 25 s) in my training environment.
How long does a transcript have to be to count as too long?

@Butterfly-c
Author

Thanks for your suggestion. I'm trying to upload pruned_bad_case.pt so you can debug the inf issue; it will take me some time.

@Butterfly-c
Author

We compared two models trained separately with warp-transducer and fast-rnnt, but the GPU usage does not decrease significantly.

The training times of the two models are roughly as follows:
loss               time per update
warp-transducer    7m40s
fast-rnnt          6m40s

Both models above are trained on 2 × 4 V100-32G GPUs (i.e. 8 GPUs).
Are there any suggestions for speeding up training?

@csukuangfj

csukuangfj commented Jul 4, 2022

  1. What is your vocabulary size?
  2. What is your batch size? And how much data does each batch contain (i.e., what is the total duration )?
  3. Is your GPU usage over 90%? (You can get this information with watch -n 0.5 nvidia-smi)
  4. What is the value of prune_range?

You wrote "the The GPU usage does not decrease significantly." What do you mean by that?

@csukuangfj

You asked how long a transcript has to be to count as too long.

If the sentence is broken into BPE tokens, it is "too long" if the number of BPE tokens is larger than the number of acoustic frames (after subsampling) of this sentence.
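That definition can be turned into a data filter. A minimal sketch, assuming (as described earlier in this thread) 4x subsampling from two stride-2 max-pooling layers that each halve the length with floor rounding, and a 10 ms frame shift (so the 2500-frame maximum is 25 s). Whether the real pooling layers round up or down is an assumption; with ceil rounding the frame count can be slightly larger. The function names are hypothetical.

```python
def num_encoder_frames(num_input_frames: int, num_pool_layers: int = 2) -> int:
    """Frames left after stride-2 max-pooling layers, assuming floor rounding."""
    t = num_input_frames
    for _ in range(num_pool_layers):
        t //= 2
    return t


def transcript_too_long(num_tokens: int, num_input_frames: int) -> bool:
    """True if the token count exceeds the number of frames after subsampling."""
    return num_tokens > num_encoder_frames(num_input_frames)
```

Under these assumptions, 2500 input frames become 625 encoder frames, so a 25 s utterance counts as too long only if its transcript has more than 625 tokens.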

@Butterfly-c
Author

Here is some of my configuration:

1. The vocabulary size is 8245, comprising 6726 Chinese characters, 1514 BPE subwords and 5 special symbols.
2. The batch size is 5000 frames (i.e. 50 s).
3. With "watch -n 0.5 nvidia-smi", the peak Volatile GPU-Util is over 90%, but most of the time it is between 80% and 90%.
4. prune_range is 4.

As shown in this paper, https://arxiv.org/abs/2206.13236, the peak GPU usage of fast_rnnt is far below that of warp-transducer, and the training time is also greatly reduced. But when we run fast_rnnt in our environment, the training time is not reduced as much as expected.
With the same batch size (50 s), the training times are as follows:
loss               time per update
warp-transducer    7m40s
fast-rnnt          6m40s

Finally, I have another question about training time. According to the paper, the training time per batch of optimized transducer is more than 4 times that of fast_rnnt, but its training time per epoch is only 2 times that of fast_rnnt.

I would really appreciate your reply.

@danpovey
Collaborator

danpovey commented Jul 5, 2022

I think the comparisons in the paper may have just been for the core RNN-T loss. It does not count the neural net forward, which would not be affected by speedups in the loss computation.
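This is essentially Amdahl's law: if only the loss computation gets faster, the end-to-end speedup is bounded by the fraction of step time the loss takes. A small sketch; the fraction and speedup used below are hypothetical inputs, not measurements from this issue.

```python
def overall_speedup(loss_fraction: float, loss_speedup: float) -> float:
    """End-to-end speedup when only a fraction `loss_fraction` of step time is sped up."""
    return 1.0 / ((1.0 - loss_fraction) + loss_fraction / loss_speedup)
```

For example, if the loss took 20% of each step and became 4x faster, overall_speedup(0.2, 4.0) is 1 / 0.85, about 1.18x. For comparison, the timings reported above (7m40s vs 6m40s, i.e. 460 s vs 400 s) correspond to a 1.15x speedup.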

@Butterfly-c
Author

Thanks for your reply; that clears up my confusion.

@Butterfly-c
Author

Butterfly-c commented Jul 6, 2022

Following your suggestion, I saved some bad cases. Interestingly, most of the 'ranges' tensors are all zeros.

For example, when the training sample is background music, the label is a single symbol. The shapes of lm and am are as follows:
decoder_out [1, 2, 8245]
encoder_out [1, 314, 8245]

Does the training loss become inf when the input and output lengths are unbalanced (i.e. the output is far shorter than the input)?
Can you give some explanation?

@Butterfly-c
Author

After filtering the training data with the following conditions, the inf problem occurs less often:
1. label_len > 2
2. feat_len // label_len > 30
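The two filtering rules above can be written as a keep/drop predicate. A sketch, assuming feat_len counts input feature frames, label_len counts tokens, and an utterance is kept only when both conditions hold; the function and parameter names are hypothetical. Rule 1 drops one- and two-token utterances, which includes the single-symbol case discussed in this thread; rule 2 drops utterances whose integer frames-per-token ratio is 30 or less.

```python
def keep_utterance(feat_len: int, label_len: int,
                   label_len_threshold: int = 2, ratio_threshold: int = 30) -> bool:
    """Keep an utterance only if label_len > 2 and feat_len // label_len > 30."""
    if label_len <= label_len_threshold:
        return False
    return feat_len // label_len > ratio_threshold
```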

@Butterfly-c
Author

Due to network limitations, I will share pruned_bad_case.pt later.

@pkufool
Copy link
Contributor

pkufool commented Jul 6, 2022

For example, when the training sample is background music, the label is only one symbol. The lm.shape and am.shape are as follows:
decoder_out [1, 2, 8245]
encoder_out [1, 314, 8245]

Does only one sequence have a single symbol, or do all the sequences in the batch have a single symbol?
Thanks, this is very valuable information for us.

@Butterfly-c
Author

Butterfly-c commented Jul 7, 2022

Based on 40 pruned_bad_case.pt files, in every bad case all the sequences in the batch have only one symbol, and the 'ranges' tensors are all zeros.

@pkufool
Copy link
Contributor

pkufool commented Jul 7, 2022

OK, thanks! That's it. I think our code does not handle S == 1 properly; I will try to fix it.
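Until a fixed version is installed, a training loop could detect the affected batches from tensor shapes. In fast_rnnt, lm (decoder_out here) has shape [B, S + 1, C], so the decoder_out [1, 2, 8245] case above has S = 1, and a padded symbols tensor of shape [B, 1] means every sequence in the batch has at most one symbol. A minimal sketch on plain shape tuples; the helper names, and the idea of skipping or handling such batches separately, are hypothetical and not part of fast_rnnt.

```python
from typing import Sequence


def num_symbols(decoder_out_shape: Sequence[int]) -> int:
    """S for a decoder output of shape [B, S + 1, C]."""
    return decoder_out_shape[1] - 1


def is_single_symbol_batch(symbols_shape: Sequence[int]) -> bool:
    """True if the padded symbols tensor [B, S] has S == 1 (at most one symbol per sequence)."""
    return symbols_shape[1] == 1
```

With torch tensors, decoder_out.shape can be passed directly, since torch.Size is a tuple subclass.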

@pkufool
Contributor

pkufool commented Jul 7, 2022

@Butterfly-c If you have trouble uploading your bad cases to GitHub, could you send them to me by email at wkang.pku@gmail.com? I want to use them to test my fixes. Thanks!

@Butterfly-c
Author

Due to data permission restrictions, I can't share the bad cases until I get permission. The permission request is in progress.

@pkufool
Contributor

pkufool commented Jul 7, 2022

OK. I think there won't be any text or audio in your bad cases, only floating-point and integer numbers. I hope you can get the permission; in the meantime, I am testing with randomly generated bad cases. Thanks.

@Butterfly-c
Author

OK, I will contact you as soon as I get the permission.

This was referenced Jul 11, 2022
@Butterfly-c
Author

After updating fast-rnnt to the "fix_s_range" version, the inf problem has been fixed. Thanks!

4 participants