# Influence Stats

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.stats import spearmanr, wasserstein_distance
import torch

In [None]:
def _to_distribution(values: np.ndarray, source: Path) -> np.ndarray:
    """Rescale ``values`` so they sum to 1.

    Raises:
        ValueError: if the sum is zero, negative, NaN or infinite, in which
            case no probability distribution can be formed. ``source`` is
            only used to name the offending file in the message.
    """
    total = values.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            f"Cannot normalize values from {source}: "
            f"expected a finite, positive sum but got {total}."
        )
    return values / total


def _compare_position_distributions(
    influence_arr: np.ndarray,
    rollout_arr: np.ndarray,
) -> tuple[float, float]:
    """Return ``(spearman, normalized_wasserstein)`` for two distributions.

    Both inputs must be equally long, non-empty and already normalized. The
    Wasserstein distance is taken over the index positions ``0..n-1`` and
    divided by ``n - 1`` (the largest distance possible on that support).
    """
    n = len(influence_arr)
    if n == 1:
        # A single position has no ranking to correlate, and both
        # distributions are the same point mass: the distance is exactly 0.
        # Returning early avoids the 0 / 0 that ``wd / (n - 1)`` would hit.
        return float("nan"), 0.0

    spearman_corr, _ = spearmanr(influence_arr, rollout_arr)

    positions = np.arange(n)
    wd = wasserstein_distance(
        positions,
        positions,
        u_weights=influence_arr,
        v_weights=rollout_arr,
    )
    return spearman_corr, wd / (n - 1)


def build_metrics_tables(
    datasets: dict[str, dict[str, Path]],
    models: dict[str, tuple[str, str]],
    sequence_length: int,
    tensor_name: str = "input_grad_l2_mean.pt",
) -> dict[str, pd.DataFrame]:
    """
    Build one dataframe per dataset with rows=metrics and columns=models.

    For every (dataset, model) pair the influence tensor is read from
    ``results_root / influence_model_dir / str(sequence_length) / tensor_name``
    and the rollout distribution from the ``probability`` column of
    ``rollout_dir / rollout_model_dir / "last_row_distribution.csv"``.
    Both are flattened, trimmed to the shorter length, normalized to sum
    to 1 and then compared.

    Args:
        datasets: mapping like
            {
              "fineweb": {"results_root": Path(...), "dataset_name": "...", "rollout_dir": Path(...)},
              ...
            }
            Only "results_root" and "rollout_dir" are read here.
        models: mapping {display_name: (influence_model_dir, rollout_model_dir)}.
        sequence_length: sequence length subfolder name.
        tensor_name: influence tensor filename.

    Returns:
        dict[str, pd.DataFrame] with rows [spearman, wasserstein]. Spearman
        is rounded to 2 decimals; the Wasserstein distance is divided by
        ``n - 1`` and rounded to 4 decimals.

    Raises:
        ValueError: if either input is empty after trimming, or if its
            values do not have a finite, positive sum.

    """
    metric_index = ["spearman", "wasserstein"]
    tables: dict[str, pd.DataFrame] = {}

    for dataset_key, dataset_info in datasets.items():
        results_root = dataset_info["results_root"]
        rollout_root = dataset_info["rollout_dir"]
        df = pd.DataFrame(index=metric_index, columns=list(models), dtype=float)

        for display_name, (influence_model_dir, rollout_model_dir) in models.items():
            tensor_path = (
                results_root / influence_model_dir / str(sequence_length) / tensor_name
            )
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted result directories (consider weights_only=True).
            influence = torch.load(tensor_path, map_location="cpu")
            influence_arr = (
                influence.detach().cpu().numpy().astype(np.float64).flatten()
            )

            rollout_path = (
                rollout_root / rollout_model_dir / "last_row_distribution.csv"
            )
            rollout_df = pd.read_csv(rollout_path)
            rollout_arr = rollout_df["probability"].to_numpy(dtype=np.float64).flatten()

            # Match lengths by trimming to the shorter one. A mismatch is
            # trimmed silently, so the leading positions are assumed aligned.
            min_len = min(len(influence_arr), len(rollout_arr))
            if min_len == 0:
                raise ValueError(
                    f"Nothing to compare for {display_name!r} in {dataset_key!r}: "
                    f"{tensor_path} has {len(influence_arr)} values and "
                    f"{rollout_path} has {len(rollout_arr)}."
                )

            # Normalize to probability distributions
            influence_arr = _to_distribution(influence_arr[:min_len], tensor_path)
            rollout_arr = _to_distribution(rollout_arr[:min_len], rollout_path)

            spearman_corr, wd_norm = _compare_position_distributions(
                influence_arr,
                rollout_arr,
            )

            df.loc["spearman", display_name] = round(spearman_corr, 2)
            df.loc["wasserstein", display_name] = round(wd_norm, 4)

        tables[dataset_key] = df
    return tables


