# Attention Heat Map

In [None]:
from pathlib import Path
from typing import Literal

import ipywidgets as widgets
from ipywidgets import interact
from jaxtyping import Float
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from numpy.typing import NDArray
import plotly.express as px
import scienceplots  # noqa: F401
import torch

In [None]:
data_path = "/data/shared_data/position-bias/results-final-fineweb-edu"
model_name: Literal[
    "anas-awadalla_mpt-7b",
    "bigscience_bloom",
    "bigscience_bloom-7b1",
    "tiiuae_falcon-rw-7b",
    "eluzhnica_mpt-30b-peft-compatible",
] = "bigscience_bloom"
sequence_length = 64

# Per-model, per-sequence-length result directory.
data_dir = Path(data_path) / model_name / str(sequence_length)

# map_location="cpu": torch.load otherwise restores tensors to the device they
# were saved from, and the plotting/export cells call ``.numpy()``, which only
# works on CPU tensors.
# weights_only=True: the files live on a shared volume; restrict unpickling to
# plain tensors rather than arbitrary Python objects.
mean_attention_matrix: Float[torch.Tensor, "n_layers n_heads seq_len seq_len"] = (
    torch.load(data_dir / "qk_mean.pt", map_location="cpu", weights_only=True)
)
var_attention_matrix: Float[torch.Tensor, "n_layers n_heads seq_len seq_len"] = (
    torch.load(data_dir / "qk_var.pt", map_location="cpu", weights_only=True)
)

# Explicit check (not ``assert``) so it is not stripped under ``python -O``.
if mean_attention_matrix.shape != var_attention_matrix.shape:
    msg = (
        "qk_mean.pt and qk_var.pt have different shapes: "
        f"{tuple(mean_attention_matrix.shape)} vs "
        f"{tuple(var_attention_matrix.shape)}"
    )
    raise ValueError(msg)

# Unpacking also enforces that the tensors are 4-D.
n_layers, n_heads, seq_len, _ = mean_attention_matrix.shape

result_dir = Path("../../results/attention")
result_dir.mkdir(parents=True, exist_ok=True)

In [None]:
def get_stats(matrix: Float[NDArray, "seq_len seq_len"]) -> dict[str, float]:
    """Compute summary statistics of a 2-D (square) matrix.

    Args:
        matrix: 2-D array; annotated as ``seq_len x seq_len``.

    Returns:
        Mapping from display label to a plain Python ``float``, in this order:
        Min, Max, Mean, Std, Median, L2 Norm (Frobenius norm), P99
        (99th percentile), Diag Mean (mean of the main diagonal) and Avg Dist.
        "Avg Dist" is the mean distance ``|row - col|`` from the diagonal,
        weighted by ``|matrix|``; it is 0.0 when the matrix sums to zero in
        absolute value.
    """
    rows, cols = np.indices(matrix.shape)
    dist = np.abs(rows - cols)

    abs_matrix = np.abs(matrix)
    total_abs = np.sum(abs_matrix)
    # Guard against division by zero for an all-zero matrix.
    avg_dist = np.sum(dist * abs_matrix) / total_abs if total_abs > 0 else 0.0

    stats = {
        "Min": matrix.min(),
        "Max": matrix.max(),
        "Mean": matrix.mean(),
        "Std": matrix.std(),
        "Median": np.median(matrix),
        "L2 Norm": np.linalg.norm(matrix),
        "P99": np.percentile(matrix, 99),
        "Diag Mean": np.diagonal(matrix).mean(),
        "Avg Dist": avg_dist,
    }
    # numpy reductions return numpy scalars (e.g. np.float32); convert every
    # value so the result actually matches the annotated ``dict[str, float]``.
    return {name: float(value) for name, value in stats.items()}


def format_stats(stats: dict[str, float]) -> str:
    """Render a stats mapping as two ``" | "``-joined rows split by ``<br>``.

    Each entry becomes ``"<label>: <value>"``. Values use 4-digit scientific
    notation, except ``"Avg Dist"``, which is shown with two decimals. The
    first five entries (in dict order) form the first row; the rest form the
    second row.
    """
    formatted = []
    for name, value in stats.items():
        spec = ".2f" if name == "Avg Dist" else ".4e"
        formatted.append(f"{name}: {value:{spec}}")
    first_row, second_row = formatted[:5], formatted[5:]
    return "<br>".join((" | ".join(first_row), " | ".join(second_row)))

