In [None]:
!pip install numpy torch sympy mod blobfile pandas seaborn matplotlib tqdm einops wandb

import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from contextlib import suppress
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Callable, Literal, Optional, Union, Tuple, List
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import wandb

if os.getenv('USE_TQDM_NOTEBOOK', 'NO').lower() in ['yes', 'true', '1']:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

import ipywidgets as widgets
import wandb

import matplotlib as mpl
from matplotlib.colors import LogNorm
import seaborn as sns
import matplotlib.pyplot as plt

from patterns.dataset import ModularArithmetic, Operator
from patterns.transformer import Transformer
from patterns.utils import generate_run_name
from patterns.learner import Config

from toy_models.fit import rescale_run, Pattern, PatternLearningModel
from unifying.sweep import get_history, handle_outliers
from unifying.plotting import BLUE, RED

# Modulus constant; not referenced by any other visible cell of this notebook.
DEFAULT_MODULUS = 113
# Use the GPU when CUDA is available, otherwise the CPU.
# Not referenced by any other visible cell of this notebook.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Apply seaborn's "darkgrid" theme to every figure drawn below.
sns.set_theme(style="darkgrid")

# W&B entity that owns the "fit-toy-model" project queried via `wandb.Api()` below.
# Placeholder value: replace it before running the cells that fetch runs.
ENTITY = "INSERT ENTITY HERE"

In [None]:
def plot_patterns(pl_model, run, log=False):
    """Plot the toy model's predicted accuracy curves next to the observed ones.

    The left panel shows the model's predictions, the right panel the logged
    values, both over the steps found in ``run``.

    Args:
        pl_model: Model to evaluate. It is called as ``pl_model(t)`` for the
            train curve and ``pl_model.test(t)`` for the test curve; both
            results must be convertible with ``.detach().numpy()``.
        run: DataFrame with ``_step``, ``train/acc`` and ``test/acc`` columns.
        log: If true, both x-axes are drawn on a log scale.
    """
    ts = run["_step"].values
    train_preds = [pl_model(t).detach().numpy() for t in ts]
    test_preds = [pl_model.test(t).detach().numpy() for t in ts]
    train_ys = torch.tensor(run["train/acc"].values).float()
    test_ys = torch.tensor(run["test/acc"].values).float()

    fig, axes = plt.subplots(1, 2, figsize=(20, 5))

    axes[0].plot(ts, train_preds, label="train", color="blue")
    axes[0].plot(ts, test_preds, label="test", color="red")

    axes[1].plot(ts, train_ys, label="train", color="blue")
    axes[1].plot(ts, test_ys, label="test", color="red")

    for ax, title in zip(axes, ("Predictions", "True values")):
        ax.set_title(title)
        # The lines carry labels, so render them; without a legend the
        # train/test colours are unexplained.
        ax.legend()
        if log:
            ax.set_xscale("log")

In [None]:
from unifying.sweep import METRICS, get_pivot
from scipy.ndimage import gaussian_filter

def get_mw_sweep(metrics=None, steps=None, num_steps=2000):
    """Load the model-width sweep log and return it in long format, subsampled in time.

    Args:
        metrics: Metric columns to keep. Defaults to test/train accuracy and
            test/train loss.
        steps: Explicit steps to keep. When omitted (or empty), log-spaced
            integer steps between 1 and the last step of the pivoted table
            are used instead.
        num_steps: Number of log-spaced candidates generated when ``steps``
            is not given; candidates that collapse to the same integer are
            de-duplicated, so fewer steps may be kept.

    Returns:
        DataFrame with one row per ``(_step, d_model)`` pair and one column
        per metric.
    """
    if metrics is None:
        # Built per call instead of using a shared mutable default argument.
        metrics = ["test/acc", "train/acc", "test/loss", "train/loss"]

    df = pd.read_csv("../logs/mw_sweep.csv")

    # Clone the step-60 rows as synthetic step-1 rows whose accuracies are
    # pinned to MINIMUM, so every run gets an explicit starting point.
    # NOTE(review): 1/97 is presumably chance accuracy for a modulus of 97,
    # while DEFAULT_MODULUS is 113 — confirm. Losses in the cloned rows keep
    # their step-60 values.
    MINIMUM = 1/97
    missing_row = df.loc[df["_step"] == 60, :].copy()
    missing_row["_step"] = 1
    missing_row["train/acc"] = MINIMUM
    missing_row["test/acc"] = MINIMUM
    df = pd.concat([missing_row, df])

    d_models = sorted(df["d_model"].unique())
    print(d_models)
    df = get_pivot(df, "d_model", metrics, reindex=True, interpolate=True)

    # Keep only a log-spaced subset of the (step-indexed) rows.
    steps = steps or list({int(s) for s in np.logspace(0, np.log10(df.index.max()), num_steps)})
    df = df.loc[df.index.isin(steps), :]

    # Flatten the two-level (metric, d_model) columns into "<metric>_<d_model>"
    # so the table can be melted into long format.
    df = df.reset_index()  # to make _step a regular column
    df.columns = ['_step'] + [f'{x}_{y}' for x, y in df.columns[1:]]
    df_melted = df.melt(id_vars='_step', var_name='variable', value_name='value')

    # Split on the LAST underscore only: metric names may themselves contain
    # underscores (e.g. "weight/dist_from_init"), the d_model suffix does not.
    df_melted[['metric', 'd_model']] = df_melted['variable'].str.rsplit('_', n=1, expand=True)
    df_melted['d_model'] = df_melted['d_model'].astype(int)  # convert d_model back to integer

    # Pivot back so each metric is its own column again.
    df_final = df_melted.pivot(index=['_step', 'd_model'], columns='metric', values='value').reset_index()

    return df_final

