TF port of ESM #19587
Conversation
Force-pushed from ab774a6 to 7508d13
The documentation is not available anymore as the PR was closed or merged.
Pipeline tests are failing because the model has no SEP token and doesn't work with multiple sequences. Working on it!
There's one final test remaining that's failing because of some arcane issue in the code that generates data batches for the pipeline. I'm trying to figure it out!
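For context, a minimal sketch of the workaround that eventually landed (see the commit list at the end): since the ESM vocabulary has no SEP token, sequence pairs are joined with the EOS token instead. The function name and token IDs below are illustrative, not the actual tokenizer code.

```python
# Illustrative sketch only: ESM has no SEP token, so EOS doubles as the separator
# between paired sequences. The token IDs here are made up for the example.
def build_pair_with_eos_as_sep(tokens_a, tokens_b=None, cls_id=0, eos_id=2):
    if tokens_b is None:
        return [cls_id] + tokens_a + [eos_id]
    return [cls_id] + tokens_a + [eos_id] + tokens_b + [eos_id]

print(build_pair_with_eos_as_sep([5, 6, 7], [8, 9]))
# [0, 5, 6, 7, 2, 8, 9, 2]
```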
Looks very clean, thanks a lot for porting this model in TensorFlow!
@@ -42,12 +42,14 @@

 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "facebook/esm-1b"
+_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D"
Will need an update :-)
Yep, all of these will be moved to `facebook` before the next release!
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
#                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
Those long comments make review very hard in GitHub.
That one's copied from BERT!
Might be worth fixing on a followup PR then!
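For readers who haven't seen the trick the comment refers to, a small standalone illustration (not code from the PR): comparing a sequence-id tensor against its transpose yields a mask in which each position may only attend to positions belonging to the same sequence.

```python
import tensorflow as tf

# Standalone illustration of the T5/mesh-style same-sequence mask mentioned above.
seq_ids = tf.constant([[0, 0, 1, 1, 1]])                        # (batch, seq_len)
mask = tf.math.equal(seq_ids[:, :, None], seq_ids[:, None, :])  # (batch, seq_len, seq_len)
print(mask.numpy().astype(int))
```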
Force-pushed from ed2c1e4 to f5fbfb9
Tests are green, and #19124 has been merged! Going to use it to upload the remaining checkpoints and then merge this.
🔥🔥🔥
(Now that I've reviewed this PR, does it mean I can get a job in the biotech industry? :P )
# Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation
# and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at
# all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the
# original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that
# the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our
# models give different outputs from the original.
If I got it right: we want to load `inv_freq` as a weight when it exists, because it was stored in float16. If we were to use the float32 values, we would get different outputs. Correct?
Also - does XLA automatically create constant caches when appropriate? 😱
I believe it does! And if not, it can compute this during the 'downtime' of other small tasks once it's compiled - it's a really small tensor!
Also, you're correct about the float16/float32 issue. I was getting divergent outputs in my port at first because I recomputed the value rather than loading it from the checkpoint.
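To make the precision point concrete, a tiny standalone check (the rotary dimension and base here are example values, not the exact ESM configuration):

```python
import numpy as np

# Recomputing inv_freq in float32 does not exactly match the float16 values stored
# in fp16-trained checkpoints, which is why the TF port loads it as a weight instead.
dim = 64  # example rotary dimension
inv_freq_fp32 = 1.0 / (10000 ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
inv_freq_from_ckpt = inv_freq_fp32.astype(np.float16)  # what an fp16 checkpoint stores

# The round trip back to float32 is not bit-identical, so downstream outputs drift.
print(np.max(np.abs(inv_freq_fp32 - inv_freq_from_ckpt.astype(np.float32))))
```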
def set_input_embeddings(self, value: tf.Variable):
    self.embeddings.weight = value
    self.embeddings.vocab_size = shape_list(value)[0]
Given that `get_input_embeddings` returns `self.embeddings.word_embeddings`, I'm assuming that this function should overwrite `self.embeddings.word_embeddings` and `value` is of type `Embedding` - right? (Like `set_output_embeddings` below.)
Correct, good catch!
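A minimal sketch of the agreed direction, assuming the attribute names from the snippet above (this is not the merged implementation):

```python
import tensorflow as tf

# Sketch only: the setter mirrors get_input_embeddings and swaps the embedding
# layer itself, rather than assigning a bare weight tensor as in the snippet above.
class MainLayerSketch(tf.keras.layers.Layer):
    def __init__(self, embeddings, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = embeddings  # expected to expose a .word_embeddings layer

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # `value` is an embedding layer (e.g. tf.keras.layers.Embedding),
        # matching what get_input_embeddings returns
        self.embeddings.word_embeddings = value
```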
@@ -0,0 +1,287 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team. All rights reserved.
Needs an update :D
Fixed!
* Partial TF port for ESM model
* Add ESM-TF tests
* Add the various imports for TF-ESM
* TF weight conversion almost ready
* Stop ignoring the decoder weights in PT
* Add tests and lots of fixes
* fix-copies
* Fix imports, add model docs
* Add get_vocab() to tokenizer
* Fix vocab links for pretrained files
* Allow multiple inputs with a sep
* Use EOS as SEP token because ESM vocab lacks SEP
* Correctly return special tokens mask from ESM tokenizer
* make fixup
* Stop testing unsupported embedding resizing
* Handle TF bias correctly
* Skip all models with slow tokenizers in the token classification test
* Fixing the batch/unbatcher of pipelines to accommodate the `None` being passed around.
* Fixing pipeline bug caused by slow tokenizer being different.
* Update src/transformers/models/esm/modeling_tf_esm.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* Update src/transformers/models/esm/modeling_tf_esm.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* Update src/transformers/models/esm/modeling_tf_esm.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* Update set_input_embeddings and the copyright notices

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Working out the last few issues now! Models with <3B parameters have already been ported; larger models will need to wait for #19124.
This PR also includes fixes for a couple of issues in the original PyTorch ESM.