# Lambda Schedule

In [None]:
import pandas as pd
import wandb

In [None]:
from pathlib import Path

# Output directory for exported figures; the path is relative to the
# notebook's working directory and is created (with parents) if missing.
result_dir = Path("../../results") / "wandbs"
result_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Retrieve the per-layer "final/*" metrics from the history of every finished run in the project, together with each run's flattened config

from typing import Any

import numpy as np

api = wandb.Api()

ENTITY = "your-wandb-entity"  # Replace with your W&B username or team name
PROJECT = "position-bias"

# Per-layer metrics logged under the "final/" prefix. "final/layer_idx" is
# also used as the x-axis of the history query below.
METRIC_NAMES = [
    "final/attn_mean",
    "final/attn_stddev",
    "final/res_mean",
    "final/res_stddev",
    "final/layer_idx",
]


def unfold_dict(
    dictionary: dict[str, Any],
    parent_key: str = "",
    sep: str = "_",
) -> dict[str, Any]:
    """Flatten a nested dictionary into a single-level dictionary.

    Nested keys are joined with ``sep`` (``{"a": {"b": 1}}`` -> ``{"a_b": 1}``).
    Scalars (str/int/bool/float) are kept as-is. Single-element lists are
    unwrapped. Other lists and ``None`` become ``np.nan``.

    Args:
        dictionary: Mapping to flatten (here: a W&B run config).
        parent_key: Prefix for all keys produced at this level.
        sep: Separator placed between a parent key and its child key.

    Returns:
        A new flat dict; the input is not modified.

    Raises:
        ValueError: If a value has a type not listed above.
    """
    items: dict[str, Any] = {}
    for key, value in dictionary.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, (str, int, bool, float)):
            items[new_key] = value
        elif isinstance(value, list):
            # Only a single-element list maps cleanly onto one DataFrame cell.
            items[new_key] = value[0] if len(value) == 1 else np.nan
        elif isinstance(value, dict):
            # Forward ``sep`` so every nesting level uses the same separator.
            items.update(unfold_dict(value, new_key, sep))
        elif value is None:
            items[new_key] = np.nan
        else:
            raise ValueError(f"Unsupported type: {type(value)}")
    return items


project_path = f"{ENTITY}/{PROJECT}"

print(f"Fetching runs from project: {project_path}")
runs = api.runs(path=project_path)

metrics_list = []

for run_ind, run in enumerate(runs):
    print(f"Processing run {run_ind}: {run.id}")
    if run.state != "finished":
        print(f"Run {run.id} is not finished. Skipping.")
        continue
    print(f"Run {run.id} is finished.")

    print("Fetching metrics...")
    metrics_df = run.history(keys=METRIC_NAMES, x_axis="final/layer_idx", pandas=True)
    # Drop the "final/" prefix so columns read e.g. "attn_mean".
    metrics_df = metrics_df.rename(columns=lambda x: x.replace("final/", ""))

    metrics_df["run_id"] = run.id
    metrics_df["run_name"] = run.name

    # Broadcast every flattened config value onto all rows of this run.
    for key, value in unfold_dict(run.config).items():
        metrics_df[key] = value

    metrics_list.append(metrics_df)
    print(f"Successfully retrieved metrics for run {run.id}.")

if not metrics_list:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise ValueError(f"No finished runs with metrics found in {project_path}.")

combined_metrics = pd.concat(metrics_list, ignore_index=True)
print("Successfully combined all metrics into a single DataFrame.")

In [None]:
# Inspect which metric and config columns were collected.
combined_metrics.columns

In [None]:
# (number of rows, number of columns) of the combined frame.
print((len(combined_metrics), len(combined_metrics.columns)))

In [None]:
# Distinct raw model ids before any normalization.
pd.unique(combined_metrics["hf_model_config_model_id"])

In [None]:
# Keep only the identifying columns and per-layer metrics needed downstream.
# ``.copy()`` makes clean_df an independent frame, so the column assignment
# below does not raise pandas' SettingWithCopyWarning.
clean_df = combined_metrics[
    [
        "run_id",
        "dataset_config_repo_id",
        "hf_model_config_model_id",
        "num_batches",
        "batch_size",
        "sequence_length",
        "attn_mean",
        "attn_stddev",
        "res_mean",
        "res_stddev",
        "layer_idx",
    ]
].copy()

# Raw (lower-cased, org-stripped) model names -> labels used for consistency.
MODEL_ALIASES = {
    "bloom-7b1": "bloom-7b",
    "bloom": "bloom-176b",
    "mpt-30b-peft-compatible": "mpt-30b",
}

# Normalize ids to the bare lower-case model name ("Org/Model" -> "model"),
# then apply the aliases (exact-match replacement).
clean_df["hf_model_config_model_id"] = (
    clean_df["hf_model_config_model_id"]
    .str.split("/")
    .str[-1]
    .str.lower()
    .replace(MODEL_ALIASES)
)
clean_df

In [None]:
# Distinct model ids after normalization and renaming.
pd.unique(clean_df["hf_model_config_model_id"])

In [None]:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import scienceplots  # noqa: F401


_PLOT_STYLES = ["science", "bright", "grid", "no-latex"]


