Unable to train "bigcode/starcoder" model on 80 A100-80GB GPUs using FSDP #1864

Closed
dineshkh opened this issue Aug 19, 2023 · 10 comments
Labels: solved (The bug or feature request has been solved, but the issue is still opened)

@dineshkh

dineshkh commented Aug 19, 2023

I am trying to further train the bigcode/starcoder 15-billion-parameter model with an 8k context length on 80 A100-80GB GPUs (10 nodes with 8 GPUs each) using Accelerate FSDP. I am using gradient checkpointing and my per-device batch size is only 1.
Even with fsdp_backward_prefetch_policy: NO_PREFETCH and fsdp_offload_params: true, I am getting the following OOM error:

Traceback of TorchScript (most recent call last):
  File "<string>", line 83, in <backward op>
            result = torch.softmax(self, dim, dtype)
            def backward(grad_output):
                grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                return grad_self, None, None
RuntimeError: CUDA out of memory. Tried to allocate 12.00 GiB (GPU 2; 79.15 GiB total capacity; 56.66 GiB already allocated; 9.54 GiB free; 68.85 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
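
(As an aside, the max_split_size_mb hint in that message can be tried by setting PYTORCH_CUDA_ALLOC_CONF before the first CUDA allocation. A minimal sketch, with 512 MB used purely as an example value; this only mitigates fragmentation, it does not create more memory.)

import os

# Hedged sketch: the allocator option suggested by the error message above.
# Must be set before the first CUDA allocation; 512 is an example value only.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch  # import torch (and do any CUDA work) only after setting the variable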

Following is my FSDP configuration:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: NO_PREFETCH
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: GPTBigCodeBlock
machine_rank: 0
main_process_ip: 172.168.12.19
main_process_port: 55930
main_training_function: main
mixed_precision: 'bf16'
num_machines: 10
num_processes: 80
use_cpu: false

I have tried with accelerate==0.20.3 and accelerate==0.21.0. My transformers and PyTorch versions are as follows:

transformers 4.29.0 pypi_0 pypi
pytorch 2.0.1 py3.11_cuda11.7_cudnn8.5.0_0 pytorch

During the forward pass, the memory on each GPU goes up to roughly 41.3 GB; when the backward pass starts it climbs to 77 GB and then the program crashes.

Please let me know if I am missing something here.
Thanks!

@sgugger
Collaborator

sgugger commented Aug 21, 2023

cc @pacman100

@pacman100
Contributor

pacman100 commented Aug 21, 2023

Hello, I am able to train Starcoder with 8K seq len on 16 A100 80GB GPUs (2 nodes each having 8 GPUs) + Gradient Checkpointing + Flash Attention V2 without any issues.

Code: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/code_assistant/training
Accelerate Config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/code_assistant/training/configs/fsdp_config.yaml
SLURM Launcher:

#!/bin/bash
#SBATCH --job-name=ift_llama
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1          # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=96
#SBATCH --mem-per-cpu=11G # Important to enable "mix" use of GPUs across cluster users
#SBATCH --partition=XXXXX
#SBATCH --gres=gpu:8 # Adjust number of GPUs here
#SBATCH --output=shared_storage/sourab/temp/logs/%x-%j.out
#SBATCH --err=shared_storage/sourab/temp/logs/%x-%j.err

set -x -e

# CHANGE HERE THE CONDA ENV AND ANY STARTUP SCRIPTS
source ~/sourab/.bashrc
source shared_storage/sourab/miniconda3/etc/profile.d/conda.sh
conda activate hf
cd shared_storage/sourab/DHS-LLM-Workshop/code_assistant/training
git pull

# have the below in case of debugging nccl issues such as nccl timeout.
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO
# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

echo "START TIME: $(date)"

# CHANGE TO CUMULATIVELY LOG OUTPUTS
LOG_PATH="main_log.txt"

GPUS_PER_NODE=8
NNODES=$SLURM_NNODES
NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

# OTHER LAUNCHERS CAN BE USED HERE
export LAUNCHER="accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --num_processes $NUM_PROCESSES \
    --num_machines $NNODES \
    "
# Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable

export PROGRAM="\
train.py \
--model_name "bigcode/starcoder" \
--dataset_name "smangrul/code-chat-assistant-v1" \
--max_seq_len 8192 \
--bf16 True \
--num_train_epochs 2 \
--logging_steps 1 \
--packing True \
--output_dir "shared_storage/sourab/temp/starcoder-chat-asst" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "content" \
--learning_rate 5e-5  \
--lr_scheduler_type "cosine" \
--weight_decay 0.01 \
--warmup_ratio 0.03 \
--use_flash_attn True \
--use_gradient_checkpointing True
"


export CMD="$LAUNCHER $PROGRAM"

srun --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $LOG_PATH

echo "END TIME: $(date)"

Output:

{'loss': 1.4429, 'learning_rate': 1.2195121951219514e-06, 'epoch': 0.0}
  8%|▊         | 102/1360 [26:20<6:06:17, 17.2023-08-21 08:09:04,551 _dedup_tensors.py:44 INFO p:MainProcess t:MainThread: Duplicate keys to remove: {}
{'loss': 1.5454, 'learning_rate': 2.4390243902439027e-06, 'epoch': 0.0}
{'loss': 1.3818, 'learning_rate': 3.6585365853658537e-06, 'epoch': 0.0}
{'loss': 1.4468, 'learning_rate': 4.8780487804878055e-06, 'epoch': 0.01}
{'loss': 1.5747, 'learning_rate': 6.0975609756097564e-06, 'epoch': 0.01}
{'loss': 1.2974, 'learning_rate': 7.317073170731707e-06, 'epoch': 0.01}
{'loss': 1.3887, 'learning_rate': 8.53658536585366e-06, 'epoch': 0.01}
{'loss': 1.3201, 'learning_rate': 9.756097560975611e-06, 'epoch': 0.01}
{'loss': 1.2383, 'learning_rate': 1.0975609756097562e-05, 'epoch': 0.01}
{'loss': 1.2544, 'learning_rate': 1.2195121951219513e-05, 'epoch': 0.01}
{'loss': 1.3003, 'learning_rate': 1.3414634146341466e-05, 'epoch': 0.02}
{'loss': 1.1711, 'learning_rate': 1.4634146341463415e-05, 'epoch': 0.02}
{'loss': 1.1685, 'learning_rate': 1.5853658536585366e-05, 'epoch': 0.02}
{'loss': 1.3093, 'learning_rate': 1.707317073170732e-05, 'epoch': 0.02}
{'loss': 1.1169, 'learning_rate': 1.8292682926829268e-05, 'epoch': 0.02}
{'loss': 1.1589, 'learning_rate': 1.9512195121951222e-05, 'epoch': 0.02}
{'loss': 1.3359, 'learning_rate': 2.073170731707317e-05, 'epoch': 0.03}
{'loss': 1.092, 'learning_rate': 2.1951219512195124e-05, 'epoch': 0.03}
{'loss': 1.1694, 'learning_rate': 2.3170731707317075e-05, 'epoch': 0.03}
{'loss': 1.116, 'learning_rate': 2.4390243902439026e-05, 'epoch': 0.03}
{'loss': 1.0955, 'learning_rate': 2.5609756097560977e-05, 'epoch': 0.03}
{'loss': 1.2515, 'learning_rate': 2.682926829268293e-05, 'epoch': 0.03}
{'loss': 1.3267, 'learning_rate': 2.8048780487804882e-05, 'epoch': 0.03}
{'loss': 1.1934, 'learning_rate': 2.926829268292683e-05, 'epoch': 0.04}
{'loss': 1.3142, 'learning_rate': 3.048780487804878e-05, 'epoch': 0.04}
{'loss': 1.3301, 'learning_rate': 3.170731707317073e-05, 'epoch': 0.04}
{'loss': 1.2852, 'learning_rate': 3.292682926829269e-05, 'epoch': 0.04}
{'loss': 1.1448, 'learning_rate': 3.414634146341464e-05, 'epoch': 0.04}
{'loss': 1.2456, 'learning_rate': 3.5365853658536584e-05, 'epoch': 0.04}
{'loss': 1.1499, 'learning_rate': 3.6585365853658535e-05, 'epoch': 0.04}

@pacman100 added the solved label on Aug 21, 2023
@dineshkh
Author

dineshkh commented Aug 21, 2023

Thanks @pacman100 for the reply.
Currently I am not using Flash Attention in my code.
Could you please guide me on how to incorporate Flash Attention V2 into my code?
Also, do I need to change anything else in my code, such as the dataloader or the packing of the pre-training data?
My goal is extended pre-training, not fine-tuning.

@pacman100
Contributor

For extended pre-training, refer to this code: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/personal_copilot/training

The rest remains the same except for the dataset and sample creation, which builds <FIM_PREFIX> ... <FIM_SUFFIX> ... <FIM_MIDDLE> ... samples for continued pre-training.

The above code uses monkey patching to enable Flash Attention V2.
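
For orientation, here is a minimal sketch of the monkey-patching pattern only, not the actual patch from the linked repo (which targets StarCoder's GPTBigCode attention and handles its multi-query layout and KV cache). The SimpleAttention module below is invented purely for illustration; the one real external API used is flash_attn_func from the flash-attn package, which expects (batch, seq_len, num_heads, head_dim) tensors in fp16/bf16 on GPU and has no notion of padding masks.

import torch
import torch.nn as nn
from flash_attn import flash_attn_func  # requires flash-attn v2


class SimpleAttention(nn.Module):
    # Toy causal self-attention layer, defined here only to demonstrate the patch.
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.reshape(b, s, self.num_heads, self.head_dim) for t in (q, k, v))
        # Eager attention path: this is what the patch swaps out.
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        ).transpose(1, 2)
        return self.proj(out.reshape(b, s, -1))


def _flash_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Replacement forward: same projections, but attention runs through Flash Attention v2.
    b, s, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)
    q, k, v = (t.reshape(b, s, self.num_heads, self.head_dim) for t in (q, k, v))
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # ignores padding masks
    return self.proj(out.reshape(b, s, -1))


def replace_attn_with_flash_attn():
    # The monkey patch itself: rebind forward on the class *before* the model is built,
    # so every attention layer constructed afterwards uses the Flash kernel.
    SimpleAttention.forward = _flash_forward

The replace_starcoder_attn_with_flash_attn() helper mentioned later in this thread applies the same idea of class-level rebinding to StarCoder's attention implementation.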

@dineshkh
Author

dineshkh commented Aug 21, 2023

Thanks @pacman100.
Sorry for asking so many questions, but I have a few more queries.
I can train with an 8K seq len on 16 A100 80GB GPUs without PEFT or any quantization, right?
Also, in my existing pre-training code, if I add replace_starcoder_attn_with_flash_attn() before loading the model via AutoModelForCausalLM.from_pretrained, will everything be fine if I only want to use the causal masking objective and not the FIM objective?
Also, without Flash V2, is it possible to train StarCoder using Accelerate FSDP or Accelerate DeepSpeed?

@pacman100
Contributor

pacman100 commented Aug 21, 2023

I can train with an 8K seq len on 16 A100 80GB GPUs without PEFT or any quantization, right?

Yes, that is what I have tested above.

Also, in my existing pre-training code, if I add replace_starcoder_attn_with_flash_attn() before loading the model via AutoModelForCausalLM.from_pretrained, then everything should be fine?

Yes, that should be fine. For installing Flash V2, refer to: https://github.com/Dao-AILab/flash-attention/tree/main#installation-and-features

Also, without Flash V2, is it possible to train StarCoder using Accelerate FSDP or Accelerate DeepSpeed?

No, it was leading to OOM with an 8K seq len (traceback below). DeepSpeed with CPU offloading might work; please test it out and share the results with the community.

File "/fsx/sourab/miniconda3/envs/hf/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 288, in backward
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 288, in backward
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.11/site-packages/torch/autograd/__init__.py", line 251, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)    
torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.11/site-packages/torch/autograd/__init__.py", line 251, in backward
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.11/site-packages/torch/autograd/__init__.py", line 251, in backward
        Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passtorch.autograd.backward(outputs_with_grad, args_with_grad)

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 12.00 GiB. GPU 4 has a total capacty of 79.35 GiB of which 2.88 GiB is free. Including non-PyTorch memory, this process has 76.21 GiB memory in use. Of the allocated memory 46.66 GiB is allocated by PyTorch, and 27.85 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Note:

  1. The Flash V2 support that I have implemented above ignores padding/attention_mask/custom_mask. It is meant for continued pre-training with packed inputs that consume the entire sequence length (see the packing sketch below).
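
For concreteness, a minimal sketch of what such packing can look like, assuming already-tokenized examples and an 8192-token block size. The helper name and the choice to drop the incomplete trailing block instead of padding it are illustrative assumptions, not taken from the linked repo.

from itertools import chain


def pack_examples(tokenized_examples, block_size=8192):
    # Concatenate all token ids, then cut them into full-length blocks so every sample
    # is exactly block_size tokens and no padding/attention mask is needed.
    ids = list(chain.from_iterable(ex["input_ids"] for ex in tokenized_examples))
    n_blocks = len(ids) // block_size  # the incomplete trailing block is dropped, not padded
    return [
        {"input_ids": ids[i * block_size : (i + 1) * block_size]}
        for i in range(n_blocks)
    ]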

@dineshkh
Author

dineshkh commented Aug 21, 2023

Thanks @pacman100.
I will try with DeepSpeed and paste the results here.

Also, maybe I didn't understand the last statement, so I am asking again: when I create packed sequences, I add right padding to the last sequence (whose length is not exactly 8192), and since I use batched=True there will be many such sequences (equal to the number of threads). Should I throw them away, or will your code still work?

I am calling the forward function as follows:

outputs = model(
    input_ids,
    labels=input_ids,
    attention_mask=attention_mask,
    token_type_ids=None,
)

I hope that is fine? I don't have to remove attention_mask from the above call, do I?

Also, does your Flash V2 code work if I provide batches as a list of tensors? I don't have to provide them as a list of lists or anything like that, right? I am asking because I have seen other Flash V2 support that requires batches as a list of lists without any padding.

@dineshkh
Author

@pacman100 I am able to train StarCoder-15B with batch size = 1 on 16 A100-80GB GPUs with an 8k context length. I saw that GPU memory utilization was less than 40 GB; I will also try increasing the batch size.

One more question:

Can we also use BetterTransformer from Optimum with Accelerate? I think it currently only supports Flash Attention v1.
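
For reference, a minimal sketch of how BetterTransformer is typically applied, assuming an optimum version that supports the gpt_bigcode architecture; whether the converted model then plays well with FSDP in this setup is exactly the open question, so treat it as untested here.

import torch
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", torch_dtype=torch.bfloat16)
# Swap supported modules for BetterTransformer/SDPA-based implementations.
model = BetterTransformer.transform(model)
# The converted model is then prepared and trained as usual, e.g. via accelerator.prepare(model).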

@prince14322

How can Flash V2 be used for fine-tuning?
As far as I can see, the above solution only works for pre-training.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
