In [7]:
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetPaddedKeysMask,
    BlockDiagonalMask,
)

import pandas as pd

In [8]:
# Color scheme for the mask visualization: one background color per mask value
# (0.0 and -inf are the two values that get a color).
highlight_dict = {
    0.0: '#68A357',
    float('-inf'): '#C97064',
}


def highlight_cells(value):
    """Return the CSS background-color declaration for a single cell value.

    The color is looked up in ``highlight_dict``; a value with no entry
    yields a declaration with an empty color.
    """
    color = highlight_dict.get(value, "")
    return 'background-color: ' + color

In [9]:
# Example 1: BlockDiagonalCausalMask with Sliding Window
#
# Builds a block-diagonal causal mask over three packed prompts, restricts it
# to a local window of `window_size`, and renders the materialized mask as a
# color-coded table.

# Example prompts: 4, 6, and 5 tokens respectively
sequence_lengths = [4, 6, 5]
window_size = 2

causal_mask = BlockDiagonalCausalMask.from_seqlens(sequence_lengths).make_local_attention(window_size)

batch_dim = 1
total_tokens = sum(sequence_lengths)
# Square mask: both the query and key axes span every packed token.
mask_tensor = causal_mask.materialize((batch_dim, total_tokens, total_tokens))

df_mask = pd.DataFrame(mask_tensor[0].numpy())
styler = df_mask.style
# Styler.applymap is deprecated since pandas 2.1 (it emits a FutureWarning) in
# favor of Styler.map; fall back to applymap only on older pandas versions.
(styler.map if hasattr(styler, "map") else styler.applymap)(highlight_cells)

  df_mask.style.applymap(highlight_cells)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
1,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
2,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
3,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
4,-inf,-inf,-inf,-inf,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
5,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
6,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
7,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf
8,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf
9,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf


In [10]:
# Example 2: BlockDiagonalMask for prefill (queries and KV lengths differ)
#
# Builds a block-diagonal mask whose blocks have different query and key/value
# lengths, applies a local window of `window_size` via
# make_local_attention_from_bottomright, and renders the materialized mask as
# a color-coded table.

query_lengths = [2, 4]
kv_lengths = [8, 5]   # For example: first prompt = 2 queries, 8 kv; second = 4 queries, 5 kv
window_size = 2

bd_mask = BlockDiagonalMask.from_seqlens(query_lengths, kv_lengths).make_local_attention_from_bottomright(window_size)

batch_dim = 1
# Rectangular mask: rows cover all queries, columns cover all keys/values.
q_total = sum(query_lengths)
kv_total = sum(kv_lengths)

mask_tensor = bd_mask.materialize((batch_dim, q_total, kv_total))

df_mask = pd.DataFrame(mask_tensor[0].numpy())
styler = df_mask.style
# Styler.applymap is deprecated since pandas 2.1 (it emits a FutureWarning) in
# favor of Styler.map; fall back to applymap only on older pandas versions.
(styler.map if hasattr(styler, "map") else styler.applymap)(highlight_cells)

  df_mask.style.applymap(highlight_cells)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12
0,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf
1,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf
2,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf
3,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf
4,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf
5,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0


In [11]:
# Example 3: BlockDiagonalCausalWithOffsetPaddedKeysMask for KV Cache Handling
#
# Each sequence contributes one query and owns a fixed-size slot of
# `kv_tensor_capacity` key/value positions, of which only the first
# `kv_actual_lengths[i]` are in use. The materialized mask is rendered as a
# color-coded table.

q_len_list = [1, 1]
kv_actual_lengths = [4, 6]
kv_tensor_capacity = 8  # Fixed size tensor used in backend

offset_mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=q_len_list,
    kv_padding=kv_tensor_capacity,
    kv_seqlen=kv_actual_lengths
)

batch_dim = 1
q_sum = sum(q_len_list)
# The key axis spans one full padded slot per sequence, not just the used keys.
kv_tensor_size = kv_tensor_capacity * len(kv_actual_lengths)

mask_tensor = offset_mask.materialize((batch_dim, q_sum, kv_tensor_size))

df_mask = pd.DataFrame(mask_tensor[0].numpy())
styler = df_mask.style
# Styler.applymap is deprecated since pandas 2.1 (it emits a FutureWarning) in
# favor of Styler.map; fall back to applymap only on older pandas versions.
(styler.map if hasattr(styler, "map") else styler.applymap)(highlight_cells)

  df_mask.style.applymap(highlight_cells)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
0,0.0,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
1,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,0.0,0.0,0.0,-inf,-inf