def _apply_plot_style() -> None:
    """Activate the SciencePlots styles and sans-serif fonts globally.

    This mutates global matplotlib rcParams. Call it before creating a figure
    so that settings read at creation time (e.g. the line color cycle) apply.
    """
    plt.style.use(_PLOT_STYLES)
    # Use sans-serif fonts (incl. mathtext) for cleaner small-scale rendering.
    plt.rcParams.update(
        {
            "font.family": "sans-serif",
            "font.sans-serif": ["DejaVu Sans", "Liberation Sans", "Arial"],
            "mathtext.fontset": "dejavusans",
        },
    )


def style_attn_plot(ax: Axes, ylabel: str = "Attention mean") -> None:
    """Style attention plot axes: labels, [0, 1] y-range, legend and grid.

    Args:
        ax: Axes that already holds the labelled lines; the legend is
            built from them.
        ylabel: Text for the y-axis label.
    """
    _apply_plot_style()
    ax.set_xlabel("Layer", fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)
    ax.tick_params(axis="both", which="major", labelsize=9)
    ax.set_ylim(0, 1)
    ax.legend(fontsize=9, frameon=False, handlelength=1.2, columnspacing=0.8)
    ax.margins(x=0.02)
    ax.grid(visible=True, which="major", alpha=0.2)


def plot_attn_mean_by_layer(
    clean_df: pd.DataFrame,
    selected_models: list[str],
    dataset_name: str = "HuggingFaceFW/fineweb-edu",
    sequence_length: int = 2048,
) -> tuple[Figure, Axes, pd.DataFrame]:
    """Plot per-layer ``attn_mean`` with a +/-2*``attn_stddev`` band per model.

    Args:
        clean_df: One row per (run, layer). Must contain the columns
            ``hf_model_config_model_id``, ``dataset_config_repo_id``,
            ``sequence_length``, ``layer_idx``, ``attn_mean`` and ``attn_stddev``.
        selected_models: Model ids to plot, in legend order. An "org/" prefix
            and letter case are ignored; duplicates are plotted once.
        dataset_name: Keep only rows with this ``dataset_config_repo_id``.
        sequence_length: Keep only rows with this ``sequence_length``.

    Returns:
        The figure, its axes, and the filtered, sorted rows that were plotted.
        Models without matching rows are skipped silently.
    """
    # Style must be active before the figure exists to affect its defaults.
    _apply_plot_style()

    wanted = list(dict.fromkeys(m.split("/")[-1].lower() for m in selected_models))

    # Lower-case ids *before* filtering so matching is case-insensitive.
    model_ids = clean_df["hf_model_config_model_id"].str.lower()
    mask = (
        model_ids.isin(wanted)
        & (clean_df["sequence_length"] == sequence_length)
        & (clean_df["dataset_config_repo_id"] == dataset_name)
    )
    plot_df = clean_df[mask].assign(hf_model_config_model_id=model_ids[mask])
    plot_df = plot_df.sort_values(["hf_model_config_model_id", "layer_idx"])

    fig, ax = plt.subplots(figsize=(4.5, 2), dpi=400)

    grouped = plot_df.groupby("hf_model_config_model_id", sort=False)
    for model_id in wanted:
        if model_id not in grouped.groups:
            continue  # no rows for this model at the chosen dataset/length
        group = grouped.get_group(model_id)
        x = group["layer_idx"].to_numpy()
        y = group["attn_mean"].to_numpy()
        # Half-width of the shaded band: two standard deviations.
        band = (2 * group["attn_stddev"]).to_numpy()

        ax.plot(x, y, label=model_id, linewidth=1.5)
        ax.fill_between(x, y - band, y + band, alpha=0.2, linewidth=0)

    style_attn_plot(ax)
    fig.tight_layout()
    return fig, ax, plot_df


def save_figure(fig: Figure, filename: str) -> None:
    """Write ``fig`` as a 400-dpi PDF into ``result_dir`` and print its path.

    Args:
        fig: Figure to export.
        filename: Name of the output file inside ``result_dir``. The file is
            always written in PDF format, whatever its extension.
    """
    out_path = result_dir / filename
    fig.savefig(out_path, format="pdf", dpi=400, bbox_inches="tight")
    print(f"Figure saved to {out_path}")

In [None]:
# Datasets present in the cleaned results (candidates for dataset_name below).
pd.unique(clean_df["dataset_config_repo_id"])

In [None]:
# Models to compare (edit this list as needed). Ids are matched
# case-insensitively against the normalized names in clean_df.
selected_models = [
    "falcon-rw-7b",
    "mpt-7b",
    "mpt-30b",
    "bloom-7b",
    "bloom-176b",
]

# Dataset to plot. Alternatives used with this notebook:
# "HuggingFaceFW/fineweb-edu", "mlfoundations/dclm-baseline-1.0".
dataset_name = "wikimedia/wikipedia"

sequence_length = 512

fig, ax, plot_df = plot_attn_mean_by_layer(
    clean_df,
    selected_models,
    dataset_name=dataset_name,
    sequence_length=sequence_length,
)
display(plot_df)

# Save before plt.show(): some backends close or release the figure on show.
figure_name = f"lambda_t_alibi_{dataset_name.replace('/', '_')}_{sequence_length}.pdf"
save_figure(fig, figure_name)
plt.show()