[Question] OOM on finetuning vicuna-7b llava model on 4*A800 80G, anything wrong with my cfg? #394

Closed
ldfandian opened this issue Aug 25, 2023 · 17 comments

@ldfandian

Question

Thanks for the great work!

I'm hitting OOM when fine-tuning the vicuna-7b LLaVA model on 4x A800 80G with the command below. Is anything wrong with my config? Also, it looks like flash-attn cannot be enabled on the A800 (error screenshots below).

python \
    llava/train/train.py \
    --model_name_or_path /root/devroot/models/vicuna-7b-v1.3 \
    --version v1 \
    --data_path /root/devroot/datasets/llava_instruct/llava_instruct_150k.json \
    --image_folder /root/devroot/datasets/llava_instruct/coco-train2017/ \
    --vision_tower openai/clip-vit-large-patch14 \
    --pretrain_mm_mlp_adapter ./checkpoints/vicuna-7b-pretrain/mm_projector.bin \
    --mm_vision_select_layer -1 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/vicuna-7b-finetune \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 32 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb
[three error screenshots attached]
@haotian-liu (Owner)

It seems that the A800 is supported by flash-attn. Please try this script with flash attention and DeepSpeed.
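
(Note for later readers: the launch that eventually worked in this thread uses the flash-attn entry point train_mem.py together with a DeepSpeed ZeRO config. A minimal sketch, assuming the same arguments as the command above and that a zero3.json is available; the config path is illustrative:)

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3.json \
    ... (remaining flags unchanged from the python launch above)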

@ldfandian (Author)

> It seems that the A800 is supported by flash-attn. Please try this script with flash attention and DeepSpeed.

Yeah, I tried. The error message is the same.

@haotian-liu (Owner)

Can you provide the versions of your flash-attn, transformers, accelerate, and PyTorch? Also, is flash-attn compiled with the same CUDA version as PyTorch?

@ldfandian (Author)

ldfandian commented Aug 26, 2023

> flash attention

I followed the instructions on the project page.

Also, can you please tell me whether flash-attn and DeepSpeed are a MUST for fine-tuning on 4 x A800 80G?
What if I upgrade the machine to 8 x A800? Will it be able to run without flash-attn and DeepSpeed?

> root / conda-environment.yaml
name: llava
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.05.30=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.10=h7f8727e_2
  - python=3.10.12=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - accelerate==0.21.0
    - aiofiles==23.2.1
    - aiohttp==3.8.5
    - aiosignal==1.3.1
    - altair==5.0.1
    - anyio==3.7.1
    - appdirs==1.4.4
    - async-timeout==4.0.3
    - attrs==23.1.0
    - bitsandbytes==0.41.0
    - certifi==2023.7.22
    - charset-normalizer==3.2.0
    - click==8.1.7
    - cmake==3.27.2
    - contourpy==1.1.0
    - cycler==0.11.0
    - deepspeed==0.9.5
    - docker-pycreds==0.4.0
    - einops==0.6.1
    - einops-exts==0.0.4
    - exceptiongroup==1.1.3
    - fastapi==0.101.1
    - ffmpy==0.3.1
    - filelock==3.12.2
    - flash-attn==2.0.9
    - fonttools==4.42.1
    - frozenlist==1.4.0
    - fsspec==2023.6.0
    - gitdb==4.0.10
    - gitpython==3.1.32
    - gradio==3.35.2
    - gradio-client==0.2.9
    - h11==0.14.0
    - hjson==3.1.0
    - httpcore==0.17.3
    - httpx==0.24.0
    - huggingface-hub==0.16.4
    - idna==3.4
    - jinja2==3.1.2
    - joblib==1.3.2
    - jsonschema==4.19.0
    - jsonschema-specifications==2023.7.1
    - kiwisolver==1.4.5
    - linkify-it-py==2.0.2
    - lit==16.0.6
    - llava==1.0.1
    - markdown-it-py==2.2.0
    - markdown2==2.4.10
    - markupsafe==2.1.3
    - matplotlib==3.7.2
    - mdit-py-plugins==0.3.3
    - mdurl==0.1.2
    - mpmath==1.3.0
    - multidict==6.0.4
    - networkx==3.1
    - ninja==1.11.1
    - numpy==1.25.2
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-cupti-cu11==11.7.101
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - nvidia-cufft-cu11==10.9.0.58
    - nvidia-curand-cu11==10.2.10.91
    - nvidia-cusolver-cu11==11.4.0.1
    - nvidia-cusparse-cu11==11.7.4.91
    - nvidia-nccl-cu11==2.14.3
    - nvidia-nvtx-cu11==11.7.91
    - orjson==3.9.5
    - packaging==23.1
    - pandas==2.0.3
    - pathtools==0.1.2
    - peft==0.4.0
    - pillow==10.0.0
    - pip==23.2.1
    - protobuf==4.24.1
    - psutil==5.9.5
    - py-cpuinfo==9.0.0
    - pydantic==1.10.12
    - pydub==0.25.1
    - pygments==2.16.1
    - pyparsing==3.0.9
    - python-dateutil==2.8.2
    - python-multipart==0.0.6
    - pytz==2023.3
    - pyyaml==6.0.1
    - referencing==0.30.2
    - regex==2023.8.8
    - requests==2.31.0
    - rpds-py==0.9.2
    - safetensors==0.3.3
    - scikit-learn==1.2.2
    - scipy==1.11.2
    - semantic-version==2.10.0
    - sentencepiece==0.1.99
    - sentry-sdk==1.29.2
    - setproctitle==1.3.2
    - setuptools==68.0.0
    - shortuuid==1.0.11
    - six==1.16.0
    - smmap==5.0.0
    - sniffio==1.3.0
    - starlette==0.27.0
    - svgwrite==1.4.3
    - sympy==1.12
    - threadpoolctl==3.2.0
    - timm==0.6.13
    - tokenizers==0.13.3
    - toolz==0.12.0
    - torch==2.0.1
    - torchvision==0.15.2
    - tqdm==4.66.1
    - transformers==4.31.0
    - triton==2.0.0
    - typing-extensions==4.7.1
    - tzdata==2023.3
    - uc-micro-py==1.0.2
    - urllib3==2.0.4
    - uvicorn==0.23.2
    - wandb==0.15.8
    - wavedrom==2.0.3.post3
    - websockets==11.0.3
    - wheel==0.38.4
    - yarl==1.9.2
prefix: /root/miniconda3/envs/llava

@haotian-liu (Owner)

haotian-liu commented Aug 26, 2023

Can you try pip install flash-attn==2.0.4 --no-build-isolation?

I find that all my current versions are 2.0.4.

Also, flash attention is necessary. You may do gradient accumulation with bs4 x accu4, which can make 7B fit on 8x A100s (maybe 4x as well). But flash attention brings at least a 2x speedup in my experiments, so spending some time to make flash-attn work should be worth it.
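
(A minimal sketch of that bs4 x accu4 suggestion, expressed with the flag names already used in the command earlier in this thread; the effective global batch size is per-device batch x accumulation steps x number of GPUs:)

    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \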

@ldfandian (Author)

Got it, thanks.

I don't have the box right now. Will report the result back.

@ldfandian (Author)

> Can you try pip install flash-attn==2.0.4 --no-build-isolation?
>
> I find that all my current versions are 2.0.4.
>
> Also, flash attention is necessary. You may do gradient accumulation with bs4 x accu4, which can make 7B fit on 8x A100s (maybe 4x as well). But flash attention brings at least a 2x speedup in my experiments, so spending some time to make flash-attn work should be worth it.

BTW, even when I use bs1 x accu16, it still OOMs... so it looks like flash-attn is a MUST.

@ldfandian (Author)

ldfandian commented Aug 26, 2023

> Can you try pip install flash-attn==2.0.4 --no-build-isolation?
>
> I find that all my current versions are 2.0.4.
>
> Also, flash attention is necessary. You may do gradient accumulation with bs4 x accu4, which can make 7B fit on 8x A100s (maybe 4x as well). But flash attention brings at least a 2x speedup in my experiments, so spending some time to make flash-attn work should be worth it.

I tried... no luck. Still the same error:

  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/devroot/src/LLaVA/llava/model/language_model/llava_llama.py", line 78, in forward
    outputs = self.model(
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/devroot/src/LLaVA/llava/train/llama_flash_attn_monkey_patch.py", line 92, in forward
    output_unpad = flash_attn_unpadded_qkvpacked_func(
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 406, in flash_attn_varlen_qkvpacked_func
    return FlashAttnVarlenQKVPackedFunc.apply(
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 123, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/root/miniconda3/envs/llava/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
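
(Note on this error: the flash-attn varlen kernels only accept fp16/bf16 tensors, so the call fails when the packed qkv reaches it in fp32. This can happen when the bf16/autocast setup is bypassed, for example when a plain python launch falls back to nn.DataParallel as in the traceback above. A hypothetical guard sketch, not the repo's actual monkey patch:)

import torch

def to_flash_attn_dtype(qkv: torch.Tensor) -> torch.Tensor:
    # flash-attn kernels require fp16 or bf16; cast fp32 activations before the call.
    if qkv.dtype not in (torch.float16, torch.bfloat16):
        qkv = qkv.to(torch.bfloat16)
    return qkv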
deepspeed llava/train/train_mem.py \
    --model_name_or_path /root/devroot/models/vicuna-7b-v1.3 \
    --version v1 \
    --data_path /root/devroot/datasets/llava/llava_instruct_150k.json \
    --image_folder /root/devroot/datasets/llava/train2017/ \
    --vision_tower openai/clip-vit-base-patch16 \
    --pretrain_mm_mlp_adapter ./checkpoints/vicuna-7b-pretrain/mm_projector.bin \
    --mm_vision_select_layer -1 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/vicuna-7b-finetune \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 12 \
    --lazy_preprocess True \
    --report_to wandb
name: llava
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.05.30=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.10=h7f8727e_2
  - python=3.10.12=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - accelerate==0.21.0
    - aiofiles==23.2.1
    - aiohttp==3.8.5
    - aiosignal==1.3.1
    - altair==5.0.1
    - anyio==3.7.1
    - appdirs==1.4.4
    - async-timeout==4.0.3
    - attrs==23.1.0
    - bitsandbytes==0.41.0
    - certifi==2023.7.22
    - charset-normalizer==3.2.0
    - click==8.1.7
    - cmake==3.27.2
    - contourpy==1.1.0
    - cycler==0.11.0
    - deepspeed==0.9.5
    - docker-pycreds==0.4.0
    - einops==0.6.1
    - einops-exts==0.0.4
    - exceptiongroup==1.1.3
    - fastapi==0.102.0
    - ffmpy==0.3.1
    - filelock==3.12.2
    - flash-attn==2.0.4
    - fonttools==4.42.1
    - frozenlist==1.4.0
    - fsspec==2023.6.0
    - gitdb==4.0.10
    - gitpython==3.1.32
    - gradio==3.35.2
    - gradio-client==0.2.9
    - h11==0.14.0
    - hjson==3.1.0
    - httpcore==0.17.3
    - httpx==0.24.0
    - huggingface-hub==0.16.4
    - idna==3.4
    - jinja2==3.1.2
    - joblib==1.3.2
    - jsonschema==4.19.0
    - jsonschema-specifications==2023.7.1
    - kiwisolver==1.4.5
    - linkify-it-py==2.0.2
    - lit==16.0.6
    - llava==1.0.1
    - markdown-it-py==2.2.0
    - markdown2==2.4.10
    - markupsafe==2.1.3
    - matplotlib==3.7.2
    - mdit-py-plugins==0.3.3
    - mdurl==0.1.2
    - mpmath==1.3.0
    - multidict==6.0.4
    - networkx==3.1
    - ninja==1.11.1
    - numpy==1.25.2
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-cupti-cu11==11.7.101
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - nvidia-cufft-cu11==10.9.0.58
    - nvidia-curand-cu11==10.2.10.91
    - nvidia-cusolver-cu11==11.4.0.1
    - nvidia-cusparse-cu11==11.7.4.91
    - nvidia-nccl-cu11==2.14.3
    - nvidia-nvtx-cu11==11.7.91
    - orjson==3.9.5
    - packaging==23.1
    - pandas==2.0.3
    - pathtools==0.1.2
    - peft==0.4.0
    - pillow==10.0.0
    - pip==23.2.1
    - protobuf==4.24.2
    - psutil==5.9.5
    - py-cpuinfo==9.0.0
    - pydantic==1.10.12
    - pydub==0.25.1
    - pygments==2.16.1
    - pyparsing==3.0.9
    - python-dateutil==2.8.2
    - python-multipart==0.0.6
    - pytz==2023.3
    - pyyaml==6.0.1
    - referencing==0.30.2
    - regex==2023.8.8
    - requests==2.31.0
    - rpds-py==0.9.2
    - safetensors==0.3.3
    - scikit-learn==1.2.2
    - scipy==1.11.2
    - semantic-version==2.10.0
    - sentencepiece==0.1.99
    - sentry-sdk==1.29.2
    - setproctitle==1.3.2
    - setuptools==68.0.0
    - shortuuid==1.0.11
    - six==1.16.0
    - smmap==5.0.0
    - sniffio==1.3.0
    - starlette==0.27.0
    - svgwrite==1.4.3
    - sympy==1.12
    - threadpoolctl==3.2.0
    - timm==0.6.13
    - tokenizers==0.13.3
    - toolz==0.12.0
    - torch==2.0.1
    - torchvision==0.15.2
    - tqdm==4.66.1
    - transformers==4.31.0
    - triton==2.0.0
    - typing-extensions==4.7.1
    - tzdata==2023.3
    - uc-micro-py==1.0.2
    - urllib3==2.0.4
    - uvicorn==0.23.2
    - wandb==0.15.8
    - wavedrom==2.0.3.post3
    - websockets==11.0.3
    - wheel==0.38.4
    - yarl==1.9.2
prefix: /root/miniconda3/envs/llava

@ldfandian (Author)

ldfandian commented Aug 26, 2023

I don't provide a "--deepspeed /path/to/deepspeed.json" in the run script. That should be OK, right?

And the pretrain checkpoint comes from a non-flash-attn train.py run instead of the flash-attn train_mem.py. Should I re-pretrain everything with flash-attn enabled from scratch?

@haotian-liu (Owner)

> I don't provide a "--deepspeed /path/to/deepspeed.json" in the run script. That should be OK, right?

That is not okay. Please use zero3.json or zero2.

> And the pretrain checkpoint comes from a non-flash-attn train.py run instead of the flash-attn train_mem.py. Should I re-pretrain everything with flash-attn enabled from scratch?

That is not needed. It is a linear layer, so it will be fine.
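
(For reference, a minimal sketch of a ZeRO-3 config of the kind being recommended here. Field values are illustrative and use the HF Trainer "auto" placeholders; check the zero3.json that ships with the repo for the actual file:)

{
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto"
}

It is passed to the training command via --deepspeed, e.g. --deepspeed ./scripts/zero3.json.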

@ldfandian (Author)

ldfandian commented Aug 26, 2023

> I don't provide a "--deepspeed /path/to/deepspeed.json" in the run script. That should be OK, right?
>
> That is not okay. Please use zero3.json or zero2.
>
> And the pretrain checkpoint comes from a non-flash-attn train.py run instead of the flash-attn train_mem.py. Should I re-pretrain everything with flash-attn enabled from scratch?
>
> That is not needed. It is a linear layer, so it will be fine.

Wow! DeepSpeed with zero3.json works great.

Thanks for all the quick responses and your amazing work.

@ldfandian (Author)

Using DeepSpeed with zero3.json on my 8 x A800 80G, it takes ~15 hours for a full fine-tune (3 epochs).
And it takes only <35 GB of GPU memory on each A800, at full (~90%) GPU utilization.

Is the speed expected?

[training progress screenshot]

@haotian-liu (Owner)

3699 iters is 3 epochs already. So you do not need to multiply by 3. ~6 hours is expected.
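
(A rough sanity check of that step count: assuming the effective batch size implied by the command above run on 8 GPUs, i.e. 8 x per-device batch 8 x gradient accumulation 2 = 128, and roughly 158K samples in llava_instruct_150k, 3 epochs x 158,000 / 128 ≈ 3,700 optimizer steps, which lines up with the 3699 iterations reported.)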

@ldfandian (Author)

> 3699 iters is 3 epochs already. So you do not need to multiply by 3. ~6 hours is expected.

BTW, I see a considerable loss drop at the beginning of each epoch. So the reason we limit training to 3 epochs is to prevent too much overfitting, right?

[W&B training loss chart]

@ldfandian (Author)


@haotian-liu

Also, I tried LLaVA, blip2-flan-t5-xl/xxl, and instructblip-vicuna7b, and found that LLaVA works best on photos taken with my iPhone. May I take away this conclusion:

  1. quality of training data >>> 2. complexity of the connector architecture (linear layer vs. Q-Former) >>> 3. number of ViT/LLM parameters?
    (Because LLaVA clearly has considerably fewer trainable parameters connecting the ViT and the LLM, yet performs very well.)

And what do you think of ResNet-50/101 as an image encoder? Will it perform similarly to ViT?

@wizyoung

wizyoung commented Sep 8, 2023

@ldfandian Hi, I'm hitting an issue where the loss does not converge well. Can you post your full training log for me?

@wizyoung

wizyoung commented Sep 8, 2023

@ldfandian Btw, are you using gradient checkpointing under DeepSpeed ZeRO-3? In my environment, ZeRO-3 seems to conflict with gradient checkpointing, but ZeRO-2 does not.