df = get_mw_sweep(num_steps=500)
df

In [None]:
# Pick the single run (one embedding width) that the next cells inspect and fit.
D_MODEL = 128
run = df[df["d_model"] == D_MODEL]
run

In [None]:
# Raw (unsmoothed) train/test accuracy of the selected run over log-scaled steps.
plt.plot(run["_step"], run["train/acc"], label="train")
plt.plot(run["_step"], run["test/acc"], label="test")
plt.xscale("log")
plt.legend()
plt.show()

In [None]:
# Rescale the run with `rescale_run` (new_max=100, log=False); the toy model
# in the next cell is built with max_time=100 as well.
rescaled_run = rescale_run(run, new_max=100., log=False)

# Quick visual check of the rescaled accuracy curves (log-scaled x-axis).
rescaled_run.plot(
    x="_step",
    y=["train/acc", "test/acc"],
    logx=True,
    figsize=(10, 5),
)
rescaled_run

In [None]:
# Toy model with three patterns, fitted to `rescaled_run` below.
model = PatternLearningModel(max_time=100., num_patterns=3)

# Hand-picked initialization: pattern i gets onset 0.1 * (max_time / 4) ** i
# and speed (max_time / 2) * 10 ** (-i).
for i, pattern in enumerate(model.patterns):
    max_time = 100
    onset_init = 0.1 * (0.25 * max_time) ** i
    speed_init = (max_time / 2) * 10 ** (-i)
    pattern.onset.data = torch.tensor(onset_init)
    pattern.speed.data = torch.tensor(speed_init)
    # Strength / generalization are left at the model's own initial values.

print(model.patterns)


def callback(x):
    """Show the current predictions of `x` against `rescaled_run` on log-scaled axes."""
    plot_patterns(x, rescaled_run, log=True)
    plt.show()


# Plot the initial state once, then every 25 epochs while fitting.
callback(model)

model.fit(
    rescaled_run,
    lr=0.1,
    num_epochs=500,
    callback=callback,
    callback_ivl=25,
)

In [None]:
# Column names that `fit_sweep` excludes (together with the swept column)
# when it collects a run's hyperparameters from the first row of a sweep.
VARIABLE_COLS = [
    "test/acc",
    "train/acc",
    "test/loss",
    "train/loss",
    "_step",
    "weight/norm",
    "test/efficiency",
    "train/efficiency",
    "weight/dist_from_init",
    "weight/cos_sim_with_init",
]

