In [1]:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import torch

In [2]:
def noop(score, b, h, q_idx, kv_idx):
    """Identity ``score_mod``: hand the attention score back unmodified.

    ``b``, ``h``, ``q_idx`` and ``kv_idx`` are accepted only so the function has
    the five-argument shape this notebook passes to ``flex_attention``; none of
    them is read.
    """
    return score

In [3]:
# Toy problem size used by the first smoke test.
b = 2    # batch size
s = 16   # sequence length
d = 32   # model width
# Number of attention heads; the model width is split evenly across them.
heads = 2
headdim = d // heads

In [4]:
# flex_attention reads its inputs as (batch, heads, seq_len, head_dim). Putting
# the sequence axis before the head axis still runs, but the sequence positions
# are then silently treated as heads. The packed-document cells further down
# already build their tensors as (1, heads, seq_len, headdim).
q, k, v = (
    torch.randn(b, heads, s, headdim, device='cuda', requires_grad=True)
    for _ in range(3)
)

In [5]:
# Smoke test of the compiled kernel; noop leaves every score unchanged.
output = torch.compile(flex_attention)(q, k, v, score_mod=noop)

  return _C._get_float32_matmul_precision()


In [6]:
# Reduce to a scalar first so backward() needs no explicit gradient argument.
output.mean().backward()

In [7]:
from torch import Tensor
from typing import List, Callable

def generate_doc_mask(seq_lens: List[int] | Tensor, device="cpu") -> Callable:
    """
    Generates a document mask function for flex attention.
    The first step is to create a flat tensor of length `sum(seq_lens)` where each position
    is assigned a document ID based on the sequence lengths provided.

    Args:
        seq_lens (List[int] | Tensor): A list or tensor of sequence lengths for each batch element.

    Returns:
        Callable: A mask function that can be passed to flex_attention.
    """
    if isinstance(seq_lens, list):
        seq_lens = torch.tensor(seq_lens, device=device)

    total_len = seq_lens.sum().item()
    document_ids = torch.repeat_interleave(
        torch.arange(len(seq_lens), device=device), seq_lens
    )

    def doc_mask_mod(b, h, q_idx, kv_idx):
        return document_ids[q_idx] == document_ids[kv_idx]

    return doc_mask_mod, document_ids

In [8]:
# Lengths of the documents that get packed into one sequence; they sum to 16.
sequences = [2, 5, 3, 6]

In [10]:
# Tensor copy of the document lengths, used by the scratch cells below.
seq_lens = torch.tensor(sequences)

In [68]:
# Total number of packed positions.
total_lens = seq_lens.sum().item()
# One index per document, each repeated as many times as that document is long.
doc_index = torch.arange(len(seq_lens), device=seq_lens.device)
document_ids = doc_index.repeat_interleave(seq_lens)

In [74]:
# All-zero placeholder with one entry per packed position; a later cell rebinds
# `mask` to the actual keep-mask.
mask = torch.zeros_like(document_ids)

In [70]:
mask, document_ids

(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3]))

In [None]:
# Targets for sequences = [2, 5, 3, 6].
# probs_expected flags the first position of every document after the first
# (indices 2, 7 and 10); expected is its complement.
probs_expected = torch.tensor([0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0])
expected = 1 - probs_expected

In [59]:
# Start from all zeros (same shape and dtype as document_ids); the next cell
# fills in the document-boundary flags.
probs_mask = torch.zeros_like(document_ids)

In [60]:
# Compare each position's document ID with its left neighbour: a mismatch means
# a new document starts there. Index 0 has no left neighbour and stays 0. The
# boolean comparison result is written into the integer-typed probs_mask.
probs_mask[1:] = document_ids[:-1] != document_ids[1:]

In [63]:
probs_mask == probs_expected

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

In [64]:
# One uniform random value per packed position, shaped (1, mask.shape[0]).
# NOTE(review): no seed is set in the visible cells, so values differ per run.
probs = torch.rand(1, mask.shape[0], requires_grad=True)

In [65]:
probs

tensor([[0.9147, 0.1161, 0.8733, 0.6780, 0.8278, 0.8219, 0.2384, 0.4243, 0.6202,
         0.2918, 0.0021, 0.7642, 0.1010, 0.7634, 0.6756, 0.0488]],
       requires_grad=True)

In [66]:
# Force the value to 1 wherever probs_mask flags a document start and keep the
# random value elsewhere. This rebinds `probs` to the result of torch.where.
probs = torch.where(probs_mask == 1, 1, probs)

In [67]:
probs

tensor([[0.9147, 0.1161, 1.0000, 0.6780, 0.8278, 0.8219, 0.2384, 1.0000, 0.6202,
         0.2918, 1.0000, 0.7642, 0.1010, 0.7634, 0.6756, 0.0488]],
       grad_fn=<WhereBackward0>)

In [81]:
# Threshold at 0.5; positions forced to 1 in the previous step always pass.
bounds = probs > 0.5

In [82]:
bounds

tensor([[ True, False,  True,  True,  True,  True, False,  True,  True, False,
          True,  True, False,  True,  True, False]])

In [77]:
# Complement of probs_mask: 0 at flagged document starts, 1 everywhere else.
mask = 1 - probs_mask

In [78]:
torch.all(mask == expected)

tensor(True)

