
GPT-J-6B #13022

Merged · 133 commits · Aug 31, 2021

Conversation

@StellaAthena (Contributor) commented Aug 6, 2021

What does this PR do?

Introduces the long-awaited GPT-J model class to HuggingFace! Concurrently with this PR being merged, I will make a GPT-J-6B checkpoint public on the EleutherAI HF page for people to use. The model has been evaluated as being within error tolerances of the GPT-J-6B model we released in JAX two months ago.

@patil-suraj was very helpful in assisting me to understand the HF philosophy and how to make this PR fit the rest of the codebase. Other than that, the major design consideration was to make the configs compatible with GPT-2 rather than GPT-Neo. GPT-Neo has some usability limitations due to its configs having names unrelated to GPT-2’s (see #12183 for details). Given those problems and my hope that GPT-Neo will have its configs updated in the future, it seemed like a clear choice to align GPT-J with GPT-2.

Shout-outs to @finetuneanon, whose implementation this one is based on, as well as @kumuruz for assistance with optimizing and debugging.

Supersedes #12243 #13010 #13022

Closes #12098

Before submitting

Who can review?

@EricHallahan (Contributor)

@patrickvonplaten The script currently within Mesh Transformer JAX builds a split checkpoint in the format used by the @finetuneanon fork of transformers, and I needed to heavily modify it to generate checkpoints in the HF format. @kingoflolz has already asked me to make a PR to update the script.

@EricHallahan (Contributor) commented Aug 25, 2021

If we feel that the solution I outlined is the best solution, I can put that plan into action and update the repo on Model Hub. Maybe vote 👍/👎 on this?

@EricHallahan (Contributor)

@g-karthik I have ported the experimental parallelization code from GPT-2 into GPT-J. I wouldn't personally recommend it to anyone unless they need it, but it should work in a pinch when a better solution is unavailable.

(I note that this bears no resemblance to the implementation of model parallelism in Mesh Transformer JAX and should not be thought of as an equivalent implementation or a replacement for it.)
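For reference, a rough sketch of what using that experimental hook could look like, assuming GPT-J exposes the same parallelize(device_map) interface as the GPT-2 code it was ported from; the device map below is purely illustrative:

```python
from transformers import GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

# Illustrative split of the 28 transformer blocks across two GPUs;
# the exact assignment should be tuned to the available hardware.
device_map = {
    0: list(range(0, 14)),
    1: list(range(14, 28)),
}
model.parallelize(device_map)

# ... run generation as usual; inputs go to the first device ...

model.deparallelize()  # moves everything back to the CPU when done
```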

@oborchers commented Aug 28, 2021

@EricHallahan: Thank you very much for the awesome work on the issue! There is one thing to note regarding the 16 GB VRAM configuration:

As far as I can tell, even with the floating-point revision set correctly (and also when only the local files are fp16, local_files_only=True), the model will still be loaded in float32 until model.half() is called, so roughly 23 GB of RAM must be available before model.half() and model.to(device) are called.

By extension, this means that a text-generation pipeline cannot be loaded with device=0 in a setting with less than 23 GB of VRAM, as .half() isn’t called automatically anywhere. In this case the model must be loaded, .half()-ed, and then passed to the pipeline via the model argument (see the sketch below).
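A minimal sketch of that workaround, assuming the EleutherAI/gpt-j-6B checkpoint and its float16 revision:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# The weights come off disk as float32, so ~23 GB of CPU RAM is needed here.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
model = model.half().to("cuda:0")  # cast to fp16 on the CPU, then move to the GPU

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Hand the already-halved model to the pipeline instead of letting
# device=0 push a float32 copy onto the GPU.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
print(generator("Hello, my name is", max_length=32)[0]["generated_text"])
```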

Correct me if this observation is wrong. Is there any way of loading/moving the model to the GPU iteratively so that the 23 GB RAM limitation can be circumvented, similar to what is done in the @finetuneanon repository? (Probably out of scope for this PR, but likely a problem for larger models in general in the future.) Presumably this can be done using the state dict, but I‘m not deep enough into the inner workings to judge.

Also tagging @patrickvonplaten

@EricHallahan (Contributor)

@oborchers Yes, we have had multiple people test this via Colab and they have reported the same issue.

I have verified that choosing the "float16" revision loads the model at float32. I don't understand why it doesn't load the model at half precision, especially because I explicitly set "torch_dtype": "float16" in the model config on the float16 branch. Maybe I am interpreting the functionality of that config parameter wrong, but my understanding is that it explicitly tells the model loader to use the specified type as the default.
(I also want to make sure to point out that naively loading the model is not particularly useful at this time, as the Model Hub repo still has an extraneous main branch that I have been unable to remove and replace with the float32 branch.)

The multi-part loading scheme used by the @finetuneanon fork was purposefully built to bypass the suboptimal way that transformers loads checkpoints so that resource-constrained systems could load GPT-Neo (and later GPT-J) without running out of memory. In order to meet the requirements for integration into transformers we had to adapt that code to instead use the existing single-file checkpoint format. It is up to the transformers maintainers to consider an alternative/optimized checkpoint loading pipeline, and I assume that such a system would need a separate PR considering the changes probably needed to PretrainedModel.
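For comparison, a sketch of loading with an explicit dtype override at call time, assuming the torch_dtype keyword argument of from_pretrained; per the loading steps described in the next comment, this only controls the final cast and does not by itself reduce the fp32 peak memory during loading:

```python
import torch
from transformers import AutoModelForCausalLM

# Request the half-precision branch and an explicit dtype at call time,
# rather than relying on the "torch_dtype" entry in the config file.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
)
print(next(model.parameters()).dtype)  # expected: torch.float16
```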

@patrickvonplaten (Contributor)

> I have verified that choosing the "float16" revision loads the model at float32. […] It is up to the transformers maintainers to consider an alternative/optimized checkpoint loading pipeline, and I assume that such a system would need a separate PR considering the changes probably needed to PretrainedModel.

Thanks a lot for the detailed message here. What we currently do in .from_pretrained(...) is definitely suboptimal if one is sure that all the loaded parameters are correct and complete. What happens under the hood is:

  1. A random model with the correct configuration is instantiated, meaning all layers are randomly initialized as defined in the config. Random initialization always happens in fp32.
  2. Then the state_dict (the correct weights) is loaded.
  3. All layers of the random model are compared to the "real" layers found in the state dict.
  4. The weights of all layers that are present in the state dict are dropped and overwritten by the state dict.
  5. All layers are cast to the correct dtype (as defined by torch_dtype).

=> We have this logic mainly for models like BERT, for which one would load the "base" model and then add a randomly initialized head for the specific downstream task. It is quite clear, however, that this makes less sense for GPT-like models.
There is an open issue to solve this problem (#12274), but I don't think it will be that easy to solve.
If it's ok with you (@EricHallahan @StellaAthena), we would merge the PR and then try to fast-track the issue - what do you think?
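Schematically, a simplified sketch of the five steps above (not the actual transformers code):

```python
import torch

def sketch_from_pretrained(model_class, config, state_dict, torch_dtype=None):
    # 1. Instantiate a randomly initialized model from the config (always in fp32).
    model = model_class(config)

    # 2.-4. Load the real weights; every layer present in the state dict has its
    # random weights overwritten, anything missing keeps its random initialization.
    model.load_state_dict(state_dict, strict=False)

    # 5. Only at the very end is the model cast to the requested dtype.
    if torch_dtype is not None:
        model = model.to(torch_dtype)
    return model
```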

@StellaAthena StellaAthena mentioned this pull request Aug 31, 2021
@oborchers commented Aug 31, 2021

FYI: model acceleration for GPT-J via DeepSpeed is in the making: microsoft/DeepSpeed#1332

@patrickvonplaten (Contributor)

Important

We will merge GPT-J into master now. Note that at the moment GPT-J cannot be run on a free Google Colab GPU, since loading the model weights in fp16 requires too much CPU RAM. At the moment one needs at least 26 GB of CPU RAM in order to load GPT-J in fp16 precision. We are working on fixing the problem so that, as a next step, one can load GPT-J with just 12 GB of CPU RAM.

@patrickvonplaten patrickvonplaten merged commit c02cd95 into huggingface:master Aug 31, 2021
@EricHallahan (Contributor)

I feel the need to reiterate that there remains a redundant main branch on the Model Hub that is neither the true single-precision checkpoint found in float32 nor the half-precision checkpoint found in float16. This means that naive usage (i.e. not specifying revision="float32" or revision="float16") will not download the proper checkpoints.

@patrickvonplaten (Contributor)

> I feel the need to reiterate that there remains a redundant main branch […]

Thanks for letting me know - is it ok if I put the "correct fp32 weights" in the main branch for now? Or do you prefer "fp16"? Both are fine with us :-) Think we can't completely delete the "main" branch for now (cc @LysandreJik)

@EricHallahan (Contributor)

> Think we can't completely delete the "main" branch for now

That is my understanding.

> is it ok if I put the "correct fp32 weights" in the main branch for now? Or do you prefer "fp16"? Both are fine with us :-)

Putting the single-precision weights in main should be fine for now.

@StellaAthena (Contributor, Author)

> Putting the single-precision weights in main should be fine for now.

+1 this

@patrickvonplaten (Contributor) commented Aug 31, 2021

Ok great - I just uploaded the correct weights to "main". You can see that the sha256 hashes of "main" (https://huggingface.co/EleutherAI/gpt-j-6B/blob/main/pytorch_model.bin) and "float32" (https://huggingface.co/EleutherAI/gpt-j-6B/blob/float32/pytorch_model.bin) match now :-)
