
Offline/Online (standalone) ESPnet2 Transducer #4479

Merged: 125 commits into espnet:master from streaming_transducer_v2 on Aug 17, 2022

Conversation

@b-flo (Member) commented Jun 29, 2022

Hi,

This PR is a re-do of #4032, with streaming capabilities based on the WeNet chunk-by-chunk approach and Icefall implementations.
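For context, the chunk-by-chunk idea boils down to restricting self-attention so that each frame can only see its own chunk plus a limited number of past chunks. A minimal sketch of such a mask, loosely following WeNet's `subsequent_chunk_mask` (illustrative only, not this PR's actual code):

```python
import torch


def chunk_attention_mask(
    size: int, chunk_size: int, num_left_chunks: int = -1
) -> torch.Tensor:
    """Build a chunk-wise attention mask (True = attendable).

    Each frame attends to frames in its own chunk and, optionally, to a
    limited number of previous chunks (-1 means unlimited left context).
    """
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        start = 0
        if num_left_chunks >= 0:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True
    return mask


# e.g., 8 frames, chunks of 2 frames, 1 left chunk of context
print(chunk_attention_mask(8, 2, num_left_chunks=1))
```

In the WeNet approach, the same masking is applied at training time, which is what lets a single model run both full-context (offline) and chunked (streaming) inference.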

The custom encoder architecture was kept here but limited to conv1d and conformer blocks. The idea is to support other *-former architectures (branchformer, enformer, k2-conformer, longformer, etc.) as blocks, making it possible to compose a custom X-former architecture for offline and streaming ASR. I've already implemented and tested most of them, but they'll be added in later PRs.
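To make the block-wise design concrete, here is a hypothetical sketch of what a per-block configuration could look like (the keys are illustrative, not the actual ESPnet2 schema):

```python
# Each entry configures one encoder block independently, rather than
# sharing a single config across the whole encoder. Keys are made up
# for illustration and do not match the real ESPnet2 parameters.
body_conf = [
    {"block_type": "conv1d", "output_size": 256, "kernel_size": 3},
    {
        "block_type": "conformer",
        "hidden_size": 256,
        "linear_size": 1024,
        "heads": 4,
        "conv_mod_kernel_size": 31,
    },
    # Later PRs could mix in other X-former blocks here (branchformer, ...).
]
```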

In regards to the reviews in #4032 and change requests:

  • Naming: There may be minor differences, but it should be consistent with other models now!
  • Duplication: We should discuss each duplicated module individually in this version. Most of them were duplicated in preparation for future additions, but a few may just be for my convenience and can be removed/merged.

NOTE: Everything should work, but this PR is a rebase of the previous PR with stitched-together elements from different work branches, so it may contain bugs or mistakes. Please feel free to correct or point out any suspicious parts!

TO DO:

  • Add tests for online ASR (dummy ones for this PR)
  • Refine docs
  • Add missing references

@csukuangfj I would be glad if you or other Icefall members could take a look at the PR!
Also, if you could point out any missing references to your work/implementation, that would be great! Because we've gone full circle (ESPnet -> WeNet -> Icefall -> ESPnet ...) on some parts, I'm a bit confused about the proper references...

@pyf98 (Collaborator) commented Aug 2, 2022

Thanks for the great PR! I didn't look into the algorithm itself, but I made a few comments about the doc and init just now.

I think it is already well organized. I especially like the flexible design of the encoder, which supports different hyper-parameters for different blocks (if my understanding is correct) instead of sharing the same config across all encoder blocks.

@b-flo (Member, Author) commented Aug 2, 2022

Thanks a lot @pyf98 and @pengchengguo!

> I especially like the flexible design of the encoder which supports different hyper-parameters for different blocks (if my understanding is correct) instead of sharing the same config across all encoder blocks.

Your understanding is correct! You can also mix blocks if you want (well, in upcoming PRs)! I'll add some ensemble methods and revisit auxiliary losses with intermediate representations for that.

@danpovey commented Aug 2, 2022

> I included support for the simplified attention score computation and BasicNorm from K2 in the default Conformer implementation. I won't add the other modules from their reworked model here but I'll support a k2Conformer in later PRs, alongside other X-former.
>
> @csukuangfj I referenced the pull requests here because PEP8 won't allow longer links in docstrings (referencing commit, file or method won't work). Feel free to propose changes if there are better ways!

Just FYI, one of the changes I made in our Conformer was to remove the normalization from the individual modules inside the conformer layer. I only expect this to work well if you are using the ScaledLinear/ScaledConv1d modules, which learn a scaling factor for each weight and bias. Otherwise it has no way to learn the appropriate scale on each sub-module except for scaling the whole weight matrix, which is difficult for SGD to learn. BasicNorm would not be expected to be a good solution for the normalization for the individual modules, because it does not support an overall scale on the output.
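For readers following along, here is a minimal sketch of the two ideas mentioned above, as I understand them from the K2/Icefall code. These are concept sketches only; the actual Icefall modules differ in initialization and other details:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledLinear(nn.Linear):
    """Linear layer with learned log-scales on its weight and bias.

    Sketch of the concept only, not Icefall's actual implementation.
    """

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        init = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(init.clone())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(init.clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight * self.weight_scale.exp()
        bias = self.bias * self.bias_scale.exp() if self.bias is not None else None
        return F.linear(x, weight, bias)


class BasicNorm(nn.Module):
    """RMS-style normalization with a learnable epsilon but, notably,
    no learnable output gain (hence the reliance on scaled modules)."""

    def __init__(self, eps: float = 0.25):
        super().__init__()
        # eps is stored in log space and learned jointly with the model.
        self.eps = nn.Parameter(torch.tensor(eps).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scales = (x.pow(2).mean(dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
        return x * scales
```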

Also, I am working on (still tuning) an optimization method that will learn the parameter scales as part of the optimizer, without requiring the individual scales for weights and biases, so I expect to eventually remove the ScaledLinear and ScaledConv1d (in newer directories), but the recipe will depend on properties of the optimizer.

@b-flo (Member, Author) commented Aug 3, 2022

Thanks a lot for the explanation @danpovey !!
To be honest, I didn't expect BasicNorm to be a good replacement candidate, for the reason you gave, and it was removed/added back multiple times. However, because it was found appropriate (i.e., same performance at a lower cost) in some setups, I decided to keep it in the end.

I'm reworking the normalization module definition and will add some warnings and explanations to the class doc. I'm also testing AdaNorm right now.

> Also, I am working on (still tuning) an optimization method that will learn the parameter scales as part of the optimizer, without requiring the individual scales for weights and biases, so I expect to eventually remove the ScaledLinear and ScaledConv1d (in newer directories), but the recipe will depend on properties of the optimizer.

Thanks for the update, I'll keep an eye on the development!

@pyf98 (Collaborator) commented Aug 4, 2022

I have two questions:

  1. Does it support GPU inference?
  2. Does it support automatic mixed precision training with `use_amp: true`?

For LibriSpeech, I'm increasing the nonstreaming model size to 120M and extending the number of epochs to 60.

@csukuangfj commented

> For LibriSpeech, I'm increasing the nonstreaming model size to 120M and extending the number of epochs to 60.

Does the model need to be so large, and does it need to be trained for so many epochs?

We are using a model with about 80M parameters and training it for 30 epochs on the LibriSpeech dataset in icefall.

@b-flo (Member, Author) commented Aug 4, 2022

> Does it support GPU inference?

It does! If it doesn't, that's a bug on my part.

> Does it support automatic mixed precision training with `use_amp: true`?

Sorry, not yet; I need to update warp-transducer for that. Let me finish some other things first and then I'll work on it (I'll take a look this weekend)!
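For reference, the standard AMP training pattern in PyTorch is sketched below; the catch with warp-transducer is presumably that its loss kernels want float32 inputs, so the joint network output would need an explicit cast under autocast. The toy model and loss here are stand-ins, not the actual ESPnet code:

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

# Toy stand-ins for the real encoder/joint network and transducer loss,
# used only to show the AMP pattern.
model = nn.Linear(80, 512).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

for _ in range(10):
    feats = torch.randn(8, 80, device="cuda")
    optimizer.zero_grad()
    with autocast():
        out = model(feats)  # computed in fp16 where safe
    # A fused loss whose CUDA kernel only supports fp32 (as assumed for
    # warp-transducer here) needs an explicit cast back to float32:
    loss = out.float().pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```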

> Does the model need to be so large, and does it need to be trained for so many epochs?
>
> We are using a model with about 80M parameters and training it for 30 epochs on the LibriSpeech dataset in icefall.

Good question.
We usually (at least in my experience) prioritize performance in terms of xER over everything else in ESPnet for ASR. We sometimes add many extra epochs or rely on a really large architecture to squeeze out an extra 0.x%.

That being said:

  1. I do think 120M parameters is too much. I don't mind for the offline version, but we should be careful about the number of parameters in the online model.
  2. 60 epochs is OK compared to other model trainings in ESPnet. However, we really need to improve training in terms of stability and efficiency. That will be the focus after the next PR, alongside initialization.

@b-flo (Member, Author) commented Aug 4, 2022

Small update:

  1. I removed the parts related to initialization, as we don't use them in our experiments. They'll be reworked in a later PR.
  2. I refactored the normalization module for future additions/work. I also added RMSNorm and ScaleNorm, but they were not extensively tested (I also tried AdaNorm but found it difficult to converge); minimal sketches of both norms are shown below.
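For reference, minimal textbook versions of the two added normalizations (sketches only; the PR's actual modules may differ):

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain."""

    def __init__(self, size: int, eps: float = 1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.scale * x / rms


class ScaleNorm(nn.Module):
    """L2 normalization with a single learnable scalar gain,
    initialized to sqrt(d) as in the original paper."""

    def __init__(self, size: int, eps: float = 1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(float(size)).sqrt())
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        return self.scale * x / norm
```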

@pyf98 (Collaborator) commented Aug 4, 2022

> > For LibriSpeech, I'm increasing the nonstreaming model size to 120M and extending the number of epochs to 60.
>
> Does the model need to be so large, and does it need to be trained for so many epochs?
>
> We are using a model with about 80M parameters and training it for 30 epochs on the LibriSpeech dataset in icefall.

Thanks for the info! I'm just trying to match the original Conformer-Transducer Large config and see how it performs. This should be an interesting investigation. I don't know whether 30 epochs is sufficient for the Transducer, but it is not for our joint CTC/attention models, according to previous experiments.

@mergify bot (Contributor) commented Aug 10, 2022

This pull request is now in conflict :(

mergify bot added the conflicts label on Aug 10, 2022
mergify bot removed the conflicts label on Aug 11, 2022
@b-flo merged commit c83abd7 into espnet:master on Aug 17, 2022
@b-flo (Member, Author) commented Aug 17, 2022

After discussion, I'm merging this PR! It was a long road for this one; thanks to everyone for your help 🎉

Now onto the next items!

@b-flo deleted the streaming_transducer_v2 branch on September 12, 2022 at 08:35
Labels: CI (Travis, Circle CI, etc.), Documentation, ESPnet2, RNNT ((RNN) transducer related issue), Streaming