In [1]:
import os

# Expose only physical GPU 1 to this process. The variable takes effect only if
# it is set before CUDA is initialized, hence its place in the very first cell.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
import torch
# Record which PyTorch build produced the outputs below (a 2.6 nightly in this run).
# NOTE(review): the FlexAttention APIs imported next are version-dependent; consider pinning a minimum version.
torch.__version__

'2.6.0.dev20241212+cu124'

In [3]:
from attn_gym import visualize_attention_scores
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.attention.flex_attention import _mask_mod_signature
from torch.nn.attention.flex_attention import (
    _score_mod_signature,
    _mask_mod_signature,
    _vmap_for_bhqkv,
    _ModificationType,
)
from torch.nn.utils.rnn import pad_sequence
from typing import List, Union
from torch import Tensor
import torch
import random

def causal_mask(b, h, q_idx, kv_idx):
    """Plain causal mask_mod: a query may attend to keys at its own or earlier positions.

    ``b`` and ``h`` belong to the FlexAttention mask_mod signature and are unused.
    """
    return kv_idx <= q_idx

def _offsets_to_doc_ids_tensor(offsets):
    device = offsets.device
    offsets = offsets[offsets != -1]
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), device=device, dtype=torch.int32), counts
    )


def length_to_offsets(lengths: List[int], device: Union[str, torch.device]) -> Tensor:
    """Turn per-document lengths into cumulative start offsets.

    A leading 0 is prepended and a running sum taken, so ``[3, 2]`` becomes
    ``tensor([0, 3, 5])``: entry ``i`` is where document ``i`` starts and the
    last entry is the total length.

    Args:
        lengths: Length of each document, in order.
        device: Device on which the returned tensor is allocated.

    Returns:
        A 1-D tensor with ``len(lengths) + 1`` entries. The input is built as
        int32, but ``cumsum`` promotes integer inputs, so the result is int64.
    """
    lengths_with_origin = torch.tensor([0, *lengths], device=device, dtype=torch.int32)
    return lengths_with_origin.cumsum(dim=-1)

def generate_doc_mask_mod(offsets):
    """Build a document-causal ``mask_mod`` for a batch of packed sequences.

    Args:
        offsets: list with one 1-D tensor per batch element holding cumulative
            document offsets (leading 0, e.g. from ``length_to_offsets``).
            Rows may contain different numbers of documents.

    Returns:
        ``document_causal_mask(b, h, q_idx, kv_idx)``: True where ``q_idx`` may
        attend to ``kv_idx``, i.e. the key is not in the future and both tokens
        belong to the same document of batch element ``b``.

    Rows whose offsets end at different totals are right-padded with document
    id -1, so padding positions only ever match other padding positions. The
    mask indexes ``doc_ids[b, idx]`` directly; NOTE(review): assumes the
    attention length does not exceed the longest row's total (not checked).
    """
    padded_offsets = pad_sequence(offsets, batch_first=True, padding_value=-1)
    # pad_sequence (rather than torch.stack) so rows whose totals differ can
    # still be batched; for equal-length rows the result is identical to stack.
    doc_ids = pad_sequence(
        [_offsets_to_doc_ids_tensor(row) for row in padded_offsets],
        batch_first=True,
        padding_value=-1,
    )

    def document_causal_mask(b, h, q_idx, kv_idx):
        # Named `is_causal` so it does not shadow the module-level causal_mask().
        is_causal = q_idx >= kv_idx
        same_document = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return is_causal & same_document

    return document_causal_mask

In [4]:
flex_attention = torch.compile(flex_attention, dynamic = False)
create_block_mask = torch.compile(create_block_mask, dynamic = False)

In [5]:
# Every tensor below is allocated on this device (the single GPU left visible
# by CUDA_VISIBLE_DEVICES in the first cell).
device = "cuda"

In [6]:
def generate_random_lengths(total_length, num_documents, rng=None):
    """Randomly split ``total_length`` tokens into ``num_documents`` positive lengths.

    Every document starts with one token; each of the remaining
    ``total_length - num_documents`` tokens is then assigned to a uniformly
    chosen document, so the lengths always sum to ``total_length``.

    Args:
        total_length: Total number of tokens to distribute.
        num_documents: Number of documents (length of the returned list).
        rng: Optional ``random.Random`` instance for reproducible draws.
            Defaults to the module-level ``random`` generator (original behavior).

    Returns:
        A list of ``num_documents`` ints, each >= 1, summing to ``total_length``.

    Raises:
        ValueError: if ``num_documents`` is negative, or the tokens cannot be
            split into that many positive lengths.
    """
    if num_documents < 0:
        raise ValueError(f"num_documents must be >= 0, got {num_documents}")
    if total_length < num_documents or (num_documents == 0 and total_length > 0):
        # Previously this silently returned lengths that did not sum to total_length.
        raise ValueError(
            f"cannot split {total_length} tokens into {num_documents} non-empty documents"
        )

    gen = random if rng is None else rng
    lengths = [1] * num_documents
    for _ in range(total_length - num_documents):
        lengths[gen.randint(0, num_documents - 1)] += 1

    return lengths

max_seq_len, doc_count = 21, 4
B, H, SEQ_LEN, HEAD_DIM = 2, 1, max_seq_len, 128

# Seed both RNGs so the document packings (Python `random`) and q/k/v (torch)
# are reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# One packing per batch element; batch element i is split into doc_count + i documents.
offsets = [
    length_to_offsets(generate_random_lengths(max_seq_len, doc_count + i), device)
    for i in range(B)
]