def fit_sweep(df: pd.DataFrame, unique_col: str, lr=0.1, max_time=1.0, num_patterns=3, num_epochs=500, log=False, **kwargs):
    """Fit one `PatternLearningModel` per run of a sweep and log each fit to W&B.

    Args:
        df: Long-format sweep table; rows sharing a value of ``unique_col``
            form one run.
        unique_col: Column whose distinct values identify the runs.
        lr: Learning rate passed to ``PatternLearningModel.fit``.
        max_time: Passed to ``rescale_run`` as ``new_max`` and to the model
            as ``max_time``.
        num_patterns: Number of patterns in each fitted model.
        num_epochs: Number of fitting epochs per run.
        log: Forwarded to ``rescale_run``.
        **kwargs: Extra key/value pairs added to every logged record.

    Side effects:
        Starts a W&B run in project "fit-toy-model", logs one record per swept
        value, shows the fit plots, and always finishes the W&B run before
        returning (also after an error or a manual interrupt).
    """
    unique_vals = df.loc[:, unique_col].unique()

    # Everything that is neither the swept column nor a per-step metric counts
    # as a run-level hyperparameter and is read off the first row.
    # `Series.drop` needs `labels=`: its `columns=` argument is ignored for a
    # Series, which previously let the first row's `unique_col` (and metrics)
    # into `hyperparams`, overwriting the swept value in every logged record.
    variable_cols = [c for c in df.columns if c in VARIABLE_COLS]
    hyperparams: dict = (
        df.iloc[0]
        .drop(labels=[unique_col, *variable_cols])
        .to_dict()
    )

    wandb.init(
        project="fit-toy-model",
    )

    try:
        for unique_val in tqdm(unique_vals):
            run = df.loc[df[unique_col] == unique_val]
            rescaled_run = rescale_run(run, new_max=max_time, log=log)

            pl_model = PatternLearningModel(
                num_patterns=num_patterns,
                max_time=max_time
            )

            def _plot_patterns(pl_model):
                plot_patterns(pl_model, rescaled_run)
                plt.show()

            pl_model.fit(rescaled_run, lr=lr, num_epochs=num_epochs, callback=_plot_patterns)
            pl_model.rescale(1.)

            # The swept value goes last so no other entry can shadow it.
            wandb.log({**pl_model.to_dict(), **hyperparams, **kwargs, unique_col: unique_val})

            # `_plot_patterns` already calls `plt.show()`.
            _plot_patterns(pl_model)

    except KeyboardInterrupt:
        # A manual interrupt ends the sweep early without raising; the records
        # logged so far are kept.
        pass
    finally:
        # Close the W&B run on every exit path, not only on interrupt.
        wandb.finish()

In [None]:
# Find the first _step for each run (= unique d_model) where the test/acc > 0.5 in df
onsets = df.loc[df["test/acc"] > 0.5, :].groupby("d_model").first()
d_models = onsets.index.values
onsets = onsets["_step"].values
onsets = gaussian_filter(onsets, 2)
onsets_df = pd.DataFrame([{"d_model": d_model, "step": step} for d_model, step in zip(d_models, onsets)])

# Fit a straight line in log-y space to the onsets_df (all samples before d_model =100)
onsets_train = onsets_df.loc[onsets_df.d_model < 120, :]
onsets_train["log_step"] = np.log(onsets_train["step"])

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

model = make_pipeline(PolynomialFeatures(1), LinearRegression())
model.fit(onsets_train[["d_model"]], onsets_train["log_step"])

# Plot the onsets_df and the fitted line
plt.plot(onsets_df["d_model"], onsets_df["step"], label="Truth")
plt.plot(onsets_df["d_model"], np.exp(model.predict(onsets_df[["d_model"]])), label="Fit")
plt.yscale("log")
plt.vlines(120, 1, 500_000, label="__hidden", linestyle="--")
plt.ylim(1, 500_000)
plt.ylabel("Onset step", fontsize=16)
plt.xlabel("Emedding dim.", fontsize=16)
plt.legend()
# onsets_df

In [None]:
# Inspect the object currently bound to `model` (the regression pipeline
# fitted in the previous cell). A bare `model.` is a SyntaxError.
model

In [None]:
# Swept d_model values in order of first appearance; a later cell indexes
# into this array to re-attach d_model to the logged fits.
unique_col = "d_model"
unique_vals = df[unique_col].unique()
unique_vals

In [None]:
# Fit the sweeps: one two-pattern toy model per d_model, logged to W&B.
fit_sweep(
    df,
    "d_model",
    max_time=100.,
    num_patterns=2,
    num_epochs=500,
    log=False,
)

In [None]:
# List every run of the "fit-toy-model" project under ENTITY.
api = wandb.Api()
runs = api.runs(f"{ENTITY}/fit-toy-model")
list(runs)

In [None]:
# Load the logged fit records of the first listed run.
# NOTE: this rebinds `run` and `df`; the sweep table previously held in `df`
# is no longer reachable under that name after this cell.
run = runs[0]
df = run.history()

# Re-attach the swept d_model to each record, using the W&B step as an index
# into `unique_vals` (assumes record k was logged for unique_vals[k]).
for step in df["_step"].unique():
    rows_at_step = df["_step"] == step
    df.loc[rows_at_step, "d_model"] = unique_vals[step]

