# Attention Monotonicity

In [None]:
from dataclasses import dataclass
from pathlib import Path

import einops
from jaxtyping import Float
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
import scienceplots  # noqa: F401
import torch

In [None]:
# Root directory of the precomputed results. load_attention_matrix reads
# <data_path>/<model_name>/<sequence_length>/attention_probs_mean.pt.
data_path = "/data/shared_data/position-bias/results-final-fineweb-edu"
# Result sub-directory names of the models to analyse (one per model).
model_names = [
    "anas-awadalla_mpt-7b",
    "bigscience_bloom",
    "bigscience_bloom-7b1",
    "tiiuae_falcon-rw-7b",
    "eluzhnica_mpt-30b-peft-compatible",
]

# Sequence length: selects the result sub-directory and sizes the ALiBi bias.
sequence_length = 256
# When True, the absolute/relative error plots also draw a band of
# +/- 2 standard deviations around each curve.
show_std_error = False

# Output directory for the generated figures; created here if it is missing.
result_dir = Path("../../results/monotonicity")
result_dir.mkdir(parents=True, exist_ok=True)

In [None]:
@dataclass
class LayerStats:
    """Monotonicity-violation statistics for a single layer.

    Instances are filled in by ``compute_monotonicity_stats``. The first
    group of fields holds raw totals, the second group holds values derived
    from those totals (rates, means and standard deviations).
    """

    # --- Raw totals ---
    # Number of comparisons evaluated: (row pairs) x (sequence positions).
    combinations: int = 0
    # Number of comparisons that violate monotonicity beyond the tolerance.
    violations_count: int = 0
    # Sum of absolute gaps over all violating comparisons.
    total_absolute_error: float = 0.0
    # Sum of relative gaps (absolute gap divided by the larger of the two
    # compared cumulative values, plus a small constant) over all violations.
    total_relative_error: float = 0.0

    # --- Derived values ---
    # violations_count / combinations.
    violation_rate: float = 0.0
    # Mean absolute gap per violation.
    absolute_error: float = 0.0
    # Standard deviation of the absolute gaps (0.0 with fewer than 2 violations).
    absolute_error_std: float = 0.0
    # Mean relative gap per violation.
    relative_error: float = 0.0
    # Standard deviation of the relative gaps (0.0 with fewer than 2 violations).
    relative_error_std: float = 0.0

In [None]:
def compute_alibi_slopes(
    n_heads: int,
) -> Float[torch.Tensor, "n_heads"]:
    """Compute the ALiBi slopes for each attention head.

    For a power-of-two head count ``n`` the slopes are the geometric sequence
    ``2 ** (-8 * (h + 1) / n)`` for ``h = 0 .. n - 1``.

    For any other head count the scheme of the ALiBi reference implementation
    is used: take the slopes of the closest smaller power of two ``p``, then
    append every other slope of the ``2 * p`` sequence until ``n_heads``
    values are available. Applying the power-of-two formula directly to such
    head counts yields slopes that differ from the reference scheme.

    NOTE(review): some model families may order or derive the slopes
    differently for non-power-of-two head counts -- verify against the
    model's own implementation when adding such a model.

    Args:
        n_heads: Number of attention heads.

    Returns:
        A float64 tensor of shape ``(n_heads,)``; empty when ``n_heads <= 0``.
    """

    def _power_of_two_slopes(n: int) -> list[float]:
        # Closed form, valid only when n is a power of two.
        return [pow(2, -8 * (h + 1) / n) for h in range(n)]

    n_heads = int(n_heads)

    if n_heads <= 0:
        slopes: list[float] = []
    elif n_heads & (n_heads - 1) == 0:
        # Exactly one bit set -> power of two.
        slopes = _power_of_two_slopes(n_heads)
    else:
        closest_power = 1 << (n_heads.bit_length() - 1)
        slopes = _power_of_two_slopes(closest_power)
        # Fill the remaining heads with the even-indexed slopes of the next
        # power of two.
        slopes += _power_of_two_slopes(2 * closest_power)[0::2][
            : n_heads - closest_power
        ]

    return torch.tensor(slopes, dtype=torch.float64)


