From b6f332ecaf18054109294dd2efa1a5e6aa274a03 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Fri, 27 Aug 2021 20:52:51 +0300 Subject: [PATCH] Add Wav2Vec2 & Hubert ForSequenceClassification (#13153) * Add hubert classifier + tests * Add hubert classifier + tests * Dummies for all classification tests * Wav2Vec2 classifier + ER test * Fix hubert integration tests * Add hubert IC * Pass tests for all classification tasks on Hubert * Pass all tests + copies * Move models to the SUPERB org --- docs/source/model_doc/hubert.rst | 8 + docs/source/model_doc/wav2vec2.rst | 8 + src/transformers/__init__.py | 4 + src/transformers/models/hubert/__init__.py | 2 + .../models/hubert/configuration_hubert.py | 9 + ...rt_original_s3prl_checkpoint_to_pytorch.py | 69 +++++++ .../models/hubert/modeling_hubert.py | 155 +++++++++++++-- src/transformers/models/wav2vec2/__init__.py | 2 + .../models/wav2vec2/configuration_wav2vec2.py | 9 + ...c2_original_s3prl_checkpoint_to_pytorch.py | 69 +++++++ .../wav2vec2/feature_extraction_wav2vec2.py | 6 +- .../models/wav2vec2/modeling_wav2vec2.py | 128 +++++++++++- src/transformers/utils/dummy_pt_objects.py | 18 ++ tests/test_modeling_hubert.py | 185 +++++++++++++++++- tests/test_modeling_wav2vec2.py | 185 +++++++++++++++++- utils/check_repo.py | 2 + 16 files changed, 823 insertions(+), 36 deletions(-) create mode 100644 src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py diff --git a/docs/source/model_doc/hubert.rst b/docs/source/model_doc/hubert.rst index 4e0bdcca326cd..0df42fdef5259 100644 --- a/docs/source/model_doc/hubert.rst +++ b/docs/source/model_doc/hubert.rst @@ -64,6 +64,14 @@ HubertForCTC .. autoclass:: transformers.HubertForCTC :members: forward + +HubertForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.HubertForSequenceClassification + :members: forward + + TFHubertModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index e96eb80329704..2aef0abb86a0e 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -96,6 +96,14 @@ Wav2Vec2ForCTC .. autoclass:: transformers.Wav2Vec2ForCTC :members: forward + +Wav2Vec2ForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.Wav2Vec2ForSequenceClassification + :members: forward + + Wav2Vec2ForPreTraining ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c5e22b1c718d5..421ffd06b9d3a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -818,6 +818,7 @@ [ "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "HubertForCTC", + "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel", ] @@ -1128,6 +1129,7 @@ "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", "Wav2Vec2ForPreTraining", + "Wav2Vec2ForSequenceClassification", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -2424,6 +2426,7 @@ from .models.hubert import ( HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, HubertForCTC, + HubertForSequenceClassification, HubertModel, HubertPreTrainedModel, ) @@ -2681,6 +2684,7 @@ Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, + Wav2Vec2ForSequenceClassification, Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) diff --git a/src/transformers/models/hubert/__init__.py b/src/transformers/models/hubert/__init__.py index 4f7d8f8facd2e..f62cc14bd76d9 100644 --- a/src/transformers/models/hubert/__init__.py +++ b/src/transformers/models/hubert/__init__.py @@ -28,6 +28,7 @@ _import_structure["modeling_hubert"] = [ "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "HubertForCTC", + "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel", ] @@ -48,6 +49,7 @@ from .modeling_hubert import ( HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, HubertForCTC, + HubertForSequenceClassification, HubertModel, HubertPreTrainedModel, ) diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py index f3d2f77ed0290..633807684fcca 100644 --- a/src/transformers/models/hubert/configuration_hubert.py +++ b/src/transformers/models/hubert/configuration_hubert.py @@ -115,6 +115,11 @@ class HubertConfig(PretrainedConfig): Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance of :class:`~transformers.HubertForCTC`. + use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of :class:`~transformers.HubertForSequenceClassification`. + classifier_proj_size (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
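For reference, the two new configuration options work together as follows: ``use_weighted_layer_sum`` averages all hidden states (the transformer layers plus the input embeddings) with softmax-normalized learned weights, and ``classifier_proj_size`` sets the width of the linear projection applied before pooling. A minimal sketch of wiring them up (the label count below is an arbitrary illustration, not part of this patch)::

    from transformers import HubertConfig, HubertForSequenceClassification

    # num_labels is a standard PretrainedConfig argument; 12 is only an example
    config = HubertConfig(
        use_weighted_layer_sum=True,  # pool hidden states of all layers with learned weights
        classifier_proj_size=256,     # projection width applied before mean-pooling
        num_labels=12,
    )
    model = HubertForSequenceClassification(config)  # classifier head is randomly initialized
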
@@ -165,6 +170,8 @@ def __init__( mask_feature_length=10, ctc_loss_reduction="sum", ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, @@ -197,6 +204,8 @@ def __init__( self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm self.gradient_checkpointing = gradient_checkpointing + self.use_weighted_layer_sum = use_weighted_layer_sum + self.classifier_proj_size = classifier_proj_size if ( (len(self.conv_stride) != self.num_feat_extract_layers) diff --git a/src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000..51908f930242c --- /dev/null +++ b/src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Hubert checkpoint.""" + + +import argparse + +import torch + +from transformers import HubertConfig, HubertForSequenceClassification, Wav2Vec2FeatureExtractor, logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +SUPPORTED_MODELS = ["UtteranceLevel"] + + +@torch.no_grad() +def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path): + """ + Copy/paste/tweak model's weights to transformers design. + """ + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS: + raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}") + + downstream_dict = checkpoint["Downstream"] + + hf_congfig = HubertConfig.from_pretrained(config_path) + hf_model = HubertForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig) + hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + base_model_name, return_attention_mask=True, do_normalize=False + ) + + if hf_congfig.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + + hf_model.projector.weight.data = downstream_dict["projector.weight"] + hf_model.projector.bias.data = downstream_dict["projector.bias"] + hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + + hf_feature_extractor.save_pretrained(model_dump_path) + hf_model.save_pretrained(model_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model." 
+ ) + parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.") + parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.") + args = parser.parse_args() + convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 910baf1c9f06e..012cd774dafda 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -20,12 +20,13 @@ import torch import torch.utils.checkpoint from torch import nn +from torch.nn import CrossEntropyLoss from transformers.deepspeed import is_deepspeed_zero3_enabled from ...activations import ACT2FN from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings -from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_hubert import HubertConfig @@ -735,6 +736,18 @@ def _conv_out_length(input_length, kernel_size, stride): return input_lengths + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + HUBERT_START_DOCSTRING = r""" Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units @@ -904,19 +917,8 @@ def forward( extract_features = extract_features.transpose(1, 2) if attention_mask is not None: - # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) - - attention_mask = torch.zeros( - extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device - ) - - # these two operations makes sure that all values - # before the output lengths indices are attended to - attention_mask[ - (torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1) - ] = 1 - attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) hidden_states = self.feature_projection(extract_features) hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) @@ -1070,3 +1072,128 @@ def forward( return CausalLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions ) + + +@add_start_docstrings( + """ + Hubert Model with a sequence classification 
head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. + """, + HUBERT_START_DOCSTRING, +) +class HubertForSequenceClassification(HubertPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert + def __init__(self, config): + super().__init__(config) + + self.hubert = HubertModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + self.init_weights() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.hubert.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.hubert.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Example:: + + >>> import torch + >>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification + >>> from datasets import load_dataset + + >>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks") + >>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks") + + >>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test") + + >>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> predicted_class_ids = torch.argmax(logits, dim=-1) + + >>> # compute loss + >>> target_label = "down" + >>> labels = torch.tensor([model.config.label2id[target_label]]) + + >>> loss = model(input_values, labels=labels).loss + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.hubert( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[1] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + hidden_states[~padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index ae6a2a7931430..445e918303402 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -33,6 +33,7 @@ "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", "Wav2Vec2ForPreTraining", + "Wav2Vec2ForSequenceClassification", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -66,6 +67,7 @@ Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, + Wav2Vec2ForSequenceClassification, Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 88200133d5404..6df4a87064bf0 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -133,6 +133,11 @@ class Wav2Vec2Config(PretrainedConfig): Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance of :class:`~transformers.Wav2Vec2ForCTC`. 
+ use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`. + classifier_proj_size (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): If True, use gradient checkpointing to save memory at the expense of slower backward pass. @@ -191,6 +196,8 @@ def __init__( diversity_loss_weight=0.1, ctc_loss_reduction="sum", ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, @@ -223,6 +230,8 @@ def __init__( self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm self.gradient_checkpointing = gradient_checkpointing + self.use_weighted_layer_sum = use_weighted_layer_sum + self.classifier_proj_size = classifier_proj_size if ( (len(self.conv_stride) != self.num_feat_extract_layers) diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000..bd7a7370cfcde --- /dev/null +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Hubert checkpoint.""" + + +import argparse + +import torch + +from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +SUPPORTED_MODELS = ["UtteranceLevel"] + + +@torch.no_grad() +def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path): + """ + Copy/paste/tweak model's weights to transformers design. 
+ """ + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS: + raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}") + + downstream_dict = checkpoint["Downstream"] + + hf_congfig = Wav2Vec2Config.from_pretrained(config_path) + hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig) + hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + base_model_name, return_attention_mask=True, do_normalize=False + ) + + if hf_congfig.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + + hf_model.projector.weight.data = downstream_dict["projector.weight"] + hf_model.projector.bias.data = downstream_dict["projector.bias"] + hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + + hf_feature_extractor.save_pretrained(model_dump_path) + hf_model.save_pretrained(model_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model." + ) + parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.") + parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.") + args = parser.parse_args() + convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path) diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 805e54ae9a605..01c6966637d6d 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -83,9 +83,6 @@ def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[ """ Every array in the list is normalized to have zero mean and unit variance """ - if isinstance(input_values[0], np.ndarray): - input_values = [x.astype(np.float32) for x in input_values] - normed_input_values = [ (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths) ] @@ -205,6 +202,9 @@ def __call__( padded_input_values = padded_inputs["input_values"] input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])] + if isinstance(padded_inputs["input_values"][0], np.ndarray): + padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]] + # zero-mean and unit-variance normalization if self.do_normalize: padded_inputs["input_values"] = self.zero_mean_unit_var_norm( diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 9454f5b00567a..7db5fd7f1d427 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -22,6 +22,7 @@ import torch import torch.utils.checkpoint from torch import nn +from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled @@ -31,7 +32,7 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, 
) -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_wav2vec2 import Wav2Vec2Config @@ -1057,7 +1058,7 @@ def forward( extract_features = extract_features.transpose(1, 2) if attention_mask is not None: - # compute reduced attention_mask correponding to feature vectors + # compute reduced attention_mask corresponding to feature vectors attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) hidden_states, extract_features = self.feature_projection(extract_features) @@ -1527,3 +1528,126 @@ def forward( return CausalLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions ) + + +@add_start_docstrings( + """ + Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. + """, + WAV_2_VEC_2_START_DOCSTRING, +) +class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wav2vec2 = Wav2Vec2Model(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wav2vec2.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wav2vec2.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Example:: + + >>> import torch + >>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification + >>> from datasets import load_dataset + + >>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks") + >>> model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks") + + >>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test") + + >>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> predicted_class_ids = torch.argmax(logits, dim=-1) + + >>> # compute loss + >>> target_label = "down" + >>> labels = torch.tensor([model.config.label2id[target_label]]) + + >>> loss = model(input_values, labels=labels).loss + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # End copy + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[2] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + hidden_states[~padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 58da53ce4181c..630456c80b0f5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1863,6 +1863,15 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class HubertForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HubertModel: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -3473,6 +3482,15 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Wav2Vec2ForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Wav2Vec2Model: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_hubert.py b/tests/test_modeling_hubert.py index c42014d6a702a..e79442646d0d7 100644 --- a/tests/test_modeling_hubert.py +++ b/tests/test_modeling_hubert.py 
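To make the pooling used by the new classification heads easier to follow outside the diff, here is a self-contained sketch of the same two steps, weighted layer sum followed by padding-aware mean pooling (pure ``torch``; the shapes and mask are made up for illustration, and in the model the layer weights are a learned ``nn.Parameter``)::

    import torch

    batch, seq_len, hidden, num_layers = 2, 50, 768, 13  # 12 transformer layers + input embeddings

    # 1) weighted layer sum: softmax-normalized weights over the stacked hidden states
    hidden_states = torch.randn(batch, num_layers, seq_len, hidden)
    layer_weights = torch.ones(num_layers) / num_layers               # learned in the real model
    norm_weights = torch.softmax(layer_weights, dim=-1)
    summed = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)  # (batch, seq_len, hidden)

    # 2) padding-aware mean pooling: zero out padded frames, divide by the true lengths
    padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)
    padding_mask[1, 30:] = False                                       # pretend the second clip is shorter
    summed[~padding_mask] = 0.0
    pooled = summed.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)   # (batch, hidden)
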
@@ -31,7 +31,13 @@ if is_torch_available(): import torch - from transformers import HubertForCTC, HubertModel, Wav2Vec2Processor + from transformers import ( + HubertForCTC, + HubertForSequenceClassification, + HubertModel, + Wav2Vec2FeatureExtractor, + Wav2Vec2Processor, + ) from transformers.models.hubert.modeling_hubert import _compute_mask_indices @@ -187,7 +193,32 @@ def check_ctc_loss(self, config, input_values, *args): self.parent.assertTrue(isinstance(sum_loss, float)) self.parent.assertTrue(isinstance(mean_loss, float)) - def check_training(self, config, input_values, *args): + def check_seq_classifier_loss(self, config, input_values, *args): + model = HubertForSequenceClassification(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + unmasked_loss = model(input_values, labels=labels).loss.item() + + self.parent.assertTrue(isinstance(masked_loss, float)) + self.parent.assertTrue(isinstance(unmasked_loss, float)) + self.parent.assertTrue(masked_loss != unmasked_loss) + + def check_ctc_training(self, config, input_values, *args): config.ctc_zero_infinity = True model = HubertForCTC(config=config) model.to(torch_device) @@ -216,6 +247,29 @@ def check_training(self, config, input_values, *args): loss.backward() + def check_seq_classifier_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = HubertForSequenceClassification(config=config) + model.to(torch_device) + model.train() + + # freeze everything but the classification head + model.freeze_base_model() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + def check_labels_out_of_vocab(self, config, input_values, *args): model = HubertForCTC(config) model.to(torch_device) @@ -238,7 +292,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class HubertModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else () + all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -258,9 +312,17 @@ def test_ctc_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) - def test_train(self): + def test_seq_classifier_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.check_training(*config_and_inputs) + self.model_tester.check_seq_classifier_loss(*config_and_inputs) + + def test_ctc_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + 
self.model_tester.check_ctc_training(*config_and_inputs) + + def test_seq_classifier_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_training(*config_and_inputs) def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -371,7 +433,7 @@ def test_model_from_pretrained(self): @require_torch class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else () + all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else () test_pruning = False test_headmasking = False test_torchscript = False @@ -397,9 +459,17 @@ def test_ctc_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) - def test_train(self): + def test_seq_classifier_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_loss(*config_and_inputs) + + def test_ctc_train(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.check_training(*config_and_inputs) + self.model_tester.check_ctc_training(*config_and_inputs) + + def test_seq_classifier_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_training(*config_and_inputs) def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -557,6 +627,13 @@ def map_to_array(batch): return ds["speech"][:num_samples] + def _load_superb(self, task, num_samples): + from datasets import load_dataset + + ds = load_dataset("anton-l/superb_dummy", task, split="test") + + return ds[:num_samples] + def test_inference_ctc_batched(self): model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device) processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True) @@ -579,3 +656,95 @@ def test_inference_ctc_batched(self): "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore", ] self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) + + def test_inference_keyword_spotting(self): + model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks") + input_data = self._load_superb("ks", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1) + + expected_labels = [2, 6, 10, 9] + # s3prl logits for the same batch + expected_logits = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], device=torch_device) + + self.assertListEqual(predicted_ids.tolist(), expected_labels) + self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2)) + + def test_inference_intent_classification(self): + model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ic").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ic") + input_data = self._load_superb("ic", 4) + 
inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + + predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1) + predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1) + predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1) + + expected_labels_action = [1, 0, 4, 3] + expected_logits_action = torch.tensor([5.9052, 12.5865, 4.4840, 10.0240], device=torch_device) + expected_labels_object = [1, 10, 3, 4] + expected_logits_object = torch.tensor([5.5316, 11.7946, 8.1672, 23.2415], device=torch_device) + expected_labels_location = [0, 0, 0, 1] + expected_logits_location = torch.tensor([5.2053, 8.9577, 10.0447, 8.1481], device=torch_device) + + self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action) + self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object) + self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location) + + # TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572 + self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=3e-1)) + self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=3e-1)) + self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=3e-1)) + + def test_inference_speaker_identification(self): + model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-sid") + input_data = self._load_superb("si", 4) + + output_logits = [] + with torch.no_grad(): + for example in input_data["speech"]: + input = processor(example, return_tensors="pt", padding=True) + output = model(input.input_values.to(torch_device), attention_mask=None) + output_logits.append(output.logits[0]) + output_logits = torch.stack(output_logits) + predicted_logits, predicted_ids = torch.max(output_logits, dim=-1) + + expected_labels = [5, 1, 1, 3] + # s3prl logits for the same batch + expected_logits = torch.tensor([78231.5547, 123166.6094, 122785.4141, 84851.2969], device=torch_device) + + self.assertListEqual(predicted_ids.tolist(), expected_labels) + # TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572 + self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=10)) + + def test_inference_emotion_recognition(self): + model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er") + input_data = self._load_superb("er", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1) + + expected_labels = [1, 1, 2, 2] + # s3prl logits for the same batch + expected_logits = torch.tensor([2.8384, 2.3389, 3.8564, 4.5558], device=torch_device) + + self.assertListEqual(predicted_ids.tolist(), 
expected_labels) + # TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572 + self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-1)) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 8206797a9983d..ae12717666e6e 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -14,7 +14,6 @@ # limitations under the License. """ Testing suite for the PyTorch Wav2Vec2 model. """ - import math import unittest @@ -36,6 +35,7 @@ Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, + Wav2Vec2ForSequenceClassification, Wav2Vec2Model, Wav2Vec2Processor, ) @@ -194,7 +194,32 @@ def check_ctc_loss(self, config, input_values, *args): self.parent.assertTrue(isinstance(sum_loss, float)) self.parent.assertTrue(isinstance(mean_loss, float)) - def check_training(self, config, input_values, *args): + def check_seq_classifier_loss(self, config, input_values, *args): + model = Wav2Vec2ForSequenceClassification(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item() + unmasked_loss = model(input_values, labels=labels).loss.item() + + self.parent.assertTrue(isinstance(masked_loss, float)) + self.parent.assertTrue(isinstance(unmasked_loss, float)) + self.parent.assertTrue(masked_loss != unmasked_loss) + + def check_ctc_training(self, config, input_values, *args): config.ctc_zero_infinity = True model = Wav2Vec2ForCTC(config=config) model.to(torch_device) @@ -223,6 +248,29 @@ def check_training(self, config, input_values, *args): loss.backward() + def check_seq_classifier_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = Wav2Vec2ForSequenceClassification(config=config) + model.to(torch_device) + model.train() + + # freeze everything but the classification head + model.freeze_base_model() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + def check_labels_out_of_vocab(self, config, input_values, *args): model = Wav2Vec2ForCTC(config) model.to(torch_device) @@ -246,7 +294,9 @@ def prepare_config_and_inputs_for_common(self): @require_torch class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else () + (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining) + if is_torch_available() + else () ) test_pruning = False test_headmasking = False @@ -267,9 +317,17 @@ def test_ctc_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() 
self.model_tester.check_ctc_loss(*config_and_inputs) - def test_train(self): + def test_seq_classifier_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.check_training(*config_and_inputs) + self.model_tester.check_seq_classifier_loss(*config_and_inputs) + + def test_ctc_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_training(*config_and_inputs) + + def test_seq_classifier_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_training(*config_and_inputs) def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -384,7 +442,9 @@ def test_model_from_pretrained(self): @require_torch class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else () + (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining) + if is_torch_available() + else () ) test_pruning = False test_headmasking = False @@ -411,9 +471,17 @@ def test_ctc_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) - def test_train(self): + def test_seq_classifier_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_loss(*config_and_inputs) + + def test_ctc_train(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.check_training(*config_and_inputs) + self.model_tester.check_ctc_training(*config_and_inputs) + + def test_seq_classifier_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_seq_classifier_training(*config_and_inputs) def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -691,6 +759,13 @@ def map_to_array(batch): return ds["speech"][:num_samples] + def _load_superb(self, task, num_samples): + from datasets import load_dataset + + ds = load_dataset("anton-l/superb_dummy", task, split="test") + + return ds[:num_samples] + def test_inference_ctc_normal(self): model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") model.to(torch_device) @@ -795,7 +870,10 @@ def test_inference_integration(self): # fmt: off expected_cosine_sim_masked = torch.tensor( - [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997], + [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, + 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, + 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, + 0.6997], device=torch_device, ) # fmt: on @@ -913,3 +991,92 @@ def test_loss_pretraining(self): expected_loss = 62.5170 self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3) + + def test_inference_keyword_spotting(self): + model = 
Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks") + input_data = self._load_superb("ks", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1) + + expected_labels = [7, 6, 10, 9] + # s3prl logits for the same batch + expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device) + + self.assertListEqual(predicted_ids.tolist(), expected_labels) + self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2)) + + def test_inference_intent_classification(self): + model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic") + input_data = self._load_superb("ic", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + + predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1) + predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1) + predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1) + + expected_labels_action = [0, 0, 2, 3] + expected_logits_action = torch.tensor([0.4568, 11.0848, 1.6621, 9.3841], device=torch_device) + expected_labels_object = [3, 10, 3, 4] + expected_logits_object = torch.tensor([1.5322, 10.7094, 5.2469, 22.1318], device=torch_device) + expected_labels_location = [0, 0, 0, 1] + expected_logits_location = torch.tensor([1.5335, 6.5096, 10.5704, 11.0569], device=torch_device) + + self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action) + self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object) + self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location) + + self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2)) + self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2)) + self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2)) + + def test_inference_speaker_identification(self): + model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid") + input_data = self._load_superb("si", 4) + + output_logits = [] + with torch.no_grad(): + for example in input_data["speech"]: + input = processor(example, return_tensors="pt", padding=True) + output = model(input.input_values.to(torch_device), attention_mask=None) + output_logits.append(output.logits[0]) + output_logits = torch.stack(output_logits) + predicted_logits, predicted_ids = torch.max(output_logits, dim=-1) + + expected_labels = [251, 1, 1, 3] + # s3prl logits for the same batch + expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device) + + 
self.assertListEqual(predicted_ids.tolist(), expected_labels) + self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2)) + + def test_inference_emotion_recognition(self): + model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er") + input_data = self._load_superb("er", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1) + + expected_labels = [1, 1, 2, 2] + # s3prl logits for the same batch + expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device) + + self.assertListEqual(predicted_ids.tolist(), expected_labels) + self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 068efc0b15b51..088d760aa9b7f 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -122,6 +122,8 @@ "TFRagTokenForGeneration", "Wav2Vec2ForCTC", "HubertForCTC", + "Wav2Vec2ForSequenceClassification", + "HubertForSequenceClassification", "XLMForQuestionAnswering", "XLNetForQuestionAnswering", "SeparableConv1D",
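
Beyond the inference examples already included in the forward docstrings, the new ``freeze_base_model`` helper is intended for head-only fine-tuning: the encoder is frozen and only the projector and classifier are updated. A minimal sketch of that workflow, assuming a plain pretrained encoder checkpoint (the checkpoint name, label count, learning rate, and random waveforms below are illustrative only)::

    import torch
    from transformers import Wav2Vec2ForSequenceClassification

    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        "facebook/wav2vec2-base-960h", num_labels=12  # classification head is newly initialized
    )
    model.freeze_base_model()  # freeze the encoder; only projector/classifier keep gradients
    model.train()

    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4
    )

    # one toy step: (batch, samples) float waveforms and integer class labels
    input_values = torch.randn(2, 16000)
    labels = torch.randint(0, 12, (2,))

    loss = model(input_values, labels=labels).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()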