In [1]:
import json
import os
import typing
from collections import defaultdict, namedtuple
from typing import Type

import einops
import numpy as np
import pandas as pd
import torch as t
from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import evaluate_ce
from devinterp.vis_utils import EpsilonBetaAnalyzer
from jaxtyping import Float
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split

import wandb
from pizza_clock.config import get_device
from pizza_clock.dataset import AdditionDataset
from pizza_clock.metrics import compute_gradient_similarity
from pizza_clock.training import ModularAdditionModelTrainer
from pathlib import Path
from pizza_clock.dataset import get_train_val_data
from pizza_clock.config import Config
from functools import partial
from torch.nn import functional as F




In [2]:
def evaluate_last_position(criterion, model, data):
    x, y = data
    out = model(x)
    logits = out[:, -1, :]  # Get the last position's logits: [batch, vocab]
    return criterion(logits, y), {"output": logits}

evaluate_last_position_ce = partial(evaluate_last_position, F.cross_entropy)


def estimate_llc_given_model(
    model: t.nn.Module,
    loader: t.utils.data.DataLoader,
    evaluate: typing.Callable,
    epsilon: float,
    beta: float,
    sampling_method: Type[t.optim.Optimizer] = SGLD,
    localization: float = 5.0,
    num_chains: int = 2,
    num_draws: int = 500,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    device: t.device = get_device(),
    online: bool = True,
    verbose: bool = False,
):
    # Copied from devinterp grokking notebook https://github.com/timaeus-research/devinterp/blob/main/examples/grokking.ipynb
    sweep_stats = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=sampling_method,
        optimizer_kwargs=dict(lr=epsilon, localization=localization, nbeta=beta),
        num_chains=num_chains,  # How many independent chains to run
        num_draws=num_draws,  # How many samples to draw per chain
        num_burnin_steps=num_burnin_steps,  # How many samples to discard at the beginning of each chain
        num_steps_bw_draws=num_steps_bw_draws,  # How many steps to take between each sample
        device=device,
        online=online,
        verbose=verbose,
    )

    sweep_stats["llc/trace"] = np.array(sweep_stats["llc/trace"])
    return sweep_stats

In [3]:
def load_model_and_config(dir_path: str):
    config_json = json.load(open(f"{dir_path}/config.json", "r"))
    config = Config(**config_json)
    final_model = t.load(f"{dir_path}/final_model.pt", map_location=get_device(), weights_only=False)
    all_models = [
        t.load(f"{dir_path}/model_{i}.pt", map_location=get_device(), weights_only=False)
        for i in range(len(list(Path(dir_path).glob("model_*.pt"))))
    ]
    return final_model, config, all_models

In [4]:
dir_path = "saved_models/2026-01-27/attn0.0_seed4"
final_model, config, all_models = load_model_and_config(dir_path)
train_loader, _ = get_train_val_data(config, squeeze_targets=True)

analyzer = EpsilonBetaAnalyzer()
analyzer.configure_sweep(
    llc_estimator=estimate_llc_given_model,
    llc_estimator_kwargs=dict(
        model=final_model,
        evaluate=evaluate_last_position_ce,  # Use custom evaluate function
        device=get_device(),
        loader=train_loader,
    ),
    min_epsilon=3e-5,
    max_epsilon=3e-1,
    epsilon_samples=5,
    min_beta=None,
    max_beta=None,
    beta_samples=5,
    dataloader=train_loader,
)
analyzer.sweep()

100%|██████████| 25/25 [01:11<00:00,  2.88s/it]
100%|██████████| 25/25 [01:11<00:00,  2.88s/it]


In [5]:
analyzer.plot()

In [7]:
learning_coeff_stats = estimate_learning_coeff_with_summary(
    final_model,
    loader=train_loader,
    evaluate=evaluate_last_position_ce,
    sampling_method=SGLD,
    optimizer_kwargs=dict(lr=0.03, nbeta=2.0, localization=5.0),
    num_chains=3,
    num_draws=1500,
    device=get_device(),
    online=True,
)
trace = learning_coeff_stats["loss/trace"]

Chain 0:   0%|          | 6/1500 [00:00<00:41, 36.31it/s]
Chain 0:   0%|          | 6/1500 [00:00<00:41, 36.31it/s]


RuntimeError: NaN detected in loss at chain 0, draw 6