
Add TF whisper #19378

Merged: 204 commits from add-tf-whisper-rebase into huggingface:main on Oct 10, 2022

Conversation

amyeroberts (Collaborator)

What does this PR do?

Adds a TensorFlow (TF) port of the PyTorch Whisper implementation.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev commented Oct 6, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante (Member) left a comment

🔥 🔥 🔥

I have a bunch of [XLA Notes] comments -- they don't need to be addressed now. I've written them down as potential sources of problems for XLA :D

Comment on lines +525 to +526
indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
gante (Member):
[XLA notes] I'm suspicious about these lines, list comprehensions at call time often cause problems.

indices: A more common pattern here is to build indices with TF functions like tf.tile and tf.concat.
updates: tf.ones(scores.shape[0] * len(self.begin_suppress_tokens), dtype=tf.float32) * -float("inf") would work here :)
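
As a side note for readers of this thread, here is a minimal, self-contained sketch in that spirit (using tf.repeat, tf.tile and tf.stack rather than Python list comprehensions); it is illustrative only, not the code that was merged. It assumes `scores` is a `[batch_size, vocab_size]` float tensor and `begin_suppress_tokens` is a static Python list of token ids; `suppress_begin_tokens` is a hypothetical helper name:

```python
import tensorflow as tf

def suppress_begin_tokens(scores, begin_suppress_tokens):
    """Set scores[:, t] = -inf for every t in begin_suppress_tokens, building
    the scatter indices/updates with TF ops instead of Python list comprehensions."""
    batch_size = tf.shape(scores)[0]
    num_tokens = len(begin_suppress_tokens)  # static, known at trace time

    # Row indices [0, 0, ..., 1, 1, ...] paired with the suppressed token ids.
    rows = tf.repeat(tf.range(batch_size), num_tokens)
    cols = tf.tile(tf.constant(begin_suppress_tokens, dtype=tf.int32), [batch_size])
    indices = tf.stack([rows, cols], axis=1)  # shape [batch_size * num_tokens, 2]

    # One -inf update per (row, token) pair.
    updates = tf.fill([batch_size * num_tokens], tf.constant(-float("inf"), dtype=scores.dtype))
    return tf.tensor_scatter_nd_update(scores, indices, updates)

# Example usage: scores = suppress_begin_tokens(scores, [220, 50256])
```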

ArthurZucker (Collaborator):

Very interesting!

gante (Member):

FYI post XLA fixing: this wasn't a problem 👍

Comment on lines +543 to +544
indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
gante (Member):

[XLA notes] same as above

gante (Member):

FYI post XLA fixing: this wasn't a problem 👍

other tokens to `-inf` so that they are sampled at their corresponding index."""

    def __init__(self, force_token_map):
        self.force_token_map = dict(force_token_map)
gante (Member):

[XLA notes] I'm also skeptical that dict .get() works with XLA. We might want to convert this to a flat tensor, with negative tokens in the empty positions.
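
For readers, a rough sketch of that flat-tensor idea (illustrative only, not the code that landed): build a lookup table once at `__init__` time, with -1 in the positions that have no forced token, so the per-step lookup becomes a `tf.gather` that XLA can trace instead of a Python `dict.get()`. `build_force_token_array` and `cur_len` are hypothetical names:

```python
import tensorflow as tf

def build_force_token_array(force_token_map):
    # Flat table indexed by generation step; -1 marks "no forced token here".
    size = max(force_token_map.keys()) + 1
    force_token_array = [-1] * size
    for step, token in force_token_map.items():
        force_token_array[step] = token
    return tf.constant(force_token_array, dtype=tf.int32)

# At call time the lookup is a graph op, so it still works when cur_len is a
# traced tensor under tf.function / XLA:
#   forced_token = tf.gather(force_token_array, cur_len)
#   has_forced_token = forced_token >= 0
```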

Member:

I'm almost certain that anything like get() or a dictionary lookup will be done once when the function is traced, with any arguments treated as constants, rather than in each loop iteration with the arguments treated as variables.

gante (Member):

FYI post XLA fixing: this was indeed a problem 😱

@ArthurZucker (Collaborator) left a comment

LGTM other than maybe XLA and the shared embedding


return_dict=return_dict,
)
# Decoder and encoder embeddings are tied
lm_logits = tf.matmul(outputs[0], self.model.get_input_embeddings().weights, transpose_b=True)
ArthurZucker (Collaborator):

Here I would just reiterate my previous comment about TFSharedEmbedding, which has a linear mode.

amyeroberts (Author):

It looks like the TFSharedEmbedding layer is flagged for deletion cc @gante

I've tidied up the call a little bit though. Let me know what you think.
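
For context, here is a hedged, standalone illustration of the weight-tying pattern being discussed: the decoder input embedding matrix is reused, transposed, as the output projection, instead of a separate LM-head layer or TFSharedEmbedding's linear mode. The sizes and names below are made up for the example; this is not the merged Whisper code:

```python
import tensorflow as tf

vocab_size, d_model = 1000, 64  # illustrative sizes only
embed_tokens = tf.keras.layers.Embedding(vocab_size, d_model)
_ = embed_tokens(tf.zeros((1, 1), dtype=tf.int32))  # build the [vocab_size, d_model] weight

hidden_states = tf.random.normal((2, 10, d_model))  # decoder output: [batch, seq_len, d_model]

# Tied projection: the matrix that maps token ids -> vectors is reused,
# transposed, to map hidden vectors -> vocabulary logits.
lm_logits = tf.matmul(hidden_states, embed_tokens.embeddings, transpose_b=True)
print(lm_logits.shape)  # (2, 10, 1000)
```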

@Rocketknight1 (Member) left a comment

I focused on the core model code and on comparing it to the PT implementation, since it seemed like @gante was handling generation and XLA compatibility. The one place where I thought I'd found an error (missing padding in the Conv1Ds) turned out to be handled in the call(). I added a suggestion for a comment there, but other than that the core model code LGTM!

@gante (Member) left a comment

Correction to the forced ids logits processor test :)

@sgugger (Collaborator) left a comment

Thanks a lot for adding the port so quickly 💪

@patrickvonplaten (Contributor) left a comment

Very nice!

@patrickvonplaten (Contributor) left a comment

Super nice!

    def decoder(self):
        return self.model.decoder

    def encoder(self):
patrickvonplaten (Contributor):

Not really relevant for this PR, but why are there both an encoder and a get_encoder function?

amyeroberts (Author):

Good question. Tbh, I was just copying this to match the PT model and didn't think about it. Looking at other models, e.g. BART, it seems to be a common pattern in the codebase.

@amyeroberts force-pushed the add-tf-whisper-rebase branch from 5fa46be to f320909 on October 10, 2022 at 10:57
@amyeroberts force-pushed the add-tf-whisper-rebase branch from 46728e3 to 53e4627 on October 10, 2022 at 12:22
@amyeroberts merged commit e3f028f into huggingface:main on Oct 10, 2022
@amyeroberts deleted the add-tf-whisper-rebase branch on October 10, 2022 at 13:48