# Input Token Influence

In [None]:
from typing import Literal

import numpy as np
from numpy.typing import NDArray
import torch

In [None]:
from pathlib import Path

# Output directory for the figures saved by this notebook (used by
# ``save_figure``). The path is relative, so it is resolved against the
# current working directory at run time.
result_dir = Path("../../results/input_token_influence")
# Create the directory and any missing parents up front; an already existing
# directory is not an error.
result_dir.mkdir(parents=True, exist_ok=True)

In [None]:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, StrMethodFormatter
import scienceplots  # noqa: F401


def style_attn_plot(ax: Axes) -> None:
    """Restyle an already-plotted axes with the notebook's shared look.

    Activates the ``science``/``bright``/``grid``/``no-latex`` style sheets
    and sans-serif font settings through ``plt.style.use`` and
    ``plt.rcParams.update``. Both mutate matplotlib's global rcParams, so the
    settings remain active after this function returns.

    The first line on ``ax`` is then recoloured with the fifth colour of the
    active colour cycle, and axis labels, tick formatting, margins and grid
    are set.

    Args:
        ax: Axes to restyle. It must already hold at least one line, because
            ``ax.lines[0]`` is indexed unconditionally.
    """
    # Style sheets go first: the font overrides must land on top of them, and
    # the colour cycle below is read from the resulting rcParams.
    plt.style.use(["science", "bright", "grid", "no-latex"])

    # Sans-serif fonts (incl. mathtext) render more cleanly at small sizes.
    sans_serif_overrides = {
        "font.family": "sans-serif",
        "font.sans-serif": ["DejaVu Sans", "Liberation Sans", "Arial"],
        "mathtext.fontset": "dejavusans",
    }
    plt.rcParams.update(sans_serif_overrides)

    cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    first_line = ax.lines[0]
    first_line.set_color(cycle_colors[4])

    label_size = 14
    ax.set_xlabel("Token position", fontsize=label_size)
    ax.set_ylabel("$\\hat{p}^{(T)}(j)$", fontsize=label_size)

    # At most five major y ticks, each printed with three decimals.
    y_axis = ax.yaxis
    y_axis.set_major_locator(MaxNLocator(nbins=5))
    y_axis.set_major_formatter(StrMethodFormatter("{x:.3f}"))

    ax.tick_params(axis="both", which="major", labelsize=11)
    ax.margins(x=0.02)
    ax.grid(visible=True, which="major", alpha=0.2)


def plot_influence(
    influence: NDArray,
) -> tuple[Figure, Axes]:
    """Plot per-position influence scores as a single marked line.

    Args:
        influence: One-dimensional array-like of scores. Element ``j`` is
            drawn at x-position ``j``.

    Returns:
        The newly created figure and its axes, after ``style_attn_plot`` and
        ``tight_layout`` have been applied.

    Raises:
        ValueError: If ``influence`` is not one-dimensional.
    """
    # Validate before deriving positions: ``len()`` on a 0-d input raises an
    # unrelated TypeError, and an ``assert`` would be stripped under
    # ``python -O``. ``np.ndim``/``np.shape`` also accept plain sequences.
    if np.ndim(influence) != 1:
        raise ValueError(
            f"Expected 1D influence tensor, got shape {np.shape(influence)}",
        )
    positions = np.arange(len(influence))

    fig, ax = plt.subplots(figsize=(4, 3), dpi=400)
    ax.plot(positions, influence, marker="o", linewidth=1.5)

    style_attn_plot(ax)
    fig.tight_layout()
    return fig, ax


def save_figure(fig: Figure, filename: str) -> None:
    """Write ``fig`` to ``result_dir / filename`` as a PDF and report the path.

    ``format="pdf"`` is passed to ``savefig`` explicitly, together with a
    400 dpi setting and a tight bounding box. The destination path is printed
    after the save call returns.

    Args:
        fig: Figure to export.
        filename: File name (including extension) inside ``result_dir``.
    """
    destination = result_dir / filename
    export_options = {"format": "pdf", "dpi": 400, "bbox_inches": "tight"}
    fig.savefig(destination, **export_options)
    print(f"Figure saved to {destination}")

In [None]:
# Root folder of the saved per-model results that are read below. Not to be
# confused with ``result_dir``, which is where this notebook writes figures.
results_dir = Path(
    "/data/shared_data/position-bias/results-final",
)  # Update this path to your saved results folder
# Sub-folder name of the model whose results are plotted.
model: Literal[
    "anas-awadalla_mpt-7b",
    "bigscience_bloom",
    "bigscience_bloom-7b1",
    "tiiuae_falcon-rw-7b",
    "eluzhnica_mpt-30b-peft-compatible",
] = "bigscience_bloom"

# The dataset name is only used to build the output file name.
dataset_name = "HuggingFaceFW/fineweb-edu"
sequence_length = 256
tensor_path = results_dir / model / str(sequence_length) / "input_grad_l2_mean.pt"

# NOTE(review): ``torch.load`` unpickles the file, so only load results you
# produced or trust. Consider ``weights_only=True`` if the installed torch
# version supports it.
influence = torch.load(tensor_path, map_location="cpu")
# Cast to float64 inside torch before handing over to NumPy, so that dtypes
# NumPy cannot represent (e.g. bfloat16) still convert.
influence = influence.detach().to(device="cpu", dtype=torch.float64).numpy()

# Normalise the scores to sum to one. A zero or non-finite total would
# otherwise silently produce NaN/inf values in the plot and the saved PDF.
total = influence.sum()
if not np.isfinite(total) or total == 0:
    raise ValueError(
        f"Cannot normalise influence from {tensor_path}: sum is {total}",
    )
influence = influence / total

fig, ax = plot_influence(influence)
plt.show()

save_figure(fig, f"iti_{model}_{dataset_name.replace('/', '_')}_{sequence_length}.pdf")