[Wav2Vec2] PyCTCDecode Integration to support language model boosted decoding #14339
Merged
patrickvonplaten merged 34 commits into huggingface:master from patrickvonplaten:pyctcdecode_integration on Dec 8, 2021
Changes from 3 commits
Commits (all by patrickvonplaten):

6de1445  up
b68faa9  up
8294efa  up
6ec01c2  make it cleaner
52afd82  correct
e3b0fde  make styhahalal
e7eb51c  add more tests
ff0de09  finish
6296938  small fix
84bfdf3  make style
4caf406  up
d59b594  tryout to solve cicrle ci
ead3873  Merge branch 'master' into pyctcdecode_integration
682b258  up
6320a5a  Merge branch 'pyctcdecode_integration' of https://github.com/patrickv…
53aaeff  fix more tests
7b24cdc  fix more tests
f3648f6  apply sylvains suggestions
f39f02c  fix import
19a1301  correct docs
88783e3  add pyctcdecode only to speech tests
51f3dc7  fix more tests
ceb6ea2  add tf, flax and pt tests
e2b19af  add pt
b1ba5dd  fix last tests
a52f319  fix more tests
66dd6d8  Apply suggestions from code review
d9cdb5e  change lines
b93b954  Merge branch 'pyctcdecode_integration' of https://github.com/patrickv…
0fe15e1  Apply suggestions from code review
2382b92  correct tests
8e70208  Merge branch 'pyctcdecode_integration' of https://github.com/patrickv…
b46df6b  correct tests
776d152  add doc string
@@ -151,3 +151,195 @@ def as_target_processor(self):
        self.current_processor = self.tokenizer
        yield
        self.current_processor = self.feature_extractor

class Wav2Vec2ProcessorWithLM:
    r"""
    Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor, a Wav2Vec2 CTC tokenizer and a
    language model into a single processor for language model boosted speech recognition decoding.

    :class:`~transformers.Wav2Vec2Processor` offers all the functionalities of
    :class:`~transformers.Wav2Vec2FeatureExtractor` and :class:`~transformers.Wav2Vec2CTCTokenizer`. See the docstring
    of :meth:`~transformers.Wav2Vec2Processor.__call__` and :meth:`~transformers.Wav2Vec2Processor.decode` for more
    information.

    Args:
        feature_extractor (:obj:`Wav2Vec2FeatureExtractor`):
            An instance of :class:`~transformers.Wav2Vec2FeatureExtractor`. The feature extractor is a required input.
        tokenizer (:obj:`Wav2Vec2CTCTokenizer`):
            An instance of :class:`~transformers.Wav2Vec2CTCTokenizer`. The tokenizer is a required input.
        ctc_decoder (:obj:`pyctcdecode.BeamSearchDecoderCTC`):
            An instance of :obj:`pyctcdecode.BeamSearchDecoderCTC`. The decoder is a required input.
    """
    def __init__(self, feature_extractor, tokenizer, ctc_decoder):
        if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
            raise ValueError(
                f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__name__}, but is {type(feature_extractor)}"
            )
        if not isinstance(tokenizer, Wav2Vec2CTCTokenizer):
            raise ValueError(
                f"`tokenizer` has to be of type {Wav2Vec2CTCTokenizer.__name__}, but is {type(tokenizer)}"
            )

        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.ctc_decoder = ctc_decoder
        self.current_processor = self.feature_extractor
    def save_pretrained(self, save_directory):
        """
        Save a Wav2Vec2 feature_extractor object and Wav2Vec2 tokenizer object to the directory ``save_directory``, so
        that they can be re-loaded using the :func:`~transformers.Wav2Vec2Processor.from_pretrained` class method.

        .. note::

            This class method is simply calling
            :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and
            :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
            docstrings of the methods above for more information.

        Args:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                be created if it does not exist).
        """
        self.feature_extractor.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
    @staticmethod
    def _load_ctc_decoder(pretrained_model_name_or_path, vocab_dict, **kwargs):
        from pyctcdecode import Alphabet, BeamSearchDecoderCTC

        # i.) build the alphabet
        # check https://github.com/kensho-technologies/pyctcdecode/blob/94dfdae1d18ad95e799286173826aec2dec9a6b2/pyctcdecode/alphabet.py#L122
        sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
        vocab_labels = list(sorted_dict.keys())
        alphabet = Alphabet.build_alphabet(vocab_labels)

        # ii.) build the language model
        # different design options

        # 1) either:
        # ---------------------
        # from pyctcdecode import AutoLanguageModel
        # language_model = AutoLanguageModel.from_pretrained(...)
        # (this requires the following:
        # a. add `AutoLanguageModel` class in https://github.com/kensho-technologies/pyctcdecode/blob/main/pyctcdecode/language_model.py
        # b. add `.from_pretrained(...)` to `AutoLanguageModel` in kensho-technologies/pyctcdecode
        # => requires some work, but should be easy (needs to be discussed with pyctcdecode)

        # 2) or:
        # ---------------------
        # from pyctcdecode import LanguageModel
        # if self._is_ken_lm_model(pretrained_model_name_or_path):
        #     language_model = LanguageModel.load_from_hf_hub("...")
        # elif self._is_hf_lm_model(pretrained_model_name_or_path):
        #     language_model = HfLanguageModel.load_from_hf_hub("...")
        # (this requires the following:
        # a. add a `.from_pretrained(...)` class method in kensho-technologies/pyctcdecode
        # => requires very little work and should be pretty easy (needs to be discussed with pyctcdecode)
        # b. (Future Work): add `HfLanguageModel` or `AutoLanguageModel`

        # 3) or:
        # ---------------------
        # do the whole model loading ourselves and create an `AutoLanguageModel` class in `transformers`
        # => requires a fair amount of work but no need to discuss with pyctcdecode
        language_model = AutoLanguageModel.load_from_hf_hub("...")

        # iii.) build the CTC decoder
        # see: https://github.com/kensho-technologies/pyctcdecode/blob/94dfdae1d18ad95e799286173826aec2dec9a6b2/pyctcdecode/decoder.py#L181
        ctc_decoder = BeamSearchDecoderCTC(alphabet, language_model)

        return ctc_decoder
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a :class:`~transformers.Wav2Vec2Processor` from a pretrained Wav2Vec2 processor.

        .. note::

            This class method is simply calling Wav2Vec2FeatureExtractor's
            :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and
            Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`.
            Please refer to the docstrings of the methods above for more information.

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                This can be either:

                - a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
                  huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                  namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a feature extractor file saved using the
                  :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
                  ``./my_model_directory/``.
                - a path or url to a saved feature extractor JSON `file`, e.g.,
                  ``./my_model_directory/preprocessor_config.json``.
            **kwargs
                Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
                :class:`~transformers.PreTrainedTokenizer`.
        """
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        ctc_decoder = cls._load_ctc_decoder(pretrained_model_name_or_path, vocab_dict=tokenizer.get_vocab(), **kwargs)

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, ctc_decoder=ctc_decoder)
    def __call__(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
        :meth:`~transformers.Wav2Vec2FeatureExtractor.__call__` and returns its output. If used in the context
        :meth:`~transformers.Wav2Vec2Processor.as_target_processor` this method forwards all its arguments to
        Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.__call__`. Please refer to the docstring of
        the above two methods for more information.
        """
        return self.current_processor(*args, **kwargs)

    def pad(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
        :meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context
        :meth:`~transformers.Wav2Vec2Processor.as_target_processor` this method forwards all its arguments to
        Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the
        above two methods for more information.
        """
        return self.current_processor.pad(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        # TODO (PVP): build switch so that both the tokenizer and the lm model can be used for decoding
        """
        return self._batch_lm_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        # TODO (PVP): build switch so that both the tokenizer and the lm model can be used for decoding
        """
        return self._lm_decode(*args, **kwargs)
    def _batch_lm_decode(self, logits: Union[torch.FloatTensor, tf.Tensor, jnp.ndarray]):
        """
        `logits` are the outputs of a Wav2Vec2-like model.
        **kwargs will be all arguments of https://github.com/kensho-technologies/pyctcdecode/blob/94dfdae1d18ad95e799286173826aec2dec9a6b2/pyctcdecode/decoder.py#L633
        """
        array_list = [array for array in logits.numpy()]

        return self.ctc_decoder.decode_batch(array_list)

    def _lm_decode(self, logits: Union[torch.FloatTensor, tf.Tensor, jnp.ndarray], **kwargs):
        """
        `logits` are the outputs of a Wav2Vec2-like model.
        **kwargs will be all arguments of https://github.com/kensho-technologies/pyctcdecode/blob/94dfdae1d18ad95e799286173826aec2dec9a6b2/pyctcdecode/decoder.py#L600
        """
        return self.ctc_decoder.decode(logits.numpy())
    @contextmanager
    def as_target_processor(self):
        """
        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
        Wav2Vec2.
        """
        self.current_processor = self.tokenizer
        yield
        self.current_processor = self.feature_extractor

Review comment on `as_target_processor`: Since the processor now has 3 modes of operation, maybe this design can be deprecated in favor of …
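Step i. of `_load_ctc_decoder` orders the tokenizer vocabulary by token index so that the position of each label in the list matches the model's logit dimension before the list is handed to pyctcdecode's alphabet builder. A minimal sketch of just that sorting step, using a made-up toy `vocab_dict` (the tokens and indices below are illustrative, not from a real Wav2Vec2 checkpoint):

```python
# Toy vocab_dict mapping tokens to model output indices, deliberately unordered.
# Tokens and indices are invented for illustration only.
vocab_dict = {"B": 2, "<pad>": 0, "A": 1, "|": 3}

# Same logic as in `_load_ctc_decoder`: sort entries by index, lowercase the
# keys, and keep the keys as the label list passed to the alphabet builder.
sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
vocab_labels = list(sorted_dict.keys())

print(vocab_labels)  # position in this list == logit index of the token
```

After sorting, `vocab_labels[i]` is the token the model emits at logit column `i`, which is the invariant the decoder relies on.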
Review comment: The decoder should also be saved here, I think.