This notebook investigates whether local learning coefficient (LLC) estimation can be used as a training metric, using a peak detection task as the test case.

This is based on the Timaeus grokking notebook found here: https://github.com/timaeus-research/devinterp/blob/main/examples/grokking.ipynb

In [None]:
%pip install devinterp nbformat
%pip install devinterp[vis]

In [None]:
import random
import math
from copy import deepcopy
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchinfo import summary
from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import evaluate_ce
from llc_training.grokking.peak_models import Nano, Small, Medium_Large

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Make all relevant functions and classes for training
@dataclass
class ExperimentParams:
    """Settings for one peak-detection training run.

    The whole object is handed to the model constructor, so some fields are
    only consumed there; fields not read anywhere in this notebook are marked
    as such below.
    """

    p: int = 100  # number of sequences generated by make_dataset
    l: int = 5  # NOTE(review): presumably the sequence length; not read in this notebook
    n_batches: int = 1000  # number of optimisation steps in train()
    n_save_model_checkpoints: int = 100  # target number of saved model snapshots
    print_times: int = 100  # target number of periodic evaluations
    lr: float = 3e-3  # Adam learning rate
    batch_size: int = 128
    hidden_size: int = 48  # not read in this notebook; presumably used by the model
    linear_hidden_size: int = 200  # not read in this notebook; presumably used by the model
    embed_dim: int = 127  # not read in this notebook; presumably used by the model
    train_frac: float = 0.4  # fraction of the dataset assigned to the train split
    random_seed: int = 0
    device: str = DEVICE
    weight_decay: float = 2e-5  # Adam weight decay
    blocks: int = 6  # not read in this notebook; presumably used by the model
    # `v` used to be declared without an annotation, which made it a plain
    # class attribute rather than a dataclass field: it could not be set via
    # the constructor and was missing from repr/eq. It is declared last so the
    # positional order of the pre-existing fields is unchanged.
    v: int = 12  # NOTE(review): presumably a vocabulary-related size — confirm

def test(model, dataset, device):
    """Evaluate ``model`` on ``dataset`` one example at a time.

    Args:
        model: Module called as ``model(x)`` on a batch holding a single
            sequence; element ``[0]`` of its output is indexed as
            ``[position][class]`` with two classes.
        dataset: Sized iterable of ``(x, y)`` tensor pairs where ``y`` holds a
            two-column target row per position.
        device: Device the inputs and targets are moved to.

    Returns:
        ``(accuracy, mean_loss)``. Accuracy is the fraction of positions whose
        larger output column matches the larger target column (ties count as
        wrong); mean loss is the cross-entropy averaged over examples. Both
        are ``nan`` for an empty dataset.
    """
    n_examples = len(dataset)
    if n_examples == 0:
        return float("nan"), float("nan")

    loss_fn = nn.CrossEntropyLoss()
    n_correct = 0
    n_positions = 0
    total_loss = 0.0

    # Remember the caller's mode: evaluation must not leave a model that is
    # still being trained stuck in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in dataset:
                x = torch.as_tensor(x).unsqueeze(0).to(device)
                y = y.to(device)
                out = model(x)[0]
                total_loss += loss_fn(out, y).item()

                target_is_peak = y[:, 1] > y[:, 0]
                target_is_flat = y[:, 0] > y[:, 1]
                pred_is_peak = out[:, 1] > out[:, 0]
                pred_is_flat = out[:, 0] > out[:, 1]
                hits = (target_is_peak & pred_is_peak) | (target_is_flat & pred_is_flat)
                n_correct += int(hits.sum().item())
                # Count the positions actually scored instead of assuming a
                # fixed sequence length of 5.
                n_positions += len(out)
    finally:
        model.train(was_training)

    accuracy = n_correct / n_positions if n_positions else float("nan")
    return accuracy, total_loss / n_examples


# The name matches pytest's default ``test*`` pattern; opt out so a test runner
# importing this code does not try to collect the helper as a test case.
test.__test__ = False


