Gradient accumulation incurs 10GiB VRAM cost when mixed-precision is enabled #2035
Small update: as I don't have the hardware to replicate the original issue 1:1, I ran a simplified version based on
@BenjaminBossan here's a smaller repro with Pythia 1.4b. It should fit into <16GiB of VRAM. If that's still too big, it can be tried with smaller Pythia models, but the smaller it goes, the harder it is to tell fixed overheads apart from per-param overheads. To test this smaller-footprint repro on my 2xA40, I used:

mixed fp32/bf16, 2 microsteps gradient accumulation (14.8GiB NVML usage):

```
ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
  --disable_tqdm True \
  --model_name_or_path EleutherAI/pythia-1.4b-deduped \
  --dataset prm800k-solutions \
  --dataset_format prm800k-solutions \
  --simulate_worst_case_seq_len \
  --truncate_toward_center \
  --source_max_len 4 \
  --target_max_len 4 \
  --gradient_accumulation_steps 2 \
  --optim sgd \
  --measure_memory \
  --terminate_after_step 3 \
  --bits 32 \
  --full_finetune
```
mixed fp32/bf16, no gradient accumulation (13.6GiB NVML usage):

```
ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
  --disable_tqdm True \
  --model_name_or_path EleutherAI/pythia-1.4b-deduped \
  --dataset prm800k-solutions \
  --dataset_format prm800k-solutions \
  --simulate_worst_case_seq_len \
  --truncate_toward_center \
  --source_max_len 4 \
  --target_max_len 4 \
  --gradient_accumulation_steps 1 \
  --optim sgd \
  --measure_memory \
  --terminate_after_step 3 \
  --bits 32 \
  --full_finetune
```
That's a difference of 1.2GiB.

full fp32, 2 microsteps gradient accumulation (12.0GiB NVML usage):

```
python -m qlora \
  --disable_tqdm True \
  --model_name_or_path EleutherAI/pythia-1.4b-deduped \
  --dataset prm800k-solutions \
  --dataset_format prm800k-solutions \
  --simulate_worst_case_seq_len \
  --truncate_toward_center \
  --source_max_len 4 \
  --target_max_len 4 \
  --gradient_accumulation_steps 2 \
  --optim sgd \
  --measure_memory \
  --terminate_after_step 3 \
  --bits 32 \
  --full_finetune
```
full fp32, no gradient accumulation (11.6GiB NVML usage):

```
python -m qlora \
  --disable_tqdm True \
  --model_name_or_path EleutherAI/pythia-1.4b-deduped \
  --dataset prm800k-solutions \
  --dataset_format prm800k-solutions \
  --simulate_worst_case_seq_len \
  --truncate_toward_center \
  --source_max_len 4 \
  --target_max_len 4 \
  --gradient_accumulation_steps 1 \
  --optim sgd \
  --measure_memory \
  --terminate_after_step 3 \
  --bits 32 \
  --full_finetune
```
That's a difference of 0.4GiB.
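For context, the two settings these runs toggle map onto Accelerate's API roughly as follows. This is a minimal sketch with a stand-in model, not the script's actual wiring (qlora goes through the transformers Trainer, which drives Accelerate internally):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Sketch only: "bf16" vs. "no", and 2 vs. 1, are the two knobs the runs above toggle.
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=2)

model = torch.nn.Linear(8, 8)  # stand-in for the fp32 model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for (x,) in dataloader:
    with accelerator.accumulate(model):   # handles microstep bookkeeping
        loss = model(x).square().mean()
        accelerator.backward(loss)
        optimizer.step()                  # real update only on accumulation boundaries
        optimizer.zero_grad()
```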
@Birch-san still working on this. Are you adjusting the batch size during this? (It's a little unclear.) E.g. with
I'm using Python 3.11.2, as stated in the system info in the issue. I am not adjusting the batch size to compensate. I am always leaving the batch size at its default.
You must be, because otherwise we're not measuring the effective batch sizes properly here. For example, this would be the equivalent of me saying that a bs of 16 effectively uses less memory than a bs of 32, which is true. Why are we not doing that here? (If I'm reading/understanding this wrong, that's okay.)
Working on getting my environment set up to properly reproduce this; I'll have more comments as I'm able to play with the code.
I thought the point of gradient accumulation is that (half-precision) gradients from each microstep are accumulated into the same (full-precision?) buffer, so it doesn't matter how many microsteps of gradient accumulation you perform: peak memory usage doesn't change. You only need enough memory to survive a single microstep.

Activations scale with batch size, but again, that's only a concern for peak memory usage within the microstep. The size of the weight update doesn't get bigger: you need an update per param, not per sample. After you compute a microbatch of 16 and accumulate the gradients into a "weight update" buffer, no further memory is required to compute another microbatch of 16; every buffer you used for the first microstep can be re-used.
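A toy PyTorch-only illustration of that argument (nothing qlora-specific; it just shows that `backward()` sums into the existing per-parameter `.grad` buffers):

```python
import torch

model = torch.nn.Linear(1024, 1024)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

for microstep in range(4):          # 4 microsteps, one optimizer step
    x = torch.randn(16, 1024)       # microbatch of 16
    loss = model(x).square().mean()
    loss.backward()                 # accumulates in-place into .grad; no new buffers

opt.step()                          # one update for the effective batch of 4 * 16 = 64
opt.zero_grad()
```

Peak memory here is set by one microstep's activations plus the persistent `.grad` buffers, regardless of how many microsteps run.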
That makes sense, yes. Sorry, I blanked on this. Still working on getting access to compute to run your code in bf16; however, as Benjamin stated, we didn't see this with raw gradient accumulation. (I saw this as well when running the
@muellerzr if bf16 is a blocker: I confirm that I get the same numbers using mixed fp32/fp16 too.
System Info

Information

Tasks

`no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)

Reproduction
You will need ~71000MiB of VRAM for this experiment (e.g. 2xA40).
Set up our fork of qlora (which adds support for keeping the model in fp32 for mixed-precision full finetunes, and adds a memory callback):

```
git clone https://github.com/scottlogic-alex/qlora.git
cd qlora
pip install -r requirements.txt
```
Run training for a couple of steps to see the memory measurements. The memory is reported via the HF transformers Trainer's `on_step_end` callback, which fires after all gradient-accumulation microsteps have completed (a sketch of such a callback follows the list below).

- mixed fp32/bf16, 2 microsteps gradient accumulation (68.7GiB)
- mixed fp32/bf16, no gradient accumulation (59.0GiB)
- full-fp32, 2 microsteps gradient accumulation (53.4GiB)
- full-fp32, no gradient accumulation (52.3GiB)
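For reference, a callback along these lines takes only a few lines. This is a hedged sketch, not the fork's actual `--measure_memory` implementation: the class name is hypothetical, and the NVML figures quoted in this issue come from the driver's view of the whole process, which is larger than what `torch.cuda` reports.

```python
import torch
from transformers import TrainerCallback

class MemoryCallback(TrainerCallback):  # hypothetical name; pass via Trainer(callbacks=[...])
    def on_step_end(self, args, state, control, **kwargs):
        # Fires once per optimizer step, i.e. after all accumulation microsteps.
        allocated = torch.cuda.memory_allocated() / 2**30
        reserved = torch.cuda.memory_reserved() / 2**30
        print(f"step {state.global_step}: {allocated:.1f}GiB allocated, {reserved:.1f}GiB reserved")
```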
In summary:
with mixed fp32/bf16: enabling gradient accumulation increased VRAM usage 59.0->68.7GiB (that's 9.7GiB, or 1.54 bytes/param).
with pure-fp32: enabling gradient accumulation increased VRAM usage only 52.3->53.4GiB, which is not nothing (1.1GiB), but closer to what we'd expect (i.e. zero cost).
Side-note: it is also surprising that mixed fp32/bf16 costs 6.7GiB (1.075 bytes/param) more VRAM than pure-fp32.
In mixed precision we spend 2 bytes/param on a half-precision compute copy of the weights, but we're supposed to win this back and then some: our gradients get 2 bytes/param smaller, and our activations shrink too (admittedly this is a batch-of-1, sequence-of-8 run with checkpointing enabled, so activations are not significant here). Mixed precision is supposed to be smaller, not 1 byte/param bigger than full-fp32.
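To make that accounting explicit, here's the back-of-envelope arithmetic. My assumptions (not confirmed against the actual allocation behavior): plain SGD with no momentum state, and bf16 gradients in the mixed run.

```python
fp32_weights = 4          # bytes/param, master weights
fp32_grads   = 4
bf16_copy    = 2          # half-precision compute copy of the weights
bf16_grads   = 2

pure_fp32 = fp32_weights + fp32_grads               # 8 bytes/param
mixed     = fp32_weights + bf16_copy + bf16_grads   # 8 bytes/param

print(pure_fp32, mixed)   # 8 8 -> mixed should at worst break even,
                          # before counting its smaller activations
```

Against that baseline, the observed overheads (~1.075 bytes/param extra for mixed precision, a further ~1.54 bytes/param with accumulation enabled) are each on the order of one extra half-precision per-parameter buffer staying alive.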
Expected behavior
In mixed fp32/bf16 mode, enabling gradient accumulation should cost no extra VRAM, or at least should incur only the same overhead as it does in full-fp32 mode.
Also: mixed fp32/bf16 training should not cost more VRAM than full-fp32 training.