In [1]:
import torch as t
from pizza_clock.dataset import AdditionDataset
from pizza_clock.training import ModularAdditionModelTrainer
from torch.utils.data import DataLoader, random_split
from torch import Tensor, nn
from jaxtyping import Float
import einops
import os
import json
import pandas as pd
from collections import defaultdict, namedtuple
import wandb


In [2]:
def parse_name(name: str) -> dict:
    """Parse the hyperparameters encoded in a sweep run name.

    Run names are underscore-separated ``<prefix><value>`` tokens, e.g.
    ``p113_attn1.0_td0.8_wd2.0_seed4``.

    Args:
        name: The run name to parse.

    Returns:
        A dict with any of the keys ``p``, ``attention_rate``,
        ``train_data_fraction``, ``weight_decay`` and ``seed`` whose prefix
        appears in ``name``. Tokens with an unrecognized prefix are ignored.

    Raises:
        ValueError: if a recognized prefix is followed by a non-numeric value.
    """
    # (token prefix, output key, value type). Checked in order and the first
    # match wins, mirroring the original if/elif chain.
    fields = (
        ("p", "p", int),
        ("attn", "attention_rate", float),
        ("td", "train_data_fraction", float),
        ("wd", "weight_decay", float),
        ("seed", "seed", int),
    )
    params = {}
    for part in name.split("_"):
        for prefix, key, cast in fields:
            if part.startswith(prefix):
                # Slice by the prefix's own length so the value's leading
                # digit is kept (e.g. "wd2.0" -> 2.0, not ".0" -> 0.0).
                params[key] = cast(part[len(prefix):])
                break
    return params

# CSV export of the sweep's runs downloaded from the W&B UI. It is expected to
# hold exactly two columns, (run name, run id), in that order.
RUNS_EXPORT_CSV = "wandb_export_2026-01-22T16_23_22.764+01_00.csv"

names_and_ids = pd.read_csv(RUNS_EXPORT_CSV)
# One record per run: its name, its W&B id, and the hyperparameters parsed from
# the name. `run_id` (not `id`) avoids shadowing the builtin for later cells.
runs = [
    {"name": run_name, "id": run_id, **parse_name(run_name)}
    for run_name, run_id in names_and_ids.itertuples(index=False, name=None)
]
runs_df = pd.DataFrame(runs)


In [3]:
# Grouping key for runs. Seed is intentionally left out, so runs that differ
# only by seed map to the same key.
Params = namedtuple("Params", "p attention_rate train_data_fraction weight_decay")
# W&B public API client, used below to look up each sweep run by id.
api = wandb.Api()


[34m[1mwandb[0m: [wandb.Api()] Loaded credentials for https://api.wandb.ai from /Users/hanna/.netrc.


In [None]:
# Collect each run's metric histories, grouped by hyperparameter setting (seed
# excluded). Runs that differ only in seed share a key, so they can later be
# averaged and plotted together.
SWEEP_PATH = "gahanna999-/modular-addition-attention-grokking-sweep"  # <entity>/<project>
METRIC_KEYS = ["train loss", "val loss", "val accuracy"]

train_loss_histories = defaultdict(list)
val_loss_histories = defaultdict(list)
val_accuracy_histories = defaultdict(list)
for run_row in runs_df.itertuples(index=False):
    # A run is addressed as <entity>/<project>/<run_id>.
    run = api.run(f"{SWEEP_PATH}/{run_row.id}")
    params = Params(
        p=run_row.p,
        attention_rate=run_row.attention_rate,
        train_data_fraction=run_row.train_data_fraction,
        weight_decay=run_row.weight_decay,
    )
    # scan_history returns a lazy scan object, not a list (see the
    # SampledHistoryScan reprs in the output). Iterating it queries the W&B
    # API, so materialize it once instead of re-scanning it for every metric.
    history_rows = list(run.scan_history(keys=METRIC_KEYS))
    train_loss_histories[params].append([step["train loss"] for step in history_rows])
    val_loss_histories[params].append([step["val loss"] for step in history_rows])
    val_accuracy_histories[params].append([step["val accuracy"] for step in history_rows])

<wandb.apis.public.history.SampledHistoryScan object at 0x14548a660>
<wandb.apis.public.history.SampledHistoryScan object at 0x1515a9590>
<wandb.apis.public.history.SampledHistoryScan object at 0x1515a9810>
<wandb.apis.public.history.SampledHistoryScan object at 0x1432e2060>
<wandb.apis.public.history.SampledHistoryScan object at 0x1432e1a70>
<wandb.apis.public.history.SampledHistoryScan object at 0x1508c9010>
<wandb.apis.public.history.SampledHistoryScan object at 0x150bb4f30>
<wandb.apis.public.history.SampledHistoryScan object at 0x150bb4e20>
<wandb.apis.public.history.SampledHistoryScan object at 0x150f27450>
<wandb.apis.public.history.SampledHistoryScan object at 0x150f27350>
<wandb.apis.public.history.SampledHistoryScan object at 0x117fb6e40>
<wandb.apis.public.history.SampledHistoryScan object at 0x1518ce4e0>
<wandb.apis.public.history.SampledHistoryScan object at 0x1508cdc50>
<wandb.apis.public.history.SampledHistoryScan object at 0x1508cec10>
<wandb.apis.public.history.Sampled