# Example usage: influence results compared with the rollout distributions
# stored under results/rollout/all_content/<dataset>.
datasets = {
    key: {
        "results_root": Path(results_root),
        "dataset_name": dataset_name,
        "rollout_dir": Path(rollout_dir),
    }
    for key, results_root, dataset_name, rollout_dir in (
        (
            "fineweb",
            "/data/shared_data/position-bias/results-final-fineweb-edu",
            "HuggingFaceFW/fineweb-edu",
            "../../results/rollout/all_content/fineweb-edu",
        ),
        (
            "dclm",
            "/data/shared_data/position-bias/results-final-dclm",
            "mlfoundations/dclm-baseline-1.0",
            "../../results/rollout/all_content/dclm",
        ),
        (
            "wikipedia",
            "/data/shared_data/position-bias/results-final-wikipedia",
            "wikimedia/wikipedia",
            "../../results/rollout/all_content/wikipedia",
        ),
    )
}

# display_name -> (influence_model_dir, rollout_model_dir); in this cell the
# rollout directory carries the same name as the display name.
models = {
    name: (influence_model_dir, name)
    for name, influence_model_dir in (
        ("falcon-rw-7b", "tiiuae_falcon-rw-7b"),
        ("mpt-7b", "anas-awadalla_mpt-7b"),
        ("mpt-30b", "eluzhnica_mpt-30b-peft-compatible"),
        ("bloom-7b1", "bigscience_bloom-7b1"),
        ("bloom", "bigscience_bloom"),
    )
}

sequence_length = 256

tables_cd = build_metrics_tables(datasets, models, sequence_length)

# Figure c and d
tables_cd["fineweb"]
# Swap in tables_cd["dclm"] or tables_cd["wikipedia"] to show another dataset.

In [None]:
# Same influence results as before, but every dataset is compared with the
# rollout distributions stored under results/rollout/no_content.
datasets = {
    key: {
        "results_root": Path(results_root),
        "dataset_name": dataset_name,
        "rollout_dir": Path("../../results/rollout/no_content"),
    }
    for key, results_root, dataset_name in (
        (
            "fineweb",
            "/data/shared_data/position-bias/results-final-fineweb-edu",
            "HuggingFaceFW/fineweb-edu",
        ),
        (
            "dclm",
            "/data/shared_data/position-bias/results-final-dclm",
            "mlfoundations/dclm-baseline-1.0",
        ),
        (
            "wikipedia",
            "/data/shared_data/position-bias/results-final-wikipedia",
            "wikimedia/wikipedia",
        ),
    )
}

# display_name -> (influence_model_dir, rollout_model_dir); in this cell the
# rollout directory carries the same name as the display name.
models = {
    name: (influence_model_dir, name)
    for name, influence_model_dir in (
        ("falcon-rw-7b", "tiiuae_falcon-rw-7b"),
        ("mpt-7b", "anas-awadalla_mpt-7b"),
        ("mpt-30b", "eluzhnica_mpt-30b-peft-compatible"),
        ("bloom-7b1", "bigscience_bloom-7b1"),
        ("bloom", "bigscience_bloom"),
    )
}

