In [9]:
import torch
from einops import rearrange
from torch.nn import Linear, Module
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# --- Flex Attention Utils ---


def causal(b, h, q_idx, kv_idx):
    """FlexAttention mask_mod: a query may see keys at its own or earlier positions.

    ``b`` and ``h`` (batch / head index) are accepted for signature
    compatibility and ignored.
    """
    return kv_idx <= q_idx


def generate_doc_mask_mod(mask_mod, document_id):
    """Build a FlexAttention mask_mod that confines attention to each document.

    A "document" is a maximal run of equal consecutive values in
    ``document_id``. A query position may attend to a key/value position only
    when both lie in the same run. If ``mask_mod`` is given (e.g. ``causal``),
    it is additionally applied to positions measured *relative to the start of
    their document*, so e.g. causal masking restarts at every boundary.

    Args:
        mask_mod: Optional inner mask function ``(b, h, q_idx, kv_idx) -> bool``
            tensor, or ``None`` to apply only the same-document restriction.
        document_id: Tensor of per-token document labels. 1-D, or any shape
            whose leading dimensions are all 1 (e.g. ``(1, seq_len)``), which
            is flattened. Labels need not start at 0 or be consecutive
            integers; only the run boundaries matter. A label that reappears in
            a later, non-adjacent run is treated as a separate document.

    Returns:
        A callable ``(b, h, q_idx, kv_idx) -> bool tensor`` suitable for
        ``create_block_mask`` / ``flex_attention``.

    Raises:
        ValueError: if ``document_id`` has more than one non-singleton
            dimension (per-row document layouts are not supported here).
    """
    if document_id.dim() > 1:
        if document_id.numel() != document_id.shape[-1]:
            raise ValueError(
                "document_id must be 1-D (or have only singleton leading dims); "
                f"got shape {tuple(document_id.shape)}"
            )
        document_id = document_id.reshape(-1)

    # Length of each run of identical consecutive labels (one run == one document).
    _, counts = torch.unique_consecutive(document_id, return_counts=True)
    # Per-token run index. Looking up offsets by run index (not by label value)
    # keeps this correct for arbitrary labels; for labels 0, 1, 2, ... in run
    # order the run index equals the label, matching the previous behavior.
    run_index = torch.repeat_interleave(
        torch.arange(counts.numel(), device=document_id.device), counts
    )
    # Start position of each run, then expanded to one entry per token.
    run_starts = counts.cumsum(0) - counts
    token_offsets = run_starts[run_index]

    if mask_mod is None:

        def doc_mask_wrapper_solo(b, h, q_idx, kv_idx):
            return run_index[q_idx] == run_index[kv_idx]

        return doc_mask_wrapper_solo

    def doc_mask_wrapper(b, h, q_idx, kv_idx):
        same_doc = run_index[q_idx] == run_index[kv_idx]
        # Positions relative to the start of each token's own document.
        q_logical = q_idx - token_offsets[q_idx]
        kv_logical = kv_idx - token_offsets[kv_idx]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask

    return doc_mask_wrapper

In [10]:
# First debug dump, read from the working directory.
# NOTE(review): depending on the torch version's `weights_only` default, this
# may fully unpickle the file; only load trusted dumps here.
debug_info = torch.load(f="debug_info.pth")

In [11]:
# Second debug dump, read from the working directory.
# NOTE(review): same unpickling caveat as the first dump — trusted files only.
mdebug_info = torch.load(f="mdebug_info.pth")

In [12]:
# Pull the per-token document ids and the activations out of the first dump.
doc_ids_down, x_down_unsqueezed = [
    debug_info[key] for key in ("doc_ids_down", "x_down_unsqueezed")
]

In [13]:
# Same pair from the second dump; per the recorded shapes, these ids carry a
# leading singleton dim (1, seq_len) unlike `doc_ids_down`.
mdoc_ids, x_down = (mdebug_info[key] for key in ("doc_ids_down", "x_down"))

In [14]:
tuple(t.shape for t in (doc_ids_down, x_down_unsqueezed, mdoc_ids, x_down))

(torch.Size([16761]),
 torch.Size([1, 16761, 64]),
 torch.Size([1, 20992]),
 torch.Size([1, 20992, 768]))

In [16]:
mdoc_ids.view(-1).size()

torch.Size([20992])

In [5]:
# Random (1, seq_len, 768) tensor; 768 matches x_down's last dim in the shapes above.
# NOTE(review): hard-coded "cuda" device and no RNG seed; not used in the cells
# shown here — confirm it is still needed.
some_x = torch.randn((1, doc_ids_down.shape[0], 768), device="cuda")

In [6]:
# Document-only mask (no causal component). Pass `causal` instead of None to
# also hide later positions within each document.
mask_mod = generate_doc_mask_mod(mask_mod=None, document_id=doc_ids_down)

In [7]:
# Self-attention over the sequence axis, so query and key lengths are equal and
# the mask is shared across batch and heads (B=None, H=None).
# NOTE(review): the mask_mod indexes `doc_ids_down`; presumably it must live on
# the same device as `x_down_unsqueezed` — confirm.
block_mask = create_block_mask(
    mask_mod=mask_mod,
    B=None,
    H=None,
    Q_LEN=x_down_unsqueezed.size(1),
    KV_LEN=x_down_unsqueezed.size(1),
    device=x_down_unsqueezed.device,
)

In [8]:
x_down_unsqueezed.size()

torch.Size([1, 16761, 64])