In [7]:
import random
import torch

import torch.nn.functional as F

from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)
from functools import lru_cache, partial

@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
    print(block_mask.to_dense())
    print(block_mask.shape)
    return block_mask

random.seed(0)
torch.manual_seed(0)

batch_size = 4
n_heads = 6
D = 64


def build_seq_idx(tensor: torch.Tensor):
    offsets = tensor.offsets()
    total_length = tensor.offsets()[-1].item()
    print("total_length:", total_length)
    # Create a range tensor from 0 to total_length
    range_tensor = torch.arange(total_length, device="cuda", dtype=torch.int32)

    # Use searchsorted to find the index for each position
    seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1

    return seq_idx


def create_njt_wrapper(orig_mask_mod, offsets, seq_idx):
    """Generic Wrapper that converts Dense mask_mod functions to NJT mask_mod functions"""

    def njt_score_mod(b, h, q_idx, kv_idx):
        q_nested = q_idx - offsets[seq_idx[q_idx]]
        kv_nested = kv_idx - offsets[seq_idx[kv_idx]]
        is_same_sequence = seq_idx[q_idx] == seq_idx[kv_idx]
        return orig_mask_mod(b, h, q_nested, kv_nested) & is_same_sequence

    return njt_score_mod


# Dense Score Mod
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
    # return torch.where(q_idx >= kv_idx, score, -float("inf"))


# Current limitation that the total sequnce length must be divisible by 128
sentence_lengths = [random.randint(1, 1024) for _ in range(batch_size - 1)]
print("sentence lengths:", sentence_lengths)
total = sum(sentence_lengths)
sentence_lengths.append(128 - total % 128)
total = sum(sentence_lengths)
print("actual total:", total)

ragged_tensors = [torch.randn(l, n_heads, D, device="cuda") for l in sentence_lengths]
query = torch.nested.nested_tensor(
    ragged_tensors, layout=torch.jagged, requires_grad=True
)
key = torch.nested.nested_tensor(
    ragged_tensors, layout=torch.jagged, requires_grad=True
)
value = torch.nested.nested_tensor(
    ragged_tensors, layout=torch.jagged, requires_grad=True
)

# Build the seq_idx lookup table for
offsets = query.offsets()
seq_idx = build_seq_idx(query)

causal_score_mod_njt = create_njt_wrapper(causal_mask, offsets, seq_idx)

# print("query_values:", query_values)
# print("key_values:", key_values)
# print("value_values:", value_values)

# print("query:", query)
# print("key:", key)
# print("value:", value)

print("query.shape:", query.shape)
print("key.shape:", key.shape)
print("value.shape:", value.shape)

block_mask = create_block_mask_cached(
    causal_score_mod_njt, 1, 1, total, total, device=query.device
)
query_flex = query.transpose(1, 2).detach().requires_grad_()
key_flex = key.transpose(1, 2).detach().requires_grad_()
value_flex = value.transpose(1, 2).detach().requires_grad_()

print("query_flex:", query_flex.shape)
print("key_flex:", key_flex.shape)
print("value_flex:", value_flex.shape)

out_flex = flex_attention(
    query_flex,
    key_flex,
    value_flex,
    block_mask=block_mask,
)
print("out_flex:", out_flex.shape)
out_sdpa = F.scaled_dot_product_attention(
    query.transpose(1, 2),
    key.transpose(1, 2),
    value.transpose(1, 2),
    is_causal=True,
)
print("out_sdpa:", out_sdpa.shape)

out_sdpa

sdpa_outs = []
flex_outs = []

gradOut = torch.randn_like(out_sdpa)

sdpa_outs.append(out_sdpa)
out_sdpa.backward(gradOut)
sdpa_outs += [query.grad, key.grad, value.grad]

flex_outs.append(out_flex)
out_flex.backward(gradOut._values.unsqueeze(0))
flex_outs += [query_values.grad, key_values.grad, value_values.grad]

for flex, sdpa in zip(flex_outs, sdpa_outs):
    print("flex, sdpa")
    print(flex.shape, sdpa.shape)

    flex = flex.squeeze(0)
    torch.testing.assert_close(flex, sdpa._values, atol=1e-2, rtol=1e-2)


print("Correctness check passed ✅")

print(block_mask)

sentence lengths: [789, 862, 83]
actual total: 1792
total_length: 1792
query.shape: torch.Size([4, j19, 6, 64])
key.shape: torch.Size([4, j20, 6, 64])
value.shape: torch.Size([4, j21, 6, 64])
tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]]], device='cuda:0',
       dtype=torch.int32)
