-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Add TF implementation of GPT-J #15623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
Super exciting! cc @Rocketknight1 and @gante |
patil-suraj
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for adding the TF version!
The modeling code looks good to me! Will let @Rocketknight1 and @gante review the TF side of things :-)
| class TFGPTJModelLanguageGenerationTest(unittest.TestCase): | ||
| @tooslow | ||
| def test_lm_generate_gptj(self): | ||
| model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should use TF checkpoint here. Once the PR is approved I will upload the TF checkpoint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 because cross-loading PyTorch checkpoints requires torch to be installed, so we should make conversions whenever possible.
* Adjust split and merge heads to handle 4 and 5-dim tensors * Fix use_cache according to PT implementation * Add some missing comments * Fix formattin of expected_output_ids in the test file
lovely 👌 lmk if you need a hand with anything, let's push this beauty to the finish line |
gante
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than the two issues where you have tagged me, it seems good to go 🔥 I will have a look at those.
|
I have uploaded the TF weights https://huggingface.co/EleutherAI/gpt-j-6B/blob/main/tf_model.h5 @gante Do you know how to save the fp16 weights in TF ? We need those for the |
@patil-suraj negative 😬 @Rocketknight1, do you know how to do it? |
|
Not easy to do in TF, unfortunately! Keras really wants you to do mixed precision, with full-precision weights. I don't think there's any "native" way to convert a |
|
Okay, thanks for the answer.
so should we manually create fp16 weights for the |
|
I'm not sure, basically - I don't know how you'd load them into the model class without converting them to float32. Keras has some options for forcing dtypes but I don't think they're very well-supported, so as long as we want the model in Keras and not raw TF then I don't really know how to do this. Can we just drop the |
|
Gently pinging everyone here :) Is this PR good for merge ? |
|
There are two outstanding issues that require light changes (@stancld): Other than that, good to go IMO 👍 |
|
The documentation is not available anymore as the PR was closed or merged. |
* Update set/get output embeddings method * Update prepare prepare_inputs_for_generation * Skip test_resize_token_embeddings as this part of code is going to undergo a major refactor * Update outputs for @tooslow tests
|
@gante All the remaining issues/comments should be resolved now :] Thanks a lot for your guidance! O:] |
|
@stancld amazing, thank you so much for this contribution! Adding these models is always a tough task, especially the last mile, so I appreciate your effort 🤗 and I'm sure the community does as well. @patil-suraj @Rocketknight1 are you cool with me merging this PR? We won't be able to call |
|
Good to merge for me! Maybe just override the |
|
(@Rocketknight1 approved on slack, merging now to get a boost on Friday afternoon good vibes) |
What does this PR do?
This PR adds a TensorFlow implementation of
GPT-JmodelsFixes #15583
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed.
@LysandreJik @patrickvonplaten