In [None]:
!pip install numpy torch sympy mod blobfile pandas seaborn matplotlib tqdm einops wandb

import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from contextlib import suppress
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Callable, Literal, Optional, Union, Tuple, List
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import wandb
import os
import ipywidgets as widgets
import wandb

if os.getenv('USE_TQDM_NOTEBOOK', 'NO').lower() in ['yes', 'true', '1']:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

import matplotlib as mpl
from matplotlib.colors import LogNorm
import seaborn as sns
import matplotlib.pyplot as plt

from patterns.dataset import ModularArithmetic, Operator
from patterns.transformer import Transformer
from patterns.utils import generate_run_name
from patterns.learner import Config

from toy_models.fit import rescale_run, Pattern, PatternLearningModel

from unifying.sweep import get_history, handle_outliers
from unifying.plotting import BLUE, RED

# Presumably the modulus of the modular-arithmetic task — TODO confirm; it is not
# referenced by any other cell visible in this notebook.
DEFAULT_MODULUS = 113
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Apply seaborn's "darkgrid" theme to the figures drawn below.
sns.set_theme(style="darkgrid")
# Placeholder: replace with a real Weights & Biases entity before running the
# cell that builds the run path f"{ENTITY}/grokking/<run id>".
ENTITY = "<Insert entity here>"

In [None]:
def plot_patterns(pl_model, run, log=False):
    """Plot the model's predicted accuracy curves next to the observed ones.

    The left panel shows the model's predictions, the right panel the values
    recorded in ``run``. Both panels get a legend so train and test curves can
    be told apart.

    Args:
        pl_model: Model to visualise. It is called once per step value as
            ``pl_model(t)`` (train prediction) and ``pl_model.test(t)``
            (test prediction); each result must support ``.detach().numpy()``.
        run: Table with ``"_step"``, ``"train/acc"`` and ``"test/acc"``
            columns (each accessed through ``.values``).
        log: When true, both panels use a logarithmic x-axis.
    """
    ts = run["_step"].values
    train_preds = [pl_model(t).detach().numpy() for t in ts]
    test_preds = [pl_model.test(t).detach().numpy() for t in ts]
    # The observed curves are only plotted, so the raw column values suffice.
    train_ys = run["train/acc"].values
    test_ys = run["test/acc"].values

    fig, axes = plt.subplots(1, 2, figsize=(20, 5))

    axes[0].plot(ts, train_preds, label="train", color="blue")
    axes[0].plot(ts, test_preds, label="test", color="red")

    axes[1].plot(ts, train_ys, label="train", color="blue")
    axes[1].plot(ts, test_ys, label="test", color="red")

    axes[0].set_title("Predictions")
    axes[1].set_title("True values")

    for ax in axes:
        ax.set_ylabel("Accuracy")
        # Without this call the labels passed to plot() are never displayed.
        ax.legend()
        if log:
            ax.set_xscale("log")

In [None]:
# Fetch the logged history of one Weights & Biases run through the public API.
api = wandb.Api()

# Run path is "<entity>/grokking/<run id>"; ENTITY is the placeholder set in the
# configuration cell and must hold a real entity for this call to succeed.
run = api.run(f"{ENTITY}/grokking/newag4a9")
# NOTE(review): history() is called with its defaults — confirm whether that
# returns every logged row or only a sample of them.
history = run.history()
# Show the resulting table as the cell output.
history

In [None]:
# Sort the rows by step so the curves are drawn in chronological order.
# NOTE: this mutates `history` in place; every later cell sees the sorted frame.
history.sort_values(by="_step", inplace=True)
# Quick look at train/test accuracy over steps, with a logarithmic x-axis.
history.plot(x="_step", y=["train/acc", "test/acc"], logx=True)

In [None]:
# Rescale the run to a maximum of 100 and build a two-pattern model that uses
# the same max_time.
rescaled_run = rescale_run(history, log=False, new_max=100.)
model = PatternLearningModel(max_time=100., num_patterns=2)

# Hand-picked starting point for the fit. With max_time = 100 this gives
#   pattern 0: onset 0.1, speed 50
#   pattern 1: onset 2.5, speed 5
# Strength and generalization are not touched here, so they keep whatever the
# model constructor assigned.
max_time = 100
for i, pattern in enumerate(model.patterns):
    initial_onset = 0.1 * (0.25 * max_time) ** i
    initial_speed = (max_time / 2) * 10 ** (-i)
    pattern.onset.data = torch.tensor(initial_onset)
    pattern.speed.data = torch.tensor(initial_speed)


def callback(x):
    """Draw the predictions of ``x`` against ``rescaled_run`` and show the figure.

    Serves both as a one-off preview and as the hook handed to ``model.fit``.
    The x-axes are always logarithmic.
    """
    plot_patterns(pl_model=x, run=rescaled_run, log=True)
    plt.show()

# Preview the curves produced by the hand-set initial parameters, before training.
callback(model)