df

In [None]:
# Fitted onset of the first two patterns as a function of embedding width.
df.plot(
    x="d_model",
    y=["pattern_0/onset", "pattern_1/onset"],
    figsize=(10, 5),
)

In [None]:
# d_model values present in the logged fits.
col = "d_model"
unique_vals = df[col].unique()
print(unique_vals)
df

In [None]:
def df_row_to_toy_model(row):
    """Rebuild a `PatternLearningModel` from one row of logged pattern parameters.

    For every pattern ``i`` of a freshly constructed model (``max_time=1.``),
    the row's ``pattern_{i}/onset``, ``/speed``, ``/strength`` and
    ``/generalization`` entries are written into the pattern's parameters.
    Strength is passed through the pattern's ``_inv_sigmoid`` and
    generalization through ``torch.log`` before being stored in the
    underscore-prefixed parameters.
    """
    toy_model = PatternLearningModel(max_time=1.)

    for idx, pattern in enumerate(toy_model.patterns):

        def logged(name):
            # Tensor holding the logged value of parameter `name` for pattern `idx`.
            return torch.tensor(row[f"pattern_{idx}/{name}"])

        pattern.onset.data = logged("onset")
        pattern.speed.data = logged("speed")
        pattern._strength.data = pattern._inv_sigmoid(logged("strength"))  # type: ignore
        pattern._generalization.data = torch.log(logged("generalization"))

    return toy_model

# Rebuild the fitted toy model for one embedding width.
D_MODEL = 115
# Was misspelled `co9l`, so the lookup below only worked when `col` had
# leaked in from an earlier cell.
col = "d_model"
model_entry = df.loc[df[col] == D_MODEL, :].iloc[0, :]
print(model_entry)
# NOTE: rebinds `model` to the reconstructed PatternLearningModel.
model = df_row_to_toy_model(model_entry)
# Rescale the model with argument 100; the figure cells below evaluate it on
# ts up to 100.
model.rescale(100)
model.patterns

In [None]:
# Get corresponding original run
og_df = get_history(DM_SWEEP_ID, unique_cols="d_model")
run = og_df.loc[og_df.d_model==D_MODEL,:] #.plot(x="_step", y="test/acc")

In [None]:
# Two stacked panels: logged accuracy of `run` (top, "Truth") and the toy
# model's curves (bottom, "Fit").
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))

ax1.plot(run["_step"], run["test/acc"], label="Test", color=RED, linewidth=2)
ax1.plot(run["_step"], run["train/acc"], label="Train", color=BLUE, linewidth=2)
ax1.set_ylabel("Accuracy", fontsize=18)
# ax1.set_xlabel("Steps", fontsize=18)
# Empty, white tick labels: presumably meant to hide the top panel's x tick
# labels. NOTE(review): these are set before `set_xscale("log")` — verify they
# still apply once the scale is switched.
ax1.set_xticklabels(["", "", "", "", "", ""], color="white")
ax1.set_xscale("log")
ax1.legend(title="Truth", fontsize=16, title_fontsize=18)

# `min_step` is not used anywhere below in this cell; `max_step` is fixed at 100.
min_step, max_step = og_df["_step"].min(), 100 # run["_step"].max()

# Evaluate the toy model at 1000 evenly spaced points from 12 to `max_step`.
ts = np.linspace(12, max_step, 1000)
train_ys = [model(t).detach().numpy() for t in ts]
test_ys = [model.test(t).detach().numpy() for t in ts]
ax2.plot(ts, train_ys, label="Train", color=BLUE, linewidth=2)
ax2.plot(ts, test_ys, label="Test", color=RED, linewidth=2)
ax2.set_ylabel("Accuracy", fontsize=18)
ax2.set_xlabel("Steps", fontsize=18)
# ax2.set_title("Fit", )
ax2.legend(title="Fit", fontsize=16, title_fontsize=18)

# Hand-written power-of-ten labels on an axis that is left on a linear scale.
# NOTE(review): the labels are not tied to explicit tick positions — verify
# they line up with the intended steps.
ax2.set_xticklabels(["$10^0$", "$10^1$", "$10^2$", "$10^3$", "$10^4$", "$10^5$"])
# ax2.set_xlim(10, 100)

fig.tight_layout(pad=0.25)

# Already in log scale
# train_ys, test_ys

