Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC #10089

Merged: 2 commits, Feb 9, 2021. The diff below shows changes from 1 commit.
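For downstream users the change is a drop-in rename: `Wav2Vec2ForMaskedLM` keeps working but now emits a `FutureWarning` on instantiation, and the new `Wav2Vec2ForCTC` has the same `wav2vec2` encoder plus `lm_head` architecture, so existing checkpoints load unchanged. A minimal migration sketch (the checkpoint name is taken from the tests in this PR):

    # Deprecated after this PR (still works, but warns):
    # from transformers import Wav2Vec2ForMaskedLM
    # model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")

    # Preferred going forward; the same weights load into the CTC-named class:
    from transformers import Wav2Vec2ForCTC

    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")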
7 changes: 7 additions & 0 deletions docs/source/model_doc/wav2vec2.rst
@@ -63,3 +63,10 @@ Wav2Vec2ForMaskedLM

.. autoclass:: transformers.Wav2Vec2ForMaskedLM
    :members: forward


Wav2Vec2ForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2ForCTC
    :members: forward
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -367,6 +367,7 @@
    _import_structure["models.wav2vec2"].extend(
        [
            "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
            "Wav2Vec2ForCTC",
            "Wav2Vec2ForMaskedLM",
            "Wav2Vec2Model",
            "Wav2Vec2PreTrainedModel",
@@ -1813,6 +1814,7 @@
    )
    from .models.wav2vec2 import (
        WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
        Wav2Vec2ForCTC,
        Wav2Vec2ForMaskedLM,
        Wav2Vec2Model,
        Wav2Vec2PreTrainedModel,
2 changes: 2 additions & 0 deletions src/transformers/models/wav2vec2/__init__.py
@@ -29,6 +29,7 @@
    _import_structure["modeling_wav2vec2"] = [
        "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "Wav2Vec2ForMaskedLM",
        "Wav2Vec2ForCTC",
        "Wav2Vec2Model",
        "Wav2Vec2PreTrainedModel",
    ]
@@ -41,6 +42,7 @@
if is_torch_available():
    from .modeling_wav2vec2 import (
        WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
        Wav2Vec2ForCTC,
        Wav2Vec2ForMaskedLM,
        Wav2Vec2Model,
        Wav2Vec2PreTrainedModel,
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
@@ -20,7 +20,7 @@
import fairseq
import torch

-from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging
+from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging


logging.set_verbosity_info()
@@ -141,7 +141,7 @@ def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_
"""
Copy/paste/tweak model's weights to transformers design.
"""
hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())
hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())

model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
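Given the `convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path)` signature above, a hedged usage sketch with placeholder paths (the script is normally driven by its command-line interface; the file names here are illustrative assumptions, not values from this PR):

    # Hypothetical paths for illustration only.
    convert_wav2vec2_checkpoint(
        checkpoint_path="./wav2vec_small_960h.pt",        # fairseq checkpoint to convert
        pytorch_dump_folder_path="./wav2vec2-base-960h",  # output folder for the transformers model
        dict_path="./dict.ltr.txt",                       # fairseq dictionary, passed via arg_overrides
    )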
85 changes: 84 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -15,6 +15,7 @@
""" PyTorch Wav2Vec2 model. """


import warnings
from typing import Optional, Tuple

import torch
@@ -24,7 +25,7 @@

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
@@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        warnings.warn(
            "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
        )

        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
@@ -685,6 +690,10 @@ def forward(
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            TODO(PVP): Fill out when adding training

        .. warning::

            Wav2Vec2ForMaskedLM has been deprecated. Please use Wav2Vec2ForCTC instead.
        Returns:

        Example::
@@ -729,3 +738,77 @@ def forward(
            return output

        return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


@add_start_docstrings(
    """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
    WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_values,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            TODO(PVP): Fill out when adding training

        Returns:

        Example::

            >>> import torch
            >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
            >>> from datasets import load_dataset
            >>> import soundfile as sf

            >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
            >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

            >>> def map_to_array(batch):
            ...     speech, _ = sf.read(batch["file"])
            ...     batch["speech"] = speech
            ...     return batch

            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
            >>> ds = ds.map(map_to_array)

            >>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
            >>> logits = model(input_values).logits

            >>> predicted_ids = torch.argmax(logits, dim=-1)
            >>> transcription = tokenizer.decode(predicted_ids[0])
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.wav2vec2(
            input_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
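A note on the docstring example above: `torch.argmax` picks the highest-scoring token per audio frame, and `tokenizer.decode` turns that frame sequence into text by collapsing consecutive repeats and dropping the CTC blank token. A standalone sketch of that collapse step, assuming blank id 0 (the real id comes from the tokenizer's pad token):

    import torch

    def ctc_greedy_collapse(logits, blank_id=0):
        # logits: (time, vocab_size) scores from the CTC lm_head for one sample.
        frame_ids = torch.argmax(logits, dim=-1).tolist()
        out, prev = [], None
        for idx in frame_ids:
            # Keep a frame only if it differs from the previous frame (collapse
            # repeats) and is not the blank token (remove blanks).
            if idx != prev and idx != blank_id:
                out.append(idx)
            prev = idx
        return out  # token ids, ready for the tokenizer to map back to text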
5 changes: 5 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -2229,6 +2229,11 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


class Wav2Vec2ForCTC:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class Wav2Vec2ForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)
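The dummy class follows the existing pattern in `dummy_pt_objects.py`: it lets `from transformers import Wav2Vec2ForCTC` succeed in an installation without PyTorch and defers the failure to instantiation. Roughly, `requires_pytorch` behaves like this simplified sketch (an assumption about the helper's behavior, not its exact implementation):

    def requires_pytorch(obj):
        # Raise only when the placeholder is actually constructed, so that
        # top-level imports of transformers keep working without torch.
        name = obj.__class__.__name__
        if not is_torch_available():
            raise ImportError(f"{name} requires the PyTorch library, but it was not found in your environment.")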
10 changes: 5 additions & 5 deletions tests/test_modeling_wav2vec2.py
@@ -29,7 +29,7 @@
if is_torch_available():
    import torch

-    from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
+    from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer


class Wav2Vec2ModelTester:
@@ -204,7 +204,7 @@ def test_model_from_pretrained(self):

@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
+    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
    test_pruning = False
    test_headmasking = False
    test_torchscript = False
@@ -289,7 +289,7 @@ def map_to_array(batch):
        return ds["speech"][:num_samples]

    def test_inference_masked_lm_normal(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        model.to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

@@ -307,7 +307,7 @@ def test_inference_masked_lm_normal(self):
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_masked_lm_normal_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        model.to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

@@ -330,7 +330,7 @@ def test_inference_masked_lm_normal_batched(self):
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_masked_lm_robust_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

        input_speech = self._load_datasamples(4)