# Attention Stats

In [None]:
from pathlib import Path
from typing import Literal

from jaxtyping import Float
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import scienceplots  # noqa: F401
import torch

In [None]:
# --- Run configuration -------------------------------------------------------
# Root folder of the precomputed results; one sub-folder per model, then one
# per sequence length.
data_path = "/data/shared_data/position-bias/results-final-wikipedia"
model_name: Literal[
    "anas-awadalla_mpt-7b",
    "bigscience_bloom",
    "bigscience_bloom-7b1",
    "tiiuae_falcon-rw-7b",
    "eluzhnica_mpt-30b-peft-compatible",
] = "bigscience_bloom"

sequence_length = 256

# --- Input -------------------------------------------------------------------
data_dir = Path(data_path, model_name, str(sequence_length))
# NOTE(review): torch.load unpickles the file; if the shared directory is not
# fully trusted, consider passing weights_only=True — confirm the file holds a
# plain tensor first.
mean_attention_matrix: Float[torch.Tensor, "n_layers n_heads seq_len seq_len"] = (
    torch.load(data_dir.joinpath("qk_mean.pt"))
)

# Unpacking into four names also enforces that the tensor is 4-dimensional.
n_layers, n_heads, seq_len, _ = mean_attention_matrix.shape

# Every leading dimension must be non-empty for the statistics below.
assert min(n_layers, n_heads, seq_len) > 0

# --- Output ------------------------------------------------------------------
result_dir = Path("../../results/attention")
result_dir.mkdir(parents=True, exist_ok=True)

# Text file that print_stats appends its summary lines to.
log_file = result_dir.joinpath(
    f"attention_stats_summary_{model_name}_{sequence_length}.txt",
)

In [None]:
# Pool every score into one flat vector so that a single set of histogram bin
# edges can be shared by all layers and heads.
all_vals = mean_attention_matrix.reshape(-1)
max_samples = 1_000_000
if all_vals.numel() > max_samples:
    # Estimate the quantiles from a random subsample. The draw uses a locally
    # seeded generator so the bin edges -- and therefore every statistic and
    # figure derived from them -- are identical across notebook runs, without
    # touching the global RNG state.
    sample_rng = torch.Generator(device=all_vals.device).manual_seed(0)
    all_vals = all_vals[
        torch.randperm(
            all_vals.numel(),
            generator=sample_rng,
            device=all_vals.device,
        )[:max_samples]
    ]
# Clip the histogram range to the central 99% of values so that a few extreme
# scores do not stretch the bins.
vmin = torch.quantile(all_vals, 0.005)
vmax = torch.quantile(all_vals, 0.995)
# 200 edges -> 199 equal-width bins between vmin and vmax.
bins = torch.linspace(vmin, vmax, steps=200, device=all_vals.device)

In [None]:
def hist_probs(
    values: Float[torch.Tensor, "n_layers n_heads num_values"],
    eps: float = 1e-12,
    bin_edges: Float[torch.Tensor, "n_edges"] | None = None,
) -> Float[torch.Tensor, "n_layers n_heads n_bins"]:
    """Compute histogram-based probability distributions over specified bins.

    Each slice along the last axis of ``values`` is histogrammed independently.
    Bins are left-open / right-closed: a value ``v`` falls into bin ``i`` when
    ``edges[i] < v <= edges[i + 1]``. Values outside ``(edges[0], edges[-1]]``
    are ignored.

    Args:
        values: Samples to histogram; the last axis is the sample axis.
        eps: Constant added to every bin count before normalisation so that no
            bin ends up with probability exactly zero.
        bin_edges: Sorted 1-D tensor of bin boundaries. When omitted, the
            module-level ``bins`` tensor is used.

    Returns:
        A float64 tensor with the leading shape of ``values`` and a last axis
        of ``n_edges - 1`` probabilities that sum to one.
    """
    edges = bins if bin_edges is None else bin_edges
    n_bins = edges.numel() - 1

    # Views such as Tensor.diagonal() are non-contiguous; copy up front rather
    # than letting bucketize do it implicitly (and warn about it).
    idx = torch.bucketize(values.contiguous(), edges, right=False) - 1
    in_range = (idx >= 0) & (idx < n_bins)

    # Out-of-range samples get a dummy index of 0 but a weight of 0, so they
    # pass through scatter_add_ without contributing to any bin.
    idx = idx.masked_fill(~in_range, 0)
    counts = torch.zeros(
        *values.shape[:-1],
        n_bins,
        dtype=torch.float64,
        device=values.device,
    )
    counts.scatter_add_(-1, idx, in_range.to(counts.dtype))

    probs = counts + eps
    return probs / probs.sum(dim=-1, keepdim=True)