def compute_alibi_scores(
    n_layers: int,
    n_heads: int,
) -> Float[torch.Tensor, "n_layers n_heads sequence_length sequence_length"]:
    """Build the position-only ALiBi bias, replicated for every layer.

    Entry ``[l, h, i, j]`` equals ``-slope[h] * (i - j)``, where the slopes
    come from ``compute_alibi_slopes`` and ``i``/``j`` range over the
    module-level ``sequence_length``. Nothing is masked here, so entries with
    ``j > i`` are positive.
    """
    positions = torch.arange(sequence_length, dtype=torch.float64)
    row_positions = einops.rearrange(positions, "n -> n 1")
    col_positions = einops.rearrange(positions, "n -> 1 n")

    # One slope per head, shaped to broadcast over the (i, j) grid.
    slopes = einops.rearrange(compute_alibi_slopes(n_heads), "h -> h 1 1")
    per_head_bias = -slopes * (row_positions - col_positions)

    # The bias does not depend on the layer; repeat it along a new layer axis.
    return einops.repeat(per_head_bias, "h i j -> l h i j", l=n_layers)


def load_attention_matrix(
    model_name: str,
) -> Float[torch.Tensor, "n_layers n_heads seq_len seq_len"]:
    """Load the mean attention matrix for a model and add the ALiBi bias.

    Reads ``<data_path>/<model_name>/<sequence_length>/attention_probs_mean.pt``,
    converts it to float64 and adds the output of ``compute_alibi_scores``
    in place.

    NOTE(review): the file name says "probs", but the values are combined
    with an additive bias and later passed through a softmax, i.e. they are
    treated as pre-softmax scores -- confirm against the producer of the file.

    Args:
        model_name: Name of the model's result sub-directory.

    Returns:
        A float64 CPU tensor of shape ``(n_layers, n_heads, seq_len, seq_len)``.
    """
    # Named ``model_dir`` so it does not shadow the module-level ``result_dir``
    # (the figure output directory).
    model_dir = Path(data_path) / model_name / str(sequence_length)

    # ``map_location="cpu"``: the ALiBi bias below is created on the CPU, so a
    # tensor saved from a GPU must be brought to the CPU before the in-place
    # add (and must remain loadable on machines without that GPU).
    # NOTE: torch.load unpickles the file; only load files from trusted sources.
    mean_attention_matrix: Float[torch.Tensor, "n_layers n_heads seq_len seq_len"] = (
        torch.load(
            model_dir / "attention_probs_mean.pt",
            map_location="cpu",
        ).to(torch.float64)
    )

    n_layers, n_heads, _, _ = mean_attention_matrix.shape

    mean_attention_matrix += compute_alibi_scores(
        n_layers=n_layers,
        n_heads=n_heads,
    )

    return mean_attention_matrix


def compute_softmax_attention(
    attention_matrix: Float[torch.Tensor, "n_layers n_heads seq_len seq_len"],
) -> Float[torch.Tensor, "n_layers seq_len seq_len"]:
    """Causally mask the scores, softmax each row, then average over heads.

    Positions strictly above the diagonal of the last two dimensions are set
    to ``-inf`` before the softmax, so they receive zero probability.
    """
    # Lower triangle (diagonal included) marks the positions that stay visible.
    visible = torch.ones_like(attention_matrix, dtype=torch.bool).tril()
    masked_scores = attention_matrix.masked_fill(~visible, float("-inf"))

    per_head_probs = torch.softmax(masked_scores, dim=-1)

    # We only need monotonicity to hold in expectation across heads
    return einops.reduce(per_head_probs, "l h i j -> l i j", "mean")


