
Add TFSpeech2Text #15113

Merged (47 commits) on Feb 8, 2022
Conversation

gante (Member) commented Jan 11, 2022

What does this PR do?

This PR adds a TF port of Speech2Text. A summary of the changes:

  • This model borrows a lot of code from TFBart, just like Speech2Text borrowed from Bart;
  • I tried to follow the changes in other PRs to enable smooth interoperation with other parts of transformers (e.g. auto classes), but might be missing a few things 👼 ;
  • This seems to be the first TF model with speech as input, so I had to touch common TF code to enable correct data piping and miscellaneous operations (e.g. loading Conv1D PT weights into TF);
  • Likewise, a few tests in test_modeling_tf_common.py didn't quite fit this new kind of model.
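The Conv1D weight-loading point above comes down to a layout mismatch between the two frameworks; a minimal NumPy sketch of the required transpose (the real cross-loading logic lives in modeling_tf_pytorch_utils.py, and the shapes here are illustrative):

```python
import numpy as np

# PyTorch Conv1d stores weights as (out_channels, in_channels, kernel_size),
# while Keras Conv1D expects its kernel as (kernel_size, in_channels, out_channels).
# Cross-loading PT weights into TF therefore needs a transpose.
def pt_conv1d_weight_to_tf(pt_weight: np.ndarray) -> np.ndarray:
    return np.transpose(pt_weight, (2, 1, 0))

pt_w = np.zeros((64, 80, 3))       # (out, in, kernel) as stored by PyTorch
tf_w = pt_conv1d_weight_to_tf(pt_w)
print(tf_w.shape)                  # (3, 80, 64)
```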

TODO:

  • create TF version of the weights, so we can load a model without from_pt=True

@gante added the WIP label Jan 13, 2022
@gante marked this pull request as ready for review January 28, 2022 20:18
@gante added the TensorFlow label and removed the WIP label Jan 28, 2022
gante (Member, author) commented Jan 28, 2022

Tagging @sgugger as a core dev, @patil-suraj as a core dev + original creator of our Speech2Text, and @Rocketknight1 as the TensorFlow boy. Feel free to redirect the reviews if you know of better people to review.

Some pipeline tests are failing almost surely due to the changes in generation_tf_utils.py, as some models expect (encoder_outputs, past) in the past variable and others don't -- having a look, but open to suggestions.

@gante changed the title from Add TF Speech2Text to Add TFSpeech2Text on Jan 28, 2022
sgugger (Collaborator) left a comment:

Very nice addition, great work!

src/transformers/generation_tf_utils.py (comment resolved)
@@ -756,6 +759,8 @@ def generate(
)
# expand encoder_outputs
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),)
if "attention_mask" in locals(): # vision models don't have this
Collaborator:

This check is a bit weird. "attention_mask" is in the keyword arguments of this function, so it will always show up in the locals. Maybe check if it's not set to None?

gante (Member, author):

Excellent point, for some reason I didn't notice that it was in the keyword arguments (and thus the current version would fail) 👍 Checking against None now
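A quick illustration of why the locals() test can never be False for a declared keyword argument (toy function, not the real generate signature):

```python
def generate(input_ids=None, attention_mask=None):
    # A keyword argument is always bound in the local scope, so this
    # membership test is True even when the caller passed nothing:
    always_true = "attention_mask" in locals()
    # The meaningful check is whether the argument was actually provided:
    provided = attention_mask is not None
    return always_true, provided

print(generate())                           # (True, False)
print(generate(attention_mask=[1, 1, 0]))   # (True, True)
```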

Comment on lines 878 to 880
# past drops the `encoder_outputs` during the loop and it may be needed (encoder-decoder models)
if len(past) > 1 and encoder_outputs is not None:
past = (encoder_outputs, past)
Collaborator:

Let's double check with @patrickvonplaten for this change.

gante (Member, author) commented Jan 31, 2022:

I've reworked this part to be less intrusive: the past variable is not touched, encoder_outputs is passed to prepare_inputs_for_generation a few lines below. This means that it will be ignored as part of kwargs in most models, and encoder-decoder models can now explicitly access it if they wish to do so (like TFSpeech2Text does).

This change also solves the assert comments below, in prepare_inputs_for_generation.
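A hypothetical sketch of the dispatch described above (simplified signatures, not the actual transformers APIs): encoder_outputs travels as an ordinary keyword argument, so decoder-only models swallow it in **kwargs while encoder-decoder models can name it explicitly:

```python
# Decoder-only models: encoder_outputs is silently absorbed by **kwargs.
def prepare_inputs_decoder_only(input_ids, past=None, **kwargs):
    return {"input_ids": input_ids, "past": past}

# Encoder-decoder models: the argument is named, so it can be used directly.
def prepare_inputs_encoder_decoder(input_ids, past=None, encoder_outputs=None, **kwargs):
    return {"decoder_input_ids": input_ids, "past": past, "encoder_outputs": encoder_outputs}

inputs = prepare_inputs_decoder_only([1, 2], past=None, encoder_outputs=("enc",))
print("encoder_outputs" in inputs)   # False -- ignored by decoder-only models
inputs = prepare_inputs_encoder_decoder([1, 2], past=None, encoder_outputs=("enc",))
print(inputs["encoder_outputs"])     # ('enc',)
```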

src/transformers/modeling_tf_pytorch_utils.py (comment resolved)
self.num_heads = num_heads
self.dropout = tf.keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
Collaborator:

We'll need to fix this assert in the original model (normally no new asserts in the modeling files, but this is a copied-from block).

gante (Member, author):

Copied the PT check (which raises a ValueError) to both models.
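A sketch of the replacement check, mirroring the PT-style ValueError (the exact message wording here is illustrative):

```python
# Replaces `assert self.head_dim * num_heads == self.embed_dim, ...` with an
# explicit exception that survives `python -O` and gives a clearer message.
def check_divisible(embed_dim: int, num_heads: int) -> int:
    head_dim = embed_dim // num_heads
    if head_dim * num_heads != embed_dim:
        raise ValueError(
            f"embed_dim must be divisible by num_heads (got `embed_dim`: {embed_dim}"
            f" and `num_heads`: {num_heads})."
        )
    return head_dim

print(check_divisible(256, 8))  # 32
```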

gante (Member, author):

Well, it turns out that many models used a copy of this 😅 Changed them all 👍

Comment on lines 455 to 456
if model_class.__name__ in ["TFSpeech2TextModel", "TFSpeech2TextForConditionalGeneration"]:
inputs = {
"decoder_input_ids": tf.keras.Input(
batch_shape=(2, max_input),
name="decoder_input_ids",
dtype="int32",
),
"input_ids": tf.keras.Input(batch_shape=(2, 32, max_input), name="input_ids", dtype="float32"),
}
Collaborator:

Is there another (more general) check we could use here?

gante (Member, author):

By more general do you mean instead of hardcoding the possible names in model_class.__name__? We have TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, but it doesn't cover TFSpeech2TextModel :(

tests/test_modeling_tf_common.py (outdated, comment resolved)
@@ -756,6 +759,8 @@ def generate(
)
# expand encoder_outputs
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),)
if attention_mask is not None: # vision models don't have this
attention_mask = tf.gather(attention_mask, expanded_batch_idxs, axis=0)
Contributor:

Is this only needed for the new model or was it a bug to not have done this previously?

Contributor:

I'm a bit worried about this as it could lead to some unexpected errors in the more complicated models like TFRag. Could we maybe run the following tests here to double-check that this change is ok?

RUN_SLOW=1 pytest tests/test_modeling_tf_bart.py
RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py
RUN_SLOW=1 pytest tests/test_modeling_tf_rag.py (note that you need some more dependencies installed here)

gante (Member, author):

Spot on, it failed for those models.

I've reworked these lines to behave like PyTorch's _expand_inputs_for_generation(), which is used across all modalities. I've tested the changes against bart, bert, t5, vit, and s2t -- let me know what you think, @patrickvonplaten
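The expansion pattern borrowed from PyTorch's _expand_inputs_for_generation can be sketched in NumPy (illustrative only; the real code operates on framework tensors via gather/index_select): each batch row is repeated num_beams times through one shared index, and tensors that don't exist for a modality simply stay None:

```python
import numpy as np

# Expand every per-batch tensor num_beams times with a single gathered index,
# skipping tensors a modality doesn't have (e.g. no attention_mask for vision).
def expand_for_generation(batch_size: int, num_beams: int, *tensors):
    idx = np.repeat(np.arange(batch_size), num_beams)  # [0,0,0,1,1,1] for (2, 3)
    return tuple(t[idx] if t is not None else None for t in tensors)

ids = np.array([[1, 2], [3, 4]])
mask = None  # e.g. a vision/speech model without an attention mask
expanded_ids, expanded_mask = expand_for_generation(2, 3, ids, mask)
print(expanded_ids.shape)  # (6, 2)
print(expanded_mask)       # None
```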

Contributor:

Cool! Could you please also run the slow TFRag tests? It's a bit annoying as it's a huge and slow suite, but it's the model most prone to break when tf generate changes.

gante (Member, author) commented Feb 1, 2022:

I confirm that it passes TFRag tests as well (had to spin up a larger machine 😅 )

Contributor:

Cool! Good for me then

patrickvonplaten (Contributor) left a comment:

Looks good to me in general! Maybe just two final things before merging:

  • Quickly run some slow tests to make sure we haven't broken generate accidentally for some of the more complicated models
  • Should we maybe leave past for now, until we improve the naming of generate more globally?

patil-suraj and others added 10 commits February 1, 2022 16:57
Rocketknight1 (Member) commented:
Cloned your branch and did some experimentation -- LGTM! I noticed one issue: model.save() seems to encounter problems, but I'm not totally sure why, or whether it's limited to this model. save_pretrained and save_weights worked correctly, and saving in SavedModel format has always been a bit shaky for us, so this isn't critical.

@@ -1131,8 +1119,9 @@ def _generate_beam_search(
done = [False for _ in range(batch_size)]

while cur_len < max_length:
# import pdb; pdb.set_trace()
Contributor:

can we delete this?

gante (Member, author):

🤦

Member:

smh not using breakpoint() in the year 2022

shape=(-1,),
)
# prepares text-based inputs
if len(shape_list(input_ids)) == 2:
Contributor:

Slightly worried here about the TF vision tests. Can you check this one as well:

RUN_SLOW=1 pytest tests/test_modeling_tf_vision_encoder_decoder.py

gante (Member, author):

Can confirm that they pass 👍
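For context on the rank check in the hunk above: text inputs arrive as rank-2 integer id matrices, while speech features are rank-3 float tensors, so the branch can key off tensor rank. A minimal NumPy illustration (shapes are examples, not the model's actual dimensions):

```python
import numpy as np

# Text: (batch, sequence) integer ids. Speech: (batch, frames, feature_dim)
# float features. Rank alone is enough to tell them apart.
def is_text_input(inputs: np.ndarray) -> bool:
    return inputs.ndim == 2

text_ids = np.zeros((2, 16), dtype=np.int32)
speech_features = np.zeros((2, 16, 80), dtype=np.float32)
print(is_text_input(text_ids))         # True
print(is_text_input(speech_features))  # False
```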

patrickvonplaten (Contributor) left a comment:

One final vision test as mentioned before, and then it's good for me as well.

gante (Member, author) commented Feb 4, 2022

Pending the results of automated tests, this should be the last planned commit. We already have facebook/s2t-small-librispeech-asr as a TF model, and I will upload the TF version of the others today.

@patrickvonplaten -- One important final change I'd like double-checked is the removal of the positional embedding weights from the nn.Parameter() wrapper in the PT model. It is a constant that was neither saved nor loaded, and it was causing issues in the pt_tf tests (the TF model had no such variable, and technically it is not a parameter).

To check the changes, I've run:

RUN_SLOW=1 pytest tests/test_modeling_tf_vision_encoder_decoder.py
RUN_SLOW=1 pytest tests/test_modeling_tf_bart.py
RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py
RUN_SLOW=1 pytest tests/test_modeling_tf_rag.py
RUN_SLOW=1 pytest tests/test_modeling_speech_to_text_2.py
RUN_PT_TF_CROSS_TESTS=1 RUN_SLOW=1 pytest tests/test_modeling_speech_to_text.py
RUN_PT_TF_CROSS_TESTS=1 RUN_SLOW=1 pytest tests/test_modeling_tf_speech_to_text.py

EDIT: TF models uploaded.

@@ -153,9 +153,9 @@ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Opt
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

self.weights = nn.Parameter(emb_weights)
self.weights = emb_weights
Contributor:

@patil-suraj - is this change ok? Why exactly do we have to do this here actually?

Contributor:

I'm not sure we can just save a tensor like this in PyTorch. When doing model.to(device) does this tensor correctly move to the device as well?

Contributor:

I don't really understand why this would cause an issue in TF<=>PT exactly -- what was the issue? BTW, we always wrap the static sinusoidal encodings in nn.Parameter() or nn.Embedding(...), e.g. ViT's positional embeddings, DistilBERT's position embeddings. Why should we do it differently here?

Contributor:

But maybe it's totally fine. What do you think @LysandreJik @sgugger ?

gante (Member, author):

what was the issue?

This is the only model I could find that may change the size of the embedding layer in the forward call (here). The TF version of other embeddings allocates these named model parameters in build(), which is expected to happen before the first call(). Since nn.Parameter() is mostly a wrapper that adds the variable to the model's list of parameters (AFAIK), dropping it means we don't need a named parameter in TF, which lets us simplify the code and avoid solving the resizing issue.
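Background for the discussion above: sinusoidal position embeddings are a deterministic function of position and dimension, so there is nothing to learn, which is why recomputing (and resizing) them at forward time is possible at all. A standard construction in NumPy (illustrative; not the exact Speech2Text layout):

```python
import numpy as np

# Build the classic sin/cos position table: even columns get sin, odd columns
# get cos, with geometrically decreasing frequencies across the dimension.
def sinusoidal_embeddings(num_positions: int, dim: int) -> np.ndarray:
    positions = np.arange(num_positions)[:, None]              # (P, 1)
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))   # (dim/2,)
    angles = positions * inv_freq[None, :]                     # (P, dim/2)
    emb = np.zeros((num_positions, dim))
    emb[:, 0::2] = np.sin(angles)
    emb[:, 1::2] = np.cos(angles)
    return emb

print(sinusoidal_embeddings(128, 64).shape)  # (128, 64)
```

Because the table is pure arithmetic, a larger table computed for the same dimension agrees with the smaller one on its first rows, which is what makes forward-time resizing safe.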

Collaborator:

Mmm, but dropping it also removes it from the checkpoint, no? This might prevent proper loading/saving AFAICT.

gante (Member, author) commented Feb 7, 2022:

Facebook checkpoints do not have it, and they load correctly in TF and PT regardless of this change. I'm assuming that, for the same reason, we suppress errors when loading a PT checkpoint where this variable doesn't exist, since it is in _keys_to_ignore_on_load_missing (here).

The issue that prompted this change was uncovered in the equivalence tests, where we try to load in TF a PT model that does not come from a checkpoint (here). We also have this variable in _keys_to_ignore_on_save, but the test attempts to load a PT model in TF without saving it first -- i.e., having the variable in nn.Parameter() breaks the test because it is not excluded there, unlike in the checkpoint path.

sgugger (Collaborator) commented Feb 7, 2022:

Technically, someone who fine-tuned (or just saved and reshared) a Speech2Text model will have this weight in the checkpoint. I don't know how it's used exactly, so I don't know if it's breaking to ignore it (we would have to put it in the _keys_to_ignore_on_load_unexpected).

gante (Member, author):

(sorted on Slack -- conclusion: reverting this change since it is needed for the to method)

gante (Member, author) commented Feb 7, 2022

@patrickvonplaten reverted the previous change and added the embedding weights as named variables in TF

patrickvonplaten (Contributor) left a comment:

Thanks for working on it!

@gante gante merged commit 8406fa6 into huggingface:master Feb 8, 2022
@gante gante deleted the add_tf_s2t branch February 8, 2022 16:27
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Feb 18, 2022
* Add wrapper classes

* convert inner layers to tf

* Add TF Encoder and Decoder layers

* TFSpeech2Text models

* Loadable model

* TF model with same outputs as PT model

* test skeleton

* correct tests and run the fixup

* correct attention expansion

* TFSpeech2Text past_key_values with TF format
Labels: TensorFlow