In [None]:
current_mean_state: tuple[int, int, bool] = (0, 0, True)


def plot_mean(layer: int, head: int, mask_upper: bool) -> None:
    """Show an interactive plotly heatmap of the mean attention for one head.

    The title includes summary statistics from ``get_stats``. They are
    computed on the full matrix, before any masking. When ``mask_upper`` is
    true, entries strictly above the diagonal are set to NaN before plotting.
    The selection ``(layer, head, mask_upper)`` is stored in the global
    ``current_mean_state``, which ``export_mean_png`` reads.
    """
    global current_mean_state  # noqa: PLW0603
    # ``.numpy()`` shares memory with the tensor, so copy before writing NaNs.
    heat = mean_attention_matrix[layer, head].numpy().copy()
    summary = format_stats(get_stats(heat))

    if mask_upper:
        # k=1 keeps the diagonal; only strictly-upper entries are hidden.
        heat[np.triu(np.ones_like(heat, dtype=bool), k=1)] = np.nan

    figure = px.imshow(
        heat,
        color_continuous_scale="Viridis",
        title=(
            f"Attention Matrix Mean (Layer {layer}, Head {head})"
            f"<br><sup>{summary}</sup>"
        ),
        labels={"x": "Key Position", "y": "Query Position", "color": "Attention"},
    )
    figure.update_layout(width=800, height=850, margin={"t": 120})
    current_mean_state = (layer, head, mask_upper)
    figure.show()


# Sliders span every layer/head index of the loaded tensor; the checkbox
# toggles masking of the strictly-upper triangle.
interact(
    plot_mean,
    layer=widgets.IntSlider(value=0, min=0, max=n_layers - 1, step=1),
    head=widgets.IntSlider(value=0, min=0, max=n_heads - 1, step=1),
    mask_upper=widgets.Checkbox(description="Mask Upper Triangle", value=True),
)