def compute_monotonicity_stats(
    attention_matrix: Float[torch.Tensor, "n_layers seq_len seq_len"],
) -> list[LayerStats]:
    """Measure, per layer, how often cumulative attention breaks monotonicity.

    For every pair of rows ``i < j`` and every position ``k``, the cumulative
    sum of row ``i`` up to ``k`` is compared with that of row ``j``. A
    comparison counts as a violation when row ``i``'s value is smaller than
    row ``j``'s by more than a small tolerance. One ``LayerStats`` is returned
    per layer, in layer order.
    """
    _, seq_len, _ = attention_matrix.shape

    running_mass = attention_matrix.cumsum(dim=-1)

    # All row pairs (i, j) with i < j; shared by every layer.
    row_pairs = torch.triu_indices(
        seq_len,
        seq_len,
        offset=1,
        device=attention_matrix.device,
    )
    earlier_rows, later_rows = row_pairs[0], row_pairs[1]
    total_combinations = earlier_rows.shape[0] * seq_len

    tolerance = 1e-10

    per_layer: list[LayerStats] = []
    for layer_mass in running_mass.unbind(0):
        earlier = layer_mass[earlier_rows]
        later = layer_mass[later_rows]

        gap = earlier - later
        is_violation = gap < -tolerance
        n_violations = int(is_violation.sum().item())

        abs_sum = 0.0
        rel_sum = 0.0
        abs_std = 0.0
        rel_std = 0.0

        if n_violations > 0:
            abs_errors = gap[is_violation].abs()
            # Normalise by the larger of the two compared values; the small
            # constant keeps the division finite.
            scale = torch.max(earlier[is_violation], later[is_violation]) + 1e-5
            rel_errors = abs_errors / scale

            abs_sum = abs_errors.sum().item()
            rel_sum = rel_errors.sum().item()

            # A standard deviation needs at least two samples.
            if n_violations > 1:
                abs_std = abs_errors.std().item()
                rel_std = rel_errors.std().item()

        # Guard the denominators against zero.
        violation_denominator = max(n_violations, 1)
        per_layer.append(
            LayerStats(
                combinations=total_combinations,
                violations_count=n_violations,
                total_absolute_error=abs_sum,
                total_relative_error=rel_sum,
                violation_rate=n_violations / max(total_combinations, 1),
                absolute_error=abs_sum / violation_denominator,
                absolute_error_std=abs_std,
                relative_error=rel_sum / violation_denominator,
                relative_error_std=rel_std,
            ),
        )

    return per_layer

In [None]:
# Per-layer monotonicity statistics for every model, keyed by model name.
all_models_stats: dict[str, list[LayerStats]] = {}
for model_name in model_names:
    print(f"Processing model: {model_name}")
    # Mean attention matrix with the ALiBi bias added.
    mean_attention_matrix = load_attention_matrix(model_name)
    # Causal mask + softmax per row, then mean over heads.
    softmax_mean_attention_matrix = compute_softmax_attention(mean_attention_matrix)
    # One LayerStats entry per layer.
    all_layer_stats = compute_monotonicity_stats(softmax_mean_attention_matrix)
    all_models_stats[model_name] = all_layer_stats

In [None]:
# Maps each entry of ``model_names`` to the short label shown in plot legends.
formatted_model_name = {
    "anas-awadalla_mpt-7b": "mpt-7b",
    "bigscience_bloom": "bloom-176b",
    "bigscience_bloom-7b1": "bloom-7b",
    "tiiuae_falcon-rw-7b": "falcon-rw-7b",
    "eluzhnica_mpt-30b-peft-compatible": "mpt-30b",
}


def _apply_plot_style() -> None:
    """Activate the shared plot style and sans-serif fonts (global rcParams)."""
    plt.style.use(["science", "bright", "grid", "no-latex"])
    # Use sans-serif fonts (incl. mathtext) for cleaner small-scale rendering.
    plt.rcParams.update(
        {
            "font.family": "sans-serif",
            "font.sans-serif": ["DejaVu Sans", "Liberation Sans", "Arial"],
            "mathtext.fontset": "dejavusans",
        },
    )