def make_tensor():
    """Return a standard-normal tensor of shape (B, H, SEQ_LEN, HEAD_DIM) on ``device``."""
    shape = (B, H, SEQ_LEN, HEAD_DIM)
    return torch.randn(shape, device=device)

# Three independent random tensors, drawn in q, k, v order.
q, k, v = [make_tensor() for _ in range(3)]

In [7]:
from functools import lru_cache

In [8]:
def forward(q, k, v, offsets):
    """Document-causal FlexAttention over packed sequences.

    Args:
        q, k, v: query/key/value tensors; presumably (batch, heads, seq, head_dim)
            as produced by ``make_tensor``. The batch size is read from
            ``q.size(0)`` and the query/key lengths from ``size(-2)``.
        offsets: list with one 1-D tensor of cumulative document offsets per
            batch element (e.g. from ``length_to_offsets``).

    Returns:
        The ``flex_attention`` output for ``q``.
    """
    document_causal_mask = generate_doc_mask_mod(offsets)
    # The mask depends on the batch index (each row has its own documents), so
    # the block mask must be built per batch element: with B=None,
    # create_block_mask builds one block layout (evaluated at b=0) and
    # broadcasts it to the whole batch, skipping/keeping blocks based on
    # batch 0's documents only. The mask ignores `h`, so H stays None.
    block_mask = create_block_mask(
        document_causal_mask,
        q.size(0),
        None,
        q.size(-2),
        k.size(-2),
        device=q.device,
        _compile=True,
    )
    return flex_attention(q, k, v, block_mask=block_mask)

In [9]:
%%time

flex = forward(q, k, v, offsets)

CPU times: user 1.12 s, sys: 107 ms, total: 1.23 s
Wall time: 1.29 s


In [10]:
# A second set of packings with very different document counts per batch
# element (4 and 11 for doc_count=4, extra=7, B=2).
extra = 7
# Exactly one document count per batch element. The previous
# range(doc_count, doc_count + B + extra, extra) matched B only by coincidence.
doc_counts = [doc_count + extra * i for i in range(B)]
offsets = [
    length_to_offsets(generate_random_lengths(max_seq_len, n_docs), device)
    for n_docs in doc_counts
]
offsets

[tensor([ 0,  9, 12, 16, 21], device='cuda:0'),
 tensor([ 0,  3,  6,  7, 10, 11, 12, 14, 15, 16, 19, 21], device='cuda:0')]

In [11]:
%%time

flex = forward(q, k, v, offsets)

CPU times: user 920 µs, sys: 0 ns, total: 920 µs
Wall time: 848 µs


In [12]:
def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    """Combine square 0/1 masks block-diagonally into an additive attention bias.

    Args:
        *masks: square 2-D tensors, placed along the diagonal in order. An entry
            equal to 1 means "may attend".
        dtype: dtype of the returned bias (default bfloat16).

    Returns:
        Tensor of shape ``(1, N, N)`` with ``N = sum(m.size(0) for m in masks)``:
        0 where the combined mask equals 1, and the most negative value of
        ``dtype`` everywhere else (including all off-diagonal blocks), usable as
        an additive ``attn_mask`` for ``scaled_dot_product_attention``.

    Raises:
        ValueError: if a mask is not a square 2-D tensor.
    """
    for mask in masks:
        if mask.dim() != 2 or mask.size(0) != mask.size(1):
            raise ValueError(f"expected square 2-D masks, got shape {tuple(mask.shape)}")

    # torch.block_diag() with no inputs returns a (1, 0) tensor, so keep the
    # empty case explicit to preserve the (0, 0) layout.
    if masks:
        combined_mask = torch.block_diag(*masks).to(dtype)
    else:
        combined_mask = torch.zeros(0, 0, dtype=dtype)

    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    inverted_mask = torch.full_like(combined_mask, min_value).masked_fill(combined_mask == 1, 0)
    return inverted_mask.unsqueeze(0)

In [13]:
%%time

masks = []
for f in offsets:
    masks_ = []
    masking = torch.diff(f)
    for m in masking:
        masks_.append(torch.tril(torch.ones(m, m)))
    
    masks.append(block_diagonal_concat_inverted(*masks_, dtype = q.dtype))
    
masks = torch.stack(masks, 0).to('cuda')
sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask = masks)

CPU times: user 7.72 ms, sys: 0 ns, total: 7.72 ms
Wall time: 2.42 ms


In [14]:
# Index of the largest output feature per query position (SDPA reference).
torch.argmax(sdpa, dim=-1)

tensor([[[ 86,   6,  49,  49, 124, 124,  43,  49,  49,   4, 114,  34,   3,   3,
           67,  67,  14,  16,  96, 117,  16]],

        [[ 28,  73, 101, 100,  11,  75,   5,  67,  79,  74, 125, 114, 113,  34,
           10,  43,  64,  89,  93, 119,  40]]], device='cuda:0')

In [15]:
# Same per-position argmax for the FlexAttention output.
torch.argmax(flex, dim=-1)

tensor([[[ 86,   6,  49,  49, 124, 124,  43,  49,  49,   4, 114,  34,   3,   3,
           67,  67,  14,  16,  96, 117,  16]],

        [[ 28,  73, 101, 100,  11,  75,   5,  67,  79,  74, 125, 114, 113,  34,
           10,  43,  64,  89,  93, 119,  40]]], device='cuda:0')

In [16]:
# Argmax agreement alone can hide large numeric differences, so first require
# the full outputs to match within a loose tolerance, then show the per-position
# argmax agreement as before.
torch.testing.assert_close(flex, sdpa, atol=1e-3, rtol=1e-3)
flex.argmax(-1) == sdpa.argmax(-1)

tensor([[[True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True]],

        [[True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True]]],
       device='cuda:0')