Finetuning 355M or larger GPT-2 models / Gradient Checkpointing #6
Gradient checkpointing must be implemented to avoid going OOM when finetuning those models. That is apparently done at the training level, and PyTorch has tricks to do it easily, but I am having difficulty getting it to work correctly.

Comments
Are there any reference implementations for gradient checkpointing? I've heard it brought up as a PyTorch feature but I've not actually seen it in use.
I imagine you've already tried something like this, but I've taken a look at the PyTorch docs for implementation details for gradient checkpointing here and here. With normal PyTorch modules it seems that it could be implemented during the forward pass using something like:

```python
from torch.utils.checkpoint import checkpoint

def forward(self, inputs):
    # Recompute this module's activations during the backward pass
    # instead of storing them, trading compute for memory.
    return checkpoint(self.model, *inputs)
```

You could make this forward pass conditional on the model chosen, or expose it as an optional param to the train class.
Yes, the correct implementation is something along those lines; apparently the Transformers GPT-2 implementation is what complicates it. I can give it another go.
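For reference, a rough sketch of how `checkpoint()` could be applied per transformer block rather than to the whole model, assuming a Hugging Face `GPT2LMHeadModel` whose blocks live in `model.transformer.h`; the helper name is hypothetical and this is not necessarily how aitextgen ended up doing it:

```python
from torch.utils.checkpoint import checkpoint

def forward_blocks_with_checkpointing(transformer, hidden_states):
    # Run each GPT-2 block under checkpoint() so its activations are
    # recomputed during the backward pass instead of being stored.
    for block in transformer.h:
        # GPT2Block returns a tuple; the first element is the hidden state.
        hidden_states = checkpoint(lambda x, blk=block: blk(x)[0], hidden_states)
    return transformer.ln_f(hidden_states)
```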
I am getting OOM for even the smaller 124M model if the input file is bigger than 100 MB.

```
/usr/local/lib/python3.6/dist-packages/aitextgen/TokenDataset.py in __init__(self, file_path, vocab_file, merges_file, texts, line_by_line, from_cache, header, save_cache, cache_destination, compress, block_size, tokenized_texts, text_delim, bos_token, eos_token, unk_token, pad_token, progress_bar_refresh_rate, **kwargs)

AttributeError: 'list' object has no attribute 'shape'
```
The input dataset file is not related to these GPU OOM issues, so you are hitting something else. You should not get OOM on the 124M model unless you are using a small GPU.
That's unrelated, but a legit bug. Filed at #49
My input file was around 20 GB and I got those OOMs. I broke it down with `split -b` into 100 MB chunks and I have no issues running it now.
@ganeshkrishnan1 after splitting the file into 100 MB chunks, did you create several TokenDatasets and merge them? Or did you just train the model a little bit on every txt file separately?
I trained the model a bit, saved it, and then reloaded it again to train on the next file.
Awesome, thanks.
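For reference, a rough sketch of that chunk-by-chunk workflow, assuming aitextgen's `train()` accepts a file path, saves to `trained_model/` by default, and that a saved model can be reloaded via the `model_folder` argument; the chunk names and step counts are illustrative:

```python
import glob
from aitextgen import aitextgen

# Assumes the large corpus was already split into ~100 MB chunks,
# e.g. named chunk_aa, chunk_ab, ... via `split -b 100m corpus.txt chunk_`.
chunks = sorted(glob.glob("chunk_*"))

for i, chunk in enumerate(chunks):
    if i == 0:
        # Start from the base 124M GPT-2 for the first chunk.
        ai = aitextgen(tf_gpt2="124M")
    else:
        # Reload the weights saved by the previous training round.
        ai = aitextgen(model_folder="trained_model")

    # Train briefly on this chunk; weights are saved to trained_model/ by default.
    ai.train(chunk, num_steps=2000, batch_size=1)
```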
Gradient checkpointing currently works for me by just setting the corresponding GPT2Config property.
In total this uses ~5 GB of VRAM with a small training file.
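For reference, a minimal sketch of what that might look like; the exact property name is not preserved in the comment above, so this assumes the `gradient_checkpointing` flag that Hugging Face Transformers exposes on `GPT2Config`:

```python
from transformers import GPT2Config, GPT2LMHeadModel

# Assumption: the config property being referred to is `gradient_checkpointing`.
# use_cache is disabled because caching is incompatible with checkpointing.
config = GPT2Config.from_pretrained(
    "gpt2-medium",
    gradient_checkpointing=True,
    use_cache=False,
)
model = GPT2LMHeadModel.from_pretrained("gpt2-medium", config=config)
```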
Closing and unpinning due to 0.4.0.