In [1]:
import os

# Set CUDA_VISIBLE_DEVICES so CUDA-aware libraries in this process see only GPU index 2.
# NOTE(review): the cells visible here create tensors with device = 'cpu', so this
# presumably only matters for GPU runs -- confirm it is still needed.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
from attn_gym import visualize_attention_scores
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.attention.flex_attention import _mask_mod_signature
from torch.nn.attention.flex_attention import (
    _score_mod_signature,
    _mask_mod_signature,
    _vmap_for_bhqkv,
    _ModificationType,
)
from torch.nn.utils.rnn import pad_sequence
from typing import List, Union
from torch import Tensor
import torch
import random

def causal_mask(b, h, q_idx, kv_idx):
    """Causal mask mod: a query may look at its own position and earlier ones.

    Args:
        b: Batch index (not used).
        h: Head index (not used).
        q_idx: Query position.
        kv_idx: Key/value position.

    Returns:
        The result of ``q_idx >= kv_idx``.
    """
    attends_to_self_or_past = q_idx >= kv_idx
    return attends_to_self_or_past

def _offsets_to_doc_ids_tensor(offsets):
    device = offsets.device
    offsets = offsets[offsets != -1]
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), device=device, dtype=torch.int32), counts
    )


def length_to_offsets(lengths: List[int], device: Union[str, torch.device]) -> Tensor:
    """Turn per-document lengths into cumulative offsets.

    Args:
        lengths: Token count of each document, in order.
        device: Device on which the offsets tensor is created.

    Returns:
        A 1-D tensor with ``len(lengths) + 1`` entries: a leading 0 followed by the
        running total of ``lengths``, e.g. ``[2, 4, 3]`` -> ``[0, 2, 6, 9]``.
    """
    zero_prefixed = torch.tensor([0, *lengths], device=device, dtype=torch.int32)
    return zero_prefixed.cumsum(dim=-1)