sequence_length = 256

tables_bd = build_metrics_tables(datasets, models, sequence_length)

# Figure b and d
tables_bd["fineweb"]
# Swap in tables_bd["dclm"] or tables_bd["wikipedia"] to show another dataset.

In [None]:
# Influence results compared with the "-noresidual" rollout directories
# stored under results/rollout/no_content.
datasets = {
    key: {
        "results_root": Path(results_root),
        "dataset_name": dataset_name,
        "rollout_dir": Path("../../results/rollout/no_content"),
    }
    for key, results_root, dataset_name in (
        (
            "fineweb",
            "/data/shared_data/position-bias/results-final-fineweb-edu",
            "HuggingFaceFW/fineweb-edu",
        ),
        (
            "dclm",
            "/data/shared_data/position-bias/results-final-dclm",
            "mlfoundations/dclm-baseline-1.0",
        ),
        (
            "wikipedia",
            "/data/shared_data/position-bias/results-final-wikipedia",
            "wikimedia/wikipedia",
        ),
    )
}

# display_name -> (influence_model_dir, rollout_model_dir)
models = {
    name: (influence_model_dir, rollout_model_dir)
    for name, influence_model_dir, rollout_model_dir in (
        ("falcon-rw-7b", "tiiuae_falcon-rw-7b", "falcon-rw-7b-noresidual"),
        ("mpt-7b", "anas-awadalla_mpt-7b", "mpt-7b-noresidual"),
        ("mpt-30b", "eluzhnica_mpt-30b-peft-compatible", "mpt-30b-noresidual"),
        ("bloom-7b1", "bigscience_bloom-7b1", "bloom-7b1-noresidual"),
        ("bloom", "bigscience_bloom", "bloom-noresidual"),
    )
}

sequence_length = 256

tables_ad = build_metrics_tables(datasets, models, sequence_length)

# Figure a and d
tables_ad["fineweb"]
# Swap in tables_ad["dclm"] or tables_ad["wikipedia"] to show another dataset.

In [None]:
# For each metric, merge the ad / bd / cd tables into a single model-by-
# (dataset, table) frame and bold the best of the three within each dataset.
def _highlight_row_by_dataset(s: pd.Series, metric: str) -> pd.Series:
    """Return CSS styles that bold the best ad/bd/cd value per dataset.

    ``s`` is one model row of the combined frame. Within each dataset the
    best value is the maximum when ``metric`` is "spearman" and the minimum
    otherwise; every entry equal to the best value is bolded.
    """
    styles = pd.Series("", index=s.index)
    for ds in ["fineweb", "dclm", "wikipedia"]:
        block = s.loc[ds]  # the ad, bd and cd values of this dataset
        if metric == "spearman":
            target = block.max()
        else:
            target = block.min()
        winners = block[block == target].index
        for table_name in winners:
            styles.loc[(ds, table_name)] = "font-weight: bold"
    return styles


for metric in ["spearman", "wasserstein"]:
    model_names = tables_cd["fineweb"].columns
    rows = [
        {
            (dataset_key, label): tables[dataset_key].loc[metric, model_name]
            for dataset_key in ["fineweb", "dclm", "wikipedia"]
            for label, tables in (
                ("ad", tables_ad),
                ("bd", tables_bd),
                ("cd", tables_cd),
            )
        }
        for model_name in model_names
    ]

    combined = pd.DataFrame(rows, index=model_names)
    combined.columns = pd.MultiIndex.from_tuples(
        combined.columns,
        names=["dataset", "table"],
    )
    combined.index.name = "model"

    # NOTE(review): every metric is rendered with two decimals, although
    # build_metrics_tables rounds wasserstein to four; confirm this is intended.
    styler = combined.style.apply(_highlight_row_by_dataset, axis=1, metric=metric)
    styled = styler.format("{:.2f}")

    print(metric.capitalize())
    display(styled)