Port models to core #1119
Conversation
ianstenbit
left a comment
Some minor questions for you, but looks good!
Way to go Matt 🎊 🥳 ❗
```python
        self.backbone.get_layer("token_embedding").embeddings,
        transpose_b=True,
    )
logits = self.get_layer("reverse_embedding")(x)
```
Just curious, as I see this is done throughout -- what's the purpose of using self.get_layer instead of storing the layer as a member on the object directly, e.g. self.reverse_embedding?
Because these models are functional, the layers actually aren't available as self.layer or self.backbone.layer accessors.
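A minimal sketch of what this means, assuming plain Keras (the toy architecture below is invented, not the PR's actual model): in a functional model, layers built inline are tracked by the graph rather than stored as attributes, so a name-based lookup via `get_layer` is the natural accessor.

```python
import keras

# Layer names mirror the snippet above; the architecture itself is invented.
inputs = keras.Input(shape=(8,), dtype="int32", name="tokens")
x = keras.layers.Embedding(100, 16, name="token_embedding")(inputs)
outputs = keras.layers.Dense(4, name="reverse_embedding")(x)
model = keras.Model(inputs, outputs)

# The embedding layer was never assigned to an attribute, so there is no
# `model.token_embedding`; the functional graph tracks it by name instead.
embedding = model.get_layer("token_embedding")
```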
```python
)


@pytest.mark.large  # Saving is slow, so mark these large.
@pytest.mark.tf_only
```
Is it because of the tokenizer that this is TF only? If so, can we potentially make the test work on other backends?
(Same throughout the PR)
This is because string model inputs are only supported in tf. So a model that includes tokenization inside call really only makes sense in tf.
I am also open to just deleting these tests. I think saving is important for backbone models and tasks; for tokenizers, saving directly like this matters less.
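For context, here is a hedged sketch of the kind of tf-only saving test being discussed, assuming a WordPiece tokenizer and a toy vocabulary; this is illustrative, not the PR's actual test code. The string-typed model input is what ties it to the TF backend.

```python
import pytest
import tensorflow as tf
import keras_nlp


@pytest.mark.large  # Saving is slow, so mark these large.
@pytest.mark.tf_only  # tf.string model inputs only exist on the TF backend.
def test_saved_model(tmp_path):
    tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
        vocabulary=["[UNK]", "the", "quick", "brown", "fox"],
        sequence_length=8,
    )
    # Wrapping the tokenizer in a model with a string input is the part
    # that only works in TensorFlow.
    inputs = tf.keras.Input(shape=(), dtype=tf.string)
    model = tf.keras.Model(inputs, tokenizer(inputs))

    path = str(tmp_path / "model.keras")
    model.save(path)
    restored = tf.keras.models.load_model(path)

    sentence = tf.constant(["the quick brown fox"])
    tf.debugging.assert_equal(model(sentence), restored(sentence))
```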
We could potentially rewrite the test with preprocessor=False to make it generic
Oh this is actually a test for just the tokenizer. So preprocessor=False would not apply here.
fchollet
left a comment
LGTM, thanks!
jbischof
left a comment
Thanks!
Add all model variables to the jax state mapping: we want to avoid a bug when the call sequence looks like model.generate(), model.fit(), model.generate(). In this case we need to be careful not to pull in the cached variable state at generation compile time.
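To make the failure mode concrete, here is a hedged sketch of the call ordering being guarded against. The preset name is illustrative, and compile/optimizer details are elided; any KerasNLP generative task would show the same pattern.

```python
import tensorflow as tf
import keras_nlp

# Hypothetical preset choice; compile defaults assumed for brevity.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
train_ds = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox jumped over the lazy dog"]
).batch(1)

causal_lm.generate("the quick brown")  # first call compiles a generate function
causal_lm.fit(train_ds)                # fit() updates the model variables
causal_lm.generate("the quick brown")  # must read the updated weights, not
                                       # variable state cached when generate
                                       # was first compiled
```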
* Port models to core
* Proper seed generation for jax
* Don't test metrics yet (for a separate PR)
* Add all model variables to the jax state mapping
* Address Ian's comments
* Add TODOs for reverse embedding
* Run pytest on the entirety of keras-nlp
* Misc cleanups
* Mark docstring tests tf only
* Last failing doctest
🚧 This is an experimental feature branch, more details soon. 🚧
I'm temporarily basing this on the `preprocessing` branch so we can start review.