
Conversation

@stancld
Contributor

@stancld stancld commented Feb 11, 2022

What does this PR do?

This PR adds a TensorFlow implementation of the GPT-J model.

Fixes #15583
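For illustration, a minimal usage sketch of the model class this PR adds (the checkpoint id matches the thread below; the prompt and generation settings are arbitrary, and from_pt=True is needed until a native TF checkpoint exists):

from transformers import AutoTokenizer, TFGPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Cross-load the PyTorch weights; a TF checkpoint is uploaded later in this thread.
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)

inputs = tokenizer("GPT-J is a model that", return_tensors="tf")
output_ids = model.generate(inputs["input_ids"], max_length=32)
print(tokenizer.decode(output_ids[0]))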

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed.

@LysandreJik @patrickvonplaten

@HuggingFaceDocBuilder

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@stancld stancld marked this pull request as ready for review February 19, 2022 20:46
@stancld stancld changed the title [WIP] Add TF implementation of GPT-J Add TF implementation of GPT-J Feb 19, 2022
@LysandreJik
Member

Super exciting! cc @Rocketknight1 and @gante

Contributor

@patil-suraj patil-suraj left a comment


Thanks a lot for adding the TF version!
The modeling code looks good to me! Will let @Rocketknight1 and @gante review the TF side of things :-)

import unittest

from transformers import TFGPTJForCausalLM
from transformers.testing_utils import tooslow

class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
    @tooslow
    def test_lm_generate_gptj(self):
        model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
Contributor


We should use a TF checkpoint here. Once the PR is approved, I will upload the TF checkpoint.

Member


+1 because cross-loading PyTorch checkpoints requires torch to be installed, so we should make conversions whenever possible.
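For reference, once a native TF checkpoint is uploaded, the test line can drop the cross-loading flag entirely (a minimal sketch; the tf_model.h5 upload happens later in this thread):

from transformers import TFGPTJForCausalLM

# Loads the native TF weights (tf_model.h5); no torch installation required.
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")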

* Adjust split and merge heads to handle 4- and 5-dim tensors (see the sketch after this list)

* Fix use_cache according to PT implementation

* Add some missing comments

* Fix formatting of expected_output_ids in the test file
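For context on the first item, a hedged sketch of a split-heads helper that accepts both layouts (the function name, arguments, and the rank-5 case are illustrative, not the PR's exact code):

import tensorflow as tf

def split_heads(x, num_heads, head_dim, rotary=False):
    # [..., hidden] -> [..., num_heads, head_dim]
    new_shape = tf.concat([tf.shape(x)[:-1], [num_heads, head_dim]], axis=0)
    x = tf.reshape(x, new_shape)
    if rotary:
        # Rotary position embeddings are applied on this untransposed layout.
        return x
    if x.shape.rank == 4:
        return tf.transpose(x, (0, 2, 1, 3))     # [batch, heads, seq, head_dim]
    if x.shape.rank == 5:
        return tf.transpose(x, (0, 1, 3, 2, 4))  # an extra leading dim stays in front
    raise ValueError(f"Expected a 4- or 5-dim tensor after the reshape, got rank {x.shape.rank}")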
@gante
Contributor

gante commented Mar 11, 2022

@gante Bit late but working on the PR now :]

lovely 👌 lmk if you need a hand with anything, let's push this beauty to the finish line

Contributor

@gante gante left a comment


Other than the two issues where you have tagged me, it seems good to go 🔥 I will have a look at those.

@patil-suraj
Contributor

patil-suraj commented Mar 15, 2022

I have uploaded the TF weights: https://huggingface.co/EleutherAI/gpt-j-6B/blob/main/tf_model.h5

@gante Do you know how to save the fp16 weights in TF? We need those for the float16 branch.

@gante
Contributor

gante commented Mar 15, 2022

@gante Do you know how to save the fp16 weights in TF? We need those for the float16 branch.

@patil-suraj negative 😬 @Rocketknight1, do you know how to do it?

@Rocketknight1
Member

Not easy to do in TF, unfortunately! Keras really wants you to do mixed precision, with full-precision weights. I don't think there's any "native" way to convert a Model to use float16 weights except with TFLite stuff, or just manually converting the weight arrays yourself.
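For context, the supported Keras path is mixed precision, where computations run in float16 but the variables themselves stay float32; a minimal illustration (not from this PR):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(4)
_ = layer(tf.zeros((1, 8)))  # build the layer so the weights exist

print(layer.compute_dtype)  # float16: matmuls run in half precision
print(layer.kernel.dtype)   # float32: the stored weights stay full precision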

@patil-suraj
Contributor

patil-suraj commented Mar 15, 2022

Okay, thanks for the answer.

or just manually converting the weight arrays yourself.

So should we manually create fp16 weights for the float16 branch? I'm not sure if that affects anything in TF.

@Rocketknight1
Member

I'm not sure, basically; I don't know how you'd load them into the model class without converting them to float32. Keras has some options for forcing dtypes, but I don't think they're very well supported, so as long as we want the model in Keras and not raw TF, I don't really know how to do this. Can we just drop the float16 branch for TF for now?
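To make the limitation concrete, a hedged sketch of "manually converting the weight arrays yourself" (the output filename is made up, and loading these arrays back into a stock Keras model would cast them to float32 again):

import numpy as np
from transformers import TFGPTJForCausalLM

model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

# Downcast each weight array on the NumPy side and save it outside Keras;
# the model's own variables remain float32 throughout.
fp16_weights = {w.name: w.numpy().astype(np.float16) for w in model.weights}
np.savez("gptj_fp16_weights.npz", **fp16_weights)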

@patil-suraj
Contributor

Gently pinging everyone here :) Is this PR good to merge?

@gante
Contributor

gante commented Mar 22, 2022

There are two outstanding issues that require light changes (@stancld):

  1. Add TF implementation of GPT-J #15623 (comment)
  2. Add TF implementation of GPT-J #15623 (comment)

Other than that, good to go IMO 👍

@stancld
Contributor Author

stancld commented Mar 25, 2022

There are two outstanding issues that require light changes (@stancld):

  1. Add TF implementation of GPT-J #15623 (comment)
  2. Add TF implementation of GPT-J #15623 (comment)

Other than that, good to go IMO 👍

@gante I'm gonna solve these issues now O:]

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 25, 2022

The documentation is not available anymore as the PR was closed or merged.

* Update set/get output embeddings method
* Update prepare_inputs_for_generation
* Skip test_resize_token_embeddings, as this part of the code is going to undergo a major refactor (see the sketch after this list)
* Update outputs for @tooslow tests
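For reference, a skip along those lines could look like this in the model's test class (a sketch; the class name and skip message are illustrative):

import unittest

class TFGPTJModelTest(unittest.TestCase):
    @unittest.skip("Embedding resizing for TF GPT-J is deferred to an upcoming refactor.")
    def test_resize_token_embeddings(self):
        pass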
@stancld
Contributor Author

stancld commented Mar 25, 2022

@gante All the remaining issues/comments should be resolved now :] Thanks a lot for your guidance! O:]

@gante
Contributor

gante commented Mar 25, 2022

@stancld amazing, thank you so much for this contribution! Adding these models is always a tough task, especially the last mile, so I appreciate your effort 🤗 and I'm sure the community does as well.

@patil-suraj @Rocketknight1 are you cool with me merging this PR? We won't be able to call resize_token_embeddings (and I have an action point to enable that), but other than that seems good to go!

@patil-suraj
Contributor

Good to merge for me! Maybe just override the resize_token_embeddings method and raise a NotImplementedError with a message, but I don't feel strongly about this.
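A hedged sketch of that suggestion (shown as a subclass so it runs standalone; in the PR the override would live on the model class itself, and the message text is made up):

from transformers import TFGPTJForCausalLM

class TFGPTJForCausalLMNoResize(TFGPTJForCausalLM):
    def resize_token_embeddings(self, new_num_tokens=None):
        # Fail loudly instead of producing a silently broken embedding matrix.
        raise NotImplementedError(
            "Resizing token embeddings is not yet supported for TF GPT-J; "
            "see the discussion in this PR."
        )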

@gante
Contributor

gante commented Mar 25, 2022

(@Rocketknight1 approved on Slack, merging now to get a boost of Friday afternoon good vibes)

@gante gante merged commit ed2ee37 into huggingface:main Mar 25, 2022
@stancld stancld deleted the tf_gpt-j branch March 26, 2022 12:34