In [1]:
from datasets import load_dataset

# Hugging Face Hub dataset id and the configuration to load (raw-text WikiText-2).
dataset_name = "wikitext"
dataset_config_name = "wikitext-2-raw-v1"

# Result is a DatasetDict keyed by split name (train / validation / test).
raw_datasets = load_dataset(path=dataset_name, name=dataset_config_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Inspect the available splits, their columns, and row counts.
raw_datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

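### Document attention (flash attention)

When several documents are packed into one token stream, each token should attend only to tokens from its own document. FlashAttention's variable-length kernel (`flash_attn_varlen_func`) handles this directly. It takes the packed tokens plus cumulative sequence boundaries (`cu_seqlens`), so no padding or explicit block-diagonal mask is needed.

- Related discussion: https://github.com/Dao-AILab/flash-attention/issues/654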

In [1]:
import itertools

import torch
from flash_attn import flash_attn_varlen_func

# Fixed seed so the random q/k/v below are reproducible across re-runs.
torch.manual_seed(0)

# Three documents of different lengths, packed back-to-back into one token stream.
seq_lens = [512, 1024, 256]
batch_size = len(seq_lens)  # number of packed documents
total_tokens = sum(seq_lens)

# flash_attn_varlen_func takes packed, per-head inputs shaped
# (total_tokens, num_heads, head_dim) rather than a flat (total_tokens, hidden_dim).
num_heads = 8
head_dim = 16  # size of each attention head
hidden_dim = num_heads * head_dim  # 128 here

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# Cumulative sequence lengths: document i occupies packed rows
# cu_seqlens[i]:cu_seqlens[i + 1]. Built from plain Python ints rather than a
# list of 0-d tensors. For the lengths above this is [0, 512, 1536, 1792].
cu_seqlens = torch.tensor(
    [0, *itertools.accumulate(seq_lens)],
    dtype=torch.int32,
    device="cuda",
)
# One boundary per document plus the leading 0, ending at the total length.
assert cu_seqlens.numel() == batch_size + 1
assert cu_seqlens[-1].item() == total_tokens

# Length of the longest document in the pack.
max_seqlen = max(seq_lens)

# Variable-length (packed) FlashAttention: the cu_seqlens boundaries keep
# attention inside each document's [cu_seqlens[i], cu_seqlens[i + 1]) range.
output = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,  # decoder-style causal masking within each document
)
assert output.shape == q.shape

print("Shape of the output tensor:", output.shape)

# Merge the heads back into the hidden dimension, as a transformer block would
# before its output projection / feed-forward layer. reshape (not view) also
# works if the returned tensor happens to be non-contiguous.
output_reshaped = output.reshape(total_tokens, hidden_dim)
print("Shape after reshaping to combine heads:", output_reshaped.shape)

Shape of the output tensor: torch.Size([1792, 8, 16])
Shape after reshaping to combine heads: torch.Size([1792, 128])


In [3]:
from transformers import DataCollatorWithFlattening
from datasets import load_dataset

# Load only the train split, so `train_dataset` is a Dataset from the start
# instead of being rebound from a DatasetDict to one of its splits.
train_dataset = load_dataset("microsoft/orca-math-word-problems-200k", split="train")

# Collator that packs a list of tokenized examples into a single un-padded row
# (see the `data_collator(examples)` output further down). It also supplies the
# `labels` the model uses to compute a loss; presumably it emits per-example
# `position_ids` as well — confirm against the transformers docs.
data_collator = DataCollatorWithFlattening()

In [2]:
# Sanity-check the loaded split: column names and number of rows.
train_dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 200035
})

In [4]:
from transformers import AutoTokenizer

model_name = "unsloth/Llama-3.2-1B-Instruct"
num_examples = 4  # small sample, enough to see how the collator packs sequences

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize each question separately: the collator expects a list of
# per-example feature dicts. No `labels` key is added here; the later forward
# pass still reports a loss, so the collator derives labels on its own.
questions = train_dataset[:num_examples]["question"]
examples = [tokenizer(question) for question in questions]
examples[0]

{'input_ids': [128000, 41, 2234, 74, 1982, 374, 279, 220, 20, 339, 2035, 13, 7531, 279, 1396, 315, 1274, 889, 28129, 279, 6381, 1584, 10819, 1109, 50432, 74, 1982, 13], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [16]:
# NOTE(review): leftover scratch cell. `example` does not appear to be defined
# anywhere in this notebook (only `examples` is), so uncommenting the line below
# would raise NameError. Candidate for deletion.
# dict(example)

In [24]:
# Pack the tokenized examples into one flattened row. In the output, the
# examples appear back-to-back, each starting with id 128000 (the id the
# tokenizer prepends to every example, cf. examples[0]).
data_collator(examples)

{'input_ids': tensor([[128000,     41,   2234,     74,   1982,    374,    279,    220,     20,
             339,   2035,     13,   7531,    279,   1396,    315,   1274,    889,
           28129,    279,   6381,   1584,  10819,   1109,  50432,     74,   1982,
              13, 128000,     32,   1396,  18255,    555,    220,    605,    374,
             220,     21,     13,  44188,    647,     72,   2751,    279,   1121,
             555,  33356,    287,    220,    868,    505,    264,   3738,   1396,
              13,   3639,    374,    279,   1121,    568,   2751,     30, 128000,
              35,    647,   8783,  50243,    264,   6710,    315,   5684,    449,
             264,   1396,   5439,    389,    433,     11,    323,   6944,    311,
            1304,    264,   2380,  49442,   1396,    555,  25012,    279,   1176,
            4183,   1396,    304,    279,  11758,   2035,     11,    279,   2132,
            4183,    304,    279,  22781,   2035,     11,    323,    279,   4948,
   

In [8]:
from transformers import AutoModelForCausalLM
import torch

# Load the causal LM in bfloat16 on GPU 0 with the FlashAttention-2 backend.
# The packed (flattened) inputs used below presumably rely on this backend to
# keep attention within each example — TODO confirm against the transformers docs.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0},
)

In [27]:
# Collate the examples into one packed row, move every tensor to the model's
# device (rather than a hard-coded "cuda"), then stack `num_copies` identical
# rows along dim 0 to exercise a batch size > 1 with packed inputs.
num_copies = 2
packed_input = data_collator(examples)
model_input = {
    key: torch.cat([tensor.to(model.device)] * num_copies, dim=0)
    for key, tensor in packed_input.items()
}
assert all(tensor.shape[0] == num_copies for tensor in model_input.values())

In [28]:
# Expect (2, packed_length): the single packed row duplicated along the batch dim.
model_input['input_ids'].shape

torch.Size([2, 197])

In [29]:
# Forward pass for inspection only. inference_mode() disables autograd like
# no_grad() and also skips view/version-counter tracking. The resulting
# tensors cannot later be used in autograd, which nothing below does.
# The collator supplied `labels`, so the output also carries a loss.
with torch.inference_mode():
    model_result = model(**model_input)
model_result

CausalLMOutputWithPast(loss=tensor(2.6010, device='cuda:0'), logits=tensor([[[ 2.8438,  3.5625,  7.0000,  ..., -1.2500, -1.2500, -1.2500],
         [ 7.8438,  5.3125,  5.8438,  ...,  0.1299,  0.1309,  0.1309],
         [ 7.8438,  5.5625,  4.2812,  ..., -1.7031, -1.7031, -1.7031],
         ...,
         [11.1250,  6.5625,  6.1250,  ..., -0.0830, -0.0830, -0.0830],
         [12.1875,  6.3438,  6.0312,  ..., -0.4082, -0.4082, -0.4082],
         [ 7.0000,  2.2500,  5.8438,  ..., -1.2734, -1.2734, -1.2734]],

        [[ 2.8438,  3.5625,  7.0000,  ..., -1.2500, -1.2500, -1.2500],
         [ 7.8438,  5.3125,  5.8438,  ...,  0.1299,  0.1309,  0.1309],
         [ 7.8438,  5.5625,  4.2812,  ..., -1.7031, -1.7031, -1.7031],
         ...,
         [11.1250,  6.5625,  6.1250,  ..., -0.0830, -0.0830, -0.0830],
         [12.1875,  6.3438,  6.0312,  ..., -0.4082, -0.4082, -0.4082],
         [ 7.0000,  2.2500,  5.8438,  ..., -1.2734, -1.2734, -1.2734]]],
       device='cuda:0', dtype=torch.bfloat16), p

In [30]:
# NOTE(review): duplicates the `model_input['input_ids'].shape` check a few
# cells above; candidate for removal.
model_input["input_ids"].shape

torch.Size([2, 197])

In [17]:
# NOTE(review): the recorded output below (batch dim 1) came from In [17], which
# ran before the batch-of-2 forward pass in In [29]. A fresh run should give
# (2, 197, 128256). The assert catches this kind of stale-state mismatch.
assert model_result.logits.shape[:2] == model_input["input_ids"].shape, (
    "logits must have one row per input row and one position per input token"
)
model_result.logits.shape

torch.Size([1, 197, 128256])