
Significantly increased VRAM usage for Mixtral qlora training compared to 4.36.2? #28339

Closed
1 of 4 tasks
DocShotgun opened this issue Jan 4, 2024 · 5 comments

Comments

@DocShotgun

System Info

The environment is a Runpod container with Python 3.10, a single A100 80GB, and transformers 4.37.0dev (3cefac1), using the axolotl training script (https://github.com/OpenAccess-AI-Collective/axolotl).

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello, I just tried doing a training run on the dev version of transformers (as of 3cefac1) via the common training repository axolotl (https://github.com/OpenAccess-AI-Collective/axolotl) and noticed that I went OOM using the same configuration I had previously used successfully with transformers 4.36.2 stable. And it's not just a small difference: I had to reduce my batch size by 4x to make the training fit in VRAM.

I was previously able to fit 8192 ctx, batch size 4, grad accum steps 2 without difficulty, but I found that I now had to reduce my batch size to 1 to avoid OOM. The relevant training hyperparameters are:

load_in_4bit: true
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
optimizer: adamw_bnb_8bit
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
flash_attention: true

no deepspeed or fsdp
no evals
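
For context, here is a rough sketch of how this axolotl config maps onto a plain transformers + peft + bitsandbytes setup. This is not what axolotl runs verbatim; the model id and the Mixtral target_modules list are assumptions on my part, and sequence_len / sample_packing / pad_to_sequence_len are handled by axolotl's data pipeline, so they have no direct TrainingArguments equivalent.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# load_in_4bit: true, with bf16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",            # placeholder model id
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",  # flash_attention: true
    torch_dtype=torch.bfloat16,
)
model = prepare_model_for_kbit_training(model)

# lora_target_linear: true targets every linear layer; the module names below
# are an assumed approximation for Mixtral's attention and expert MLP layers.
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "w1", "w2", "w3"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,   # the batch size that fit on 4.36.2
    gradient_accumulation_steps=2,
    optim="adamw_bnb_8bit",
    bf16=True,
    tf32=True,
    gradient_checkpointing=True,
)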

I would appreciate any insights into what caused the massive increase in memory usage. I noticed that ehartford's latest dolphin 2.7 qlora used a batch size of 3 per device at 16k ctx on an A100 80GB, so surely I'm missing something here?

Expected behavior

The training run should take a relatively similar amount of VRAM as it did previously with the same config.

@ArthurZucker
Collaborator

Hey! Thanks for the report. Here are the potential PRs I would suspect:

@MB7979

MB7979 commented Jan 7, 2024

This may not be relevant to you, but I found that this recent change to Axolotl has made a significant difference to VRAM usage. Previously I could just squeeze in a LoRA on a 34B model on my 3x3090s at batch size 2 and sequence length 4096; now it OOMs immediately. I undid the change and it fits again.

@DocShotgun
Author

> This may not be relevant to you, but I found that this recent change to Axolotl has made a significant difference to VRAM usage. Previously I could just squeeze in a LoRA on a 34B model on my 3x3090s at batch size 2 and sequence length 4096; now it OOMs immediately. I undid the change and it fits again.

Hmm, it's certainly possible, since that commit landed between my initial train and the run where I had to drop the batch size. Unfortunately I don't have a training instance up right now, so I'll have to test it the next time I try to train.

@DocShotgun
Author

I've determined that the cause of the increased VRAM usage was indeed axolotl changing the default for use_reentrant to False for gradient checkpointing. Going to go ahead and close the issue.
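For anyone who lands here: recent transformers versions let you choose the checkpointing variant via gradient_checkpointing_kwargs, so a minimal sketch of forcing the reentrant implementation (which was the lower-VRAM setting in my case) looks like the following. The model id is just a placeholder.

import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # placeholder model id
    torch_dtype=torch.bfloat16,
)

# Option 1: enable checkpointing directly on the model and pick the variant.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)

# Option 2: let the Trainer enable it by passing the same kwargs
# through TrainingArguments.
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
)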

@ArthurZucker
Collaborator

Thanks for sharing the solution! 🤗