In [None]:
# NOTE(review): this cell repeats the plotting code of the preceding cell;
# consider keeping only one copy.
# Two stacked panels: logged accuracy of `run` (top, "Truth") and the toy
# model's curves (bottom, "Fit").
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))

ax1.plot(run["_step"], run["test/acc"], label="Test", color=RED, linewidth=2)
ax1.plot(run["_step"], run["train/acc"], label="Train", color=BLUE, linewidth=2)
ax1.set_ylabel("Accuracy", fontsize=18)
# ax1.set_xlabel("Steps", fontsize=18)
# Empty, white tick labels: presumably meant to hide the top panel's x tick
# labels. NOTE(review): set before `set_xscale("log")` — verify they still apply.
ax1.set_xticklabels(["", "", "", "", "", ""], color="white")
ax1.set_xscale("log")
ax1.legend(title="Truth", fontsize=16, title_fontsize=18)

# `min_step` is not used anywhere below in this cell; `max_step` is fixed at 100.
min_step, max_step = og_df["_step"].min(), 100 # run["_step"].max()

# Evaluate the toy model at 1000 evenly spaced points from 12 to `max_step`.
ts = np.linspace(12, max_step, 1000)
train_ys = [model(t).detach().numpy() for t in ts]
test_ys = [model.test(t).detach().numpy() for t in ts]
ax2.plot(ts, train_ys, label="Train", color=BLUE, linewidth=2)
ax2.plot(ts, test_ys, label="Test", color=RED, linewidth=2)
ax2.set_ylabel("Accuracy", fontsize=18)
ax2.set_xlabel("Steps", fontsize=18)
# ax2.set_title("Fit", )
ax2.legend(title="Fit", fontsize=16, title_fontsize=18)
# Hand-written power-of-ten labels on an axis that is left on a linear scale.
ax2.set_xticklabels(["$10^0$", "$10^1$", "$10^2$", "$10^3$", "$10^4$", "$10^5$"])
# ax2.set_xlim(10, 100)

fig.tight_layout(pad=0.25)

# Already in log scale
# train_ys, test_ys

In [None]:
# Ignore any d_model < 50
df_cleaned = df.loc[df["d_model"] >= 50, :]
d_models = df_cleaned.loc[:, "d_model"].unique()

# Scaling analysis
fig = plt.figure(figsize=(15, 4))
ax = fig.add_subplot(111)

colors = [BLUE, RED, "green"]
y_max = 0

for i in range(3):
    slice = df_cleaned.loc[:, f"pattern_{i}/onset"]
    y_max = max(y_max, slice.max())
    ax.plot(d_models, slice, label=f"", color=colors[i], linewidth=2)

ax.set_xlabel("d_model", fontsize=18)
ax.set_ylabel("Onset", fontsize=18)


# Fit a power-law to the onsets 
from scipy.optimize import curve_fit

def power_law(x, a, b):
    return a * x**b

def fit_power_law(x, y):
    popt, pcov = curve_fit(power_law, x, y)
    return popt


CUTOFF = 175

# Fit power law to onset
for i in range(3):
    # Train up to a specific point
    df_to_fit = df_cleaned.loc[df_cleaned["d_model"] <= CUTOFF, :]
    d_models_to_fit = df_to_fit.loc[:, "d_model"].unique()

    onset_popt = fit_power_law(d_models_to_fit, df_to_fit.loc[:, f"pattern_{i}/onset"])
    exponent = round(onset_popt[1], 2)
    ax.plot(d_models, power_law(d_models, *onset_popt), label=f"$\\nu_{i} = {exponent}$", color=colors[i], linestyle="--", linewidth=2)

ax.vlines(CUTOFF, 0, y_max * 1.05, color="grey", linestyle="--", linewidth=2)
ax.set_xlabel("Embedding dim.", fontsize=18)
ax.set_ylim(0, y_max * 1.05)

ax.legend()

In [None]:
from scipy.ndimage import gaussian_filter1d

# Gaussian-smooth (sigma = 2) the accuracy on uncorrupted and corrupted samples.
# NOTE(review): requires `run` to contain "uncorrupted/acc" and "corrupted/acc"
# columns — confirm which sweep `run` comes from at this point.
r_uncorrupted = gaussian_filter1d(run["uncorrupted/acc"], sigma=2.)
r_corrupted = gaussian_filter1d(run["corrupted/acc"], sigma=2.)

r_uncorrupted

