In [27]:
import torch

torch.logical_and(torch.ones(8, 10).triu_(), torch.ones(8, 10).tril_())

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False, False, False],
        [False, False,  True, False, False, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False, False, False],
        [False, False, False, False,  True, False, False, False, False, False],
        [False, False, False, False, False,  True, False, False, False, False],
        [False, False, False, False, False, False,  True, False, False, False],
        [False, False, False, False, False, False, False,  True, False, False]])

In [28]:
torch.ones(8, 10).triu_()

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.]])

In [29]:
torch.ones(8, 10).tril_()

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]])

In [42]:
def get_mask_from_ranges(q_ranges, k_ranges, attn_type_map, q_len, k_len):
    bsz = q_ranges.shape[0]
    mask = torch.zeros((q_len, k_len), device='cuda', dtype=torch.bool)
    for i in range(bsz):
        if attn_type_map[i] == 1:
            mask_slice = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]]
            short_len = min(mask_slice.shape[0], mask_slice.shape[1])
            causal_part = torch.ones(short_len, short_len, device=mask_slice.device, dtype=mask_slice.dtype).tril_()
            mask_slice[-short_len:, -short_len:] = causal_part
            mask_slice[:, :-short_len] = True
            mask_slice[:-short_len, :] = False
        elif attn_type_map[i] == 0:
            mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]] = True
        elif attn_type_map[i] == 2:
            mask_slice = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]]
            short_len = min(mask_slice.shape[0], mask_slice.shape[1])
            inv_causal_part = torch.ones(short_len, short_len, device=mask_slice.device, dtype=mask_slice.dtype).triu_()
            mask_slice[:short_len, :short_len] = inv_causal_part
            mask_slice[:, short_len:] = True
            mask_slice[short_len:, :] = False
        else:
            mask_slice_causal = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]].clone()
            short_len = min(mask_slice_causal.shape[0], mask_slice_causal.shape[1])
            causal_part = torch.ones(short_len, short_len, device=mask_slice_causal.device, dtype=mask_slice_causal.dtype).tril_()
            mask_slice_causal[-short_len:, -short_len:] = causal_part
            mask_slice_causal[:, :-short_len] = True
            mask_slice_causal[:-short_len, :] = False
            
            mask_slice_inv_causal = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]].clone()
            short_len = min(mask_slice_inv_causal.shape[0], mask_slice_inv_causal.shape[1])
            inv_causal_part = torch.ones(short_len, short_len, device=mask_slice_inv_causal.device, dtype=mask_slice_inv_causal.dtype).triu_()
            mask_slice_inv_causal[:short_len, :short_len] = inv_causal_part
            mask_slice_inv_causal[:, short_len:] = True
            mask_slice_inv_causal[short_len:, :] = False

            mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]] = torch.logical_and(mask_slice_causal, mask_slice_inv_causal)

                
    return mask


In [45]:
get_mask_from_ranges(torch.tensor([[0, 4]], device='cuda'), torch.tensor([[0, 10]], device='cuda'), torch.tensor([3], device='cuda'), 4, 10)

tensor([[ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [False,  True,  True,  True,  True,  True,  True,  True, False, False],
        [False, False,  True,  True,  True,  True,  True,  True,  True, False],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True]],
       device='cuda:0')

In [1]:
from contextlib import contextmanager

import torch
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention
flex_attention_compiled = torch.compile(flex_attention)

from flash_attn_interface import flex_flash_attn_func

