In [1]:
from datasets import load_dataset

# WikiText-2 from the Hugging Face Hub: dataset "wikitext" with the
# "wikitext-2-raw-v1" configuration. The result is a DatasetDict keyed by split.
dataset_name = "wikitext"
dataset_config_name = "wikitext-2-raw-v1"
raw_datasets = load_dataset(dataset_name, dataset_config_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Inspect the loaded DatasetDict: split names, features, and row counts.
raw_datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

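### Document attention (flash attention)

Pack several variable-length sequences into a single tensor and call `flash_attn_varlen_func`. It uses the cumulative sequence offsets (`cu_seqlens`) to keep attention within each sequence, so tokens cannot attend across document boundaries.

- Reference: https://github.com/Dao-AILab/flash-attention/issues/654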

In [1]:
import itertools

import torch
from flash_attn import flash_attn_varlen_func

# Seed so the random q/k/v (and therefore `output`) are reproducible on Restart & Run All.
torch.manual_seed(0)

# Three sequences ("documents") of different lengths, packed back-to-back along
# one token dimension instead of being padded into a (batch, seq) grid.
seq_lens = [512, 1024, 256]
batch_size = len(seq_lens)
total_tokens = sum(seq_lens)

# The packed inputs are shaped (total_tokens, num_heads, head_dim), so the
# hidden dimension is split into heads explicitly.
num_heads = 8
head_dim = 16  # dimension of each attention head
hidden_dim = num_heads * head_dim  # 8 * 16 = 128

# bf16 on CUDA: the flash-attention kernels run on fp16/bf16 GPU tensors.
qkv_shape = (total_tokens, num_heads, head_dim)
q = torch.randn(qkv_shape, device="cuda", dtype=torch.bfloat16)
k = torch.randn(qkv_shape, device="cuda", dtype=torch.bfloat16)
v = torch.randn(qkv_shape, device="cuda", dtype=torch.bfloat16)

# Cumulative sequence offsets: sequence i occupies packed rows
# cu_seqlens[i]:cu_seqlens[i + 1]. These boundaries keep attention from crossing
# from one packed sequence into another. Layout: int32, batch_size + 1 entries,
# starting at 0 and ending at total_tokens.
cu_seqlens = torch.tensor(
    list(itertools.accumulate(seq_lens, initial=0)),
    dtype=torch.int32,
    device="cuda",
)
# cu_seqlens: tensor([0, 512, 1536, 1792], device='cuda:0', dtype=torch.int32)
assert cu_seqlens.numel() == batch_size + 1
assert cu_seqlens[-1].item() == total_tokens

# Longest individual sequence in the pack, passed as max_seqlen_q / max_seqlen_k.
max_seqlen = max(seq_lens)

# Variable-length (packed) flash attention over all sequences in one call.
output = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,  # decoder-style masking: tokens attend only to earlier positions
)
assert output.shape == qkv_shape

print("Shape of the output tensor:", output.shape)

# Merge the heads back into the hidden dimension, as a transformer block does
# before the attention output projection / feed-forward network.
# `reshape` (unlike `view`) still works if the output is ever non-contiguous.
output_reshaped = output.reshape(total_tokens, hidden_dim)
print("Shape after reshaping to combine heads:", output_reshaped.shape)

Shape of the output tensor: torch.Size([1792, 8, 16])
Shape after reshaping to combine heads: torch.Size([1792, 128])


In [None]:
from trl import SFTConfig, SFTTrainer