(1, 1,

Unsupported: Failed running call_method new_empty(*(NestedTensor(size=(4, 6, s2, 64), offsets=FakeTensor(..., device='cuda:0', size=(5,), dtype=torch.int64), requires_grad=True, contiguous=True), []), **{'dtype': torch.int32}):
aten.new_empty.default

from user code:
   File "/home/kkj/axolotl/.venv/lib/python3.10/site-packages/torch/nn/attention/flex_attention.py", line 1033, in _flex_attention_hop_wrapper
    return flex_attention_hop(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [2]:
random.seed(0)
torch.manual_seed(0)

batch_size = 4
n_heads = 12
D = 64

@lru_cache
def create_block_mask_cached(mask_mod, B, H, M, N, device="cuda"):
    block_mask = create_block_mask(mask_mod, B, H, M, N, device=device)
    return block_mask

def prepare_qkv_values(tensor):
    return tensor._values.detach().requires_grad_()

def build_seq_idx(tensor: torch.Tensor):
    offsets = tensor.offsets()
    total_length = tensor.offsets()[-1].item()
    # Create a range tensor from 0 to total_length
    range_tensor = torch.arange(total_length, device="cuda", dtype=torch.int32)

    # Use searchsorted to find the index for each position
    seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1

    return seq_idx

def create_njt_wrapper(seq_idx):
    """Generic Wrapper that makes a NJT mask_mod"""

    def njt_mask_mod(b, h, q_idx, kv_idx):
        is_same_sequence = seq_idx[q_idx] == seq_idx[kv_idx]
        return is_same_sequence

    return njt_mask_mod

# Current limitation that the total sequnce length must be divisible by 128
sentence_lengths = [random.randint(1, 1024) for _ in range(batch_size - 1)]
print("sentence lengths:", sentence_lengths)
total = sum(sentence_lengths)
sentence_lengths.append(128 - total % 128)
total = sum(sentence_lengths)

ragged_tensors = [torch.randn(l, n_heads, D, device="cuda") for l in sentence_lengths]
query = torch.nested.nested_tensor(
    ragged_tensors, layout=torch.jagged, requires_grad=True
)
key = torch.nested.nested_tensor(
    ragged_tensors, layout=torch.jagged, requires_grad=True
)
value = torch.nested.nested_tensor(
    ragged_tensors, layout=torch.jagged, requires_grad=True
)

# Build the seq_idx lookup table for
offsets = query.offsets()
seq_idx = build_seq_idx(query)

mask_mod_njt = create_njt_wrapper(seq_idx)

query_values = prepare_qkv_values(query)
key_values = prepare_qkv_values(key)
value_values = prepare_qkv_values(value)

print("query_values.shape:", query_values.shape)
print("key_values.shape:", key_values.shape)
print("value_values.shape:", value_values.shape)

print("query.shape:", query.shape)
print("key.shape:", key.shape)
print("value.shape:", value.shape)

block_mask = create_block_mask_cached(
    mask_mod_njt, 1, 1, total, total, device=query_values.device
)
out_flex = flex_attention(
    query_values.view(1, -1, n_heads, D).transpose(1, 2),
    key_values.view(1, -1, n_heads, D).transpose(1, 2),
    value_values.view(1, -1, n_heads, D).transpose(1, 2),
    block_mask=block_mask,
)

print("Flex attention can run for a batch size of", batch_size, "with", n_heads, "heads, a dim of", D, "and a sequence length sum of", total)

print(block_mask)

sentence lengths: [789, 862, 83]
query_values.shape: torch.Size([1792, 12, 64])
key_values.shape: torch.Size([1792, 12, 64])
value_values.shape: torch.Size([1792, 12, 64])
query.shape: torch.Size([4, j4, 12, 64])
key.shape: torch.Size([4, j5, 12, 64])
value.shape: torch.Size([4, j6, 12, 64])
Flex attention can run for a batch size of 4 with 12 heads, a dim of 64 and a sequence length sum of 1792
BlockMask(shape=(1, 1, 1792, 1792), sparsity=48.98%, 
(0, 0)
██████████████              
██████████████              
██████████████              
██████████████              
██████████████              
██████████████              
██████████████████████████  
            ██████████████  
            ██████████████  
            ██████████████  
            ██████████████  
            ██████████████  
            ████████████████
                        ████
)
