Slowdown in Training Speed Due to SDPA Mask Fix in Version 4.40.0 #30461

Closed

achew010 opened this issue Apr 24, 2024 · 3 comments

achew010 commented Apr 24, 2024

System Info

Hi,

I have been doing some PEFT tuning with Mistral/Mixtral, and recently I observed a slowdown in training since the release of version 4.40.0. I narrowed it down to the fix in 40eb6d6, where the sliding window is now passed to _prepare_4d_causal_attention_mask_for_sdpa.

I ran a simple training job, and the two releases produced two different sets of throughputs:

Sequence Length | release 4.39.3 (toks/s) | release 4.40.0 (toks/s)
4096            | 3247                    | 2483
8192            | 3083                    | 1918

When my training sequence length is within or at the sliding-window threshold (i.e. seqlen = 4096, window = 4096), it should fall back to the SDPA kernel to handle the causal mask. I also don't see the computation savings at sequence length 8192 from the introduction of sliding-window attention, compared to having no windowed causal mask at all (calculating attention across all 8192 tokens).

Below is a dummy example showing that simply not passing a causal mask into PyTorch's SDPA function (allowing the kernel to handle the causal mask itself) versus specifying the sliding-window mask has a significant impact on the kernel's processing speed.

[Screenshots: attention mask passed to Torch SDPA vs. causal mask handled internally by Torch SDPA]
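
For reference, a minimal standalone sketch of that comparison (illustrative shapes and a crude timing harness, not the exact code behind the screenshots above): it times F.scaled_dot_product_attention with an explicit additive mask versus is_causal=True.

import torch
import torch.nn.functional as F

def bench(fn, iters=50):
    # Warm up once, then time with CUDA events (returns milliseconds per call).
    fn()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

bsz, heads, seqlen, head_dim = 4, 32, 4096, 128
q = torch.randn(bsz, heads, seqlen, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Explicit additive causal mask for illustration (the real sliding-window mask also
# masks tokens beyond the window); passing any attn_mask prevents SDPA from
# dispatching to its flash-attention backend.
mask = torch.triu(
    torch.full((seqlen, seqlen), torch.finfo(q.dtype).min, device="cuda", dtype=q.dtype),
    diagonal=1,
)

t_mask = bench(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask))
t_causal = bench(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))
print(f"explicit mask: {t_mask:.2f} ms/call | is_causal=True: {t_causal:.2f} ms/call")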

Is this slowdown something we should expect from using the SDPA module with the current fix?

I attached a simple script below to reproduce the issue.

System Info

- `transformers` version: 4.40.0
- Platform: Linux-4.18.0-372.71.1.el8_6.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.8
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config:    not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True
- Using distributed or parallel set-up in script?: False

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Script to reproduce
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, __version__ as transformer_version
from datasets import load_dataset
from trl import SFTTrainer

print(f"transformers version: {transformer_version}")

dataset = load_dataset("yahma/alpaca-cleaned", split="train")
model_name = 'mistralai/Mistral-7B-v0.1'

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": tokenizer.unk_token})
tokenizer.pad_token = tokenizer.unk_token

model = AutoModelForCausalLM.from_config(
    config, torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
    attn_implementation="sdpa",
)

print(model.model._attn_implementation)

args = {
    'batch_size': 4,
    'gradient_accumulation_steps': 1,
    'use_gradient_checkpointing': True,
    'warmup_steps': 10,
    'lr': 2e-4,
    'logging_steps': 10,
    'output_dir': './results',
    'optimizer': 'adamw_torch',
    'weight_decay': 0.0,
    'lr_scheduler': 'linear',
    'seed': 42,
    'max_steps': 100,
    'context_length': 4096,
}

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n\n"
    ),
}

def formatting_prompts_func(example):
    # Build a single "prompt + response" string per example (packing=True is used below).
    if example.get("input", "") == "":
        prompt = PROMPT_DICT["prompt_no_input"].format_map(example)
    else:
        prompt = PROMPT_DICT["prompt_input"].format_map(example)
    return prompt + example["output"]

training_args = TrainingArguments(
    per_device_train_batch_size = args['batch_size'],
    gradient_accumulation_steps = args['gradient_accumulation_steps'],
    gradient_checkpointing=args['use_gradient_checkpointing'],
    warmup_steps = args['warmup_steps'],
    max_steps = args['max_steps'],
    learning_rate = args['lr'],
    logging_strategy = 'steps',
    logging_steps = args['logging_steps'],
    output_dir = args['output_dir'],
    optim = args['optimizer'],
    weight_decay = args['weight_decay'],
    lr_scheduler_type = args['lr_scheduler'],
    seed = args['seed'],
    include_tokens_per_second = True,
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    max_seq_length = args['context_length'],
    args = training_args,
    formatting_func = formatting_prompts_func,
    packing = True,
)

stats = trainer.train()

Expected behavior

  1. Throughput should remain the same for sequence lengths lower than the window size when using SDPA

  2. Throughput should be slightly faster (from fewer computations in local attention) than regular attention (when no sliding window is specified in the causal mask) for longer sequence lengths

@achew010 achew010 changed the title Slowdown in Training Speed in Due to SDPA Mask Fix in Version 4.40.0 Slowdown in Training Speed Due to SDPA Mask Fix in Version 4.40.0 Apr 24, 2024
@amyeroberts (Collaborator)

cc @fxmarty

fxmarty (Collaborator) commented Apr 26, 2024

Hi @achew010, thank you for the report. Two PRs may be at play here: #30127 and #30070. Long story short:

  • SDPA requires attn_mask=None to be able to dispatch to its FA2 backend.
  • The SDPA attention implementation used to be incorrect when a sliding window was configured: the sliding window was not applied at all because the mask was dropped (for the reason above). Fix SDPA sliding window compatibility #30127 ensures the correctness of the sliding mask and no longer relies on SDPA's is_causal argument.

As you can see here, a check on key_value_length < sliding_window still allows the mask to be ignored for some sequence lengths.

I also don't see the computation savings at sequence length 8192 from the introduction of sliding-window attention, compared to having no windowed causal mask at all (calculating attention across all 8192 tokens).

This was already the case for transformers<=4.39 (and also in Mistral's public code release). Unfortunately, apart from the original HazyResearch/flash-attn implementation (attn_implementation="flash_attention_2" in Transformers, see the doc), there is currently no efficient sliding-window implementation for eager & SDPA. I know Driss from PyTorch was working on this.

It is still unclear to me why you see this regression, given that transformers==4.39 always used the attn_mask argument (i.e. never FA2, see here and here in 4.39). To me the more likely culprit is that PyTorch now picks FA2 for you instead of mem-efficient attention and somehow FA2 is slower for you. Could you try playing with the decorator https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel and report here?
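
For reference, a minimal sketch of how to pin SDPA to a single backend with that context manager (torch.backends.cuda.sdp_kernel as of torch 2.2; newer releases expose torch.nn.attention.sdpa_kernel instead):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Memory-efficient backend only (flash and math kernels disabled).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out_mem = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Flash-attention backend only.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)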

achew010 (Author) commented Apr 29, 2024

Thanks for giving some context @fxmarty; from your explanation I have a better understanding of what caused the regression.

To me the more likely culprit is that PyTorch now picks FA2 for you instead of mem-efficient attention and somehow FA2 is slower for you. Could you try playing with the decorator https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel and report here?

Indeed, as you said: by playing around with this context manager and pinning both setups to the same backend (only enable_mem_efficient=True), I was able to match their speeds. It seems PyTorch chooses which backend to use based on the presence of a custom mask after all; if no custom mask is passed in, it will choose the faster FA2 backend. (See below.)

[Screenshots: attention mask passed to Torch SDPA vs. causal mask handled internally by Torch SDPA]

It is still unclear to me why you see this regression, given that transformers==4.39 always used the attn_mask argument (i.e. never FA2, see here and here in 4.39)

The introduction of the sliding window here influences which backend the SDPA kernel uses. In my setup, having max_context_length=4096 equal to sliding_window=4096 causes this check to set ignore_causal_mask=False. Execution then falls through to the code here that produces a custom attention mask to be passed to the SDPA kernel. By running my example script with a context length smaller than the sliding window (e.g. <=4095), I avoid generating a custom mask and let PyTorch SDPA use the faster FA2 backend; doing so, I am able to replicate the throughputs I saw in 4.39.
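
For anyone else hitting this, here is a small probe of that behaviour. Note it calls an internal transformers helper (_prepare_4d_causal_attention_mask_for_sdpa), so the signature may change between versions:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

def probe(seq_len, sliding_window=4096, hidden=4096):
    # Check whether transformers materializes a 4D SDPA mask for this sequence length.
    bsz = 1
    attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)  # no padding, as with packing
    inputs_embeds = torch.empty(bsz, seq_len, hidden, dtype=torch.float16)
    mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (bsz, seq_len),
        inputs_embeds,
        past_key_values_length=0,
        sliding_window=sliding_window,
    )
    kind = "None -> SDPA may dispatch to FA2" if mask is None else "materialized -> no FA2"
    print(f"seq_len={seq_len}, window={sliding_window}: {kind}")

probe(4095)  # below the window: the mask is ignored
probe(4096)  # at the window: a custom mask is built and passed to SDPA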

This clarifies everything, thanks a lot for the help!

fxmarty closed this as completed Apr 29, 2024