In [None]:
import torch as t
from pizza_clock.dataset import AdditionDataset
from pizza_clock.training import ModularAdditionModelTrainer
from torch.utils.data import DataLoader, random_split
from torch import Tensor, nn
from jaxtyping import Float
import einops
import os
import json
import pandas as pd
from collections import defaultdict, namedtuple
import wandb
from pizza_clock.metrics import compute_gradient_similarity
from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import evaluate_ce
from devinterp.vis_utils import EpsilonBetaAnalyzer
import typing
from typing import Type
import numpy as np
from pizza_clock.config import get_device

In [None]:
def estimate_llc_given_model(
    model: t.nn.Module,
    loader: t.utils.data.DataLoader,
    evaluate: typing.Callable,
    epsilon: float,
    beta: float,
    sampling_method: Type[t.optim.Optimizer] = SGLD,
    localization: float = 5.0,
    num_chains: int = 2,
    num_draws: int = 500,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    device: t.device = get_device(),
    online: bool = True,
    verbose: bool = False,
) -> dict:
    """Estimate the learning coefficient of ``model`` and return the summary.

    Thin wrapper around ``estimate_learning_coeff_with_summary`` that packs
    the sampler hyperparameters into ``optimizer_kwargs`` and, when present,
    converts the ``"llc/trace"`` entry of the result to a NumPy array.

    Args:
        model: Model whose learning coefficient is estimated.
        loader: Data loader forwarded to the estimator.
        evaluate: Evaluation callable forwarded to the estimator.
        epsilon: Passed to the sampler as ``lr``.
        beta: Passed to the sampler as ``nbeta``.
        sampling_method: Optimizer class used for sampling.
        localization: Passed to the sampler as ``localization``.
        num_chains: How many independent chains to run.
        num_draws: How many samples to draw per chain.
        num_burnin_steps: How many samples to discard at the beginning of
            each chain.
        num_steps_bw_draws: How many steps to take between each sample.
        device: Device forwarded to the estimator. NOTE: the default is the
            result of ``get_device()`` evaluated once, when this function is
            defined, not on every call.
        online: Forwarded to the estimator as ``online``.
        verbose: Forwarded to the estimator as ``verbose``.

    Returns:
        The results mapping returned by
        ``estimate_learning_coeff_with_summary``. If it contains an
        ``"llc/trace"`` entry, that entry is replaced by ``np.array`` of it.
    """
    sweep_stats = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=sampling_method,
        optimizer_kwargs=dict(lr=epsilon, localization=localization, nbeta=beta),
        num_chains=num_chains,
        num_draws=num_draws,
        num_burnin_steps=num_burnin_steps,
        num_steps_bw_draws=num_steps_bw_draws,
        device=device,
        online=online,
        verbose=verbose,
    )

    # Only normalise the trace when the estimator actually reported one.
    # NOTE(review): the offline estimator (online=False) appears not to emit
    # "llc/trace" in devinterp; indexing it unconditionally would raise
    # KeyError after the whole sampling run has completed -- confirm against
    # the installed devinterp version.
    if "llc/trace" in sweep_stats:
        sweep_stats["llc/trace"] = np.array(sweep_stats["llc/trace"])
    return sweep_stats