# Fit the model to the rescaled run. `callback` is passed with callback_ivl=100
# (presumably it is invoked every 100 epochs to re-plot progress — confirm in
# PatternLearningModel.fit).
model.fit(rescaled_run, lr=0.1, num_epochs=500, callback=callback, callback_ivl=100)

In [None]:
# Because of the way generalization & strength are parametrized, they can never
# reach exactly 0.0 or 1.0, respectively. Instead, we pick a threshold `eps`
# and, after training, snap every pattern that is within `eps` of the limit
# onto the limit itself by pushing its raw parameter to +/- infinity.
eps = 0.03

for p in model.patterns:
    if p.strength > 1 - eps:
        # Set the raw value to +inf explicitly. Multiplying by inf would depend
        # on the raw value's sign and yields NaN for an exact 0.
        p._strength.data.fill_(torch.inf)

    if p.generalization < eps:
        # Likewise, drive the raw generalization parameter to -inf.
        p._generalization.data.fill_(-torch.inf)


# Re-plot to check that snapping did not visibly change the fit.
callback(model)

In [None]:
# Pattern(strength=0.9986720085144043, speed=17.239028930664062, onset=0.2342032641172409, generalization=0.027139823883771896)
# model.rescale(100)
# Save grokking fit (it's not a torch model, so we can't use torch.save)

import pickle

# Serialise the fitted model to "grokking_fit.pkl" in the working directory.
with open("grokking_fit.pkl", mode="wb") as f:
    f.write(pickle.dumps(model))

# Show the fitted patterns as the cell output.
model.patterns

In [None]:
# Figure: observed accuracy curves (solid, translucent) overlaid with the
# fitted model's curves (dashed), on a logarithmic step axis.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))

# NOTE: this rebinds `run` from the W&B run object to its history table.
run = history

ax.plot(run["_step"], run["test/acc"], label="Test (Truth)", color=RED, linewidth=2, alpha=0.5)
ax.plot(run["_step"], run["train/acc"], label="Train (Truth)", color=BLUE, linewidth=2, alpha=0.5)

# Evaluate the model on its own time axis [0, 100] ...
min_, max_ = 0, 100
ts = torch.linspace(min_, max_, 10000)
# ... and map that axis back onto training steps for plotting.
# NOTE(review): 80_000 is presumably the last step of the run — confirm.
# The first sample sits at 0, which cannot be placed on a log axis.
TS = ts * 80_000 / 100
train_ys = [model(t).detach().numpy() for t in ts]
test_ys = [model.test(t).detach().numpy() for t in ts]
ax.plot(TS, train_ys, label="Train (Fit)", color=BLUE, linewidth=2, linestyle="--")
ax.plot(TS, test_ys, label="Test (Fit)", color=RED, linewidth=2, linestyle="--")

# The log scale installs its own power-of-ten tick labels. The previous manual
# set_xticklabels() call was overwritten by it (and triggered a matplotlib
# warning), so it has been dropped.
ax.set_xscale("log")
ax.set_xlabel("Steps", fontsize=18)
ax.set_ylabel("Accuracy", fontsize=18)
ax.legend(fontsize=14)

# Lay out only after all labels exist so none of them is ignored.
fig.tight_layout(pad=0.25)

# Make sure the output directory exists before writing the PDF.
os.makedirs("../figures", exist_ok=True)
fig.savefig("../figures/grokking-fit.pdf", bbox_inches="tight")

In [None]:
def df_row_to_toy_model(row):
    """Rebuild a ``PatternLearningModel`` from one row of fitted parameters.

    Args:
        row: Mapping (e.g. a DataFrame row) holding, for every pattern ``i`` of
            a default-constructed model, the scalar entries
            ``pattern_{i}/onset``, ``pattern_{i}/speed``,
            ``pattern_{i}/strength`` and ``pattern_{i}/generalization``.

    Returns:
        A ``PatternLearningModel(max_time=100.)`` whose patterns carry the
        values stored in ``row``.
    """
    model = PatternLearningModel(max_time=100.)

    for i, pattern in enumerate(model.patterns):
        # float() turns numpy scalars into Python floats, so the tensors get
        # torch's default dtype, like the ones in the initialisation cell.
        pattern.onset.data = torch.tensor(float(row[f"pattern_{i}/onset"]))
        pattern.speed.data = torch.tensor(float(row[f"pattern_{i}/speed"]))

        # `strength` and `generalization` are derived from the raw parameters
        # `_strength` / `_generalization` (the thresholding cell edits those raw
        # parameters directly). Assigning `.data` on the derived value would be
        # lost, so write the raw parameters through the inverse maps instead.
        # NOTE(review): inverse maps taken from the commented-out
        # initialisation code (`_inv_sigmoid`, `torch.log`) — confirm against
        # the Pattern class.
        strength = torch.tensor(float(row[f"pattern_{i}/strength"]))
        generalization = torch.tensor(float(row[f"pattern_{i}/generalization"]))
        pattern._strength.data = pattern._inv_sigmoid(strength)
        pattern._generalization.data = torch.log(generalization)

    return model