In [None]:
# `steps` is used as the x-axis here and in the following figure cell, but it
# is not assigned in any visible cell. It is plotted against columns of `run`,
# so take the run's own step column.
# NOTE(review): confirm that run["_step"] is the intended x-axis.
steps = run["_step"]
plt.plot(steps, run["corrupted/acc"])

In [None]:
# fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))
# Side-by-side comparison of the smoothed uncorrupted / corrupted accuracy with
# the first two patterns of the toy model (a third panel is still TODO).
# `steps` must already exist in the kernel and match the length of
# `r_uncorrupted` / `r_corrupted` — presumably run["_step"]; verify.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))


def _normalized_pattern_curve(pattern_idx):
    """Evaluate pattern `pattern_idx` of `model` at every t in `ts`, divided by its strength."""
    return [
        (model.patterns[pattern_idx](t) / model.patterns[pattern_idx].strength).detach().float()
        for t in ts
    ]


type_1 = _normalized_pattern_curve(0)
type_2 = _normalized_pattern_curve(1)
type_3 = _normalized_pattern_curve(2)  # only needed for the (disabled) third panel

# Map the model's time axis onto training steps with the factor 500_000 / 100.
TS = ts * 500_000 / 100

# Plot 1: Uncorrupted data / Type 1 Pattern
ax1.plot(steps, r_uncorrupted, label="Uncorrupted", color=RED, linewidth=2)
ax1.plot(TS, type_1, label="Prediction", color=BLUE, linestyle="-", linewidth=2, alpha=0.75)
ax1.legend(title="Type 1", fontsize=12, title_fontsize=16, loc="upper left")

# Plot 2: Corrupted data / Type 2 Pattern
ax2.plot(steps, r_corrupted, label="Corrupted", color=RED, linewidth=2)
ax2.plot(TS, type_2, label="Prediction", color=BLUE, linestyle="-", linewidth=2, alpha=0.75)
ax2.legend(title="Type 2", fontsize=12, title_fontsize=16, loc="upper left")

# Plot 3 (Type 3 pattern) is not drawn yet.

ax1.set_ylabel("Accuracy", fontsize=18)

# Shared x-axis styling for both panels.
power_of_ten_labels = ["", "", "$10^0$", "$10^1$", "$10^2$", "$10^3$", "$10^4$", "$10^5$"]
for ax in [ax1, ax2]:
    ax.set_xlabel("Steps", fontsize=18)
    ax.set_xscale("log")
    ax.set_xticklabels(power_of_ten_labels)

plt.savefig("../figures/pattern-predictions.pdf", bbox_inches="tight")

In [None]:
# Sweep ids whose histories are loaded from the "mnist-grokking" project.
INTERP_SWEEPS = ["kodd01ka", "wecya83q", "wqnakkjd"]  # "awxzpem1"
interp_sweep = get_history(
    *INTERP_SWEEPS,
    project="mnist-grokking",
    allow_duplicates=True,
    combine_seeds=True,
)
interp_sweep

In [None]:
# Average the metrics over seeds: for every value of the swept column that was
# run with more than one seed, each metric is replaced by its per-step mean
# over seeds, after which one row per (swept value, step) is kept.
histories = interp_sweep.copy()
unique_cols = ["lr_factor"]

assert (
    len(unique_cols) == 1
), "Can only combine seeds if there is a single unique column"

unique_col = unique_cols[0]
unique_vals = histories[unique_col].unique()

# Metrics that are averaged across seeds.
metrics = ["train/acc", "test/acc", "train/loss", "test/loss", "corrupted/acc", "uncorrupted/acc"]

for val in unique_vals:
    val_mask = histories[unique_col] == val
    runs = histories[val_mask]
    seeds = runs.seed.unique()

    if len(seeds) > 1:
        for metric in metrics:
            # Per-step mean over seeds; gaps inside a step group are filled
            # from neighbouring entries before averaging (a no-op when the
            # group has no missing values).
            means = runs.groupby("_step")[metric].apply(
                lambda x: x.ffill().bfill().mean()
            )

            # Write the means back to all rows of this swept value in one
            # assignment. Rows whose step has no mean keep their current
            # value. Plain arrays are used so the assignment is positional
            # and does not depend on index labels being unique.
            averaged = runs["_step"].map(means).to_numpy()
            current = runs[metric].to_numpy()
            histories.loc[val_mask, metric] = np.where(pd.isna(averaged), current, averaged)

# Remove duplicate rows
histories = histories.drop_duplicates(subset=[*unique_cols, "_step"])