
Add support for K2 pruned transducer loss #5268

Merged
merged 11 commits into from Jul 13, 2023

Conversation

@b-flo (Member) commented Jun 30, 2023

This PR adds the capability to train with the k2 pruned transducer loss. It's working, but still a WIP.
It was not a priority, but since I'm currently doing a major work/rework on the training pipeline (initialization, schedulers, criterions, normalization, loss and auxiliary losses, etc.), I think it should be considered.

Also:

  • Use maes instead of default in ErrorCalculator to avoid endless/long loops in some cases (e.g., unrestricted decoding with constrained model training).
  • Use torch.tensor instead of torch.LongTensor in all decoders' batch_score. The legacy constructor raises a device mismatch error with torch > 1.10.
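
The constructor swap in the second bullet can be sketched as follows (a minimal illustration; `ids` and `device` are made-up values, not taken from the PR):

```python
import torch

ids = [7, 3, 5]
device = "cpu"  # in batch_score this would be the decoder's device

# Legacy: torch.LongTensor always allocates on CPU, so pairing it with a
# model on another device raises a mismatch error with torch > 1.10.
legacy = torch.LongTensor(ids)

# Preferred: torch.tensor takes dtype and device explicitly.
current = torch.tensor(ids, dtype=torch.long, device=device)
```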

To add:

  • Tests
  • Docs
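
For context, the k2 pruned loss roughly follows a four-step pipeline. Below is a hedged sketch based on k2's public API (exact signatures may differ across k2 versions); the simple `am + lm` sum stands in for the real joint network, and names such as `make_boundary` and `pruned_rnnt_loss` are illustrative, not ESPnet's:

```python
import torch


def make_boundary(t_len: torch.Tensor, u_len: torch.Tensor) -> torch.Tensor:
    """Build the (B, 4) boundary tensor k2 expects: [0, 0, U, T] per utterance."""
    boundary = torch.zeros((t_len.size(0), 4), dtype=torch.int64)
    boundary[:, 2] = u_len  # label sequence length U
    boundary[:, 3] = t_len  # encoder frame length T
    return boundary


def pruned_rnnt_loss(am, lm, labels, t_len, u_len, s_range=5):
    import k2  # requires a k2 wheel matching your torch/CUDA build

    boundary = make_boundary(t_len, u_len).to(am.device)
    # 1) Cheap "simple" loss over the projected encoder/decoder outputs,
    #    returning gradients w.r.t. the lattice arcs.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
        lm=lm, am=am, symbols=labels,
        termination_symbol=0, boundary=boundary, return_grad=True,
    )
    # 2) Derive pruning bounds (at most s_range symbols kept per frame).
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=s_range,
    )
    # 3) Gather the pruned am/lm and run the joint only on that region.
    am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    logits = am_pruned + lm_pruned  # stand-in for the real joint network
    # 4) Exact transducer loss restricted to the pruned region.
    return simple_loss, k2.rnnt_loss_pruned(
        logits=logits, symbols=labels, ranges=ranges,
        termination_symbol=0, boundary=boundary,
    )
```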

@mergify mergify bot added the ESPnet2 label Jun 30, 2023
@sw005320 sw005320 added Enhancement Enhancement RNNT (RNN) transducer related issue labels Jun 30, 2023
@sw005320 sw005320 added this to the v.202307 milestone Jun 30, 2023
@b-flo b-flo added ASR Automatic speech recogntion Enhancement Enhancement and removed Enhancement Enhancement labels Jun 30, 2023
@codecov (codecov bot) commented Jun 30, 2023

Codecov Report

Merging #5268 (8c70533) into master (baaba22) will decrease coverage by 7.87%.
The diff coverage is 35.48%.

❗ Current head 8c70533 differs from the pull request's most recent head 1d590ff. Consider uploading reports for the commit 1d590ff to get more accurate results.

@@            Coverage Diff             @@
##           master    #5268      +/-   ##
==========================================
- Coverage   74.98%   67.12%   -7.87%     
==========================================
  Files         655      658       +3     
  Lines       58572    59144     +572     
==========================================
- Hits        43922    39701    -4221     
- Misses      14650    19443    +4793     
| Flag | Coverage Δ |
|---|---|
| test_integration_espnet1 | ? |
| test_integration_espnet2 | ? |
| test_python | 66.48% <35.48%> (+1.20%) ⬆️ |
| test_utils | 23.17% <ø> (-0.10%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|---|---|
| espnet2/asr_transducer/beam_search_transducer.py | 99.18% <ø> (ø) |
| espnet2/asr_transducer/decoder/rwkv_decoder.py | 25.71% <0.00%> (ø) |
| espnet2/asr_transducer/error_calculator.py | 100.00% <ø> (ø) |
| espnet2/asr_transducer/espnet_transducer_model.py | 75.28% <28.30%> (-20.25%) ⬇️ |
| espnet2/asr_transducer/joint_network.py | 92.85% <66.66%> (-7.15%) ⬇️ |
| espnet2/asr_transducer/decoder/mega_decoder.py | 100.00% <100.00%> (ø) |
| espnet2/asr_transducer/decoder/rnn_decoder.py | 100.00% <100.00%> (ø) |
| ...spnet2/asr_transducer/decoder/stateless_decoder.py | 97.36% <100.00%> (ø) |
| espnet2/tasks/asr_transducer.py | 100.00% <100.00%> (ø) |

... and 134 files with indirect coverage changes


@ftshijt (Collaborator) commented Jun 30, 2023

Maybe not related, but out of curiosity: are you using the latest k2 for this one? I remember we pinned an old fixed version for that. Is it possible to change that to an up-to-date version?

@b-flo (Member Author) commented Jun 30, 2023

Maybe not related, but curious, are you using the latest k2 for this one?

Yes, I grabbed one of the latest nightly wheels for my env: Version: 1.24.3.dev20230616+cuda11.7.torch1.13.1

I remember we are using an old fixed version for that. Is it possible to change that to up-to-date version?

It is, but I thought it was up to the user, given the comments in the install scripts:

```shell
# Please update if too old. See https://k2-fsa.org/nightly/, https://anaconda.org/k2-fsa/k2/files
pip_k2_version="1.10.dev20211112"
conda_k2_version="1.10.dev20211103"  # Empty indicates latest version
```
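
For illustration, pinning a newer wheel might look like the commands below (assumed invocations; the exact wheel tag must match your torch/CUDA build, chosen from the index pages mentioned in the script comment):

```shell
# Assumed commands -- pick the wheel tag matching your torch/CUDA build.
pip install "k2==1.24.3.dev20230616+cuda11.7.torch1.13.1" -f https://k2-fsa.org/nightly/
python -m k2.version  # prints the build configuration of the installed k2
```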

@mergify mergify bot added the Documentation label Jul 1, 2023
@b-flo (Member Author) commented Jul 5, 2023

I think it's in an OK state for now. I need to do additional testing to ensure/fix some general parameters/variables depending on how the k2 kernel is set, but that can be done in later PRs.

Also, for reference, results for an E-Branchformer/RNN model trained with either loss kernel:

| Kernel | batch_bins (accum_grad) | max_cached_mem (GB) | dev_clean (%WER) | dev_other (%WER) | test_clean (%WER) | test_other (%WER) |
|---|---|---|---|---|---|---|
| warp-transducer | 4000000 (8) | 37.40 | 5.7 | 16.8 | 6.0 | 17.1 |
| k2 pruned (regular) | 4000000 (8) | 28.84 | 5.6 | 16.7 | 6.1 | 17.1 |
| k2 pruned (regular) | 8000000 (4) | 37.50 | 5.7 | 16.8 | 6.1 | 17.1 |

I don't provide training time here because it's a bit difficult to properly monitor in my current environment.

```python
    ],
)
def test_model_training_with_k2(k2_params):
    pytest.importorskip("k2")
```

@b-flo (Member Author) commented Jul 5, 2023

Hmm, I'm not sure this is the proper way to skip a single test method? importorskip can't be used as a decorator, and I don't want to make the skip global.
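
For context, a sketch of the two usual per-test options (illustrative test names, assuming pytest; neither makes the skip module-wide):

```python
import importlib.util

import pytest


# Option 1: call importorskip inside the test body -- skips only this
# test at runtime, which is what the PR does.
def test_model_training_with_k2():
    k2 = pytest.importorskip("k2")


# Option 2: a reusable skipif mark -- applicable per test as a decorator,
# without a global pytest.skip at module level.
requires_k2 = pytest.mark.skipif(
    importlib.util.find_spec("k2") is None, reason="k2 is not installed"
)


@requires_k2
def test_other_k2_feature():
    pass
```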

```python
    t_len,
    u_len,
)
with autocast(False):
```

@b-flo (Member Author):

It's an oversight on my part; it should have been included in #5140 for amp usage. The (auto)castings will be changed when float16 support is merged into warp-transducer.
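
The pattern above can be sketched as follows (a minimal illustration assuming torch's AMP API; the `logsumexp` body is a toy stand-in for the actual loss computation):

```python
import torch
from torch.cuda.amp import autocast


def float32_loss(logits: torch.Tensor) -> torch.Tensor:
    # Transducer loss kernels without float16 support expect float32
    # inputs, so disable autocast locally and upcast before the loss.
    with autocast(enabled=False):
        return torch.logsumexp(logits.float(), dim=-1).mean()
```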

@b-flo (Member Author) commented Jul 10, 2023

@sw005320 @ftshijt I'll merge this PR if you're OK with it!

```diff
-beam_size=1,
-search_type="default",
+beam_size=2,
+search_type="maes",
```

Contributor commented:

Is it OK to change it from default to maes?
What is the intention of this change?

@b-flo (Member Author):

Is it OK to change it from default to maes?

It is; it won't make a difference for the reporting.

What is the intention of this change?

To control the number of emitted symbols at each timestep in both training and validation.
With k2, we can end up with a big performance degradation if we enforce unrestricted decoding without regard to the training conditions/parameters (see https://arxiv.org/pdf/2211.00484.pdf). With maes, it'll be "stable" under all conditions.

Also, in some cases you can end up with a really long looping procedure with default (e.g., during early stages with some architectures such as MEGA, or if there is a training/decoding condition mismatch, as I said). It won't happen with other decoding algorithms.
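
As an illustration, selecting maes at decoding time might look like the fragment below (hypothetical config; the parameter names follow `BeamSearchTransducer` and may differ across ESPnet versions, so check the class signature in your checkout):

```yaml
# Hypothetical decode config sketch; verify names against
# espnet2/asr_transducer/beam_search_transducer.py.
beam_size: 2
beam_search_config:
    search_type: maes
    nstep: 2    # cap on non-blank symbols emitted per time step
```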

Contributor:

Thanks for the clarification.
I also encountered the issue of the number of emitted symbols, and this change sounds good to me.
Can you leave this discussion somewhere in the source code?

@b-flo (Member Author):

Done, but I'm not sure it's in the right place or in the right form!

@b-flo (Member Author):

@sw005320 Is that okay for you as is, or do you want me to add a note in the tutorial doc instead? I'll merge after that.

@sw005320 (Contributor) commented Jul 13, 2023

Sure, the current one is OK.
(I'm trying to fix the CI issue...)

@b-flo b-flo merged commit 9bc75d5 into espnet:master Jul 13, 2023
10 of 23 checks passed
@b-flo b-flo deleted the k2_pruned_transducer branch July 13, 2023 14:02
@b-flo b-flo mentioned this pull request Jul 17, 2023