def get_mask_from_ranges(q_ranges, k_ranges, attn_type_map, q_len, k_len):
    bsz = q_ranges.shape[0]
    mask = torch.zeros((q_len, k_len), device='cuda', dtype=torch.bool)
    for i in range(bsz):
        if attn_type_map[i] == 1:
            mask_slice = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]]
            short_len = min(mask_slice.shape[0], mask_slice.shape[1])
            causal_part = torch.ones(short_len, short_len, device=mask_slice.device, dtype=mask_slice.dtype).tril_()
            mask_slice[-short_len:, -short_len:] = causal_part
            mask_slice[:, :-short_len] = True
            mask_slice[:-short_len, :] = False
        elif attn_type_map[i] == 0:
            mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]] = True
        elif attn_type_map[i] == 2:
            mask_slice = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]]
            short_len = min(mask_slice.shape[0], mask_slice.shape[1])
            inv_causal_part = torch.ones(short_len, short_len, device=mask_slice.device, dtype=mask_slice.dtype).triu_()
            mask_slice[:short_len, :short_len] = inv_causal_part
            mask_slice[:, short_len:] = True
            mask_slice[short_len:, :] = False
        else:
            mask_slice_causal = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]].clone()
            short_len = min(mask_slice_causal.shape[0], mask_slice_causal.shape[1])
            causal_part = torch.ones(short_len, short_len, device=mask_slice_causal.device, dtype=mask_slice_causal.dtype).tril_()
            mask_slice_causal[-short_len:, -short_len:] = causal_part
            mask_slice_causal[:, :-short_len] = True
            mask_slice_causal[:-short_len, :] = False
            
            mask_slice_inv_causal = mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]].clone()
            short_len = min(mask_slice_inv_causal.shape[0], mask_slice_inv_causal.shape[1])
            inv_causal_part = torch.ones(short_len, short_len, device=mask_slice_inv_causal.device, dtype=mask_slice_inv_causal.dtype).triu_()
            mask_slice_inv_causal[:short_len, :short_len] = inv_causal_part
            mask_slice_inv_causal[:, short_len:] = True
            mask_slice_inv_causal[short_len:, :] = False

            mask[q_ranges[i, 0]:q_ranges[i, 1], k_ranges[i, 0]:k_ranges[i, 1]] = torch.logical_and(mask_slice_causal, mask_slice_inv_causal)

    return mask


@contextmanager
def time_with_cuda_event(name, flops):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.nvtx.range_push(name)
    start.record()
    yield
    end.record()
    torch.cuda.nvtx.range_pop()
    end.synchronize()

    elapsed_time = start.elapsed_time(end)
    mfu = flops / (elapsed_time * 0.989 * 1e12)
    print(f"{name} took {elapsed_time} ms, mfu: {mfu:.2f}")

query = torch.randn(1, 48, 4096, 128, device='cuda', dtype=torch.bfloat16)
key = torch.randn(1, 48, 4096, 128, device='cuda', dtype=torch.bfloat16)
value = torch.randn(1, 48, 4096, 128, device='cuda', dtype=torch.bfloat16)

query_thd = query.squeeze().transpose(0, 1)
key_thd = key.squeeze().transpose(0, 1)
value_thd = value.squeeze().transpose(0, 1)
print(query_thd.shape, key_thd.shape, value_thd.shape)

query_thd.shape, key_thd.shape, value_thd.shape
q_ranges = torch.tensor([[0, 1024], [1024, 4096]], device='cuda', dtype=torch.int32)
k_ranges = torch.tensor([[0, 1024], [0, 4096]], device='cuda', dtype=torch.int32)
attn_type_map = torch.tensor([1, 3], device='cuda', dtype=torch.int32)

mask = get_mask_from_ranges(q_ranges, k_ranges, attn_type_map, 4096, 4096)

warmup_iters = 10
test_iters = 100

torch.Size([4096, 48, 128]) torch.Size([4096, 48, 128]) torch.Size([4096, 48, 128])


In [2]:
SLIDING_WINDOW = 1024

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask

# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them) 
block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=4096, KV_LEN=4096)
# In this case, we don't need a score_mod, so we won't pass any in.
# However, score_mod can still be combined with block_mask if you need the additional flexibility.

In [3]:
# swa
flops = 6144 * 4 * (1024 * 1024 / 2 + 1024 * 3072)
total_flops = flops * test_iters
for i in range(warmup_iters):
    o = flex_attention_compiled(query, key, value, block_mask=block_mask)