def style_attn_plot(ax: Axes, ylabel: str) -> None:
    """Style attention plot axes.

    Sets the axis labels, tick label size, legend, x-margins and grid of
    ``ax``. The global style is (re)applied as well, but style settings only
    affect artists created afterwards; for a consistent look the style must
    already be active when the figure is created.

    Args:
        ax: Axes to style; its plotted lines should carry labels for the legend.
        ylabel: Label of the y-axis.
    """
    # Kept so that standalone callers still get the style for later figures.
    _apply_plot_style()
    ax.set_xlabel("Layer", fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)
    ax.tick_params(axis="both", which="major", labelsize=9)
    ax.legend(fontsize=9, frameon=False, handlelength=1.2, columnspacing=0.8)
    ax.margins(x=0.02)
    ax.grid(visible=True, which="major", alpha=0.2)


def save_figure(fig: plt.Figure, filename: str) -> None:
    """Save ``fig`` as a PDF named ``filename`` inside ``result_dir``."""
    output_path = result_dir / filename
    fig.savefig(
        output_path,
        format="pdf",
        dpi=400,
        bbox_inches="tight",
    )
    print(f"Figure saved to {output_path}")


def _plot_metric(
    all_models_stats: dict[str, list[LayerStats]],
    value_attr: str,
    ylabel: str,
    filename: str,
    std_attr: str = "",
) -> None:
    """Plot one ``LayerStats`` attribute per layer for all models and save it.

    When ``std_attr`` is given and ``show_std_error`` is enabled, a band of
    +/- 2 times that attribute is drawn around each curve.
    """
    fig, ax = plt.subplots(figsize=(4.5, 3), dpi=400)
    for model_name, stats_list in all_models_stats.items():
        values = np.array([getattr(stats, value_attr) for stats in stats_list])
        layers = np.arange(1, len(stats_list) + 1)  # 1-based layer numbers

        ax.plot(
            layers,
            values,
            # Fall back to the raw name for models without a display label.
            label=formatted_model_name.get(model_name, model_name),
            linewidth=1.5,
        )

        if std_attr and show_std_error:
            band = 2 * np.array([getattr(stats, std_attr) for stats in stats_list])
            ax.fill_between(
                layers,
                values - band,
                values + band,
                alpha=0.2,
                linewidth=0,
            )

    style_attn_plot(ax, ylabel)
    fig.tight_layout()
    save_figure(fig, filename)
    plt.show()


def plot_layer_stats(all_models_stats: dict[str, list[LayerStats]]) -> None:
    """Plot monotonicity violation statistics across layers for all models.

    Produces three figures (violation rate, average absolute error, average
    relative error), saves each as a PDF via ``save_figure`` and shows it.

    Args:
        all_models_stats: Per-layer statistics keyed by model name.
    """
    # Activate the style BEFORE the first figure exists. Otherwise the first
    # figure is built with the default rcParams (colour cycle, ticks, fonts)
    # and looks different from the following ones.
    _apply_plot_style()

    _plot_metric(
        all_models_stats,
        value_attr="violation_rate",
        ylabel="Violation Rate",
        filename="violation_rate.pdf",
    )
    _plot_metric(
        all_models_stats,
        value_attr="absolute_error",
        ylabel="Average Absolute Error",
        filename="average_absolute_error.pdf",
        std_attr="absolute_error_std",
    )
    _plot_metric(
        all_models_stats,
        value_attr="relative_error",
        ylabel="Average Relative Error",
        filename="average_relative_error.pdf",
        std_attr="relative_error_std",
    )


# Render, save and display the three summary figures for all processed models.
plot_layer_stats(all_models_stats)