
Fix FA2 integration #28142

Merged
merged 8 commits into main from smangrul/fix-fa2-integration on Dec 20, 2023
Conversation

@pacman100 (Contributor) commented Dec 19, 2023

What does this PR do?

  1. Fix FA2 integration.

Issues with the current FA2 integration:

  1. It makes providing torch_dtype to the from_pretrained class method mandatory. This loads the whole model in half precision, which leads to unstable training because it results in pure half-precision training instead of mixed-precision training. Please refer to Mistral loss instability #26498 (comment) for more details.
    Currently, the main branch throws the error below when a half-precision torch_dtype is not passed, which shouldn't be the case:
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.

...

File /raid/sourab/transformers/src/transformers/modeling_utils.py:1422, in PreTrainedModel._check_and_enable_flash_attn_2(cls, config, torch_dtype, device_map, check_device_map, hard_check_only)
   1418     logger.warning(
   1419         "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
   1420     )
   1421 elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
-> 1422     raise ValueError(
   1423         f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
   1424         " unexpected behaviour."
   1425     )
   1427 # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
   1428 # or the model may be initialized under the context manager `with torch.device("cuda"):`.
   1429 if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":

ValueError: Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed torch.float32, this might lead to unexpected behaviour.
  2. As a workaround, one would pass torch_dtype, then recast the model to float32 and try to train, but then end up getting the following error from the Flash Attention library:
File /raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:79, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, return_softmax)
     77 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     78 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 79 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
     80     q,
     81     k,
     82     v,
     83     None,
     84     cu_seqlens_q,
     85     cu_seqlens_k,
     86     max_seqlen_q,
     87     max_seqlen_k,
     88     dropout_p,
     89     softmax_scale,
     90     False,
     91     causal,
     92     window_size[0],
     93     window_size[1],
     94     return_softmax,
     95     None,
     96 )
     97 # if out.isnan().any() or softmax_lse.isnan().any():
     98 #     breakpoint()
     99 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type
  3. To overcome that, one would need to cast the trainable params to float32 and all the other params to float16 (as sketched below); this is only possible with PEFT approaches. For normal fine-tuning, things end there, leaving no way to use Flash Attention correctly. Even in the PEFT setup, this cast leads to unstable learning that plateaus at a high loss, so there is no luck there either.
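For illustration, a minimal sketch of the per-parameter cast described in point 3 (not code from this PR); the function name is hypothetical and `model` is assumed to be a PEFT-wrapped model whose adapter parameters are the only trainable ones:

```python
import torch
from torch import nn


def cast_params_for_peft_fa2(model: nn.Module) -> None:
    """Workaround sketch from point 3: keep trainable (adapter) params in fp32
    and frozen base weights in fp16. Reported above to still plateau at a high loss."""
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)  # trainable params in full precision
        else:
            param.data = param.data.to(torch.float16)  # frozen base weights in half precision
```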

[Screenshot (2023-12-20): training-loss curves before and after the PR]

All these issues are resolved by this PR. Notice the graph above comparing the logs before and after the PR: with this PR, the loss is similar to the case when not using FA2.
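For context, here is a minimal sketch of the post-fix workflow this PR enables (not code from the PR itself): keep the weights in float32 at load time and let the Trainer's autocast (bf16=True or fp16=True) drive mixed precision while FA2 is active. The model id and random-token dataset are illustrative placeholders; running this requires a GPU with flash-attn installed.

```python
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments


class RandomTokenDataset(Dataset):
    """Random token ids, just enough for a smoke-test training run."""

    def __init__(self, num_samples: int, seq_len: int) -> None:
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        x = torch.randint(0, 1000, (self.seq_len + 1,))
        return {"input_ids": x[:-1], "labels": x[1:]}


model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    attn_implementation="flash_attention_2",
    # no torch_dtype here: the master weights stay in float32
)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    bf16=True,  # autocast-based mixed precision; fp16=True also works
    max_steps=10,
)

Trainer(model=model, args=args, train_dataset=RandomTokenDataset(1_000, 512)).train()
```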

pacman100 and others added 3 commits December 19, 2023 21:45
Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
@pacman100 pacman100 marked this pull request as ready for review December 19, 2023 18:42
@ArthurZucker (Collaborator) left a comment

Thanks a lot for the deep dive! As this is/was critical, could this be added to Llama.md as a tip? (nit)
Otherwise looks great. The autocast feature was introduced in PyTorch 1.6.0, so no worries there.
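For readers following along, a hedged sketch of the kind of autocast-aware dtype selection being discussed, under the assumption that the FA2 path casts float32 inputs to a half-precision target; the function name is hypothetical and this is not a verbatim excerpt from the PR:

```python
import torch
from torch import nn


def pick_fa2_target_dtype(proj: nn.Linear) -> torch.dtype:
    """Choose the dtype FA2 inputs should be cast to (illustrative only)."""
    if torch.is_autocast_enabled():
        # Mixed-precision training: use the dtype autocast is running in.
        return torch.get_autocast_gpu_dtype()
    # Otherwise fall back to the dtype of the layer's own weights.
    return proj.weight.dtype
```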

src/transformers/modeling_utils.py (review comment: outdated, resolved)
pacman100 and others added 2 commits December 20, 2023 13:22
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@younesbelkada (Contributor) left a comment

Makes sense, as discussed offline! Thanks very much for the deep dive, @pacman100!

@pacman100 (Contributor, Author)

this could be added to the Llama.md as a tip ? (nit)

Done.

@pacman100 pacman100 merged commit def581e into main Dec 20, 2023
21 checks passed
@pacman100 pacman100 deleted the smangrul/fix-fa2-integration branch December 20, 2023 08:55
@teknium1

So FSDP is saved?

@younesbelkada (Contributor)

I think so. From the experiment @pacman100 shared with me, you could load a transformers model with FA2 and train it with autocast (fp16=True), and the model converged nicely.

@pacman100 (Contributor, Author)

Hello @teknium1, to re-confirm, I ran the experiment below on 8x 80GB GPUs to finetune Mistral 7B for the SFT task on UltraChat 200K (1 epoch).

Code: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/run_fsdp.sh
Config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml
Versions:

- `transformers` version: 4.37.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.20.1
- Safetensors version: 0.4.1
- Accelerate version: 0.25.0.dev0
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
- trl 0.7.8.dev0

Plots:
[Screenshot (2023-12-26): training-loss plot for the Mistral 7B SFT run]

Observations:
The plot converges as expected, similar to the Zephyr SFT training plots.

staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
* fix fa2

* fix FA2 for popular models

* improve warning and add Younes as co-author

Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix the warning

* Add Tip

* typo fix

* nit

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@TJ-Solergibert

Hi! I haven't been able to use Llama3-8B w/FA2. I'm running the following code:

from typing import Dict

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import LlamaForCausalLM, Trainer, TrainingArguments


class DummyDataset(Dataset):
    """Random token ids, enough to reproduce the dtype error."""

    def __init__(self, num_samples: int, sequence_length: int) -> None:
        self.num_samples = num_samples
        self.sequence_length = sequence_length

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        x = torch.LongTensor(np.random.randint(low=0, high=1000, size=(self.sequence_length + 1,)))
        return {"input_ids": x[:-1], "labels": x[1:]}


def main():
    path_to_model = "/mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct"
    output_dir = "/mloscratch/homes/solergib/simple/output"
    training_arguments = TrainingArguments(
        per_device_train_batch_size=1,
        gradient_checkpointing=True,
        bf16=True,
        max_steps=10,
        output_dir=output_dir,
    )
    model = LlamaForCausalLM.from_pretrained(path_to_model, attn_implementation="flash_attention_2")  # It's the default
    train_dataset = DummyDataset(10_000_000, 1024)
    trainer = Trainer(model=model, args=training_arguments, train_dataset=train_dataset)
    trainer.train()


if __name__ == "__main__":
    main()

And I get the same error complaining about the dtype:

File "/home/solergib/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 507, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/home/solergib/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

My env:

- `transformers` version: 4.40.0
- Platform: Linux-6.5.0-26-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.3
- Accelerate version: 0.30.1.dev0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.3.0a0+40ec155e58.nv24.03 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes, 1x80GB A100
- Using distributed or parallel set-up in script?: No, but I tried with DeepSpeed ZeRO-3 and the same error occurs
- flash-attn                2.5.5 (Tried with 2.5.7 and same error)

Thanks!

@ArthurZucker (Collaborator)

Hey! This is unrelated to transformers; see this issue: Dao-AILab/flash-attention#822
