[Wav2Vec2] PyCTCDecode Integration to support language model boosted decoding #14339
Conversation
I like the design as suggested. Just have one comment on the `save_pretrained` method.
""" | ||
|
||
self.feature_extractor.save_pretrained(save_directory) | ||
self.tokenizer.save_pretrained(save_directory) |
The decoder should also be saved here, I think.
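A minimal sketch of what that could look like (the attribute names mirror the diff above; `save_to_dir` on the pyctcdecode decoder is an assumption, not confirmed API in this PR):

```python
def save_pretrained(self, save_directory):
    # persist all components so that from_pretrained can restore the full processor
    self.feature_extractor.save_pretrained(save_directory)
    self.tokenizer.save_pretrained(save_directory)
    # assumption: pyctcdecode's BeamSearchDecoderCTC exposes save_to_dir(...)
    self.decoder.save_to_dir(save_directory)
```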
Thanks for writing down your thoughts. I agree with everything said here, and pyctcdecode looks like a great tool, happy to use it.
I would not advocate for a `Wav2Vec2ProcessorWithLM`, however. I'd vote to have a single `Wav2Vec2Processor` instead. I guess you decided to split them as you didn't want to add unnecessary overhead to `Wav2Vec2Processor` for users that did not want to use the language model?
In that case, I'd favor either:
- Loading the LM on the fly when calling the `decode` method with the language model (maybe as a new `decode_with_lm` method? see the sketch after this comment), or
- Passing an additional argument `with_lm_decoding` (either to the `__init__`, to prevent on-the-fly instantiation, or to the `decode`).
I personally think it would be less awkward from a user's perspective to have everything bundled in a single processor.
Overall, super down and excited for this PR! Let's get those improvements on WER!
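As an illustration of the lazy-loading option above (all names here are hypothetical, not actual transformers API):

```python
class Wav2Vec2Processor:
    def __init__(self, feature_extractor, tokenizer, lm_path=None):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self._lm_path = lm_path
        self._lm_decoder = None  # heavy LM decoder is created lazily

    def decode_with_lm(self, logits):
        # hypothetical helper: build the decoder on first use, then cache it
        if self._lm_decoder is None:
            self._lm_decoder = build_lm_decoder(self._lm_path)
        return self._lm_decoder.decode(logits)
```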
Not convinced by grouping everything together in one class which will sometimes have an additional method that works and sometimes not (depending on the env). The code is probably also going to be hard to read if all the imports have to be contained in a conditional import block.
Taking previous suggestions into account, I think combining CTC decoders with the existing `Wav2Vec2Processor` is a good idea. This workflow seems pretty natural to me:

processor = Wav2Vec2Processor(feat_extractor, tokenizer)
processor.decode(logits)
>> Warning: `lm_decoder` is not specified, using a greedy decoder

lm_decoder = Wav2Vec2LMDecoder("kenlm/librispeech-100h-4gram")
processor = Wav2Vec2Processor(feat_extractor, tokenizer, lm_decoder=lm_decoder)
processor.decode(logits)

(sneaky suggestion to rename `ctc_decoder` here :))
return self.ctc_decoder.decode(logits.numpy())

@contextmanager
def as_target_processor(self):
Since the processor now has 3 modes of operation, maybe this design can be deprecated in favor of `encode()` and `decode()`?
Regarding the PyCTCDecode integration I think it makes sense to start step-by-step and simply add a `Wav2Vec2ProcessorWithLM` first. Once this PR is approved we can move forward with this PR.
Thanks a lot for the feedback regarding the design choices @sgugger, @LysandreJik and @anton-l. I agree more with @sgugger here, but I think both implementations are valid and have pros and cons. To give some more background for a better design decision:
IMO, the main reason why I think a new class would be better is that the class will be very experimental and will most likely change in the future (adding other backend libraries, other language models, ...). Language model support for decoding is by no means always necessary or needed for ASR, so I can see lots of people just keeping `Wav2Vec2Processor`. There are some other reasons why I prefer `Wav2Vec2ProcessorWithLM` as well.
Both classes would have the exact same API and can be replaced one-for-one. So for users wanting to decode with a language model, everything is bundled in a single class, namely `Wav2Vec2ProcessorWithLM`. I don't see a huge advantage in having only a single `Wav2Vec2Processor`.
When I said on the fly, I meant it would be loaded the first time; every subsequent operation would use the previously loaded model. But I don't have a strong opinion, and I understand your perspective. Good for me to go with the new class!
After looking through the previous PRs concerning the processor design: I'm interested in discussing (perhaps not in this PR) how we can evolve the current processing design, since only the feature extractor is universally required for speech models now, and the tokenizer and LM can be applied separately, depending on the target task.
Real-world demo example for a SOTA Spanish wav2vec2 model: https://huggingface.co/patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm -> seems to give a nice 10-20% WER improvement.
Final user API:

import torch
import torchaudio.functional as F
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
from datasets import load_dataset

ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()

model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

input_values = processor(resampled_audio, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

-prediction_ids = torch.argmax(logits, dim=-1)
-transcription = processor.batch_decode(prediction_ids)
+transcription = processor.batch_decode(logits.cpu().numpy()).text
print(transcription)
@dataclass
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
I'm planning on providing more outputs in the future for time-stamped word outputs, etc.
From what I understood from your comment here #14487 (review) you'd rather we not abstract model outputs, which I understand. I think this should be the case here too, right?
Ah, just understood that `text` was a single output and not an overload of a previously defined output. You can ignore my comment :)
cls._set_language_model_attribute(decoder, attribute, value)

# make sure that decoder's alphabet and tokenizer's vocab match in content
missing_decoder_tokens = cls.get_missing_alphabet_tokens(decoder, tokenizer)
aggressive check to make sure the model's vocabulary matches the decoder's alphabet
Thanks for adding this! I've left some comments as there is some cleaning up to do in some docstrings, the setup and the install instructions in the various CI jobs.
@@ -83,6 +83,7 @@ jobs:
     - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision]
     - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html
     - run: pip install tensorflow_probability
+    - run: pip install https://github.com/kpu/kenlm/archive/master.zip
What is this package? The import error indicates to do `pip install pyctcdecode` later on and does not give any instruction to install this.
`pyctcdecode` optionally depends on `kenlm` if the user would like to use a `kenlm` language model. In the future, there will probably be more language models that don't require `kenlm`.
So IMO, it's the responsibility of the `pyctcdecode` package to throw a good error in case a user requests `pyctcdecode` with a `kenlm` language model. However, since at the moment the only language model support is based on `kenlm`, I can also throw a nice error message on our side.
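For example, a nice error on our side could be raised with a small import guard (a sketch only, not the code in this PR; the install URL is the one used in the CI config above):

```python
def require_kenlm():
    # fail early with an actionable message when the optional kenlm backend is missing
    try:
        import kenlm  # noqa: F401
    except ImportError:
        raise ImportError(
            "Using a kenlm language model for decoding requires the `kenlm` package. "
            "You can install it with: "
            "`pip install https://github.com/kpu/kenlm/archive/master.zip`"
        )
```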
Is there a plan to add a real Python package for `kenlm`? This is a bit heavy :-(
Hmm, I'm really not sure - the `kenlm` repo doesn't seem to be super active: https://github.com/kpu/kenlm. It is, however, by far the most used library for language-model-supported ASR. Flashlight uses it (https://github.com/flashlight/flashlight/tree/main/bindings/python#dependencies), among many other libraries.
Espnet uses it as well: https://github.com/espnet/espnet/blob/master/espnet/nets/scorers/ngram.py
@@ -281,6 +287,7 @@ jobs:
     - run: pip install --upgrade pip
     - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]
     - run: pip install tensorflow_probability
+    - run: pip install https://github.com/kpu/kenlm/archive/master.zip
Do we need it in the TF tests?
It's a processor and therefore framework-independent - it's written in pure Python.
.circleci/config.yml

@@ -701,6 +717,7 @@ jobs:
     - v0.4-{{ checksum "setup.py" }}
     - run: pip install --upgrade pip
     - run: pip install .[torch,testing,sentencepiece,onnxruntime]
+    - run: pip install https://github.com/kpu/kenlm/archive/master.zip
Why do we need it in the ONNX tests?
removed it from ONNX
docs/source/model_doc/wav2vec2.rst

@@ -143,3 +136,19 @@ FlaxWav2Vec2ForPreTraining

.. autoclass:: transformers.FlaxWav2Vec2ForPreTraining
    :members: __call__

Wav2Vec2 specific outputs
In all other doc pages, those go before the models, so let's leave them here for now. We can decide to switch them to after the models later, but then for all models together?
Sounds good - reverting this
Great integration! Although I feel like pyctcdecode's "magic options" can be documented a bit more verbosely, so I left some suggestions :)
def get_missing_alphabet_tokens(decoder, tokenizer):
    # we need to make sure that all of the tokenizer's tokens, except the
    # special tokens, are present in the decoder's alphabet; retrieve the
    # missing alphabet tokens from the decoder
    tokenizer_vocab_list = [t.lower() for t in tokenizer.get_vocab().keys()]
`t.lower()` won't allow us to have an all-uppercase vocab (e.g. the official LibriSpeech LMs for eval are uppercase: https://www.openslr.org/11)
Totally right! Thanks a lot for catching this! Not sure what I was thinking here :D
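A sketch of a case-robust version (assuming the decoder exposes its labels via `decoder._alphabet.labels`, as pyctcdecode does internally; not the final code of this PR):

```python
def get_missing_alphabet_tokens(decoder, tokenizer):
    # compare tokens as-is instead of forcing lower case, so that
    # all-uppercase vocabs (e.g. the official LibriSpeech LMs) keep working
    alphabet = set(decoder._alphabet.labels)  # assumed pyctcdecode internal
    special_tokens = set(tokenizer.all_special_tokens)
    vocab = set(tokenizer.get_vocab())
    return sorted(t for t in vocab - special_tokens if t not in alphabet)
```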
Yes, we should definitely make a notebook about it!
@LysandreJik @sgugger - I think it's ready for a final review. I've now made sure that this LM-boosted ASR is only tested in the TF, Flax and PT tests, but not in the ONNX & Hub tests. I've also added integration tests for TF and Flax. @sgugger I don't really see how to get rid of `kenlm` here.
Thanks for containing the addition of kenlm. The documentation on how to run the tests locally should get an update, as we are very, very far now from just needing a plain pip install.
LGTM, great work @patrickvonplaten
Draft to integrate pyctcdecode into 🤗 Transformers
This is a short doc to explain all the important aspects of a possible integration of pyctcdecode into 🤗 Transformers.
What is LM-boosted Decoding?
In LM-boosted decoding, an acoustic model (Wav2Vec2) is trained on some speech data, and independently of this training a language model (e.g. a KenLM n-gram) is trained on some text in the same language as the speech data. Then, during evaluation, the language model supports the acoustic model in predicting the transcribed words via beam search decoding. To be more precise, the output (log-)probability matrix of the acoustic model - a [timesteps x log-prob for each subword token] matrix - is fed into a beam search decoder, and by means of a language model (P(subword token | prev subword token)) the overall best subword token sequence is chosen using a beam search algorithm.
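Concretely, with pyctcdecode this looks roughly as follows (the vocabulary, LM path, and logits are placeholders):

```python
import numpy as np
from pyctcdecode import build_ctcdecoder

# vocabulary of the acoustic model, in the order of its output dimension (placeholder)
labels = ["<pad>", " ", "a", "b", "c"]

# attach a KenLM n-gram so the beam search can score candidate word sequences
decoder = build_ctcdecoder(labels, kenlm_model_path="path/to/ngram.arpa")

# [timesteps x vocab] log-probability matrix, as produced by Wav2Vec2ForCTC (placeholder)
logits = np.log(np.random.dirichlet(np.ones(len(labels)), size=50))

transcription = decoder.decode(logits)
```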
Why do we need LM-boosted Decoding for Speech?

LM-boosted decoding is still the, or one of the, state-of-the-art approaches for ASR systems in terms of word-error-rate (WER) performance. The other upcoming approach is end-to-end, where the language model is learned together with the acoustic model. The advantage of LM-boosted decoding is that the language model can be trained on text data independently of the acoustic model. The disadvantage is that the decoding becomes a separate, non-differentiable step outside the model (it is not implemented in torch.nn).

Why pyctcdecode?
We could implement the whole CTC beam search algorithm ourselves in `transformers` or a separate library, but it would look very similar to already existing libraries, and in the spirit of open source it's usually better to improve existing libraries together instead of duplicating work. There are three libraries for CTC beam search decoding that I analysed. Given this analysis, and given that the spirit of `transformers` is readability and ease of contribution, 2.) makes by far the most sense to be considered for an integration into `transformers` IMO. It would be great if we manage to collaborate well with https://github.com/kensho-technologies/pyctcdecode on design choices and integrations, but in the worst-case scenario (if for some reason our vision differs too strongly from https://github.com/kensho-technologies/pyctcdecode) we can also fork the repo and shape it to how we would need it - it has an MIT license. However, the library looks quite nice to me and I'm also confident that we can start a fruitful collaboration that both `pyctcdecode` and we can profit from.

Integration into Hugging Face's `transformers`

A couple of important requirements apply for a nice integration with `transformers`. Keeping in mind that LM-boosted decoding requires the output log-probs of the acoustic model (`Wav2Vec2ForCTC`) as well as a dictionary and a language model, there are two clean ways of integrating the feature IMO:
1.) We add a new `Wav2Vec2CTCDecoder` class that replaces the `Wav2Vec2CTCTokenizer` and can be used just like `Wav2Vec2CTCTokenizer` within `Wav2Vec2Processor`. Since this class would require the vocabulary of `Wav2Vec2CTCTokenizer`, we would probably have to add a `self.tokenizer = Wav2Vec2CTCTokenizer(...)` attribute in `Wav2Vec2CTCDecoder`, which would create a bit too much abstraction IMO (Wav2Vec2Processor -> Wav2Vec2CTCDecoder -> Wav2Vec2Tokenizer).

2.) We add a new `Wav2Vec2ProcessorWithLM` class that replaces `Wav2Vec2Processor`. It essentially just adds a `self.decoder = ...` to `Wav2Vec2Processor`, and the `batch_decode()` and `decode()` methods now run LM-boosted decoding instead of the previous "tokenizer-only" decoding.

=> IMO 2.) is the better approach, as it requires less abstraction and is also "safer" in that we can simply say that `Wav2Vec2ProcessorWithLM` is an experimental class that can be used to replace `Wav2Vec2Processor`.
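The essence of 2.) in code, as a simplified sketch (not the final implementation in this PR):

```python
class Wav2Vec2ProcessorWithLM:
    def __init__(self, feature_extractor, tokenizer, decoder):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.decoder = decoder  # e.g. a pyctcdecode BeamSearchDecoderCTC

    def batch_decode(self, logits):
        # LM-boosted beam search instead of the tokenizer's argmax decoding
        return [self.decoder.decode(l) for l in logits]
```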
This PR implements more or less everything that is required on the `transformers` side for 2.).

So the change in API that I'm aiming for would look as follows:
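A sketch based on the final user API shown earlier in the thread (`audio` is a placeholder input):

```python
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM

audio = np.zeros(16_000, dtype=np.float32)  # placeholder: one second of silence at 16 kHz

model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

input_values = processor(audio, return_tensors="pt").input_values
logits = model(input_values).logits
transcription = processor.batch_decode(logits.detach().numpy()).text
```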
Thinking a bit ahead here, IMO it would also be totally fine to have both a `Wav2Vec2Processor` and a `Wav2Vec2ProcessorWithLM` work correctly with an `AutoProcessor` class. We could just add a new `processor_type` attribute to the `config.json` so that the correct processor class is loaded depending on the `config.json` of the model. We could use a similar general design (ideally even a bit cleaner) as is used here.
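For illustration, the dispatch could look roughly like this (the `processor_type` attribute is the proposal above; the lookup mechanics are hypothetical):

```python
import transformers
from transformers import AutoConfig

model_id = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
config = AutoConfig.from_pretrained(model_id)

# hypothetical: config.json carries e.g. "processor_type": "Wav2Vec2ProcessorWithLM";
# fall back to the plain processor when the attribute is absent
class_name = getattr(config, "processor_type", "Wav2Vec2Processor")
processor = getattr(transformers, class_name).from_pretrained(model_id)
```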
Feature additions to pyctcdecode for target API

It would be great if, together with `pyctcdecode`, we could add an optional `from_hf_hub(...)` functionality for their BeamSearchDecoder class(es). This should be pretty simple to do with `huggingface_hub` and should in general also make it much easier for `pyctcdecode` to load and save models online (for free). This is to be discussed.

In a first step, it would be easiest to focus on fully supporting download and upload of KenLM language models for seamless KenLM n-gram boosted decoding. KenLM n-gram boosted decoding yielded some nice improvements in my experiments here.
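A sketch of what downloading a KenLM model from the Hub could look like with `huggingface_hub` (repo id, filename, and vocabulary are placeholders):

```python
from huggingface_hub import hf_hub_download
from pyctcdecode import build_ctcdecoder

labels = ["<pad>", " ", "a", "b", "c"]  # placeholder vocabulary

# fetch the n-gram file from a (placeholder) Hub repo and build the decoder
lm_path = hf_hub_download(repo_id="username/kenlm-4gram", filename="4gram.arpa")
decoder = build_ctcdecoder(labels, kenlm_model_path=lm_path)
```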
In a next step, we could then look into support for transformer LM models in pyctcdecode (making `pyctcdecode`'s beam search compatible with our `AutoModelForCausalLM` models) and also add `load_from_hub(...)` functionality for this in `pyctcdecode`.

Other possible improvements could include checks that the `logits` and sampling rate of the model match what the decoder expects.