Add TF whisper #19378
Conversation
🔥 🔥 🔥
I have a bunch of [XLA Notes] comments -- they don't need to be addressed now. I've written them down as potential sources of problems for XLA :D
indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
[XLA notes] I'm suspicious about these lines; list comprehensions at call time often cause problems.
indices: a more common pattern here is to build indices with TF functions like tf.tile and tf.concat.
updates: tf.ones((scores.shape[0] * len(self.begin_suppress_tokens),), dtype=tf.float32) * -float("inf") would work here :)
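As a concrete illustration of that suggestion, here is a rough sketch (my own, not the PR's code) of building the scatter indices and updates with TF ops so everything stays tensor-shaped; the helper name and the use of tf.tensor_scatter_nd_update are my framing of the idea, not the actual implementation:

import tensorflow as tf

def suppress_tokens(scores, suppress_token_ids):
    # scores: (batch_size, vocab_size); suppress_token_ids: Python list of ints
    batch_size = tf.shape(scores)[0]
    tokens = tf.constant(suppress_token_ids, dtype=tf.int32)
    num_tokens = tf.shape(tokens)[0]
    # Cartesian product of batch indices and token ids, same ordering as the
    # original list comprehension: for each batch row, all suppressed tokens.
    batch_idx = tf.repeat(tf.range(batch_size), num_tokens)
    token_idx = tf.tile(tokens, [batch_size])
    indices = tf.stack([batch_idx, token_idx], axis=1)
    updates = tf.fill([batch_size * num_tokens], tf.constant(-float("inf"), dtype=scores.dtype))
    return tf.tensor_scatter_nd_update(scores, indices, updates)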
Very interesting!
FYI post XLA fixing: this wasn't a problem 👍
indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
[XLA notes] same as above
FYI post XLA fixing: this wasn't a problem 👍
other tokens to `-inf` so that they are sampled at their corresponding index."""

def __init__(self, force_token_map):
    self.force_token_map = dict(force_token_map)
[XLA notes] I'm also skeptical that dict.get() works with XLA. We might want to convert this to a flat tensor, with negative tokens in the empty positions.
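For reference, a minimal sketch of what that could look like (assumed names, not the PR's code): the dict is flattened once at init time into a tensor indexed by generation position, with -1 marking positions that force nothing, so the per-step lookup becomes a tensor gather instead of a Python dict.get():

import tensorflow as tf

def build_force_token_array(force_token_map, max_length):
    # force_token_map: {generation_index: token_id}
    force_token_array = [-1] * max_length
    for index, token in force_token_map.items():
        force_token_array[index] = token
    return tf.constant(force_token_array, dtype=tf.int32)

# At call time: forced = tf.gather(force_token_array, generation_idx),
# where a value of -1 means "no forced token at this step".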
I'm almost certain that anything like get() or a dictionary lookup will be done once when the function is traced, with any arguments treated as constants, rather than in each loop iteration with the arguments treated as variables.
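A toy example of that behaviour (nothing Whisper-specific, just the general tf.function tracing rule that plain Python logic is resolved once at trace time):

import tensorflow as tf

USE_DOUBLE = True  # plain Python state, not a tensor

@tf.function
def scale(x):
    # This Python `if` is evaluated once, while the function is being traced;
    # the branch taken at trace time is baked into the resulting graph.
    if USE_DOUBLE:
        return x * 2.0
    return x * 3.0

print(scale(tf.constant(1.0)))  # 2.0
USE_DOUBLE = False
print(scale(tf.constant(1.0)))  # still 2.0 -- the already-traced graph is reused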
FYI post XLA fixing: this was indeed a problem 😱
LGTM other than maybe XLA and the shared embedding
indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
Very interesting!
    return_dict=return_dict,
)
# Decoder and encoder embeddings are tied
lm_logits = tf.matmul(outputs[0], self.model.get_input_embeddings().weights, transpose_b=True)
Here I would just reiterate my previous comment about TFSharedEmbedding, which has a linear mode.
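Roughly, the idea behind that layer (a simplified stand-in, not the actual TFSharedEmbedding code): a single layer owns the embedding matrix and can either embed token ids or, in "linear" mode, project hidden states back onto the vocabulary:

import tensorflow as tf

class SharedEmbeddingSketch(tf.keras.layers.Layer):
    def __init__(self, vocab_size, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.shared_weight = self.add_weight("weight", shape=(vocab_size, hidden_size))

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            return tf.gather(self.shared_weight, inputs)  # ids -> (..., hidden_size)
        if mode == "linear":
            return tf.matmul(inputs, self.shared_weight, transpose_b=True)  # hidden -> (..., vocab_size)
        raise ValueError(f"unknown mode {mode!r}")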
It looks like the TFSharedEmbedding layer is flagged for deletion cc @gante. I've tidied up the call a little bit though. Let me know what you think.
I focused on the core model code and on comparing it to the PT implementation, since it seemed like @gante was handling generation and XLA compatibility. The one place where I thought I'd found an error (missing padding in the Conv1Ds) turned out to be handled in the call(). I added a suggestion for a comment there, but other than that the core model code LGTM!
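For context on the padding point, a hedged sketch of the general pattern (illustrative layer sizes, not the PR's code): Keras Conv1D defaults to padding="valid", so the equivalent of a PyTorch Conv1d(..., padding=1) either uses padding="same" or pads the time axis explicitly before the convolution:

import tensorflow as tf

conv = tf.keras.layers.Conv1D(filters=384, kernel_size=3, strides=1, padding="valid")

def padded_conv(x):
    # x: (batch, time, channels); pad the time dimension by 1 on each side,
    # mirroring PyTorch's Conv1d padding=1 before a "valid" convolution.
    x = tf.pad(x, [[0, 0], [1, 1], [0, 0]])
    return conv(x)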
Correction to the forced ids logits processor test :)
Thanks a lot for adding the port so quickly 💪
Very nice!
Super nice!
def decoder(self):
    return self.model.decoder

def encoder(self):
Not really relevant for this PR, but why is there both an encoder and a get_encoder function?
Good question. Tbh, I was just copying this to match the PT model and didn't think about it. Looking at other models, e.g. BART, it seems to be a common pattern in the codebase.
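For what it's worth, the pattern usually looks something like this (a simplified sketch, not the actual Whisper code): get_encoder/get_decoder are the generic accessors that shared utilities such as generation call on any seq2seq model, while the bare encoder/decoder methods mirror the PyTorch model's attribute-style access:

def get_encoder(self):
    return self.model.encoder

def get_decoder(self):
    return self.model.decoder

def encoder(self):
    return self.model.encoder

def decoder(self):
    return self.model.decoder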
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
What does this PR do?
Adds TF Whisper, a port of the PyTorch implementation.