In [12]:
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

In [13]:
def noop(score, b, h, q_idx, kv_idx):
    """Identity score modification: hand the attention score back untouched.

    Args:
        score: The value to pass through.
        b, h, q_idx, kv_idx: Accepted so the signature matches what the caller
            supplies, but never read.

    Returns:
        The `score` argument, unchanged.
    """
    unchanged = score
    return unchanged

In [14]:
# Problem sizes for the demo tensors (presumably batch, sequence length, model width).
b = 2
s = 16
d = 32
# The width `d` is split evenly across the heads.
heads = 2
headdim = d // heads

In [15]:
# flex_attention expects (batch, heads, seq_len, head_dim): the head axis must come
# before the sequence axis, otherwise attention is computed across heads instead of tokens.
q,k,v = [torch.randn(b, heads, s, headdim, device='cuda', requires_grad=True) for _ in range(3)]

In [16]:
# Run the torch.compile-wrapped flex_attention with the identity score modification.
output = torch.compile(flex_attention)(q, k, v, score_mod=noop)

In [17]:
# Reduce to a scalar and backpropagate through the attention call.
torch.mean(output).backward()

In [None]:
from torch import Tensor
from typing import List, Callable

def generate_doc_mask(seq_lens: List[int] | Tensor, device="cpu") -> Callable:
    """
    Generates a document mask function for flex attention.
    The first step is to create a flat tensor of length `sum(seq_lens)` where each position
    is assigned a document ID based on the sequence lengths provided.

    Args:
        seq_lens (List[int] | Tensor): A list or tensor of sequence lengths for each batch element.

    Returns:
        Callable: A mask function that can be passed to flex_attention.
    """
    if isinstance(seq_lens, list):
        seq_lens = torch.tensor(seq_lens, device=device)

    total_len = seq_lens.sum().item()
    document_ids = torch.repeat_interleave(
        torch.arange(len(seq_lens), device=device), seq_lens
    )

    def doc_mask_mod(b, h, q_idx, kv_idx):
        return document_ids[q_idx] == document_ids[kv_idx]

    return doc_mask_mod, document_ids

In [22]:
# Per-document sequence lengths; one set of q/k/v tensors is created for each entry.
sequences = [2, 5, 3, 6]

In [23]:
# first, we create each individual sequence
# flex_attention expects (batch, heads, seq_len, head_dim), so the sequence axis is dim 2
docs = []
for seq_len in sequences:
    q,k,v = [torch.randn(1, heads, seq_len, headdim, device='cuda', requires_grad=True) for _ in range(3)]
    docs.append((q,k,v))

# then, we pass them through one by one to get a reference output
ref_outputs = []
for q,k,v in docs:
    out = flex_attention(q, k, v, score_mod=noop)
    ref_outputs.append(out)

# now, we pack them along the sequence axis into one long sequence
q = torch.cat([doc[0] for doc in docs], dim=2)
k = torch.cat([doc[1] for doc in docs], dim=2)
v = torch.cat([doc[2] for doc in docs], dim=2)
mask_fn, doc_ids = generate_doc_mask(sequences, device='cuda')
# mask_fn takes (b, h, q_idx, kv_idx) and has no `score` argument, so it cannot be used
# as a score_mod; it has to be wrapped in a BlockMask and passed as `block_mask`.
# B=None / H=None broadcast the same mask over every batch entry and head.
block_mask = create_block_mask(
    mask_fn, B=None, H=None, Q_LEN=q.shape[2], KV_LEN=k.shape[2], device='cuda'
)
output = flex_attention(q, k, v, block_mask=block_mask)

# if we combine the reference outputs, they should match the big output
ref_output = torch.cat(ref_outputs, dim=2)
# explicit tolerances: the packed and per-document runs may reduce in a different order
print(torch.allclose(output, ref_output, rtol=1e-4, atol=1e-5))

True


In [24]:
# same thing, but with torch.compile
# wrap once up front instead of re-wrapping flex_attention on every call
compiled_flex_attention = torch.compile(flex_attention)

# first, we create each individual sequence
# flex_attention expects (batch, heads, seq_len, head_dim), so the sequence axis is dim 2
docs = []
for seq_len in sequences:
    q,k,v = [torch.randn(1, heads, seq_len, headdim, device='cuda', requires_grad=True) for _ in range(3)]
    docs.append((q,k,v))

# then, we pass them through one by one to get a reference output
ref_outputs = []
for q,k,v in docs:
    out = compiled_flex_attention(q, k, v, score_mod=noop)
    ref_outputs.append(out)

# now, we pack them along the sequence axis into one long sequence
q = torch.cat([doc[0] for doc in docs], dim=2)
k = torch.cat([doc[1] for doc in docs], dim=2)
v = torch.cat([doc[2] for doc in docs], dim=2)
mask_fn, doc_ids = generate_doc_mask(sequences, device='cuda')
# mask_fn takes (b, h, q_idx, kv_idx) and has no `score` argument, so it cannot be used
# as a score_mod; it has to be wrapped in a BlockMask and passed as `block_mask`.
# B=None / H=None broadcast the same mask over every batch entry and head.
block_mask = create_block_mask(
    mask_fn, B=None, H=None, Q_LEN=q.shape[2], KV_LEN=k.shape[2], device='cuda'
)
output = compiled_flex_attention(q, k, v, block_mask=block_mask)

# if we combine the reference outputs, they should match the big output
ref_output = torch.cat(ref_outputs, dim=2)
# explicit tolerances: the packed and per-document runs may reduce in a different order
print(torch.allclose(output, ref_output, rtol=1e-4, atol=1e-5))

True
