# Attention Mask Exploration

This notebook demonstrates the various attention mask utilities and their visualizations.

In [None]:
import torch

from zmaj_lm.utils.masks import (
    create_block_diagonal_mask,
    create_causal_mask,
    create_decoder_mask,
    create_padding_mask,
    create_sliding_window_mask,
    mask_to_bias,
)
from zmaj_lm.utils.visualize import plot_attention_mask, plot_mask_comparison

## 1. Causal Mask

Prevents each position from attending to future positions, so the allowed region is lower triangular.

In [None]:
# Sequence length for this example; later cells in the notebook read
# `seq_len` and `causal` as well, so both names are kept.
seq_len = 16

# Build the causal mask on the CPU.
causal = create_causal_mask(seq_len, device="cpu")

# Render the mask; the return value is kept in `fig`.
fig = plot_attention_mask(causal, title="Causal Mask")

## 2. Padding Mask

Masks out padding tokens in sequences of different lengths.

In [None]:
# Four sequences of different true lengths, all padded to `max_len`.
max_len = 16
lengths = torch.tensor([16, 12, 8, 14])
padding = create_padding_mask(lengths, max_len)

print(f"Padding mask shape: {padding.shape}")
print(f"Lengths: {lengths.tolist()}")

# Decoder (causal + padding) masks for batch rows 0 and 2.
# Slicing with row:row+1 (rather than indexing with row) keeps the leading
# dimension, so each call receives a one-sequence slice of the padding mask.
decoder_mask_seq0, decoder_mask_seq2 = (
    create_decoder_mask(max_len, device="cpu", attention_mask=padding[row : row + 1])
    for row in (0, 2)
)

# Side-by-side comparison: row 0 has no padding, row 2 is padded heavily.
fig = plot_mask_comparison(
    masks=[decoder_mask_seq0, decoder_mask_seq2],
    titles=[f"Causal + Padding (len={lengths[row]})" for row in (0, 2)],
)

## 3. Block Diagonal Mask

Prevents attention across document boundaries in packed sequences.

In [None]:
# One packed sequence of 16 tokens holding four documents; every token is
# labelled with its document id (document lengths: 4, 3, 5 and 4).
doc_ids = torch.tensor([[0] * 4 + [1] * 3 + [2] * 5 + [3] * 4])

# Derive the block-diagonal mask from the per-token document ids.
block_diag = create_block_diagonal_mask(doc_ids)
print(f"Block diagonal mask shape: {block_diag.shape}")

fig = plot_attention_mask(block_diag, title="Block Diagonal Mask")

## 4. Combined Decoder Mask

Combines causal masking with padding or block-diagonal masking.

In [None]:
# Decoder mask for packed sequences: the block-diagonal mask from the
# previous section is passed in as the attention mask.
# Relies on `seq_len`, `causal` and `block_diag` defined in earlier cells.
combined = create_decoder_mask(seq_len, device="cpu", attention_mask=block_diag)

# Pair each panel title with its mask, then plot the three side by side.
panels = {
    "Causal Only": causal,
    "Block Diagonal Only": block_diag,
    "Combined": combined,
}
fig = plot_mask_comparison(masks=list(panels.values()), titles=list(panels.keys()))

## 5. Experiment with Different Patterns

Try your own document patterns!

In [None]:
# Example: three packed documents of lengths 8, 3 and 5.
# Edit this tensor to try other document layouts.
doc_ids_custom = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]])

# Take the sequence length from the pattern itself instead of the notebook-wide
# `seq_len`: other cells rebind `seq_len`, so relying on it breaks this cell
# when cells are re-run out of order or when the pattern length is changed.
custom_len = doc_ids_custom.shape[-1]

custom_mask = create_block_diagonal_mask(doc_ids_custom)
custom_combined = create_decoder_mask(custom_len, device="cpu", attention_mask=custom_mask)

# Generate random attention scores (simulating QK^T output).
# The fixed seed makes the plots reproducible.
torch.manual_seed(42)
attention_scores = torch.randn(1, custom_len, custom_len)

# Convert the mask to a bias and add it to the scores
# (presumably masked positions receive a large negative value - see mask_to_bias).
mask_bias = mask_to_bias(custom_combined, dtype=torch.float32)
masked_scores = attention_scores + mask_bias

# Softmax over the last dimension gives the final attention weights.
attention_weights = torch.softmax(masked_scores, dim=-1)

fig = plot_mask_comparison(
    masks=[attention_scores, masked_scores, attention_weights],
    titles=["Raw Scores", "Masked Scores", "Attention Weights (after softmax)"],
)

## 6. Sliding Window Attention

Restricts attention to a local window of nearby tokens. Used in models like Mistral and Longformer for efficient long-context processing.

In [None]:
# Causal sliding-window masks for a 32-token sequence, one per window size.
seq_len = 32
window_sizes = [4, 8, 16]

sliding_masks = [
    create_sliding_window_mask(seq_len, size, device="cpu", causal=True)
    for size in window_sizes
]
titles = [f"Sliding Window (size={size})" for size in window_sizes]

# Plot the masks next to each other to compare the window sizes.
fig = plot_mask_comparison(masks=sliding_masks, titles=titles)

### Causal vs Bidirectional Sliding Window

Causal sliding window (like Mistral) only attends to past tokens within the window.
Bidirectional (like Longformer) attends to both past and future tokens within the window.

In [None]:
# Shared settings for the three masks compared below.
window_size = 6
seq_len = 20

# Same sequence length and window size; only the masking scheme differs.
full_causal = create_causal_mask(seq_len, device="cpu")
sliding_causal = create_sliding_window_mask(seq_len, window_size, device="cpu", causal=True)
sliding_bidir = create_sliding_window_mask(seq_len, window_size, device="cpu", causal=False)

fig = plot_mask_comparison(
    masks=[full_causal, sliding_causal, sliding_bidir],
    titles=[
        "Full Causal",
        f"Causal Window (size={window_size})",
        f"Bidirectional Window (size={window_size})",
    ],
)

print("Full causal attention: Each token attends to all previous tokens")
print(f"Causal sliding window: Each token attends to up to {window_size} previous tokens")
# The plus-minus sign was previously mis-encoded (UTF-8 bytes decoded as Latin-1).
print(f"Bidirectional sliding window: Each token attends to ±{window_size} nearby tokens")

### Sliding Window with Padding

Combines sliding window attention with padding masks, which is useful for batched training with variable-length sequences.

In [None]:
# Three sequences of different true lengths, all padded to `seq_len`.
seq_len = 24
window_size = 8
lengths = torch.tensor([24, 16, 12])
padding_mask = create_padding_mask(lengths, seq_len)

# One sliding-window decoder mask per batch row. Slicing with row:row+1 keeps
# the leading dimension of the padding mask for each single-sequence call.
full_len_sw, medium_len_sw, short_len_sw = (
    create_decoder_mask(
        seq_len,
        device="cpu",
        attention_mask=padding_mask[row : row + 1],
        window_size=window_size,
    )
    for row in range(3)
)

# Title suffix for each row, in batch order.
padding_notes = ("no padding", "some padding", "heavy padding")

fig = plot_mask_comparison(
    masks=[full_len_sw, medium_len_sw, short_len_sw],
    titles=[
        f"Window {window_size}, len={lengths[row]} ({note})"
        for row, note in enumerate(padding_notes)
    ],
)

print(f"Window size: {window_size}")
print(f"Sequence lengths: {lengths.tolist()}")
print("\nNote: Padding positions (grayed out) cannot be attended to,")
print("even if they fall within the sliding window.")