def generate_doc_mask_mod(
    mask_mod: _mask_mod_signature, offsets: Union[Tensor, List[Tensor]]
) -> _mask_mod_signature:
    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
    format.

    Args:
        mask_mod: The mask mod to apply to the documents
        offsets: Cumulative document token counts. Either a single 1-D tensor of
            shape (num_documents + 1), which is treated as a batch of one, or a
            list holding one such 1-D tensor per batch element.
            e.g. if you have 3 documents of length 2, 4, 3 then
            offsets = [0, 2, 6, 9]
            Batch elements may contain different numbers of documents, but every
            element must cover the same total number of tokens, because the
            per-token document ids are stacked into a single tensor.

    Returns:
        A mask mod ``(b, h, q_idx, kv_idx)`` that is true only where both positions
        fall in the same document of batch element ``b`` and ``mask_mod`` accepts
        their positions measured from the start of that document.

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form 1 large sequence. We then
        use masking to ensure that the attention scores are only applied to tokens within
        the same document.
    """
    # A bare 1-D offsets tensor (the documented example form) describes one batch
    # element; wrap it so pad_sequence receives a list of 1-D tensors.
    if isinstance(offsets, Tensor) and offsets.dim() == 1:
        offsets = [offsets]

    # Rows can have different document counts, so right-pad with -1; the helper
    # below discards those -1 entries again when expanding to per-token ids.
    offsets = pad_sequence(offsets, batch_first=True, padding_value=-1)
    document_ids = torch.stack(
        [_offsets_to_doc_ids_tensor(row) for row in offsets], dim=0
    )

    def doc_mask_mod(b, h, q_idx, kv_idx):
        q_doc = document_ids[b][q_idx]
        kv_doc = document_ids[b][kv_idx]
        same_doc = q_doc == kv_doc
        # Re-base both positions to the start of their own document so the inner
        # mask mod sees document-local indices.
        q_logical = q_idx - offsets[b][q_doc]
        kv_logical = kv_idx - offsets[b][kv_doc]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask

    return doc_mask_mod

def generate_single_doc_mask_mod(docs):
    """Build a causal, document-restricted mask mod for one packed sequence.

    Args:
        docs: Per-token document ids, indexable by query and key/value position.

    Returns:
        A mask mod ``(b, h, q_idx, kv_idx)`` that is true where ``q_idx >= kv_idx``
        and both positions carry the same id in ``docs``. ``b`` and ``h`` are ignored.
    """
    def document_causal_mask(b, h, q_idx, kv_idx):
        not_future = q_idx >= kv_idx
        same_document = docs[q_idx] == docs[kv_idx]
        return not_future & same_document

    return document_causal_mask

In [3]:
# Device used when building the offsets and q/k/v tensors, and passed on to
# visualize_attention_scores in the cells below.
device = 'cpu'

In [4]:
def generate_random_lengths(total_length, num_documents):
    """Randomly split ``total_length`` tokens into ``num_documents`` non-empty documents.

    Args:
        total_length: Total number of tokens to distribute.
        num_documents: Number of documents to create.

    Returns:
        A list of ``num_documents`` ints, each at least 1, that sum to ``total_length``.
        ``(0, 0)`` yields an empty list.

    Raises:
        ValueError: If the split is impossible, i.e. ``num_documents`` is not in
            ``1..total_length`` (previously ``num_documents > total_length`` silently
            returned lengths whose sum exceeded ``total_length``).
    """
    if total_length == 0 and num_documents == 0:
        return []
    if not 1 <= num_documents <= total_length:
        raise ValueError(
            f"cannot split {total_length} tokens into {num_documents} non-empty documents"
        )

    # Start every document at one token so that none ends up empty.
    lengths = [1] * num_documents

    # Hand out the remaining tokens one at a time to uniformly chosen documents.
    for _ in range(total_length - num_documents):
        lengths[random.randint(0, num_documents - 1)] += 1

    return lengths

# Seed the global RNG so the randomly drawn document boundaries, and therefore the
# saved mask visualisations, are the same on every Restart & Run All.
random.seed(0)

max_seq_len, doc_count = 21, 4
B, H, SEQ_LEN, HEAD_DIM = 1, 1, max_seq_len, 8

# One offsets tensor per batch element. Element j is packed with (doc_count + j)
# documents, and every element totals max_seq_len tokens.
offsets = []
for i in range(doc_count, doc_count + B):
    lengths = generate_random_lengths(max_seq_len, i)

    offsets.append(length_to_offsets(lengths, device))

def make_tensor():
    """Return an all-ones tensor of shape (B, H, SEQ_LEN, HEAD_DIM) on `device`."""
    return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

q, k, v = make_tensor(), make_tensor(), make_tensor()

In [6]:
# Batched construction: wrap causal_mask so it is applied per document, using the
# offsets of the first batch element (offsets[:1] keeps it as a one-element list).
document_causal_mask = generate_doc_mask_mod(causal_mask, offsets[:1])

# Render the mask for q/k; the recorded output below shows the result is saved as
# document_causal_mask.png.
visualize_attention_scores(
    q,
    k,
    mask_mod=document_causal_mask,
    device=device,
    name="document_causal_mask",
)

Visualization saved as document_causal_mask.png


In [8]:
from IPython.display import Image

# Load the PNG with filename= so its bytes are embedded in this cell's output.
# url= only stores a link to the file, and a later cell writes to the same file
# name, so a linked image could show the wrong mask.
Image(filename="document_causal_mask.png")

In [11]:
# Single-sequence construction: expand the first batch element's offsets into
# per-token document ids and build the causal same-document mask directly from them.
document_causal_mask = generate_single_doc_mask_mod(_offsets_to_doc_ids_tensor(offsets[0]))

# NOTE(review): this reuses name="document_causal_mask", the same name as the earlier
# visualisation cell, so it presumably overwrites that PNG -- confirm this is intended.
visualize_attention_scores(
    q,
    k,
    mask_mod=document_causal_mask,
    device=device,
    name="document_causal_mask",
)

Visualization saved as document_causal_mask.png


In [12]:
from IPython.display import Image

# Load the PNG with filename= so its bytes are embedded in this cell's output.
# url= only stores a link to the file, and that file name is shared with the
# earlier visualisation, so a linked image could be stale or cached.
Image(filename="document_causal_mask.png")