Gradient accumulation incurs 10GiB VRAM cost when mixed-precision is enabled #2035

Closed

scottlogic-alex opened this issue Oct 5, 2023 · 11 comments
scottlogic-alex commented Oct 5, 2023

System Info

- `Accelerate` version: 0.22.0
- Platform: Linux-5.19.0-45-generic-x86_64-with-glibc2.35
- Python version: 3.11.2
- Numpy version: 1.25.2
- PyTorch version (GPU?): 2.1.0.dev20230802+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 62.79 GB
- GPU type: NVIDIA A40
- `Accelerate` default config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: FSDP
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 2
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - fsdp_config: {
    'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP',
    'fsdp_backward_prefetch_policy': 'NO_PREFETCH',
    'fsdp_forward_prefetch': False,
    'fsdp_offload_params': False,
    'fsdp_sharding_strategy': 1,
    'fsdp_state_dict_type': 'SHARDED_STATE_DICT',
    'fsdp_sync_module_states': True,
    'fsdp_transformer_layer_cls_to_wrap': 'LlamaDecoderLayer',
    'fsdp_use_orig_params': True
  }
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

You will need ~71000MiB of VRAM for this experiment (e.g. 2xA40).

Set up our fork of qlora (which adds support for keeping the model in fp32 for mixed-precision full finetunes, and adds a memory callback):

git clone https://github.com/scottlogic-alex/qlora.git
cd qlora
pip install -r requirements.txt

Run training for a couple of steps to see the memory measurements. The memory is reported via the HF transformers Trainer's on_step_end callback, which fires after all gradient-accumulation microsteps have completed.
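The callback reports memory roughly along these lines (a minimal sketch using pynvml and torch.cuda; the actual callback in the fork may differ in detail):

# Minimal sketch of the memory-reporting callback (assumption: the real
# callback in the qlora fork may differ in detail).
import pynvml
import torch
from transformers import TrainerCallback

class MemoryCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        pynvml.nvmlInit()
        print("  NVML memory stats (used+reserved, all processes):")
        for ix in range(pynvml.nvmlDeviceGetCount()):
            info = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(ix))
            print(f"    Device {ix}: Used {info.used >> 20}MiB / {info.total >> 20}MiB")
        print("  Torch memory stats (allocated, reserved):")
        for ix in range(torch.cuda.device_count()):
            alloc = torch.cuda.memory_allocated(ix) >> 20
            reserved = torch.cuda.memory_reserved(ix) >> 20
            print(f"    Device {ix}: Allocated {alloc}MiB, Reserved {reserved}MiB")
        pynvml.nvmlShutdown()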

mixed fp32/bf16, 2 microsteps gradient accumulation (68.7GiB):

ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
--device_map_auto \
--disable_tqdm True \
--model_name_or_path huggyllama/llama-7b \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 2 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 35034MiB / 49140MiB
    Device 1: Used 35356MiB / 49140MiB
    Overall: Used 70391MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 34024MiB (Allocated: 12900MiB, Reserved 21123MiB)
    Device 1: Used 34346MiB (Allocated: 12900MiB, Reserved 21445MiB)
    Overall: Used 68370MiB (Allocated: 25801MiB, Reserved 42568MiB)

mixed fp32/bf16, no gradient accumulation (59.0GiB):

ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
--device_map_auto \
--disable_tqdm True \
--model_name_or_path huggyllama/llama-7b \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 1 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 30232MiB / 49140MiB
    Device 1: Used 30220MiB / 49140MiB
    Overall: Used 60453MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 29222MiB (Allocated: 12900MiB, Reserved 16321MiB)
    Device 1: Used 29210MiB (Allocated: 12900MiB, Reserved 16309MiB)
    Overall: Used 58432MiB (Allocated: 25801MiB, Reserved 32630MiB)

full-fp32, 2 microsteps gradient accumulation (53.4GiB):

python -m qlora \
--device_map_auto \
--disable_tqdm True \
--model_name_or_path huggyllama/llama-7b \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 2 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 27446MiB / 49140MiB
    Device 1: Used 27274MiB / 49140MiB
    Overall: Used 54721MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 26442MiB (Allocated: 12900MiB, Reserved 13541MiB)
    Device 1: Used 26270MiB (Allocated: 12900MiB, Reserved 13369MiB)
    Overall: Used 52712MiB (Allocated: 25801MiB, Reserved 26910MiB)

full-fp32, no gradient accumulation (52.3GiB):

python -m qlora \
--device_map_auto \
--disable_tqdm True \
--model_name_or_path huggyllama/llama-7b \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 1 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 0: Used 26772MiB / 49140MiB
    Device 1: Used 26772MiB / 49140MiB
    Overall: Used 53545MiB / 98280MiB
  Torch memory stats (allocated, reserved):
    Device 0: Used 25768MiB (Allocated: 12900MiB, Reserved 12867MiB)
    Device 1: Used 25768MiB (Allocated: 12900MiB, Reserved 12867MiB)
    Overall: Used 51536MiB (Allocated: 25801MiB, Reserved 25734MiB)

In summary:

- with mixed fp32/bf16: enabling gradient accumulation increased VRAM usage from 59.0 to 68.7 GiB (that's 9.7 GiB, or 1.54 bytes/param).
- with pure fp32: enabling gradient accumulation only increased VRAM usage from 52.3 to 53.4 GiB, which is not nothing (0.9 GiB) but is much closer to what we'd expect (i.e. zero cost).

Side note: it is also surprising that mixed fp32/bf16 costs 6.7 GiB (1.075 bytes/param) more VRAM than pure fp32.
In mixed precision we spend 2 bytes/param on a half-precision compute copy of the weights, but we are supposed to win that back and more, because the gradients shrink by 2 bytes/param and the activations shrink too (admittedly this is a batch of 1, sequence of 8, with checkpointing enabled, so the activations are not significant here). Mixed precision is supposed to be smaller than full fp32, not 1 byte/param bigger.
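For reference, a rough back-of-envelope of the expected per-param budget with plain SGD (no momentum), ignoring activations; this is my own accounting, not output from any tool:

# Back-of-envelope bytes/param for a full finetune with plain SGD, ignoring
# activations; these breakdowns are my own assumptions.
n_params = 6_738_423_808      # llama-7b

full_fp32 = 4 + 4             # fp32 weights + fp32 grads
mixed     = 4 + 2 + 2         # fp32 weights + bf16 compute copy + bf16 grads

print(f"full fp32: {full_fp32 * n_params / 2**30:.1f} GiB")  # ~50.2 GiB
print(f"mixed:     {mixed * n_params / 2**30:.1f} GiB")      # ~50.2 GiB

On this accounting the two come out equal before activations, and mixed precision should then win on activation savings, which makes the observed ~1 byte/param regression surprising.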

Expected behavior

In mixed fp32/bf16 mode, enabling gradient accumulation should cost no extra VRAM, or at the very least no more overhead than it incurs in full-fp32 mode.

Also: mixed fp32/bf16 training should not cost more VRAM than full-fp32 training.

@muellerzr muellerzr self-assigned this Oct 6, 2023
@BenjaminBossan

Small update: as I don't have the hardware to replicate the original issue 1:1, I ran a simplified version based on nlp_example.py. There, I couldn't observe the issue with mixed precision + gradient accumulation. So whatever it is, it has to be something that differs in the described setup (which, admittedly, could still be many things).
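For reference, the gradient-accumulation pattern exercised there is essentially the standard Accelerate loop; a minimal sketch with a toy placeholder model and data (not the actual script) would be:

# Sketch of the standard Accelerate gradient-accumulation pattern
# (toy placeholder model and data, not the nlp_example.py script itself).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=2)

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, labels in dataloader:
    # inside accumulate(), optimizer.step()/zero_grad() only take effect on the
    # step that completes an accumulation window; gradients accumulate in between
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()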

@Birch-san

@BenjaminBossan here's a smaller repro with Pythia 1.4b; it should fit into <16GiB of VRAM. If that's still too big, it can be tried with smaller Pythia models, but the smaller we go, the harder it becomes to tell fixed overheads apart from per-param overheads.

To test this smaller-footprint repro on my 2xA40, I used CUDA_VISIBLE_DEVICES=1 to simulate a single-GPU machine and removed --device_map_auto, since there's only one device. I pushed a new commit to the memory callback on that qlora stepwise branch to support ignoring hidden devices.

mixed fp32/bf16, 2 microsteps gradient accumulation (14.8GiB NVML usage):

ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
--disable_tqdm True \
--model_name_or_path EleutherAI/pythia-1.4b-deduped \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 2 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 1: Used 15104MiB / 49140MiB
  Torch memory stats (allocated, reserved):
    Device 1: Used 14094MiB (Allocated: 5508MiB, Reserved 8585MiB)

mixed fp32/bf16, no gradient accumulation (13.6GiB NVML usage):

ACCELERATE_MIXED_PRECISION=bf16 python -m qlora \
--disable_tqdm True \
--model_name_or_path EleutherAI/pythia-1.4b-deduped \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 1 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 1: Used 13944MiB / 49140MiB
  Torch memory stats (allocated, reserved):
    Device 1: Used 12934MiB (Allocated: 5508MiB, Reserved 7425MiB)

That's a difference of (15104-13944)*1024**2/1414541312 = 0.86 bytes/param,
which is smaller than Llama 7b's (70391-60453)*1024**2/6738423808 = 1.55 bytes/param.
So far from being a constant per-param overhead, the per-param cost itself seems to worsen with scale, and I'm not sure how that could happen.
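For clarity, the bytes/param figures are just the NVML delta divided by the parameter count:

# Per-param overhead of enabling gradient accumulation, from the NVML totals
# above (in MiB) and each model's parameter count.
def accum_overhead_bytes_per_param(used_with_accum_mib, used_without_mib, n_params):
    return (used_with_accum_mib - used_without_mib) * 1024**2 / n_params

print(accum_overhead_bytes_per_param(15104, 13944, 1_414_541_312))  # pythia-1.4b, mixed: ~0.86
print(accum_overhead_bytes_per_param(70391, 60453, 6_738_423_808))  # llama-7b,    mixed: ~1.55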

full fp32, 2 microsteps gradient accumulation (12.0GiB NVML usage):

python -m qlora \
--disable_tqdm True \
--model_name_or_path EleutherAI/pythia-1.4b-deduped \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 2 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 1: Used 12320MiB / 49140MiB
  Torch memory stats (allocated, reserved):
    Device 1: Used 11316MiB (Allocated: 5508MiB, Reserved 5807MiB)

full fp32, no gradient accumulation (11.6GiB NVML usage):

python -m qlora \
--disable_tqdm True \
--model_name_or_path EleutherAI/pythia-1.4b-deduped \
--dataset prm800k-solutions \
--dataset_format prm800k-solutions \
--simulate_worst_case_seq_len \
--truncate_toward_center \
--source_max_len 4 \
--target_max_len 4 \
--gradient_accumulation_steps 1 \
--optim sgd \
--measure_memory \
--terminate_after_step 3 \
--bits 32 \
--full_finetune
step 1
  NVML memory stats (used+reserved, all processes):
    Device 1: Used 11924MiB / 49140MiB
  Torch memory stats (allocated, reserved):
    Device 1: Used 10920MiB (Allocated: 5508MiB, Reserved 5411MiB)

That's a difference of (12320-11924)*1024**2/1414541312 = 0.29 bytes/param,
which is bigger than Llama 7b's (54721-53545)*1024**2/6738423808 = 0.18 bytes/param.
If this overhead is closer to a fixed overhead than a per-param one, it makes sense that it would account for a larger share on a smaller model.


muellerzr commented Oct 6, 2023

@Birch-san I'm still working on this; are you adjusting the batch size at the same time? (It's a little unclear.)

E.g. gradient_accumulation_steps=1 with batch_size=32 versus gradient_accumulation_steps=2 with batch_size=16?

@Birch-san

I'm using Python 3.11.2, as stated in the system info in the issue.

I am not adjusting the batch size to compensate; I am leaving it at the default (per_device_train_batch_size=1). I understand that ordinarily you would adjust it, but for these memory benchmarks I am not.


muellerzr commented Oct 6, 2023

You must, because otherwise we're not measuring the effective batch sizes properly here. For example, this would be equivalent to me saying that a batch size of 16 uses less memory than a batch size of 32, which is true. Why are we not doing that here?

(If I'm reading/understanding this wrong, that's okay.)


muellerzr commented Oct 6, 2023

I'm working on setting up my environment to reproduce this properly, and I'll have more comments once I'm able to play with the code.

@Birch-san

I thought the point of gradient accumulation is that the (half-precision) gradients from each microstep are accumulated into the same (full-precision?) buffer, so it doesn't matter how many microsteps of gradient accumulation you perform: peak memory usage doesn't change. You only need enough memory to survive the microstep.

this would be the equivalent of me saying that a bs of 16 has less memory than a bs of 32 effectively, which is true.

Activations scale with batch size, but again that only affects peak memory usage within the microstep. The size of the weight update doesn't grow: you need an update per param, not per sample.

After you compute a microbatch of 16 and accumulate the gradients into the "weight update" buffer, no further memory is required to compute another microbatch of 16: every buffer used for the first microstep can be reused.
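In plain PyTorch terms, a minimal sketch of the pattern I have in mind (toy model and data) is:

# Minimal sketch of gradient accumulation in plain PyTorch: each backward()
# adds into the same .grad buffers, so peak memory should not depend on the
# number of microsteps (model and data here are toy placeholders).
import torch

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
accumulation_steps = 2
batches = [(torch.randn(1, 16), torch.randint(0, 2, (1,))) for _ in range(8)]

for step, (inputs, labels) in enumerate(batches):
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    (loss / accumulation_steps).backward()  # accumulates into the existing .grad buffers
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()  # the same grad buffers are reused for the next microsteps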

@muellerzr

That makes sense, yes. Sorry, I blanked on this.

I'm still working on getting access to compute to run your code in bf16; however, as Benjamin stated, raw gradient accumulation doesn't show this (I saw the same when running the gradient_accumulation script with some modifications). My inclination as a result is that this doesn't stem from how we do gradient accumulation, but perhaps from something in how peft handles the gradients? At a bare minimum: this doesn't stem from accelerate directly; it's something about this combination. cc @pacman100

@muellerzr

Or @younesbelkada

@Birch-san

@muellerzr if bf16 is a blocker: I confirm that I get the same numbers using mixed fp32/fp16 too.


github-actions bot commented Nov 5, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