def train(train_dataset, test_dataset, params, verbose=True):
    """Train a fresh ``Medium_Large`` model, recording metrics and snapshots.

    Every step draws one batch from a freshly shuffled loader and applies an
    Adam update on the cross-entropy loss. Metrics are evaluated roughly
    ``params.print_times`` times and the model is snapshotted roughly
    ``params.n_save_model_checkpoints`` times over the run.

    Args:
        train_dataset: ``(x, y)`` pairs used for the updates and for the
            periodic train metrics.
        test_dataset: ``(x, y)`` pairs used for the validation metrics.
        params: Settings object. This function reads ``device``, ``lr``,
            ``weight_decay``, ``batch_size``, ``n_batches``, ``print_times``
            and ``n_save_model_checkpoints``; the model constructor receives
            the whole object.
        verbose: Show a progress bar and print the final metrics.

    Returns:
        ``(all_models, df)``: the list of model snapshots (stored in eval
        mode) and a DataFrame with one row per periodic evaluation (columns
        ``batch``, ``train_loss``, ``train_acc``, ``val_loss``, ``val_acc``).
    """
    model = Medium_Large(params).to(params.device)
    optimizer = torch.optim.Adam(
        model.parameters(), weight_decay=params.weight_decay, lr=params.lr
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True)

    # Clamp both intervals to at least one step: a requested count larger than
    # n_batches used to produce an interval of 0 (modulo-by-zero for the
    # evaluations, silently no snapshots for the checkpoints).
    print_every = None
    if params.print_times > 0:
        print_every = max(1, params.n_batches // params.print_times)
    checkpoint_every = None
    if params.n_save_model_checkpoints > 0:
        checkpoint_every = max(1, params.n_batches // params.n_save_model_checkpoints)

    all_models = []
    loss_data = []
    pbar = tqdm(total=params.n_batches, desc="Training") if verbose else None
    model.train()
    for i in range(params.n_batches):
        step = i + 1
        # Sample random batch of data (a new shuffled iterator every step, so
        # batches are drawn independently rather than epoch by epoch).
        X, Y = next(iter(train_loader))
        X, Y = X.to(params.device), Y.to(params.device)
        # Gradient update
        optimizer.zero_grad()
        out = model(X)
        # NOTE(review): test() applies the loss per example on a
        # (positions, 2) output. If model(X) is (batch, positions, 2) here,
        # CrossEntropyLoss takes dim 1 as the class axis — confirm the model's
        # output layout matches what is intended.
        loss = loss_fn(out, Y)
        loss.backward()
        optimizer.step()

        if checkpoint_every and step % checkpoint_every == 0:
            # Store snapshots in eval mode so every checkpoint is in the same
            # mode no matter when it was taken.
            snapshot = deepcopy(model)
            snapshot.eval()
            all_models.append(snapshot)

        if pbar is not None:
            pbar.update(1)

        if print_every and step % print_every == 0:
            val_acc, val_loss = test(model, test_dataset, params.device)
            train_acc, train_loss = test(model, train_dataset, params.device)
            # Evaluation may leave the model in eval mode; switch back before
            # the next update.
            model.train()
            loss_data.append(
                {
                    "batch": step,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                }
            )
            if pbar is not None:
                pbar.set_postfix(
                    {
                        "train_loss": f"{train_loss:.4f}",
                        "train_acc": f"{train_acc:.4f}",
                        "val_loss": f"{val_loss:.4f}",
                        "val_acc": f"{val_acc:.4f}",
                    }
                )
    if pbar is not None:
        pbar.close()

    df = pd.DataFrame(loss_data)
    if verbose:
        train_acc, train_loss = test(model, train_dataset, params.device)
        val_acc, val_loss = test(model, test_dataset, params.device)
        # The train line previously printed the validation numbers.
        print(f"Final Train Acc: {train_acc:.4f} | Final Train Loss: {train_loss:.4f}")
        print(f"Final Val Acc: {val_acc:.4f} | Final Val Loss: {val_loss:.4f}")
    return all_models, df


def deterministic_shuffle(lst, seed):
    """Shuffle ``lst`` in place, reproducibly for a given ``seed``.

    A private ``random.Random`` instance is used, so the process-wide
    ``random`` generator is not reseeded as a side effect. The permutation is
    the same one ``random.seed(seed); random.shuffle(lst)`` produces.

    Args:
        lst: Mutable sequence to shuffle in place.
        seed: Seed for the shuffle.

    Returns:
        The same ``lst`` object, now shuffled.
    """
    random.Random(seed).shuffle(lst)
    return lst


def _peak_flags(seq):
    """Return one bool per position of ``seq``: True where the value is a peak.

    An interior value is a peak when it is strictly greater than both
    neighbours; an end value when it is strictly greater than its single
    neighbour. Sequences of length two or less have no peaks.
    """
    n = len(seq)
    if n <= 2:
        return [False] * n
    flags = [seq[0] > seq[1]]
    flags.extend(
        seq[i] > seq[i - 1] and seq[i] > seq[i + 1] for i in range(1, n - 1)
    )
    flags.append(seq[-1] > seq[-2])
    return flags


def make_dataset(p, seq_len=5, vocab_size=10, seed=None):
    """Generate ``p`` random integer sequences with per-position peak labels.

    Args:
        p: Number of sequences to generate.
        seq_len: Length of every sequence.
        vocab_size: Tokens are drawn uniformly from ``0 .. vocab_size - 1``.
        seed: If given, sequences come from a private NumPy generator seeded
            with this value; if ``None``, NumPy's global generator is used
            (seed it with ``np.random.seed`` for reproducibility).

    Returns:
        A list of ``(tokens, labels)`` pairs: ``tokens`` is an integer tensor
        of shape ``(seq_len,)`` and ``labels`` a float tensor of shape
        ``(seq_len, 2)`` holding ``[0.0, 1.0]`` at peak positions and
        ``[1.0, 0.0]`` elsewhere.

    Raises:
        ValueError: If ``seq_len`` or ``vocab_size`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")
    if vocab_size < 1:
        raise ValueError("vocab_size must be at least 1")

    rng = np.random if seed is None else np.random.RandomState(seed)
    vocab = list(range(vocab_size))

    data = []
    for _ in range(p):
        seq = [int(rng.choice(vocab)) for _ in range(seq_len)]
        # One two-column row per position: column 1 marks "peak".
        labels = [
            [0.0, 1.0] if is_peak else [1.0, 0.0] for is_peak in _peak_flags(seq)
        ]
        data.append((torch.tensor(seq), torch.tensor(labels)))
    return data


def train_test_split(dataset, train_split_proportion, seed):
    """Split ``dataset`` into train and test lists using a seeded shuffle.

    Args:
        dataset: Indexable, sized collection of examples.
        train_split_proportion: Fraction of examples assigned to the train
            split; the count is truncated towards zero.
        seed: Seed passed to ``deterministic_shuffle`` for the index order.

    Returns:
        ``(train_examples, test_examples)`` as two lists.
    """
    n_items = len(dataset)
    order = deterministic_shuffle(list(range(n_items)), seed)
    cut = int(train_split_proportion * n_items)
    train_examples = [dataset[j] for j in order[:cut]]
    test_examples = [dataset[j] for j in order[cut:]]
    return train_examples, test_examples

In [None]:
# Initialize params and get the dataset
params = ExperimentParams()
# make_dataset draws from NumPy's global generator, so seeding torch alone
# leaves the generated dataset different on every run; seed both.
np.random.seed(params.random_seed)
torch.manual_seed(params.random_seed)

dataset = make_dataset(params.p)
train_data, test_data = train_test_split(dataset, params.train_frac, params.random_seed)

In [None]:
# Prints model parameter count
# This instance exists only for the summary printout; train() constructs its
# own model, so training does not start from this object.
model = Medium_Large(params).to(params.device)
print(summary(model))

In [None]:
# Train on the split built above; train() returns the saved model snapshots
# and a DataFrame of the periodic train/validation metrics.
all_checkpointed_models, df = train(
    train_dataset=train_data, test_dataset=test_data, params=params
)

In [None]:
# Accuracy curves, one point per periodic evaluation recorded by train().
plt.plot(df["val_acc"], label="test")
plt.plot(df["train_acc"], label="train")
plt.legend()
plt.ylabel("Correct answer %")
plt.xlabel("Checkpoint")
# The title used to say "modular addition", left over from the notebook this
# one is based on; the task here is peak detection.
plt.title(f"Train & test correct answer % for peak detection with p={params.p}")

In [None]:
# Loss curves, one point per periodic evaluation recorded by train().
plt.plot(df["val_loss"], label="test")
plt.plot(df["train_loss"], label="train")
plt.legend()
plt.ylabel("Loss")
plt.xlabel("Checkpoint")
# The title used to say "modular addition", left over from the notebook this
# one is based on; the task here is peak detection.
plt.title(f"Train & test loss for peak detection with p={params.p}")

## LLC estimation hyperparameter tuning

We will sweep over several sampler hyperparameters (the step size epsilon and the inverse temperature beta) to find values that give stable LLC estimates.

In [None]:
import typing
from typing import Type

import numpy as np


def estimate_llc_given_model(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    evaluate: typing.Callable,
    epsilon: float,
    beta: float,
    sampling_method: Type[torch.optim.Optimizer] = SGLD,
    localization: float = 5.0,
    num_chains: int = 2,
    num_draws: int = 500,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    device: torch.device = DEVICE,
    online: bool = True,
    verbose: bool = False,
):
    """Run one LLC estimation for ``model`` and return its summary statistics.

    Thin wrapper around ``estimate_learning_coeff_with_summary`` that exposes
    the sampler step size and inverse temperature as ``epsilon`` and ``beta``.

    Args:
        model: Model whose learning coefficient is estimated.
        loader: Data loader passed to the estimator.
        evaluate: Evaluation callable passed to the estimator.
        epsilon: Forwarded to the sampler as ``lr``.
        beta: Forwarded to the sampler as ``nbeta``.
        sampling_method: Optimizer class used for sampling.
        localization: Forwarded to the sampler as ``localization``.
        num_chains: How many independent chains to run.
        num_draws: How many samples to draw per chain.
        num_burnin_steps: How many samples to discard at the start of each chain.
        num_steps_bw_draws: How many steps to take between samples.
        device: Device passed to the estimator.
        online: Forwarded to the estimator's ``online`` flag.
        verbose: Forwarded to the estimator's ``verbose`` flag.

    Returns:
        The estimator's statistics mapping, with the ``"llc/trace"`` entry
        converted to a NumPy array.
    """
    sampler_settings = {"lr": epsilon, "localization": localization, "nbeta": beta}
    stats = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=sampling_method,
        optimizer_kwargs=sampler_settings,
        num_chains=num_chains,
        num_draws=num_draws,
        num_burnin_steps=num_burnin_steps,
        num_steps_bw_draws=num_steps_bw_draws,
        device=device,
        online=online,
        verbose=verbose,
    )

    trace_as_array = np.array(stats["llc/trace"])
    stats["llc/trace"] = trace_as_array
    return stats

In [None]:
from devinterp.vis_utils import EpsilonBetaAnalyzer

# Sweep epsilon and beta for the final checkpoint, using the training data.
loader = DataLoader(train_data, batch_size=params.batch_size, shuffle=True)
analyzer = EpsilonBetaAnalyzer()
analyzer.configure_sweep(
    llc_estimator=estimate_llc_given_model,
    llc_estimator_kwargs={
        "model": all_checkpointed_models[-1],
        "evaluate": evaluate_ce,
        "device": DEVICE,
        "loader": loader,
    },
    dataloader=loader,
    # Five epsilon values between 3e-5 and 3e-1.
    min_epsilon=3e-5,
    max_epsilon=3e-1,
    epsilon_samples=5,
    # Beta bounds are left unset (None); five beta values are requested.
    min_beta=None,
    max_beta=None,
    beta_samples=5,
)
analyzer.sweep()

In [None]:
# Plot the results of the epsilon/beta sweep run above.
analyzer.plot()

In [None]:
# Same sweep plotted with div_out_beta=True — presumably with the beta factor
# divided out of the estimates; confirm against the EpsilonBetaAnalyzer docs.
analyzer.plot(div_out_beta=True)

In [None]:
# Sampler settings used for the per-checkpoint LLC estimates further down.
lr = 3e-3  # passed as `lr` in optimizer_kwargs
gamma = 5  # passed as `localization` in optimizer_kwargs
nbeta = 2.0  # passed as `nbeta` in optimizer_kwargs
num_draws = 75  # passed as `num_draws`
# NOTE(review): num_chains is not read by the cells below, which hard-code
# num_chains=3 and num_chains=1 — confirm which value is intended.
num_chains = 2

In [None]:
# Diagnostic run on the final checkpoint with many draws, so the sampling loss
# trace can be inspected before settling on a draw count.
# NOTE(review): this call hard-codes lr=0.03, while the `lr` variable defined
# in the previous cell is 3e-3 — confirm the trace is meant to be taken at a
# different step size from the per-checkpoint estimates.
learning_coeff_stats = estimate_learning_coeff_with_summary(
    all_checkpointed_models[-1],
    loader=DataLoader(train_data, batch_size=params.batch_size, shuffle=True),
    evaluate=evaluate_ce,
    sampling_method=SGLD,
    optimizer_kwargs=dict(lr=0.03, nbeta=2.0, localization=5.0),
    num_chains=3,
    num_draws=1500,
    device=DEVICE,
    online=True,
)
trace = learning_coeff_stats["loss/trace"]

In [None]:
from devinterp.utils import plot_trace

# Plot the sampling loss trace from the diagnostic run; the title reports the
# average of the entries in learning_coeff_stats["llc/means"].
plot_trace(
    trace,
    "Loss",
    x_axis="Step",
    title=f"Loss Trace, avg LLC = {sum(learning_coeff_stats['llc/means']) / len(learning_coeff_stats['llc/means']):.2f}",
    plot_mean=False,
    plot_std=False,
    fig_size=(12, 9),
    true_lc=None,
)

From the loss trace, it seems we can get away with a very low draw count; around 75 draws should be enough.

In [None]:
def _estimate_checkpoint_llc(checkpoint):
    """Estimate the LLC of one checkpoint with a single chain (online=False)."""
    batches = DataLoader(train_data, batch_size=params.batch_size, shuffle=True)
    sampler_settings = {"lr": lr, "nbeta": nbeta, "localization": gamma}
    return estimate_learning_coeff_with_summary(
        checkpoint,
        loader=batches,
        evaluate=evaluate_ce,
        sampling_method=SGLD,
        optimizer_kwargs=sampler_settings,
        num_chains=1,
        num_draws=num_draws,
        device=DEVICE,
        online=False,
    )


# One LLC summary per saved checkpoint, in training order.
llcs = [_estimate_checkpoint_llc(ckpt) for ckpt in all_checkpointed_models]

In [None]:
def calc_run_avg(arr, d):
    """Return the running average of ``arr`` over windows of ``d`` values.

    The result starts with ``d // 2`` ``None`` placeholders so that each
    average lines up roughly with the centre of its window when plotted
    against ``arr``. Every full window is included; the last one used to be
    dropped by an off-by-one in the loop bound.

    Args:
        arr: Sliceable sequence of numbers.
        d: Window size, a positive integer.

    Returns:
        A list of ``d // 2`` ``None`` values followed by one average per full
        window (none if ``arr`` is shorter than ``d``).

    Raises:
        ValueError: If ``d`` is not positive.
    """
    if d <= 0:
        raise ValueError("window size d must be a positive integer")

    n_windows = max(len(arr) - d + 1, 0)
    run_avgs = [None] * (d // 2)
    run_avgs.extend(sum(arr[start:start + d]) / d for start in range(n_windows))
    return run_avgs

def calc_delta(arr, scale):
    """Return the scaled difference between consecutive elements of ``arr``.

    Args:
        arr: Indexable, sized sequence of numbers.
        scale: Factor each difference is multiplied by.

    Returns:
        A list whose first entry is ``None`` (the first element has no
        predecessor), followed by ``(arr[k] - arr[k - 1]) * scale`` for every
        later index ``k``.
    """
    scaled_steps = ((arr[k] - arr[k - 1]) * scale for k in range(1, len(arr)))
    return [None, *scaled_steps]

In [None]:
fig, ax1 = plt.subplots()
# Take the parameter count from the trained model rather than hard-coding it
# in the title, so the figure stays correct when the architecture settings
# change.
n_model_params = sum(w.numel() for w in all_checkpointed_models[-1].parameters())
plt.title(f"Peak Loss {n_model_params} params")
# Losses go on the left axis and the LLC estimate on a second y-axis, because
# the two quantities are on different scales.
ax2 = ax1.twinx()
ax1.plot(df["val_loss"], label="test loss")
ax1.plot(df["train_loss"], label="train loss")
# ax2.plot(calc_run_avg(delta_val_loss[1:], window_size), color="m", label="Change in val acc")
ax2.plot([llc["llc/mean"] for llc in llcs], color="g", label="Lambdahat")
# ax2.plot(run_avg_llc, color="r", label="Running avg llc")
ax1.set_xlabel("Checkpoint no.")
ax1.set_ylabel("Loss")
ax2.set_ylabel("Lambdahat")
fig.legend(loc="center right")