Skip to content

Commit

Permalink
Add TFSpeech2Text (#15113)
Browse files Browse the repository at this point in the history
* Add wrapper classes

* convert inner layers to tf

* Add TF Encoder and Decoder layers

* TFSpeech2Text models

* Loadable model

* TF model with same outputs as PT model

* test skeleton

* correct tests and run the fixup

* correct attention expansion

* TFSpeech2Text pask_key_values with TF format
  • Loading branch information
gante authored Feb 8, 2022
1 parent 6a5472a commit 8406fa6
Show file tree
Hide file tree
Showing 23 changed files with 2,499 additions and 96 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ Flax), PyTorch, and/or TensorFlow.
| SEW | | | | | |
| SEW-D | | | | | |
| Speech Encoder decoder | | | | | |
| Speech2Text | | | | | |
| Speech2Text | | | | | |
| Speech2Text2 | | | | | |
| Splinter | | | | | |
| SqueezeBERT | | | | | |
Expand Down
4 changes: 4 additions & 0 deletions docs/source/model_doc/auto.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its

[[autodoc]] TFAutoModelForVision2Seq

## TFAutoModelForSpeechSeq2Seq

[[autodoc]] TFAutoModelForSpeechSeq2Seq

## FlaxAutoModel

[[autodoc]] FlaxAutoModel
Expand Down
10 changes: 10 additions & 0 deletions docs/source/model_doc/speech_to_text.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ See the [model hub](https://huggingface.co/models?filter=speech_to_text) to look
[[autodoc]] Speech2TextForConditionalGeneration
- forward
## TFSpeech2TextModel
[[autodoc]] TFSpeech2TextModel
- call
## TFSpeech2TextForConditionalGeneration
[[autodoc]] TFSpeech2TextForConditionalGeneration
- call
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,7 @@
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
Expand All @@ -1635,6 +1636,7 @@
"TFAutoModelForQuestionAnswering",
"TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
Expand Down Expand Up @@ -1946,6 +1948,14 @@
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.speech_to_text"].extend(
[
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSpeech2TextForConditionalGeneration",
"TFSpeech2TextModel",
"TFSpeech2TextPreTrainedModel",
]
)
_import_structure["models.t5"].extend(
[
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -3588,6 +3598,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
Expand All @@ -3602,6 +3613,7 @@
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForSpeechSeq2Seq,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
Expand Down Expand Up @@ -3850,6 +3862,12 @@
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.speech_to_text import (
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSpeech2TextForConditionalGeneration,
TFSpeech2TextModel,
TFSpeech2TextPreTrainedModel,
)
from .models.t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
Expand Down
74 changes: 35 additions & 39 deletions src/transformers/generation_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,12 @@ def generate(
Parameters:
input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
The sequence used as a prompt for the generation. If `None` the method initializes it with
`bos_token_id` and a batch size of 1.
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 10):
Expand Down Expand Up @@ -657,11 +660,12 @@ def generate(
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
elif attention_mask is None:
attention_mask = tf.ones_like(input_ids)
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
if accepts_attention_mask:
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
elif attention_mask is None:
attention_mask = tf.ones(shape_list(input_ids)[:2], dtype=tf.int32)

if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
Expand Down Expand Up @@ -697,16 +701,12 @@ def generate(
encoder = self.get_encoder()

encoder_kwargs = {
"attention_mask": attention_mask,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict_in_generate,
}

# vision models don't use `attention_mask`.
signature = dict(inspect.signature(encoder.call).parameters)
if "attention_mask" not in signature:
encoder_kwargs.pop("attention_mask")
if accepts_attention_mask:
encoder_kwargs["attention_mask"] = attention_mask

encoder_outputs = encoder(input_ids, **encoder_kwargs)
if return_dict_in_generate:
Expand All @@ -715,23 +715,15 @@ def generate(
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

# The condition `len(shape_list(input_ids)) == 2` is to make this block treats only text inputs.
# (vision inputs might occur when the model is an encoder-decoder model)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
attention_mask = tf.broadcast_to(
tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
input_ids = tf.reshape(
input_ids, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = tf.reshape(
attention_mask, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
expanded_batch_idxs = tf.reshape(
tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
shape=(-1,),
)
# prepares text-based inputs
if len(shape_list(input_ids)) == 2:
input_ids = tf.gather(input_ids, expanded_batch_idxs, axis=0)
if accepts_attention_mask:
attention_mask = tf.gather(attention_mask, expanded_batch_idxs, axis=0)

if self.config.is_encoder_decoder:

Expand All @@ -749,11 +741,6 @@ def generate(
batch_size == encoder_outputs[0].shape[0]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = tf.reshape(
tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
shape=(-1,),
)
# expand encoder_outputs
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),)
else:
Expand Down Expand Up @@ -851,7 +838,8 @@ def _generate_no_beam_search(
unfinished_sents = tf.ones_like(input_ids[:, 0])
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
# defined for encoder-decoder models, None for decoder-only models
past = encoder_outputs

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
Expand All @@ -871,7 +859,11 @@ def _generate_no_beam_search(

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
input_ids,
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
**kwargs,
)
outputs = self(
**model_inputs,
Expand Down Expand Up @@ -1132,7 +1124,11 @@ def _generate_beam_search(

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
input_ids,
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
**kwargs,
)
outputs = self(
**model_inputs,
Expand Down
16 changes: 13 additions & 3 deletions src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TransposeType(ExplicitEnum):

NO = "no"
SIMPLE = "simple"
CONV1D = "conv1d"
CONV2D = "conv2d"


Expand Down Expand Up @@ -68,8 +69,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",

# When should we transpose the weights
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
# A simple heuristic to detect conv layer using weight array shape
transpose = TransposeType.CONV2D
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
transpose = TransposeType.CONV1D
elif bool(
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
or "emb_projs" in tf_name
Expand Down Expand Up @@ -194,7 +196,6 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
# authorized missing keys don't have to be loaded
if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
continue

raise AttributeError(f"{name} not found in PyTorch model")

array = pt_state_dict[name].numpy()
Expand All @@ -204,6 +205,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 3, 1, 0))
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# PT: (num_out_channel, num_in_channel, kernel)
# -> TF: (kernel, num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)

Expand Down Expand Up @@ -355,7 +361,6 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights = set(list(tf_weights_map.keys()))
loaded_pt_weights_data_ptr = {}
missing_keys_pt = []

for pt_weight_name, pt_weight in current_pt_params_dict.items():
# Handle PyTorch shared weight ()not duplicated in TF 2.0
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
Expand All @@ -377,6 +382,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
array = numpy.transpose(array, axes=(3, 2, 0, 1))
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# TF: (kernel, num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel)
array = numpy.transpose(array, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
Expand All @@ -101,6 +102,7 @@
"TFAutoModelForQuestionAnswering",
"TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
Expand Down Expand Up @@ -201,6 +203,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
Expand All @@ -215,6 +218,7 @@
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForSpeechSeq2Seq,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):


AutoModelForSpeechSeq2Seq = auto_class_update(
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing"
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


Expand Down
20 changes: 20 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("speech_to_text", "TFSpeech2TextModel"),
("clip", "TFCLIPModel"),
("deberta-v2", "TFDebertaV2Model"),
("deberta", "TFDebertaModel"),
Expand Down Expand Up @@ -103,6 +104,7 @@
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
Expand Down Expand Up @@ -204,6 +206,12 @@
]
)

TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
Expand Down Expand Up @@ -340,6 +348,9 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
Expand Down Expand Up @@ -468,6 +479,15 @@ class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
)


class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


TFAutoModelForSpeechSeq2Seq = auto_class_update(
TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
@classmethod
def from_config(cls, config):
Expand Down
Loading

0 comments on commit 8406fa6

Please sign in to comment.