Llama uses significantly more memory in 4.38 & 4.39 than 4.37 with identical code #30010

Closed
3 of 4 tasks
warner-benjamin opened this issue Apr 3, 2024 · 6 comments · Fixed by #30442

Comments

@warner-benjamin
Contributor

warner-benjamin commented Apr 3, 2024

System Info

Transformers 4.37.2, 4.38.2, & 4.39.3.
Python 3.11.
PyTorch 2.2.2 w/ CUDA 12.1

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

QLoRA Llama-7B with SDPA attention uses less memory when run with 4.37.2 on a 24GB card than with 4.38.2 or 4.39.3. You can reproduce it with the following script (saved as train.py) and the commands listed after it:

import argparse
from random import randint

from tqdm import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset
from peft import LoraConfig, get_peft_model

torch.set_float32_matmul_precision("high")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_sequence_length", type=int, default=1280)
    parser.add_argument("--dataset_size", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--flash_attn", action="store_true")
    parser.add_argument("--torch_compile", action="store_true")
    parser.add_argument("--profile_memory", action="store_true")

    return parser.parse_args()


def get_dataset(dataset_size, vocab_size, sequence_length):
    dataset = Dataset.from_dict(
        {"input_ids": [[randint(0, vocab_size) for _ in range(sequence_length)] for i in range(0, dataset_size)]}
    )
    return dataset


def data_collator(batch):
    batch = torch.stack([b["input_ids"] for b in batch])
    return {"input_ids": batch, "labels": batch}


def append_stats(batch, stats):
    if batch == 0:
        stats.append(torch.cuda.memory_reserved(0))


def main():
    args = parse_args()
    print(args)
    print(f"Transformers version: {transformers.__version__}")
    stats = []

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    peft_config = LoraConfig(
        r=16,
        lora_alpha=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        max_position_embeddings=args.max_sequence_length,
        attn_implementation="flash_attention_2" if args.flash_attn else "sdpa",
        torch_dtype="auto",
        use_cache=False,
        quantization_config=bnb_config,
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    optimizer = optim.SGD(model.parameters(), lr=0.001)

    dataset = get_dataset(
        args.dataset_size * args.batch_size,
        tokenizer.vocab_size,
        args.max_sequence_length,
    ).with_format("torch")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=data_collator,
    )

    if args.torch_compile:
        print("Compiling model")
        model = torch.compile(model)

    torch.cuda.reset_peak_memory_stats(0)
    for i, batch in enumerate(tqdm(dataloader)):
        input_ids = batch["input_ids"].to("cuda")
        labels = batch["labels"].to("cuda")

        if args.profile_memory and i == 0:
            torch.cuda.memory._record_memory_history()

        append_stats(i, stats)
        output = model(input_ids=input_ids, labels=labels, attention_mask=None)
        loss = output.loss

        append_stats(i, stats)
        loss.backward()

        append_stats(i, stats)
        optimizer.step()
        optimizer.zero_grad()

        if args.profile_memory and i == 0:
            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
            torch.cuda.memory._record_memory_history(enabled=None)

    for label, stat in zip(["forward", "backward", "optimizer"], stats):
        print(f"Before {label}: {stat/2**30:.2f} GiB")
    print(f"Max Reserved: {torch.cuda.max_memory_reserved(0)/2**30:.2f} GiB")


if __name__ == "__main__":
    main()
# 4.37.2
python train.py --batch_size 1 --max_sequence_length 1280
python train.py --batch_size 1 --max_sequence_length 1536
# 4.39.3
python train.py --batch_size 1 --max_sequence_length 1280
# this one errors out on a 24GB card, likely due to OOM
python train.py --batch_size 1 --max_sequence_length 1536

Expected behavior

While spot checking our FSDP+QLoRA script on Transformers 4.39.3, I noticed that the maximum batch size we could finetune on two 24 GB cards with a sequence length of 2048 was reduced from 12 to 5 compared to 4.37.2. This is due to Llama 2 in 4.38 and 4.39 using significantly more memory than 4.37.

This Llama memory issue persists post #29753, which resolved some but not all of the Llama rewrite memory issues mentioned in #29484 and other issues.

I also reproduced the issue without FSDP on one 24GB card using the script above.

| Version | Seq Len | Before Forward | Before Backward | Before Opt | Max Reserved |
| ------- | ------- | -------------- | --------------- | ---------- | ------------ |
| 4.37.2  | 1280    | 3.90 GiB       | 15.54 GiB       | 15.89 GiB  | 18.79 GiB    |
| 4.37.2  | 1536    | 3.90 GiB       | 18.21 GiB       | 18.60 GiB  | 22.08 GiB    |
| 4.39.3  | 1280    | 3.99 GiB       | 15.92 GiB       | 17.06 GiB  | 19.61 GiB    |
| 4.39.3  | 1536    | -              | -               | -          | OOM          |

4.39.3 uses almost a GB more memory at a sequence length of 1280 and errors out at 1536.

Using torch.cuda.memory._record_memory_history shows a peak in memory usage in 4.37.2 at the start of the backward pass, as expected. (Use the --profile_memory flag in the above script; the resulting memory_snapshot.pickle can be viewed at https://pytorch.org/memory_viz.)

[memory snapshot: 4.37.2, single peak at the start of the backward pass]

Switching to 4.39.3 shows an unexpected result: consistently large memory spikes during the backward pass of the SDPA kernel:

[memory snapshot: 4.39.3, repeated spikes during the SDPA backward pass]

When sharding Llama across two cards with FSDP and gradient checkpointing, the memory spikes become quite visible as ~10GB outliers:

[memory snapshot: FSDP + gradient checkpointing, ~10GB spike outliers]

There are no spikes in 4.37 with FSDP.

The culprit appears to be the switch from using the SDPA Flash Attention kernel in 4.37.2 to the SDPA Efficient kernel in 4.38 & 4.39.

# Transformers 4.37.2 LlamaSdpaAttention.forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
# Transformers 4.39.3 LlamaSdpaAttention.forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
)

By removing is_causal and passing a custom causal_mask instead, scaled_dot_product_attention now dispatches to the more memory-hungry Efficient kernel rather than the memory-efficient Flash Attention 2 kernel.
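One quick way to confirm the kernel dispatch (a minimal check, separate from the repro script; it assumes PyTorch 2.2's torch.backends.cuda.sdp_kernel context manager and an Ampere-or-newer GPU) is to restrict SDPA to the Flash backend: is_causal=True runs, while an explicit attn_mask errors out because the Flash kernel does not accept arbitrary masks:

import torch
import torch.nn.functional as F

# Shapes are illustrative: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 16, 1024, 64, device="cuda", dtype=torch.bfloat16)
causal = torch.ones(1024, 1024, dtype=torch.bool, device="cuda").tril()

# Allow only the Flash backend so the dispatch choice is visible.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)       # uses the Flash kernel
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=causal)  # Flash rejects explicit masks
    except RuntimeError as err:
        print("Flash kernel unavailable with an explicit attn_mask:", err)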

You can spot-check this by switching the reproduction script to LlamaFlashAttention2 with the --flash_attn flag.

| Version | Seq Len | Before Forward | Before Backward | Before Opt | Max Reserved |
| ------- | ------- | -------------- | --------------- | ---------- | ------------ |
| 4.37.2  | 1536    | 3.90 GiB       | 18.21 GiB       | 18.60 GiB  | 22.08 GiB    |
| 4.39.3  | 1536    | 3.99 GiB       | 18.36 GiB       | 18.75 GiB  | 22.24 GiB    |

With LlamaFlashAttention2, 4.39 uses moderately more memory than 4.37, although the gap may grow at longer context lengths.

When training, LlamaSdpaAttention should use is_causal=True when no attention mask is passed to Llama, instead of creating a custom causal_mask.
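For illustration, a minimal sketch of that dispatch (reusing the variable names from the 4.37.2 and 4.39.3 snippets above; this is not the actual patch in transformers) could look like:

# Sketch: prefer is_causal=True so SDPA can pick the Flash kernel when no mask is needed.
use_causal = self.is_causal and attention_mask is None and q_len > 1

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=None if use_causal else causal_mask,  # only pass a mask when required
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=use_causal,
)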

The switch to a custom causal_mask also causes torch.compile errors in 4.39 that are not present in 4.37, due to data type mismatches, particularly when training in pure bfloat16 instead of mixed precision.
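For reference, the compile path can be probed without QLoRA by compiling a pure-bf16 model and running one training step. A minimal sketch, using the tiny hf-internal-testing/tiny-random-LlamaForCausalLM checkpoint to keep it cheap (whether it surfaces the exact dtype mismatch may depend on the setup; swap in meta-llama/Llama-2-7b-hf to match the report):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).cuda()
model = torch.compile(model)

input_ids = torch.randint(0, model.config.vocab_size, (2, 256), device="cuda")
loss = model(input_ids=input_ids, labels=input_ids).loss  # compilation happens on the first call
loss.backward()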

@ArthurZucker
Collaborator

That is a really great report, thanks! 🤗
cc @fxmarty, I think this is important; funny that we talked about it yesterday!
Let's work toward finding a solution as these spikes seem abnormal. You are right, the culprit is the switch of kernels.

  1. torch.compile works for inference, which was not possible before 4.38; I am highly surprised that it does not work for training. We ought to add a test for this. cc @fxmarty, let's fix that asap.
  2. The spikes for 4.39 seem to add about 1GB over the spikes for 4.37. Thus I am wondering where the significantly higher memory use comes from, as it seems to be an FSDP-related problem.
  3. The check self.is_causal and attention_mask is None and q_len > 1 was already control flow blocking a few things. Down to re-enable the memory-efficient kernel, but without breaking the rest, which is a lot of work! Also, for a ~1GB surplus, it's not very high priority.

TL;DR: there must be something a bit off with FSDP.

@fxmarty
Contributor

fxmarty commented Apr 3, 2024

The culprit appears to be the switch from using the SDPA Flash Attention kernel in 4.37.2 to the SDPA Efficient kernel in 4.38 & 4.39.

Yes, that's likely the issue.

@johnowhitaker

FSDP is not the issue; in that case we're just using a lot less memory for weights and doing gradient checkpointing, so the spikes show up more clearly.

@ArthurZucker
Collaborator

@fxmarty can you have a look? 🤗

@warner-benjamin
Contributor Author

warner-benjamin commented Apr 4, 2024

As @johnowhitaker mentioned, FSDP is not the issue.

2. The spikes for 4.39 seem to add about 1GB over the spikes for 4.37. Thus I am wondering where the significantly higher memory use comes from, as it seems to be an FSDP-related problem.

3. The check self.is_causal and attention_mask is None and q_len > 1 was already control flow blocking a few things. Down to re-enable the memory-efficient kernel, but without breaking the rest, which is a lot of work! Also, for a ~1GB surplus, it's not very high priority.

TL;DR: there must be something a bit off with FSDP.

@ArthurZucker, I couldn't fit a large enough sequence length with 4-bit Llama 7B, so I'll illustrate with a single Attention layer. It uses an attention mask (and thus the SDPA Efficient kernel) when sdpa_flash=False, and the SDPA Flash Attention kernel via is_causal=True when sdpa_flash=True.

class Attention(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 16,
        seq_len: int = 2048,
        sdpa_flash: bool = False,
    ):
        super().__init__()
        self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=False)
        self.nh = num_heads
        self.sdpa_flash = sdpa_flash

        if not sdpa_flash:
            self.register_buffer(
                "causal_mask",
                torch.triu(torch.ones([seq_len, seq_len], dtype=torch.bool), diagonal=1)
                .logical_not()
                .view(1, 1, seq_len, seq_len),
            )

    def forward(self, x):
        B, S, C = x.shape
        x = self.Wqkv(x).reshape(B, S, 3, self.nh, C // self.nh)
        query, key, value = x.transpose(3, 1).unbind(dim=2)

        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=self.causal_mask[:, :, :S, :S] if not self.sdpa_flash else None,
            is_causal=self.sdpa_flash,
        )

        return self.Wo(attn.transpose(1, 2).reshape(B, S, C))

At a batch size of 128, is_causal=True still uses less memory during the backward pass than an attn_mask does at a batch size of 16. The Flash Attention kernel with BS=16 uses 15.65 GiB less memory for the backward pass. All of these results use a sequence length of 2048.

| Variant   | Batch Size | Pre-Forward | Forward Pass | Backward Pass | Max Reserved |
| --------- | ---------- | ----------- | ------------ | ------------- | ------------ |
| attn_mask | 16         | 0.04 GiB    | 1.06 GiB     | 17.82 GiB     | 17.82 GiB    |
| is_causal | 16         | 0.04 GiB    | 1.04 GiB     | 2.17 GiB      | 2.17 GiB     |
| is_causal | 128        | 0.04 GiB    | 8.08 GiB     | 16.10 GiB     | 16.10 GiB    |

In addition to LlamaSdpaAttention needing a fix, it looks like both CohereSdpaAttention (#29622) and GemmaSdpaAttention (#29167) share the same implementation and thus the same training memory issues. All the other SdpaAttention layers appear to use is_causal=True.

Here are the commands and the script (saved as example.py) to recreate the single Attention layer memory measurements:

python example.py --batch_size 16
python example.py --batch_size 16 --sdpa_flash
python example.py --batch_size 128 --sdpa_flash
import argparse
from random import randint

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import Dataset

torch.set_float32_matmul_precision("high")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_sequence_length", type=int, default=2048)
    parser.add_argument("--dataset_size", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--vocab_size", type=int, default=64)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--num_heads", type=int, default=16)
    parser.add_argument("--sdpa_flash", action="store_true")
    parser.add_argument("--torch_compile", action="store_true")
    parser.add_argument("--profile_memory", action="store_true")

    return parser.parse_args()


def get_dataset(dataset_size, sequence_length, vocab_size=64):
    dataset = Dataset.from_dict(
        {"input_ids": [[randint(0, vocab_size) for _ in range(sequence_length)] for i in range(0, dataset_size)]}
    )
    return dataset


def data_collator(batch):
    batch = torch.stack([b["input_ids"] for b in batch])
    return {"input_ids": batch, "labels": batch}


def append_stats(batch, stats):
    if batch == 0:
        stats.append(torch.cuda.memory_reserved(0))


class Attention(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 16,
        seq_len: int = 2048,
        sdpa_flash: bool = False,
    ):
        super().__init__()
        self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=False)
        self.nh = num_heads
        self.sdpa_flash = sdpa_flash

        if not sdpa_flash:
            self.register_buffer(
                "causal_mask",
                torch.triu(torch.ones([seq_len, seq_len], dtype=torch.bool), diagonal=1)
                .logical_not()
                .view(1, 1, seq_len, seq_len),
            )

    def forward(self, x):
        B, S, C = x.shape
        x = self.Wqkv(x).reshape(B, S, 3, self.nh, C // self.nh)
        query, key, value = x.transpose(3, 1).unbind(dim=2)

        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=self.causal_mask[:, :, :S, :S] if not self.sdpa_flash else None,
            is_causal=self.sdpa_flash,
        )

        return self.Wo(attn.transpose(1, 2).reshape(B, S, C))


class OneLayerTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        num_heads: int,
        seq_len: int,
        sdpa_flash: bool = False,
        loss_fn: nn.Module = nn.CrossEntropyLoss(),
    ):
        super().__init__()
        self.loss_fn = loss_fn
        self.We = nn.Embedding(vocab_size, hidden_size)
        self.attn = Attention(hidden_size, num_heads, seq_len, sdpa_flash)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids, labels):
        x = self.We(input_ids)
        x = x + self.attn(self.norm(x))

        return self.loss_fn(x.view(-1, x.shape[-1]), labels.view(-1))


def main():
    args = parse_args()
    print(args)
    stats = []

    model = OneLayerTransformer(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_heads=args.num_heads,
        seq_len=args.max_sequence_length,
        sdpa_flash=args.sdpa_flash,
    ).to(dtype=torch.bfloat16, device="cuda")

    optimizer = optim.SGD(model.parameters(), lr=0.001)

    dataset = get_dataset(
        args.dataset_size * args.batch_size,
        args.max_sequence_length,
        args.vocab_size - 1,
    ).with_format("torch")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=data_collator,
    )

    if args.torch_compile:
        print("Compiling model")
        model = torch.compile(model)

    torch.cuda.reset_peak_memory_stats(0)
    for i, batch in enumerate(tqdm(dataloader)):
        input_ids = batch["input_ids"].to("cuda")
        labels = batch["labels"].to("cuda")

        if args.profile_memory and i == 0:
            torch.cuda.memory._record_memory_history()

        append_stats(i, stats)
        loss = model(input_ids=input_ids, labels=labels)

        append_stats(i, stats)
        loss.backward()

        append_stats(i, stats)
        optimizer.step()
        optimizer.zero_grad()

        if args.profile_memory and i == 0:
            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
            torch.cuda.memory._record_memory_history(enabled=None)

    for label, stat in zip(["forward", "backward", "optimizer"], stats):
        print(f"Before {label}: {stat/2**30:.2f} GiB")
    print(f"Max Reserved: {torch.cuda.max_memory_reserved(0)/2**30:.2f} GiB")


if __name__ == "__main__":
    main()

@fxmarty
Contributor

fxmarty commented Apr 5, 2024

Hi @warner-benjamin, this will be fixed in #30070 in case you are not using torch.compile.

In case you use torch.compile, we will first need pytorch/pytorch#120400 (only disallow FA2 for compile with fullgraph=True).