def export_mean_png(_b: widgets.Button) -> None:
    """Save the mean-attention heatmap last rendered by ``plot_mean`` as a PNG.

    The selection ``(layer, head, mask_upper)`` is read from
    ``current_mean_state``. The matrix is redrawn with matplotlib using the
    same upper-triangle masking as the interactive view, then written to
    ``result_dir`` at 400 dpi.

    Args:
        _b: The clicked button, passed by ``Button.on_click``; unused.
    """
    layer, head, mask_upper = current_mean_state
    # ``.numpy()`` shares memory with the tensor, so copy before writing NaNs.
    matrix = mean_attention_matrix[layer, head].numpy().copy()
    if mask_upper:
        # k=1 keeps the diagonal; only strictly-upper entries are hidden.
        mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
        matrix[mask] = np.nan

    # NOTE(review): the file name does not encode ``mask_upper``, so masked and
    # unmasked exports of the same layer/head overwrite each other.
    file_name = (
        f"attention_heatmap_mean_{model_name}_{sequence_length}_{layer}_{head}.png"
    )
    output_path = result_dir / file_name

    rc_overrides = {
        "font.family": "sans-serif",
        "font.sans-serif": ["DejaVu Sans", "Liberation Sans", "Arial"],
        "mathtext.fontset": "dejavusans",
        "font.size": 25,
        "axes.labelsize": 25,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
    }
    # Scope the styling to this export. ``plt.style.use`` + ``rcParams.update``
    # would permanently change every later matplotlib figure in the session.
    with plt.style.context(["bright", "no-latex"]), plt.rc_context(rc_overrides):
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            im = ax.imshow(matrix, cmap="viridis")
            ax.set_xlabel("Key Position")
            ax.set_ylabel("Query Position")
            # Narrow colorbar axis appended to the right of the heatmap.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="4%", pad=0.1)
            fig.colorbar(im, cax=cax, label="Attention")
            fig.tight_layout()
            fig.savefig(output_path, dpi=400, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
    print(f"Exported to {output_path}")


# Clicking the button calls export_mean_png, which saves the heatmap last
# rendered by plot_mean.
button_png = widgets.Button(description="Export PNG")
button_png.on_click(callback=export_mean_png)

display(widgets.HBox(children=[button_png]))

In [None]:
current_var_state: tuple[int, int, bool] = (0, 0, True)


def plot_variance(layer: int, head: int, mask_upper: bool) -> None:
    """Show an interactive plotly heatmap of the attention variance for one head.

    The title includes summary statistics from ``get_stats``. They are
    computed on the full matrix, before any masking. When ``mask_upper`` is
    true, entries strictly above the diagonal are set to NaN before plotting.
    The selection ``(layer, head, mask_upper)`` is stored in the global
    ``current_var_state``, which ``export_variance_png`` reads.
    """
    global current_var_state  # noqa: PLW0603
    # ``.numpy()`` shares memory with the tensor, so copy before writing NaNs.
    heat = var_attention_matrix[layer, head].numpy().copy()
    summary = format_stats(get_stats(heat))

    if mask_upper:
        # k=1 keeps the diagonal; only strictly-upper entries are hidden.
        heat[np.triu(np.ones_like(heat, dtype=bool), k=1)] = np.nan

    figure = px.imshow(
        heat,
        color_continuous_scale="Viridis",
        title=(
            f"Attention Variance Matrix (Layer {layer}, Head {head})"
            f"<br><sup>{summary}</sup>"
        ),
        labels={"x": "Key Position", "y": "Query Position", "color": "Variance"},
    )
    figure.update_layout(width=800, height=850, margin={"t": 120})
    current_var_state = (layer, head, mask_upper)
    figure.show()


# Sliders span every layer/head index of the loaded tensor; the checkbox
# toggles masking of the strictly-upper triangle.
interact(
    plot_variance,
    layer=widgets.IntSlider(value=0, min=0, max=n_layers - 1, step=1),
    head=widgets.IntSlider(value=0, min=0, max=n_heads - 1, step=1),
    mask_upper=widgets.Checkbox(description="Mask Upper Triangle", value=True),
)


def export_variance_png(_b: widgets.Button) -> None:
    """Save the variance heatmap last rendered by ``plot_variance`` as a PNG.

    The selection ``(layer, head, mask_upper)`` is read from
    ``current_var_state``. The matrix is redrawn with matplotlib using the
    same upper-triangle masking as the interactive view, then written to
    ``result_dir`` at 400 dpi.

    Args:
        _b: The clicked button, passed by ``Button.on_click``; unused.
    """
    layer, head, mask_upper = current_var_state
    # ``.numpy()`` shares memory with the tensor, so copy before writing NaNs.
    matrix = var_attention_matrix[layer, head].numpy().copy()
    if mask_upper:
        # k=1 keeps the diagonal; only strictly-upper entries are hidden.
        mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
        matrix[mask] = np.nan

    # NOTE(review): the file name does not encode ``mask_upper``, so masked and
    # unmasked exports of the same layer/head overwrite each other.
    file_name = (
        f"attention_heatmap_variance_{model_name}_{sequence_length}_{layer}_{head}.png"
    )
    output_path = result_dir / file_name

    rc_overrides = {
        "font.family": "sans-serif",
        "font.sans-serif": ["DejaVu Sans", "Liberation Sans", "Arial"],
        "mathtext.fontset": "dejavusans",
        "font.size": 25,
        "axes.labelsize": 25,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
    }
    # Scope the styling to this export. ``plt.style.use`` + ``rcParams.update``
    # would permanently change every later matplotlib figure in the session.
    with plt.style.context(["bright", "no-latex"]), plt.rc_context(rc_overrides):
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            im = ax.imshow(matrix, cmap="viridis")
            ax.set_xlabel("Key Position")
            ax.set_ylabel("Query Position")
            # Narrow colorbar axis appended to the right of the heatmap.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="4%", pad=0.1)
            fig.colorbar(im, cax=cax, label="Variance")
            fig.tight_layout()
            fig.savefig(output_path, dpi=400, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
    print(f"Exported to {output_path}")


# Clicking the button calls export_variance_png, which saves the heatmap last
# rendered by plot_variance.
button_png = widgets.Button(description="Export PNG")
button_png.on_click(callback=export_variance_png)

display(widgets.HBox(children=[button_png]))