# NOTE(review): `model_entry` is not defined in any cell visible here — confirm
# the notebook still runs top-to-bottom on a fresh kernel.
# This also replaces the fitted `model` from the cells above.
model = df_row_to_toy_model(model_entry)
# NOTE(review): the model was built with max_time=100. — confirm what
# rescale(100) is meant to change.
model.rescale(100)
# Show the rebuilt model as the cell output.
model


In [None]:
# Get corresponding original run
# NOTE(review): DM_SWEEP_ID and D_MODEL are not defined in the cells visible
# here — confirm they are set before this cell runs.
og_df = get_history(DM_SWEEP_ID, unique_cols="d_model")
run = og_df.loc[og_df.d_model == D_MODEL, :]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))

# Left panel: accuracy curves recorded for the selected run.
ax1.plot(run["_step"], run["test/acc"], label="Test", color=RED, linewidth=2)
ax1.plot(run["_step"], run["train/acc"], label="Train", color=BLUE, linewidth=2)
ax1.set_ylabel("Accuracy", fontsize=18)
ax1.set_xlabel("Step", fontsize=18)
ax1.set_xscale("log")
ax1.legend()
ax1.set_title("Truth")

# Right panel: the model's curves from the smallest step in the whole sweep
# history up to a fixed 10000.
min_step, max_step = og_df["_step"].min(), 10000

ts = np.linspace(min_step, max_step, 1000)
train_ys = [model(t).detach().numpy() for t in ts]
test_ys = [model.test(t).detach().numpy() for t in ts]
ax2.plot(ts, train_ys, label="Train", color=BLUE, linewidth=2)
ax2.plot(ts, test_ys, label="Test", color=RED, linewidth=2)
ax2.set_ylabel("Accuracy", fontsize=18)
ax2.set_xlabel("Step", fontsize=18)
# Show the legend here too, matching the left panel; the labels were otherwise unused.
ax2.legend()
ax2.set_title("Fit")

# Original author's note: "Already in log scale" — presumably why this panel
# keeps a linear x-axis; TODO confirm.

In [None]:
# Ignore any d_model < 50
# NOTE(review): `df` is not defined in the cells visible here — confirm it is
# loaded before this cell runs.
df_cleaned = df.loc[df["d_model"] >= 50, :]
d_models = df_cleaned.loc[:, "d_model"].unique()

# Scaling analysis: onset of each of the three patterns against d_model.
fig = plt.figure(figsize=(15, 4))
ax = fig.add_subplot(111)

colors = [BLUE, RED, "green"]
# Running maximum over all plotted onsets; reused further down for the y-range.
y_max = 0

for i in range(3):
    # Named `onsets` rather than `slice`, which would shadow the builtin for
    # every later cell in the notebook.
    onsets = df_cleaned.loc[:, f"pattern_{i}/onset"]
    y_max = max(y_max, onsets.max())
    # No label: these solid lines are intentionally left out of the legend.
    ax.plot(d_models, onsets, color=colors[i], linewidth=2)

ax.set_xlabel("d_model", fontsize=18)
ax.set_ylabel("Onset", fontsize=18)


# Fit a power-law to the onsets 
from scipy.optimize import curve_fit

def power_law(x, a, b):
    """Evaluate the power law ``a * x ** b``."""
    scaled = x ** b
    return a * scaled


def fit_power_law(x, y):
    """Fit ``power_law`` to the points ``(x, y)`` with ``curve_fit``.

    Returns only the optimal parameters ``[a, b]``; the covariance estimate
    that ``curve_fit`` also produces is discarded.
    """
    optimal_params, _covariance = curve_fit(power_law, x, y)
    return optimal_params


# Only rows with d_model up to this value are used for fitting; the fitted
# curves are then drawn over the full d_model range.
CUTOFF = 175

# The fitting subset is the same for every pattern, so build it once.
df_to_fit = df_cleaned.loc[df_cleaned["d_model"] <= CUTOFF, :]
d_models_to_fit = df_to_fit.loc[:, "d_model"].unique()

# One power-law fit per pattern onset, drawn dashed in the pattern's colour and
# labelled with the fitted exponent.
for i in range(3):
    onset_popt = fit_power_law(d_models_to_fit, df_to_fit.loc[:, f"pattern_{i}/onset"])
    exponent = round(onset_popt[1], 2)
    ax.plot(
        d_models,
        power_law(d_models, *onset_popt),
        label=f"$\\nu_{i} = {exponent}$",
        color=colors[i],
        linestyle="--",
        linewidth=2,
    )

# Mark the cutoff and leave 5% headroom above the largest observed onset.
y_top = y_max * 1.05
ax.vlines(CUTOFF, 0, y_top, color="grey", linestyle="--", linewidth=2)
ax.set_xlabel("Embedding dim.", fontsize=18)
ax.set_ylim(0, y_top)

ax.legend()

In [None]:
# Let's see if we can fit 