with time_with_cuda_event("flex_attention", total_flops):
    for i in range(test_iters):
        o = flex_attention_compiled(query, key, value, block_mask=block_mask)

for i in range(warmup_iters):
    o_thd, _ = flex_flash_attn_func(query_thd, key_thd, value_thd, q_ranges, k_ranges, max_seqlen_q=3072, max_seqlen_k=4096, attn_type_map=attn_type_map)

with time_with_cuda_event("flex_flash_attn", total_flops):
    for i in range(test_iters):
        o_thd, _ = flex_flash_attn_func(query_thd, key_thd, value_thd, q_ranges, k_ranges, max_seqlen_q=3072, max_seqlen_k=4096, attn_type_map=attn_type_map)

flex_attention took 35.26921463012695 ms, mfu: 0.26
flex_flash_attn took 22.082239151000977 ms, mfu: 0.41


In [3]:
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them) 
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=4096, KV_LEN=4096)
q_ranges = torch.tensor([[0, 4096]], device='cuda', dtype=torch.int32)
k_ranges = torch.tensor([[0, 4096]], device='cuda', dtype=torch.int32)
attn_type_map = torch.tensor([1], device='cuda', dtype=torch.int32)

mask = get_mask_from_ranges(q_ranges, k_ranges, attn_type_map, 4096, 4096)

In [None]:
# causal
flops = 6144 * 4 * (4096 * 4096 / 2)
total_flops = flops * test_iters
for i in range(warmup_iters):
    o = flex_attention_compiled(query, key, value, block_mask=block_mask)

with time_with_cuda_event("flex_attention", total_flops):
    for i in range(test_iters):
        o = flex_attention_compiled(query, key, value, block_mask=block_mask)

for i in range(warmup_iters):
    o_thd, _ = flex_flash_attn_func(query_thd, key_thd, value_thd, q_ranges, k_ranges, max_seqlen_q=3072, max_seqlen_k=4096, attn_type_map=attn_type_map)

with time_with_cuda_event("flex_flash_attn", total_flops):
    for i in range(test_iters):
        o_thd, _ = flex_flash_attn_func(query_thd, key_thd, value_thd, q_ranges, k_ranges, max_seqlen_q=3072, max_seqlen_k=4096, attn_type_map=attn_type_map)

flex_attention took 69.78524780273438 ms, mfu: 0.30
flex_flash_attn took 37.469600677490234 ms, mfu: 0.56


In [2]:
q_ranges = torch.tensor([[0, 4096]], device='cuda', dtype=torch.int32)
k_ranges = torch.tensor([[0, 4096]], device='cuda', dtype=torch.int32)
attn_type_map = torch.tensor([0], device='cuda', dtype=torch.int32)

mask = get_mask_from_ranges(q_ranges, k_ranges, attn_type_map, 4096, 4096)

In [12]:
# full
flops = 6144 * 4 * (4096 * 4096)
total_flops = flops * test_iters
for i in range(warmup_iters):
    o = flex_attention_compiled(query, key, value)

with time_with_cuda_event("flex_attention", total_flops):
    for i in range(test_iters):
        o = flex_attention_compiled(query, key, value)

for i in range(warmup_iters):
    o_thd, _ = flex_flash_attn_func(query_thd, key_thd, value_thd, q_ranges, k_ranges, max_seqlen_q=3072, max_seqlen_k=4096, attn_type_map=attn_type_map)

with time_with_cuda_event("flex_flash_attn", total_flops):
    for i in range(test_iters):
        o_thd, _ = flex_flash_attn_func(query_thd, key_thd, value_thd, q_ranges, k_ranges, max_seqlen_q=3072, max_seqlen_k=4096, attn_type_map=attn_type_map)

flex_attention took 110.56114959716797 ms, mfu: 0.38
flex_flash_attn took 65.88972473144531 ms, mfu: 0.63
