Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speech] Refactor Examples #14040

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/model_doc/sew.rst
Expand Up @@ -59,3 +59,9 @@ SEWForCTC
.. autoclass:: transformers.SEWForCTC
:members: forward


SEWForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SEWForSequenceClassification
:members: forward
5 changes: 5 additions & 0 deletions docs/source/model_doc/sew_d.rst
Expand Up @@ -59,3 +59,8 @@ SEWDForCTC
.. autoclass:: transformers.SEWDForCTC
:members: forward

SEWDForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SEWDForSequenceClassification
:members: forward
18 changes: 16 additions & 2 deletions src/transformers/__init__.py
Expand Up @@ -1143,6 +1143,7 @@
[
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWForCTC",
"SEWForSequenceClassification",
"SEWModel",
"SEWPreTrainedModel",
]
Expand All @@ -1151,6 +1152,7 @@
[
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWDForCTC",
"SEWDForSequenceClassification",
"SEWDModel",
"SEWDPreTrainedModel",
]
Expand Down Expand Up @@ -2858,8 +2860,20 @@
RoFormerPreTrainedModel,
load_tf_weights_in_roformer,
)
from .models.sew import SEW_PRETRAINED_MODEL_ARCHIVE_LIST, SEWForCTC, SEWModel, SEWPreTrainedModel
from .models.sew_d import SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST, SEWDForCTC, SEWDModel, SEWDPreTrainedModel
from .models.sew import (
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWForCTC,
SEWForSequenceClassification,
SEWModel,
SEWPreTrainedModel,
)
from .models.sew_d import (
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWDForCTC,
SEWDForSequenceClassification,
SEWDModel,
SEWDPreTrainedModel,
)
from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Expand Up @@ -476,6 +476,8 @@
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
("sew", "SEWForSequenceClassification"),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to how all BERT heads (ForQA, ForSequenceClass, ForMC, ...) are added to all BERT-like models for easy comparison and added functionality, all speech models should have the superb heads.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agreed!

("sew-d", "SEWDForSequenceClassification"),
]
)

Expand Down
98 changes: 31 additions & 67 deletions src/transformers/models/hubert/modeling_hubert.py
Expand Up @@ -25,7 +25,12 @@
from transformers.deepspeed import is_deepspeed_zero3_enabled

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
Expand All @@ -36,6 +41,13 @@

_CONFIG_FOR_DOC = "HubertConfig"
_CHECKPOINT_FOR_DOC = "facebook/hubert-base-ls960"
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

_SEQ_CLASS_CHECKPOINT = ("superb/hubert-base-superb-ks",)
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"

_HIDDEN_STATES_START_POSITION = 1


HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/hubert-base-ls960",
Expand Down Expand Up @@ -999,6 +1011,7 @@ def forward(
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
HUBERT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
class HubertForCTC(HubertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
Expand All @@ -1025,7 +1038,12 @@ def freeze_feature_extractor(self):
self.hubert.feature_extractor._freeze_parameters()

@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_values,
Expand All @@ -1041,41 +1059,6 @@ def forward(
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
config.vocab_size - 1]``.

Returns:

Example::

>>> import torch
>>> from transformers import Wav2Vec2Processor, HubertForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf

>>> processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch

>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)

>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)

>>> transcription = processor.decode(predicted_ids[0])

>>> # compute loss
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"

>>> # wrap processor as target processor to encode labels
>>> with processor.as_target_processor():
... labels = processor(target_transcription, return_tensors="pt").input_ids

>>> loss = model(input_values, labels=labels).loss
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand Down Expand Up @@ -1126,7 +1109,7 @@ def forward(
)

if not return_dict:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
Comment on lines -1129 to +1112
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is way simpler to understand! We should do something like that for BERT & friends too

return ((loss,) + output) if loss is not None else output

return CausalLMOutput(
Expand All @@ -1141,8 +1124,8 @@ def forward(
""",
HUBERT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great that it works now!

class HubertForSequenceClassification(HubertPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
def __init__(self, config):
super().__init__(config)

Expand All @@ -1155,15 +1138,13 @@ def __init__(self, config):

self.init_weights()

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.hubert.feature_extractor._freeze_parameters()

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
Expand All @@ -1173,7 +1154,13 @@ def freeze_base_model(self):
param.requires_grad = False

@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
Expand All @@ -1188,29 +1175,6 @@ def forward(
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

Returns:

Example::

>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
>>> from datasets import load_dataset

>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
>>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")

>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")

>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1)

>>> # compute loss
>>> target_label = "down"
>>> labels = torch.tensor([model.config.label2id[target_label]])

>>> loss = model(input_values, labels=labels).loss
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand All @@ -1225,7 +1189,7 @@ def forward(
)

if self.config.use_weighted_layer_sum:
hidden_states = outputs[1]
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
Expand All @@ -1248,7 +1212,7 @@ def forward(
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
Expand Down
9 changes: 8 additions & 1 deletion src/transformers/models/sew/__init__.py
Expand Up @@ -28,6 +28,7 @@
_import_structure["modeling_sew"] = [
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWForCTC",
"SEWForSequenceClassification",
"SEWModel",
"SEWPreTrainedModel",
]
Expand All @@ -36,7 +37,13 @@
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig

if is_torch_available():
from .modeling_sew import SEW_PRETRAINED_MODEL_ARCHIVE_LIST, SEWForCTC, SEWModel, SEWPreTrainedModel
from .modeling_sew import (
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWForCTC,
SEWForSequenceClassification,
SEWModel,
SEWPreTrainedModel,
)


else:
Expand Down
11 changes: 11 additions & 0 deletions src/transformers/models/sew/configuration_sew.py
Expand Up @@ -113,6 +113,11 @@ class SEWConfig(PretrainedConfig):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.SEWForCTC`.
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.

Example::

Expand Down Expand Up @@ -161,6 +166,8 @@ def __init__(
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
Expand Down Expand Up @@ -214,3 +221,7 @@ def __init__(
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity

# sequence classification
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size