In [79]:
# One uniform random gate value per packed position, shaped (1, mask.shape[0]).
gates = torch.rand(1, mask.shape[0], requires_grad=True)

In [80]:
gates * mask

tensor([[0.3140, 0.2988, 0.0000, 0.2349, 0.4341, 0.0519, 0.9304, 0.0000, 0.7086,
         0.2181, 0.0000, 0.6826, 0.2217, 0.8132, 0.7034, 0.4184]],
       grad_fn=<MulBackward0>)

In [11]:
# Reference path: each document gets its own (q, k, v) triple and is attended
# to on its own, one call per document.
docs = [
    tuple(
        torch.randn(1, heads, seq_len, headdim, device='cuda', requires_grad=True)
        for _ in range(3)
    )
    for seq_len in sequences
]
ref_outputs = [flex_attention(doc_q, doc_k, doc_v, noop) for doc_q, doc_k, doc_v in docs]

# Packed path: join all documents along the sequence axis (dim=2) and rely on
# the document block mask to restrict attention to positions of the same document.
q, k, v = (torch.cat(parts, dim=2) for parts in zip(*docs))
print(q.shape, k.shape, v.shape)
mask_fn, doc_ids = generate_doc_mask(sequences, device='cuda')
block_mask = create_block_mask(
    mask_fn,
    B=None,
    H=None,
    Q_LEN=q.shape[2],
    KV_LEN=k.shape[2],
    device='cuda',
)
output = flex_attention(q, k, v, block_mask=block_mask)

# The per-document outputs, joined the same way, should equal the packed output.
ref_output = torch.cat(ref_outputs, dim=2)
print(torch.allclose(output, ref_output))


SOLUTION: Use torch.compile(flex_attention)(...)

If you want to debug your score_mod/mask_mod, you can set:
torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True

This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results.
  _warn_once(


torch.Size([1, 2, 16, 16]) torch.Size([1, 2, 16, 16]) torch.Size([1, 2, 16, 16])
True


In [None]:
# same thing, but with torch.compile
# Wrap flex_attention once and reuse the wrapper for every call below, instead
# of calling torch.compile again on each loop iteration.
compiled_flex_attention = torch.compile(flex_attention)

# first, we create each individual sequence
docs = []
for seq_len in sequences:
    q,k,v = [torch.randn(1, heads, seq_len, headdim, device='cuda', requires_grad=True) for _ in range(3)]
    docs.append((q,k,v))

# then, we pass them through one by one to get a reference output
ref_outputs = []
for q,k,v in docs:
    out = compiled_flex_attention(q, k, v, noop)
    ref_outputs.append(out)

# now, we concatenate them into one big batch along the sequence axis (dim=2)
q = torch.cat([doc[0] for doc in docs], dim=2)
k = torch.cat([doc[1] for doc in docs], dim=2)
v = torch.cat([doc[2] for doc in docs], dim=2)
mask_fn, doc_ids = generate_doc_mask(sequences, device='cuda')
block_mask = create_block_mask(mask_fn, B=None, H=None, Q_LEN=q.shape[2], KV_LEN=k.shape[2], device='cuda')
output = compiled_flex_attention(q, k, v, block_mask=block_mask)

# if we combine the reference outputs, they should match the big output
ref_output = torch.cat(ref_outputs, dim=2)
print(torch.allclose(output, ref_output))

True


  sequences = torch.tensor(sequences, device='cuda')


In [13]:
sequences

[2, 5, 3, 6]

In [14]:
### causal mask
def causal(b, h, q_idx, kv_idx):
    """Causal ``mask_mod``: a key/value position is visible when it is not after the query.

    ``b`` and ``h`` are accepted for the mask_mod call shape but are not read.
    """
    return kv_idx <= q_idx

def generate_doc_mask_mod(mask_mod, document_id):
    """Wrap ``mask_mod`` so it is applied independently inside each packed document.

    Every maximal run of equal values in ``document_id`` is treated as one
    document. Two positions may attend to each other only if they fall in the
    same run, and ``mask_mod`` is evaluated on indices counted from the start of
    that run. This lets a modifier such as ``causal`` see each document as if it
    began at position 0.

    Args:
        mask_mod: Callable ``(b, h, q_idx, kv_idx) -> bool tensor`` applied to
            the per-document (logical) indices.
        document_id: 1-D tensor with one document ID per packed position. Equal
            IDs are expected to be contiguous. The ID values themselves need not
            start at 0 or be consecutive.

    Returns:
        Callable ``(b, h, q_idx, kv_idx)`` usable as a flex-attention mask_mod.
    """
    # run_index[i] is the 0-based index of the run that position i belongs to,
    # and counts[r] is the length of run r. Indexing by run_index rather than by
    # the raw ID value keeps the offsets lookup in range for arbitrary ID values.
    _, run_index, counts = torch.unique_consecutive(
        document_id, return_inverse=True, return_counts=True
    )
    # Exclusive prefix sum of the run lengths: where each document starts.
    offsets = counts.cumsum(0) - counts

    def doc_mask_wrapper(b, h, q_idx, kv_idx):
        q_doc = run_index[q_idx]
        kv_doc = run_index[kv_idx]
        same_doc = q_doc == kv_doc
        # Shift the global indices so each document is numbered from 0.
        q_logical = q_idx - offsets[q_doc]
        kv_logical = kv_idx - offsets[kv_doc]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask
    return doc_mask_wrapper