# Attention Mask Exploration

This notebook demonstrates the various attention mask utilities and their visualizations.

In [None]:
import torch

from zmaj_lm.utils.masks import (
    create_block_diagonal_mask,
    create_causal_mask,
    create_decoder_mask,
    create_padding_mask,
    mask_to_bias,
)
from zmaj_lm.utils.visualize import plot_attention_mask, plot_mask_comparison

## 1. Causal Mask

Prevents each position from attending to future positions, so the allowed attention pattern is lower triangular.

In [None]:
# Sequence length for this demo; later cells in the notebook reuse seq_len.
seq_len = 16
# Build the causal mask on CPU; `causal` is plotted here and reused in the
# combined-mask comparison further down.
causal = create_causal_mask(seq_len, device="cpu")
fig = plot_attention_mask(causal, title="Causal Mask")

## 2. Padding Mask

Masks out padding tokens in sequences of different lengths.

In [None]:
# Four sequences padded to a common width of 16; each entry of `lengths` is
# the true (unpadded) token count of one sequence.
max_len = 16
lengths = torch.tensor([16, 12, 8, 14])
padding = create_padding_mask(lengths, max_len)

print(f"Padding mask shape: {padding.shape}")
print(f"Lengths: {lengths.tolist()}")

# Causal + padding masks for sequence 0 (length 16, i.e. no padding) and
# sequence 2 (length 8, i.e. half padding). Slicing with i : i + 1 instead of
# indexing keeps the leading dimension (size 1) on the row passed along.
decoder_mask_seq0, decoder_mask_seq2 = (
    create_decoder_mask(max_len, device="cpu", attention_mask=padding[i : i + 1])
    for i in (0, 2)
)

# Side by side: no padding vs heavy padding.
fig = plot_mask_comparison(
    masks=[decoder_mask_seq0, decoder_mask_seq2],
    titles=[f"Causal + Padding (len={lengths[i]})" for i in (0, 2)],
)

## 3. Block Diagonal Mask

Prevents attention across document boundaries in packed sequences.

In [None]:
# One packed sequence holding four documents (ids 0-3) with run lengths
# 4, 3, 5 and 4 -- 16 tokens in total.
doc_ids = torch.tensor([[0] * 4 + [1] * 3 + [2] * 5 + [3] * 4])

block_diag = create_block_diagonal_mask(doc_ids)
print(f"Block diagonal mask shape: {block_diag.shape}")
fig = plot_attention_mask(block_diag, title="Block Diagonal Mask")

## 4. Combined Decoder Mask

Combines causal masking with padding or block-diagonal masking.

In [None]:
# Combine the causal constraint with the block-diagonal mask of the packed
# sequence built in the previous section.
combined = create_decoder_mask(seq_len, device="cpu", attention_mask=block_diag)

# Keep every panel title next to its mask so the two lists handed to the
# plotting helper cannot drift out of sync (dicts preserve insertion order).
panels = {
    "Causal Only": causal,
    "Block Diagonal Only": block_diag,
    "Combined": combined,
}
fig = plot_mask_comparison(masks=list(panels.values()), titles=list(panels))

## 5. Experiment with Different Patterns

Try your own document patterns!

In [None]:
# Example: three packed documents of lengths 8, 3 and 5.
# Edit the ids below to try other patterns.
doc_ids_custom = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]])

# Size everything from the pattern itself rather than from the notebook-wide
# seq_len, so a pattern of any length works without touching other cells.
custom_len = doc_ids_custom.shape[-1]

custom_mask = create_block_diagonal_mask(doc_ids_custom)
custom_combined = create_decoder_mask(custom_len, device="cpu", attention_mask=custom_mask)

# Generate random attention scores (simulating QK^T output); the fixed seed
# keeps the figure reproducible across runs.
torch.manual_seed(42)
attention_scores = torch.randn(1, custom_len, custom_len)

# Apply mask to attention scores by converting it to an additive bias.
# NOTE(review): presumably 0 where attention is allowed and a very negative
# value where it is masked -- confirm against mask_to_bias.
mask_bias = mask_to_bias(custom_combined, dtype=torch.float32)
masked_scores = attention_scores + mask_bias

# Apply softmax over the last dimension to get final attention weights.
attention_weights = torch.softmax(masked_scores, dim=-1)

fig = plot_mask_comparison(
    masks=[attention_scores, masked_scores, attention_weights],
    titles=["Raw Scores", "Masked Scores", "Attention Weights (after softmax)"],
)