Add support for K2 pruned transducer loss #5268
Conversation
Codecov Report
```
@@            Coverage Diff             @@
##           master    #5268      +/-   ##
==========================================
- Coverage   74.98%   67.12%    -7.87%
==========================================
  Files         655      658        +3
  Lines       58572    59144      +572
==========================================
- Hits        43922    39701     -4221
- Misses      14650    19443     +4793
==========================================
```
... and 134 files with indirect coverage changes
Maybe not related, but just curious: are you using the latest k2 for this one? I remember we were pinning an old fixed version. Is it possible to change that to an up-to-date version?
Yes, I grabbed one of the latest nightly wheels for my env.
It is, but I thought it was up to the user, because of the comments in the install scripts:

```shell
# Please update if too old. See https://k2-fsa.org/nightly/, https://anaconda.org/k2-fsa/k2/files
pip_k2_version="1.10.dev20211112"
conda_k2_version="1.10.dev20211103"  # Empty indicates latest version
```
I think it's in an OK state for now. I need to do additional testing to ensure/fix some general parameters/variables depending on how the k2 kernel is set, but that can be done in later PRs. Also, for reference, here are results for an E-Branchformer/RNN trained with either loss kernel:

I don't provide training time here because it's a bit difficult to monitor properly in my current environment.
```python
    ],
)
def test_model_training_with_k2(k2_params):
    pytest.importorskip("k2")
```
Hmm, I'm not sure this is the proper way to skip a single method? `pytest.importorskip` can't be used as a decorator, and I don't want to make the skip global.
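For comparison, a hypothetical alternative (not the PR's actual code) is to gate a single test with `pytest.mark.skipif` on a module-availability check, which works as a decorator without importing `k2` at collection time:

```python
import importlib.util

import pytest

# Check availability without importing the module itself.
k2_available = importlib.util.find_spec("k2") is not None


@pytest.mark.skipif(not k2_available, reason="k2 is not installed")
def test_model_training_with_k2():
    import k2  # only reached when the skip condition passes
```

Both approaches skip rather than fail when `k2` is missing; `importorskip` inside the test body is simply the more common idiom for a single test.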
```python
    t_len,
    u_len,
)
with autocast(False):
```
It's an oversight on my part; it should have been included in #5140 for AMP usage. The (auto)castings will be changed when float16 support is merged into warp-transducer.
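For context, the pattern being applied here, sketched with a stand-in loss (illustrative only, not ESPnet's actual code): AMP is disabled locally around a kernel that lacks float16 support, so that region runs in float32 even when autocast is enabled globally.

```python
import torch
from torch.cuda.amp import autocast


def transducer_loss_fp32(joint_out: torch.Tensor) -> torch.Tensor:
    # A loss kernel without float16 support must run in float32,
    # so AMP is disabled locally even if enabled for the rest of
    # the forward pass.
    with autocast(enabled=False):
        return joint_out.float().sum()  # stand-in for the real loss kernel
```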
```diff
-    beam_size=1,
-    search_type="default",
+    beam_size=2,
+    search_type="maes",
```
Is it OK to change it from `default` to `maes`? What is the intention of this change?
> Is it OK to change it from `default` to `maes`?

It is; it won't make a difference for the reporting.

> What is the intention of this change?

To control the number of emitted symbols at each timestep in both training and validation.

With k2, we can end up with a big performance degradation if we enforce unrestricted decoding without regard to the training conditions/parameters (see https://arxiv.org/pdf/2211.00484.pdf). With `maes`, it'll be "stable" under all conditions.

Also, in some cases you can end up with a really long looping procedure with `default` (e.g., during early training stages with some architectures such as MEGA, or if there is a training/decoding condition mismatch, as said above). It won't happen with the other decoding algorithms.
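To illustrate the looping risk (a hypothetical sketch, not ESPnet's decoder): a default-style transducer search keeps emitting symbols at one time frame until the model predicts blank, so a model that never predicts blank would loop indefinitely unless a per-frame cap, as in `maes`-style constrained search, is enforced.

```python
def decode_one_frame(predict, max_symbols_per_frame=3):
    """Emit symbols for a single time frame.

    `predict` maps the symbols emitted so far to the next symbol id
    (0 = blank). Without `max_symbols_per_frame`, a model that never
    predicts blank would loop forever on this frame.
    """
    emitted = []
    while len(emitted) < max_symbols_per_frame:
        sym = predict(emitted)
        if sym == 0:  # blank ends this frame's emissions
            break
        emitted.append(sym)
    return emitted
```

With `predict=lambda prev: 5` (a model that never predicts blank), the cap stops the loop after three symbols instead of hanging.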
Thanks for the clarification.
I also encountered the issue of the number of emitted symbols, and this change sounds good to me.
Can you leave this discussion somewhere in the source code?
Done, but I'm not sure it's in the right place or in the right form!
@sw005320 Is that okay with you as is, or do you want me to add a note in the tutorial doc instead? I'll merge after that.
Sure, the current one is OK.
(I'm trying to fix the CI issue...)
This PR adds the capability to train with the k2 pruned transducer loss. It's working, but still a WIP.

It was not a priority, but since I'm currently doing a major rework of the training pipeline (initialization, schedulers, criterions, normalization, loss and auxiliary losses, etc.), I think it should be considered.

Also:

- Use `maes` instead of `default` in `ErrorCalculator` to avoid endless/long loops in some cases (e.g., unrestricted decoding with constrained model training).
- Use `torch.tensor` instead of `torch.LongTensor` in all decoders' `batch_score`. The legacy constructor will raise a device mismatch error with torch >1.10.

To add: