In [None]:
import torch as t
from pizza_clock.dataset import AdditionDataset
from pizza_clock.training import ModularAdditionModelTrainer
from torch.utils.data import DataLoader, random_split
from torch import Tensor, nn
from jaxtyping import Float
import einops
import os
import json
import pandas as pd
from collections import defaultdict, namedtuple
import wandb
from pizza_clock.metrics import compute_gradient_symmetry



In [None]:
# Gradient symmetry of every saved p=59 checkpoint, without (attn0.0) and with
# (attn1.0) attention, across seeds 0-4. The list order is the order lines are printed in.
# NOTE(review): weights_only=False unpickles arbitrary Python objects - only load
# checkpoints you produced yourself.
MODEL_PATHS = [
    "saved_models/p59_attn0.0_td0.8_wd2.0_seed0.pt",
    "saved_models/p59_attn0.0_td0.8_wd2.0_seed4.pt",
    "saved_models/p59_attn0.0_td0.8_wd2.0_seed1.pt",
    "saved_models/p59_attn0.0_td0.8_wd2.0_seed3.pt",
    "saved_models/p59_attn0.0_td0.8_wd2.0_seed2.pt",
    "saved_models/p59_attn1.0_td0.8_wd2.0_seed3.pt",
    "saved_models/p59_attn1.0_td0.8_wd2.0_seed4.pt",
    "saved_models/p59_attn1.0_td0.8_wd2.0_seed2.pt",
    "saved_models/p59_attn1.0_td0.8_wd2.0_seed0.pt",
    "saved_models/p59_attn1.0_td0.8_wd2.0_seed1.pt",
]

for path in MODEL_PATHS:
    model = t.load(path, weights_only=False)
    symmetry = compute_gradient_symmetry(model)
    print(f"Gradient similarity for {path}: {symmetry}")

Gradient similarity for saved_models/p59_attn0.0_td0.8_wd2.0_seed0.pt: 0.9943184435367585
Gradient similarity for saved_models/p59_attn0.0_td0.8_wd2.0_seed4.pt: 0.9948616594076156
Gradient similarity for saved_models/p59_attn0.0_td0.8_wd2.0_seed1.pt: 0.9923302489519119
Gradient similarity for saved_models/p59_attn0.0_td0.8_wd2.0_seed3.pt: 0.9929596620798111
Gradient similarity for saved_models/p59_attn0.0_td0.8_wd2.0_seed2.pt: 0.9935478860139847
Gradient similarity for saved_models/p59_attn1.0_td0.8_wd2.0_seed3.pt: 0.4556414250656962
Gradient similarity for saved_models/p59_attn1.0_td0.8_wd2.0_seed4.pt: 0.6980805608257651
Gradient similarity for saved_models/p59_attn1.0_td0.8_wd2.0_seed2.pt: 0.6094339151680469
Gradient similarity for saved_models/p59_attn1.0_td0.8_wd2.0_seed0.pt: 0.40657910326495766
Gradient similarity for saved_models/p59_attn1.0_td0.8_wd2.0_seed1.pt: 0.46988426925614474


In [3]:
def parse_name(name: str):
    """Parse sweep hyperparameters out of a run name.

    Names look like ``p113_attn1.0_td0.8_wd2.0_seed4``: underscore-separated
    parts, each made of a known prefix followed by its value. Parts with an
    unrecognised prefix are ignored. A part that is absent simply leaves its
    key out of the result.

    Args:
        name: The run name to parse.

    Returns:
        A dict containing whichever of these keys were found in ``name``:
        ``p`` (int), ``attention_rate`` (float), ``train_data_fraction`` (float),
        ``weight_decay`` (float) and ``seed`` (int).

    Raises:
        ValueError: If a recognised prefix is followed by text that cannot be
            converted to the expected number type.
    """
    # (prefix, output key, converter), checked in this order; the first match wins.
    # The value is sliced off with len(prefix) so the offset always matches the
    # prefix. A hand-written slice for "td" previously skipped one character too
    # many, which silently turned "td1.0" into 0.0.
    fields = (
        ("p", "p", int),
        ("attn", "attention_rate", float),
        ("td", "train_data_fraction", float),
        ("wd", "weight_decay", float),
        ("seed", "seed", int),
    )
    params = {}
    for part in name.split("_"):
        for prefix, key, convert in fields:
            if part.startswith(prefix):
                params[key] = convert(part[len(prefix):])
                break
    return params

# Run names and ids exported from the W&B sweep UI.
names_and_ids = pd.read_csv("wandb_export_2026-01-22T16_23_22.764+01_00.csv")

# One record per exported run: name, W&B id, and the hyperparameters encoded in the name.
# Assumes the export has exactly two columns, run name then run id (positional unpacking).
# The loop variable is `run_id` rather than `id` so the builtin id() is not rebound
# for the rest of the notebook session.
runs = [
    {"name": run_name, "id": run_id, **parse_name(run_name)}
    for run_name, run_id in names_and_ids.itertuples(index=False, name=None)
]
runs_df = pd.DataFrame(runs)

In [4]:
# Hyperparameter bundle identifying a sweep configuration (seed excluded).
# namedtuple accepts a space-separated field spec: same four fields, same order.
Params = namedtuple("Params", "p attention_rate train_data_fraction weight_decay")

# Public W&B API client. Per this cell's output, credentials are read from ~/.netrc.
api = wandb.Api()

[34m[1mwandb[0m: [wandb.Api()] Loaded credentials for https://api.wandb.ai from /Users/hanna/.netrc.


In [None]:
def update_run_configs(
    runs_df: pd.DataFrame,
    run_path_prefix: str = "gahanna999-/modular-addition-attention-grokking-sweep",
    api_client=None,
):
    """Write each run's parsed hyperparameters into its W&B run config.

    For every row of ``runs_df`` this does three things:

    - fetches the W&B run ``<run_path_prefix>/<row["id"]>``;
    - overwrites its ``p``, ``attention_rate``, ``train_data_fraction`` and
      ``weight_decay`` config entries with the row's values;
    - pushes the change with ``run.update()``.

    Args:
        runs_df: One row per run, with at least the columns ``id``, ``p``,
            ``attention_rate``, ``train_data_fraction`` and ``weight_decay``.
        run_path_prefix: ``"<entity>/<project>"`` that the run ids belong to.
        api_client: Object exposing ``run(path)``, e.g. a ``wandb.Api()``.
            Defaults to the notebook-level ``api``.

    Raises:
        ValueError: If a hyperparameter value cannot be converted to its
            numeric type (e.g. NaN for ``p``). This stops the loop before a
            malformed value is written to the remote config.
    """
    client = api if api_client is None else api_client
    for _, row in runs_df.iterrows():
        # Runs are addressed as <entity>/<project>/<run_id>.
        run = client.run(f"{run_path_prefix}/{row['id']}")
        # Cast explicitly: iterrows() does not preserve dtypes across a row, so an
        # all-numeric frame would otherwise store p as 59.0 / numpy scalars.
        run.config["p"] = int(row["p"])
        run.config["attention_rate"] = float(row["attention_rate"])
        run.config["train_data_fraction"] = float(row["train_data_fraction"])
        run.config["weight_decay"] = float(row["weight_decay"])
        run.update()
# Side effect: overwrites config entries on the remote W&B runs listed in runs_df.
update_run_configs(runs_df=runs_df)