def within_similarity(
    p: Float[torch.Tensor, "n_layers n_heads n_bins"],
) -> Float[torch.Tensor, "n_layers n_heads"]:
    """Compute the complement of the normalized Shannon entropy.

    The entropy of each distribution along the last axis is divided by
    ``log(n_bins)``, the entropy of a uniform distribution over the same bins.
    The result is 0 for a uniform distribution and 1 when all mass sits in a
    single bin.

    Args:
        p: Probability distributions along the last axis.

    Returns:
        ``1 - H(p) / log(n_bins)`` for every leading index.

    Note:
        With a single bin the normaliser ``log(1)`` is zero, so the result is
        NaN; at least two bins are required.
    """
    # entr(x) = -x * ln(x) with the convention 0 * ln(0) = 0, so empty bins
    # contribute nothing instead of turning the whole sum into NaN.
    entropy = torch.special.entr(p).sum(dim=-1)
    max_entropy = torch.log(
        torch.tensor(p.shape[-1], dtype=torch.float64, device=p.device),
    )
    return 1.0 - entropy / max_entropy


def print_stats(name: str, values: Float[torch.Tensor, "n_layers n_heads"]) -> None:
    """Summarise a tensor as one line, echo it, and append it to the log file.

    The line contains the minimum, maximum, mean and population standard
    deviation (``unbiased=False``) over all elements, each formatted with four
    decimals. It is printed to stdout and appended to the module-level
    ``log_file``.

    Args:
        name: Label placed at the start of the line.
        values: Tensor whose elements are summarised.
    """
    detached = values.detach()
    lowest = detached.min().item()
    highest = detached.max().item()
    average = detached.mean().item()
    spread = detached.std(unbiased=False).item()

    summary = (
        f"{name} | min={lowest:.4f} max={highest:.4f} "
        f"mean={average:.4f} std={spread:.4f}"
    )
    print(summary)
    # Append mode: earlier summaries in the file are kept.
    with log_file.open("a", encoding="utf-8") as handle:
        handle.write(f"{summary}\n")


def js_dissimilarity(
    p: Float[torch.Tensor, "n_layers n_heads n_bins"],
    q: Float[torch.Tensor, "n_layers n_heads n_bins"],
    eps: float = 1e-12,
) -> Float[torch.Tensor, "n_layers n_heads"]:
    """Compute the Jensen-Shannon dissimilarity between two probability distributions.

    The divergence is computed with base-2 logarithms, which bounds it to
    ``[0, 1]``: 0 for identical distributions and 1 for distributions with
    disjoint support. (With natural logarithms the upper bound would be
    ``ln 2``, so ``1 - JS`` could not be read as a similarity on a 0-1 scale.)

    Args:
        p: First set of distributions along the last axis.
        q: Second set of distributions, same shape as ``p``.
        eps: Constant added to every entry before renormalising, keeping all
            logarithm arguments strictly positive.

    Returns:
        The Jensen-Shannon divergence in bits for every leading index.
    """
    # Smooth and renormalise so that neither input contains exact zeros.
    p = p + eps
    q = q + eps
    p = p / p.sum(dim=-1, keepdim=True)
    q = q / q.sum(dim=-1, keepdim=True)

    # JS(p, q) = (KL(p || m) + KL(q || m)) / 2 with the mixture m = (p + q) / 2.
    m = 0.5 * (p + q)
    kl_pm = (p * torch.log2(p / m)).sum(dim=-1)
    kl_qm = (q * torch.log2(q / m)).sum(dim=-1)
    return 0.5 * (kl_pm + kl_qm)

