
Transformer Chainer #774

Merged: 47 commits, merged into espnet:v.0.5.0 on Jul 30, 2019

Conversation

@Fhrozen (Member) commented May 29, 2019

This is the second part of the updates for the Transformer with Chainer.

@sw005320 added the Enhancement label on May 29, 2019
@ShigekiKarita (Member):

I hope you can also fix #755 in the Chainer backend.

@Fhrozen (Member Author) commented Jun 11, 2019

I just added fixes for most of the problems with the PyTorch backend.
I also trained the Chainer backend with mtl_alpha 0.3 without problems.
I set the patience to 10 and the model was trained for 72 epochs.

[accuracy curve plot]

Currently, I am training an LM to test the joint decoding and finish this PR.
BTW, the train.yml file is set up with an ln command inside run.sh, but it would be better to delete it and call the config files directly from run.sh. Let me know what you think about this.

@Fhrozen (Member Author) commented Jun 11, 2019

BTW, the implemented CER/WER is based on greedy search. I am currently using this due to the large number of epochs required by the transformer. Let me know if this is OK, or whether it should be a beam search similar to the one implemented in the RNN decoder.
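
For reference, a minimal sketch of what such a greedy-search CER computation could look like (a hypothetical helper, not the actual ESPnet function; it assumes the third-party editdistance package and a char_list mapping token ids to characters):

    import editdistance  # third-party edit-distance package, used here only for illustration

    def greedy_cer(ys_hat, ys_ref, char_list, ignore_id=-1):
        """Hypothetical helper: CER from greedy (argmax) decoder outputs.

        ys_hat: per-utterance argmax token-id sequences from the decoder
        ys_ref: per-utterance reference token-id sequences, padded with ignore_id
        """
        cers = []
        for hyp, ref in zip(ys_hat, ys_ref):
            hyp_chars = "".join(char_list[int(t)] for t in hyp if int(t) != ignore_id)
            ref_chars = "".join(char_list[int(t)] for t in ref if int(t) != ignore_id)
            if ref_chars:
                cers.append(editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))
        return sum(cers) / len(cers) if cers else None

The same helper could report WER by splitting the recovered strings into words before computing the edit distance.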

@sw005320 (Contributor) commented Jun 11, 2019

CTC: I think this is fine.
Attention: Making transformer beam search work on GPU requires additional work. This is still fine as is, but people may find it confusing. We may not need it for now.

Resolved review comments on: egs/wsj/asr1/conf/tuning/train_pytorch_transformer.yaml, espnet/nets/pytorch_backend/e2e_asr_transformer.py (several threads), test/test_e2e_transformer.py

@sw005320 (Contributor):

@Fhrozen, can we make the CER computation a function, put it in a common directory, and call it from both the transformer and the RNN in both the Chainer and PyTorch backends?

@Fhrozen (Member Author) commented Jun 18, 2019

@ShigekiKarita I just finished with the requested tests:

  • w/o ctc, w/o lm (ctc_weight=0.0, lm_weight=0.0):
write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transf_noctc_nolm_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333  33341 |  94.1   1.4   4.5   1.1   7.0   77.8 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transf_noctc_nolm_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333   5643 |  87.1   8.8   4.1   1.6  14.5   73.6 |
write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transf_noctc_nolm_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503  48634 |  92.9   2.1   5.0   1.1   8.2   83.1 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transf_noctc_nolm_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503   8234 |  84.1  11.3   4.7   1.7  17.6   82.1 |

  • w/o ctc (ctc_weight=0.0, lm_weight=1.0):
write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transf_noctc_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333  33341 |   5.1   0.0  94.9   0.0  94.9   99.4 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transf_noctc_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333   5643 |   6.5   1.1  92.4   0.0  93.5   99.1 |
write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transf_noctc_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503  48634 |   5.7   0.1  94.2   0.0  94.3   99.4 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transf_noctc_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503   8234 |   7.2   1.1  91.7   0.0  92.9   99.2 |

  • w/o lm (ctc_weight=0.3, lm_weight=0.0):
write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transf_nolm_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333  33341 |  97.1   1.5   1.4   1.0   3.9   80.2 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transf_nolm_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333   5643 |  89.0  10.2   0.8   1.5  12.5   75.4 |
write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transf_nolm_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503  48634 |  95.8   2.2   2.1   1.0   5.2   83.1 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transf_nolm_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503   8234 |  85.8  12.6   1.5   1.5  15.7   81.7 |

I am using a model trained for 68 epochs (patience=10), so the averaged model comes from epochs 59–68.
Let me know if any additional tests are required.
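
As a side note, the snapshot averaging mentioned above can be pictured roughly as below (a sketch only; the snapshot paths and loading details are assumptions, and the recipe's own averaging script remains the source of truth):

    import copy

    import chainer
    import numpy as np

    def average_snapshots(model, snapshot_paths):
        """Sketch: average the parameters of several saved model snapshots in place."""
        summed = None
        for path in snapshot_paths:
            m = copy.deepcopy(model)
            chainer.serializers.load_npz(path, m)  # load one epoch snapshot into a copy
            params = {name: p.array for name, p in m.namedparams()}
            if summed is None:
                summed = {k: v.astype(np.float64) for k, v in params.items()}
            else:
                for k, v in params.items():
                    summed[k] += v
        for name, p in model.namedparams():
            p.array[...] = (summed[name] / len(snapshot_paths)).astype(p.array.dtype)
        return model

    # e.g. averaging epochs 59-68 (the path pattern here is hypothetical):
    # model = average_snapshots(model, ["exp/.../snapshot.ep.%d" % ep for ep in range(59, 69)])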

@aonotas left a comment:

Thank you for your work.
I added some comments.

For the w/o ctc case (ctc_weight=0.0, lm_weight=1.0), is the CER result this bad?

 | Corr  Sub   Del  Ins   Err  S.Err |
 |  6.5  1.1  92.4  0.0  93.5   99.1 |

if self.flag_return:
    loss_ctc = None
    return self.loss, loss_ctc, loss_att, acc
else:
    return self.loss

def recognize(self, x_block, recog_args, char_list=None, rnnlm=None):
def recognize_beam2(self, x_block, recog_args, char_list=None, rnnlm=None, use_jit=False):

Is this method necessary?
I think this PR does not use recognize_beam2 anywhere else.

@Fhrozen (Member Author):

I left this on purpose, just in case recognize_beam did not work, but I will remove it before merging.


I see, thank you.

@@ -78,37 +78,48 @@ class CustomUpdater(training.StandardUpdater):
     def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
         super(CustomUpdater, self).__init__(
             train_iter, optimizer, converter=converter, device=device)
-        self.count = 0
+        self.forward_count = 0

I feel that changing from count to forward_count may have side effects.
Do you want to fix the code of accum_grad?

@Fhrozen (Member Author):

I did not find any side effects of forward_count on accum_grad, but I will check it once more. Could you explain which possible effects might appear?

I'm sorry, this is my concern.
I think CustomUpdater and CustomParallelUpdater are common components.
Is this modification necessary for the Transformer PR?
Actually, I'm not sure why this modification is necessary. (This is just a comment.)

If this modification to CustomUpdater and CustomParallelUpdater is not related to the Transformer method, I feel you could separate it into a different PR.
But I'm not a main contributor of ESPnet, so I'm not sure whether the PR should be split or not.
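
For reference, the behaviour under discussion can be sketched like this (an illustrative Chainer updater with assumed names, not the exact code of this PR): gradients are accumulated over accum_grad forward/backward passes, forward_count tracks how many passes have happened since the last optimizer step, and the trainer iteration is advanced only when the parameters are actually updated.

    from chainer import training

    class AccumGradUpdaterSketch(training.StandardUpdater):
        """Illustrative updater with gradient accumulation (class name is an assumption)."""

        def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
            super(AccumGradUpdaterSketch, self).__init__(
                train_iter, optimizer, converter=converter, device=device)
            self.accum_grad = accum_grad
            self.forward_count = 0  # forward/backward passes since the last optimizer step

        def update_core(self):
            optimizer = self.get_optimizer('main')
            batch = self.get_iterator('main').next()
            x = self.converter(batch, self.device)
            # scale the loss so the accumulated gradient averages over accum_grad batches
            loss = optimizer.target(*x) / self.accum_grad
            loss.backward()  # Chainer accumulates gradients across backward() calls
            self.forward_count += 1
            if self.forward_count == self.accum_grad:
                optimizer.update()             # apply the accumulated gradients
                optimizer.target.cleargrads()  # reset for the next accumulation window
                self.forward_count = 0

        def update(self):
            self.update_core()
            # count an iteration only when the parameters were actually updated
            if self.forward_count == 0:
                self.iteration += 1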

@Fhrozen (Member Author) commented Jun 19, 2019

@aonotas, thank you for your support and comments. I will apply the suggested modifications later.

@ShigekiKarita (Member):
@aonotas Thanks for your help!

> Thank you for your work.
> I added some comments.
>
> For the w/o ctc case (ctc_weight=0.0, lm_weight=1.0), is the CER result this bad?
>
>  | Corr  Sub   Del  Ins   Err  S.Err |
>  |  6.5  1.1  92.4  0.0  93.5   99.1 |

Unfortunately, this is expected. This strange behaviour is already known in the PyTorch Transformer implementation. We found that LM integration without CTC seems to be difficult on WSJ.
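
For readers following the thread, the score being searched in this hybrid setup combines attention, CTC, and LM terms roughly as below (a schematic per-hypothesis score, not the exact beam-search code). With ctc_weight=0.0 the CTC term vanishes, so a strong LM (lm_weight=1.0) can pull hypotheses away from the acoustics and produce the degraded CER shown above.

    def hypothesis_score(att_logp, ctc_logp, lm_logp, ctc_weight, lm_weight):
        """Schematic hybrid CTC/attention + LM score for one partial hypothesis.

        att_logp, ctc_logp, lm_logp: cumulative log-probabilities of the hypothesis
        under the attention decoder, the CTC prefix scorer, and the language model.
        """
        return (1.0 - ctc_weight) * att_logp + ctc_weight * ctc_logp + lm_weight * lm_logp

    # The three configurations tested above correspond to:
    #   ctc_weight=0.0, lm_weight=0.0  -> attention decoder only
    #   ctc_weight=0.0, lm_weight=1.0  -> attention + LM (the failing case)
    #   ctc_weight=0.3, lm_weight=0.0  -> attention + CTC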

@aonotas commented Jun 19, 2019

> We found that LM integration without CTC seems to be difficult on WSJ.

Wow, this is interesting. Thank you for your information.

@sw005320 (Contributor):

> We found that LM integration without CTC seems to be difficult on WSJ.

@creatorscan may fix it.
He told me that he found a bug related to this.

@Fhrozen (Member Author) commented Jun 20, 2019

I just finished testing the Chainer model with ngpu=2 & accum_grad=2.
The model was trained for 71 epochs (early stopping with patience=10).
Training time: 25 hrs (2 GTX TITAN X GPUs, CUDA 10).

write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transformer_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333  33341 |  98.1   1.0   0.9   0.7   2.6   55.6 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_eval92_decode_chainer_transformer_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   333   5643 |  95.2   4.4   0.4   1.0   5.8   47.7 |

write a CER (or TER) result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transformer_lm_word65000/result.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503  48634 |  97.0   1.5   1.5   0.8   3.8   65.0 |
write a WER result in exp/train_si284_chainer_train_chainer_transformer_no_preprocess/decode_test_dev93_decode_chainer_transformer_lm_word65000/result.wrd.txt
| SPKR    | # Snt  # Wrd |  Corr   Sub   Del   Ins   Err  S.Err |
| Sum/Avg |   503   8234 |  92.6   6.5   0.9   1.5   8.9   59.4 |

The results did not change a lot; only the dev set shows a slight reduction of 0.2 CER and 0.1 WER.
I will finish the CER computation and the additional small fixes by the weekend.

@codecov codecov bot commented Jul 4, 2019

Codecov Report

Merging #774 into v.0.5.0 will increase coverage by <.01%.
The diff coverage is 64.81%.


@@             Coverage Diff             @@
##           v.0.5.0     #774      +/-   ##
===========================================
+ Coverage    51.07%   51.07%   +<.01%     
===========================================
  Files          102      110       +8     
  Lines        10957    11133     +176     
===========================================
+ Hits          5596     5686      +90     
- Misses        5361     5447      +86
Impacted Files Coverage Δ
espnet/nets/chainer_backend/rnn/decoders.py 89.43% <ø> (ø)
espnet/nets/pytorch_backend/e2e_asr_transformer.py 70.63% <ø> (ø) ⬆️
espnet/nets/chainer_backend/rnn/attentions.py 98.18% <ø> (ø)
...pnet/nets/chainer_backend/transformer/attention.py 100% <ø> (ø)
espnet/nets/chainer_backend/rnn/encoders.py 98.29% <ø> (ø)
espnet/asr/chainer_backend/asr.py 0% <0%> (ø) ⬆️
...nets/chainer_backend/transformer/optimizer_rule.py 0% <0%> (ø)
espnet/asr/pytorch_backend/asr.py 0% <0%> (ø) ⬆️
...r_backend/transformer/positionwise_feed_forward.py 100% <100%> (ø)
.../nets/chainer_backend/transformer/decoder_layer.py 100% <100%> (ø)
... and 24 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8cbde38...f0863d6.

@sw005320 (Contributor):
@Fhrozen, what is the status?

@Fhrozen (Member Author) commented Jul 18, 2019

I only need to add the CER computation to the RNN to finish this PR.
I will do this once I get back to Japan on the weekend (the IJCNN conference finishes on the 19th).

@sw005320 (Contributor):
OK. Enjoy IJCNN!

@Fhrozen (Member Author) commented Jul 30, 2019

@sw005320 @kan-bayashi, please check it for merging before someone else updates v.0.5. ;)

@sw005320 (Contributor) left a comment:

Did you add a test for this PR?

Resolved review comment on: egs/wsj/asr1/conf/tuning/decode_chainer_transformer.yaml

def update(self):
    self.update_core()
    if self.forward_count == 0:
        self.iteration += 1

@sw005320 (Contributor):

Why did you do it here?

    self.iteration += 1

seems to increase the iterations. Do you need this? If so, could you add a comment about this?

@Fhrozen (Member Author):

This was related to #777. I suppose I need to add it as a comment.

Resolved review comment on: espnet/nets/chainer_backend/transformer/attention.py

@sw005320 merged commit 19b7916 into espnet:v.0.5.0 on Jul 30, 2019
@Fhrozen deleted the pr-transf-chainer branch on July 30, 2019 at 08:56
@kan-bayashi changed the title from "[WIP] Transformer Chainer" to "Transformer Chainer" on Jul 30, 2019