In [None]:
def plot_heatmap(
    data: Float[torch.Tensor, "n_layers n_heads"] | np.ndarray,
    cbar_label: str,
    file_name: str,
) -> None:
    """Render a layer-by-head heatmap, save it under ``result_dir`` and show it.

    Rows are labelled "Layer" and columns "Head", with row 0 at the bottom.
    The colour scale is fixed to ``[0, 1]``. Calling this function also
    changes the global matplotlib style and rcParams.

    Args:
        data: 2-D values to draw; tensors are detached and moved to the CPU
            before plotting.
        cbar_label: Text shown next to the colour bar.
        file_name: Name of the image file written inside ``result_dir``.
    """
    heat = data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data

    # Style names are presumably registered by the scienceplots import.
    plt.style.use(["bright", "no-latex"])
    rc_overrides = {
        "font.family": "sans-serif",
        "font.sans-serif": ["DejaVu Sans", "Liberation Sans", "Arial"],
        "mathtext.fontset": "dejavusans",
        "font.size": 25,
        "axes.labelsize": 25,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
    }
    plt.rcParams.update(rc_overrides)

    out_path = result_dir / file_name

    fig, ax = plt.subplots(figsize=(8, 8))
    image = ax.imshow(heat, aspect="equal", origin="lower", vmin=0, vmax=1)
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")

    # Dedicated narrow axis to the right of the image for the colour bar.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="4%", pad=0.1)
    fig.colorbar(image, cax=colorbar_ax, label=cbar_label)

    fig.tight_layout()
    fig.savefig(out_path, dpi=400, bbox_inches="tight")
    plt.show()
    print(f"Exported to {out_path}")

In [None]:
# Main diagonal of every (layer, head) matrix, taken over the last two axes.
diag: Float[torch.Tensor, "n_layers n_heads seq_len"] = torch.diagonal(
    mean_attention_matrix,
    dim1=-2,
    dim2=-1,
)

# Boolean selector for the strictly lower triangle (row index > column index);
# the -1 offset leaves the diagonal itself out.
lower_mask = (
    torch.ones(seq_len, seq_len, device=mean_attention_matrix.device)
    .tril(diagonal=-1)
    .bool()
)
# Indexing the last two axes with the mask flattens the selected entries of
# each matrix into a single trailing axis.
lower: Float[torch.Tensor, "n_layers n_heads num_lower"] = mean_attention_matrix[
    :,
    :,
    lower_mask,
]

## Within-diagonal similarity

In [None]:
# Histogram the diagonal entries per (layer, head), then score how
# concentrated each histogram is.
diag_probs = hist_probs(diag)
diag_similarity = within_similarity(diag_probs)
print_stats("Within-diagonal similarity", diag_similarity)

plot_heatmap(
    data=diag_similarity,
    cbar_label="Similarity",
    file_name=f"attention_stats_within_diag_{model_name}_{sequence_length}.png",
)

## Within-lower-triangular off-diagonal similarity

In [None]:
# Histogram the strictly-lower-triangular entries per (layer, head), then
# score how concentrated each histogram is.
lower_probs = hist_probs(lower)
lower_similarity = within_similarity(lower_probs)
print_stats("Within-lower similarity", lower_similarity)

plot_heatmap(
    data=lower_similarity,
    cbar_label="Similarity",
    file_name=f"attention_stats_within_lower_{model_name}_{sequence_length}.png",
)

## JS similarity: diagonal vs lower-triangular off-diagonal

In [None]:
# Compare, per (layer, head), the histogram of diagonal entries with the
# histogram of strictly-lower-triangular entries.
dissimilarity = js_dissimilarity(p=hist_probs(diag), q=hist_probs(lower))
# Report the complement so that larger values mean "more alike".
js_similarity = 1.0 - dissimilarity
print_stats("JS similarity", js_similarity)

plot_heatmap(
    data=js_similarity,
    cbar_label="JS similarity",
    file_name=f"attention_stats_js_similarity_{model_name}_{sequence_length}.png",
)

## Mean content

In [None]:
# Grand means over every layer, head and position.
mean_diagonal = torch.mean(diag)
mean_off_diagonal = torch.mean(lower)
mean_gap = mean_diagonal - mean_off_diagonal

print(f"Mean diagonal attention: {mean_diagonal:.4f}")
print(f"Mean off diagonal attention: {mean_off_diagonal:.4f}")
print(f"Difference: {mean_gap:.4f}")

## All content

In [None]:
# Per-(layer, head) means: average over the last axis only.
layer_head_mean_diag = torch.mean(diag, dim=-1)
layer_head_mean_lower = torch.mean(lower, dim=-1)
layer_head_mean_gap = layer_head_mean_diag - layer_head_mean_lower

print(f"Diag: {layer_head_mean_diag.tolist()}")
print(f"Lower: {layer_head_mean_lower.tolist()}")
print(f"Difference: {layer_head_mean_gap.tolist()}")