In [None]:
!pip install numpy torch sympy mod blobfile pandas seaborn matplotlib tqdm einops wandb

import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from contextlib import suppress
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Callable, Literal, Optional, Union, Tuple, List
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import wandb
from tqdm.notebook import tqdm
import ipywidgets as widgets
import wandb

import matplotlib as mpl
from matplotlib.colors import LogNorm
import seaborn as sns
import matplotlib.pyplot as plt

from patterns.dataset import ModularArithmetic, Operator
from patterns.transformer import Transformer
from patterns.utils import generate_run_name
from patterns.learner import Config

from toy_models.fit import rescale_run, Pattern, PatternLearningModel

from unifying.sweep import get_history, handle_outliers
from unifying.plotting import BLUE, RED

# Global configuration: default task modulus, compute device and plot theme.
DEFAULT_MODULUS = 113
# NOTE(review): no cell shown here moves models/tensors to DEVICE.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

sns.set_theme(style="darkgrid")

In [None]:
from contextlib import suppress
from copy import deepcopy
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Tuple, Union

import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib.colors import LogNorm
from torch import nn, optim
from tqdm.notebook import tqdm

import wandb


def rescale_run(run, new_max=1.0, log=True):
    """Return a copy of ``run`` with its ``_step`` column rescaled to ``[0, new_max]``.

    Args:
        run: DataFrame with a numeric ``_step`` column. It is not modified.
        new_max: Value the largest step is mapped to.
        log: If True, steps are passed through a natural log first, so the new axis is
            logarithmic in the original step count.

    Returns:
        A copy of ``run`` with rescaled ``_step`` values.

    Raises:
        ValueError: If the (log-)steps have no positive maximum, so no scale can be found.
    """
    run = run.copy()
    steps = run["_step"]

    if log:
        # Clip to 1 so a step-0 row maps to log(1) = 0 rather than -inf. A -inf time
        # makes the fit's gradients NaN (0 * -inf).
        steps = np.log(steps.clip(lower=1))

    max_ = steps.max()
    if not max_ > 0:
        raise ValueError(f"cannot rescale run: maximum (log-)step is {max_}")

    run["_step"] = steps / max_ * new_max
    return run


class Pattern(nn.Module):
    """A single learnable pattern whose predictiveness is a scaled sigmoid in time.

    ``forward(t) = strength * sigmoid(speed * (t - onset))``.

    ``strength`` is stored as a logit and ``generalization`` as a log. Both therefore
    stay in their valid ranges ((0, 1) and (0, inf)) under unconstrained gradient updates.

    Any argument left as ``None`` is drawn at random:
    - ``strength`` and ``generalization``: uniform in [0, 1);
    - ``speed``: uniform in [0, 10 / max_time);
    - ``onset``: uniform in [0, max_time).
    """

    def __init__(self, max_time: float = 1.0, onset: Optional[float] = None, generalization: Optional[float] = None, strength: Optional[float] = None, speed: Optional[float] = None):
        super().__init__()

        # Use `is None` (not `or`) so explicit zeros such as onset=0.0 are respected.
        # The draw order matches the historical one, so seeded runs reproduce.
        if strength is None:
            strength = torch.rand(1)[0]
        if speed is None:
            speed = torch.rand(1)[0] * 10 / max_time
        if onset is None:
            onset = torch.rand(1)[0] * max_time
        if generalization is None:
            generalization = torch.rand(1)[0]

        self._strength = nn.Parameter(self._inv_sigmoid(self._as_tensor(strength)))
        self.speed = nn.Parameter(self._as_tensor(speed))
        self.onset = nn.Parameter(self._as_tensor(onset))
        self._generalization = nn.Parameter(torch.log(self._as_tensor(generalization)))

    @staticmethod
    def _as_tensor(x):
        """Convert a Python number or tensor to a standalone float32 tensor (no copy warning)."""
        return torch.as_tensor(x, dtype=torch.float32).detach().clone()

    @staticmethod
    def _inv_sigmoid(x):
        """Logit: the inverse of ``sigmoid``."""
        return torch.log(x / (1 - x))

    @property
    def strength(self):
        """Maximum predictiveness, in (0, 1)."""
        return torch.sigmoid(self._strength)

    @strength.setter
    def strength(self, value):
        # Write into the existing Parameter. Assigning a plain tensor to a registered
        # parameter name raises TypeError in nn.Module.__setattr__.
        with torch.no_grad():
            self._strength.copy_(self._inv_sigmoid(self._as_tensor(value)))

    @property
    def generalization(self):
        """Positive generalization weight of this pattern."""
        return torch.exp(self._generalization)

    @generalization.setter
    def generalization(self, value):
        with torch.no_grad():
            self._generalization.copy_(torch.log(self._as_tensor(value)))

    def forward(self, t):
        """Predictiveness at time(s) ``t``. Broadcasts over tensor-valued ``t``."""
        return self.strength * torch.sigmoid(self.speed * (t - self.onset))

    def __repr__(self):
        return f"Pattern(strength={self.strength.data.float()}, speed={self.speed.data.float()}, onset={self.onset.data.float()}, generalization={self.generalization.data.float()})"


class PatternLearningModel(nn.Module):
    """Toy model of train/test accuracy curves built from several learned patterns.

    Each of the ``num_patterns`` ``Pattern`` submodules has a time-dependent
    predictiveness ``p_j(t)``.

    - Train accuracy is the probability that at least one pattern predicts
      correctly: ``1 - prod_j (1 - p_j(t))``.
    - Test accuracy sums, over all ``2**num_patterns`` subsets of patterns, the
      subset's usage probability times the share of total generalization carried
      by the patterns in that subset.

    Args:
        num_patterns: Number of patterns.
        max_time: Length of the time axis. Initial onsets are spread evenly over
            ``(0, max_time)`` and initial speeds are ``10 / max_time``.
    """

    def __init__(self, num_patterns: int = 3, max_time=1.0):
        super().__init__()
        self.num_patterns = num_patterns
        self.patterns = nn.ModuleList([
            Pattern(
                max_time,
                onset=max_time * (i + 1) / (num_patterns + 1),
                speed=10. / max_time,
                generalization=0.5,
                strength=0.5,
            )
            for i in range(num_patterns)
        ])
        self.max_time = max_time

        # Row j holds the binary digits of j, most-significant bit first.
        self.binary_mask = torch.tensor(
            [
                [int(i) for i in bin(j)[2:].zfill(num_patterns)]
                for j in range(2**num_patterns)
            ]
        ).float()

        self.counts = self.binary_mask.sum(dim=1)

        # The same rows, least-significant bit first: entry [i, j] is 1 iff pattern j
        # belongs to subset i (the `i & (1 << j)` convention of usages/generalizations).
        self._subset_mask = self.binary_mask.flip(1)

    def forward(self, t):
        """Predicted train accuracy at time(s) ``t``. The result has ``t``'s shape."""
        return 1 - torch.prod(1 - self.predictivenesses(t), dim=0)

    def gs(self):
        """Per-pattern generalization weights, shape ``(num_patterns,)``."""
        return torch.stack([p.generalization for p in self.patterns])

    def predictivenesses(self, t):
        """Per-pattern predictiveness, shape ``(num_patterns, *shape of t)``."""
        return torch.stack([p(t) for p in self.patterns])

    def usages(self, t):
        """Probability of each subset of patterns being exactly the set in use at ``t``.

        Entry ``i`` is the product over patterns ``j`` of ``p_j(t)`` if bit ``j`` of
        ``i`` is set, else ``1 - p_j(t)``.

        Returns:
            Tensor of shape ``(2**num_patterns, *shape of t)``.
        """
        preds = self.predictivenesses(t)
        # Add trailing singleton dims so the mask broadcasts over tensor-valued t.
        mask = self._subset_mask.reshape(
            *self._subset_mask.shape, *([1] * (preds.dim() - 1))
        )
        return torch.prod(mask * preds + (1 - mask) * (1 - preds), dim=1)

    def generalizations(self):
        """Per-subset generalization.

        Each entry is the sum of the subset's pattern generalizations divided by the
        total over all patterns. Subset 0 (no patterns) is 0.

        Returns:
            Tensor of shape ``(2**num_patterns,)``.
        """
        gs = self.gs()
        sums = self._subset_mask @ gs
        total = gs.sum()

        if total > 0:
            sums = sums / total

        return sums

    def test(self, t):
        """Predicted test accuracy at time(s) ``t``. The result has ``t``'s shape."""
        usages = self.usages(t)
        gens = self.generalizations()
        gens = gens.reshape(-1, *([1] * (usages.dim() - 1)))
        return torch.sum(gens * usages, dim=0)

    def fit(self, run, lr=0.1, num_epochs=1000, callback=None, callback_ivl=100):
        """Fit the pattern parameters to observed accuracy curves with Adam.

        The objective is MSE(train) + MSE(test) over all rows of ``run``.

        Args:
            run: DataFrame with ``_step`` (time), ``train/acc`` and ``test/acc`` columns.
            lr: Adam learning rate.
            num_epochs: Number of full-batch optimization steps.
            callback: Optional ``callback(model)``. It is called once before training
                and then after every ``callback_ivl``-th epoch.
            callback_ivl: Epoch interval between callback invocations.

        Returns:
            ``self``, for chaining.
        """
        ts = torch.tensor(run["_step"].values).float()

        train_ys = torch.tensor(run["train/acc"].values).float()
        test_ys = torch.tensor(run["test/acc"].values).float()

        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.MSELoss()

        if callback is not None:
            callback(self)

        progress = tqdm(range(num_epochs))
        for epoch in progress:
            optimizer.zero_grad()

            # forward/test broadcast over a 1-D time tensor, so all points go in one batch.
            train_preds = self(ts)
            test_preds = self.test(ts)

            loss = criterion(train_preds, train_ys) + criterion(test_preds, test_ys)
            loss.backward()
            optimizer.step()

            progress.set_postfix(loss=loss.item())

            if callback is not None and epoch % callback_ivl == 0:
                callback(self)

        return self

    def to_dict(self):
        """Flatten pattern parameters into a dict of floats, with patterns sorted by onset.

        Keys are ``pattern_{i}/{strength,speed,onset,generalization}``, where ``i`` is
        the rank of the pattern's onset (0 = earliest).
        """
        d = {}

        for i, p in enumerate(sorted(self.patterns, key=lambda p: p.onset.item())):
            d[f"pattern_{i}/strength"] = p.strength.item()
            d[f"pattern_{i}/speed"] = p.speed.item()
            d[f"pattern_{i}/onset"] = p.onset.item()
            d[f"pattern_{i}/generalization"] = p.generalization.item()

        return d

    def rescale(self, max_time):
        """Re-express the model on a time axis of length ``max_time``.

        Time is stretched by ``s = max_time / self.max_time``. With ``t' = s * t``,
        ``speed * (t - onset) == (speed / s) * (t' - s * onset)``. Onsets are therefore
        multiplied by ``s`` and speeds divided by it, which leaves predictions unchanged.
        """
        scaling_factor = max_time / self.max_time

        for p in self.patterns:
            p.onset.data *= scaling_factor
            p.speed.data /= scaling_factor

        self.max_time = max_time

    def __repr__(self):
        return f"PatternLearningModel({self.to_dict()})"


In [None]:
torch.manual_seed(2)
# Untrained 3-pattern model on a [0, 100] time axis.
pl_model = PatternLearningModel(num_patterns=3, max_time=100.0)

def plot_patterns(pl_model, run):
    """Plot a model's predicted train/test accuracy next to the observed curves.

    Args:
        pl_model: A ``PatternLearningModel``. It is called for train accuracy and
            ``.test`` is used for test accuracy, one time point at a time.
        run: DataFrame with ``_step``, ``train/acc`` and ``test/acc`` columns, on the
            same time axis the model was fit on.

    Returns:
        ``(fig, axes)`` so callers can further customize or save the figure.
    """
    ts = run["_step"].values

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        train_preds = [float(pl_model(t)) for t in ts]
        test_preds = [float(pl_model.test(t)) for t in ts]

    fig, axes = plt.subplots(1, 2, figsize=(20, 5))

    axes[0].plot(ts, train_preds, label="train", color="blue")
    axes[0].plot(ts, test_preds, label="test", color="red")
    axes[0].set_title("Predictions")

    axes[1].plot(ts, run["train/acc"].values, label="train", color="blue")
    axes[1].plot(ts, run["test/acc"].values, label="test", color="red")
    axes[1].set_title("True values")

    for ax in axes:
        ax.set_xlabel("Step")
        ax.set_ylabel("Accuracy")
        ax.legend()

    return fig, axes


In [None]:
def _plot_patterns(pl_model, run=None):
    """Fit-progress callback: plot ``pl_model`` against a run and print its patterns.

    Args:
        pl_model: The model being fit.
        run: DataFrame to compare against. It defaults to the notebook-global
            ``rescaled_run``, which must then exist. No cell above defines it at module
            level, so pass ``run`` explicitly when possible.
    """
    if run is None:
        try:
            run = rescaled_run
        except NameError:
            raise NameError(
                "_plot_patterns needs a run: pass `run=` or define a global `rescaled_run`"
            ) from None

    plot_patterns(pl_model, run)
    plt.show()

    print(pl_model.patterns)


# Fresh 3-pattern model on a [0, 100] time axis. To fit it interactively, first define
# `rescaled_run` (e.g. `rescaled_run = rescale_run(run, new_max=100.)`), then run:
# pl_model.fit(rescaled_run, lr=0.1, callback=_plot_patterns, callback_ivl=10, num_epochs=500)
pl_model = PatternLearningModel(max_time=100.0, num_patterns=3)

In [None]:
# Columns that change over a run's history. Every other column is treated as a
# per-sweep hyperparameter.
VARIABLE_COLS = [
    "test/acc",
    "train/acc",
    "test/loss",
    "train/loss",
    "_step",
    "weight/norm",
    "test/efficiency",
    "train/efficiency",
    "weight/dist_from_init",
    "weight/cos_sim_with_init",
]

def fit_sweep(df: pd.DataFrame, unique_col: str, lr=0.1, max_time=1.0, num_patterns=3, num_epochs=500, **kwargs):
    """Fit one ``PatternLearningModel`` per value of ``unique_col`` and log each fit to wandb.

    Args:
        df: Sweep history. Rows sharing a ``unique_col`` value are treated as one run,
            which must have ``_step``, ``train/acc`` and ``test/acc`` columns.
        unique_col: Column identifying the runs (e.g. ``"d_model"``).
        lr, num_epochs: Passed to ``PatternLearningModel.fit``.
        max_time: Steps are rescaled to ``[0, max_time]`` before fitting. The fitted
            model is rescaled to ``max_time=1`` before logging.
        num_patterns: Number of patterns per model.
        **kwargs: Extra metadata logged with every fit. If it contains ``log``, that
            flag is also passed to ``rescale_run`` (default True).
    """
    unique_vals = df[unique_col].unique()

    # Non-variable columns are assumed constant across the sweep, so the first row's
    # values are logged for every fit. Use `labels=`: `Series.drop(columns=...)` is
    # ignored, which left `unique_col` in here and overwrote the per-run value below.
    variable_cols = [c for c in df.columns if c in VARIABLE_COLS]
    hyperparams: dict = (
        df.iloc[0]
        .drop(labels=[unique_col, *variable_cols])
        .to_dict()
    )
    log_scale = kwargs.get("log", True)

    wandb.init(
        project="fit-toy-model",
    )

    try:
        for unique_val in tqdm(unique_vals):
            run = df.loc[df[unique_col] == unique_val]
            rescaled_run = rescale_run(run, new_max=max_time, log=log_scale)

            pl_model = PatternLearningModel(
                num_patterns=num_patterns,
                max_time=max_time
            )

            def _plot_patterns(pl_model):
                plot_patterns(pl_model, rescaled_run)
                plt.show()

            pl_model.fit(rescaled_run, lr=lr, num_epochs=num_epochs, callback=_plot_patterns)

            # Plot the final fit on the axis it was fit on, i.e. before rescaling.
            _plot_patterns(pl_model)

            pl_model.rescale(1.)

            # `unique_col` goes last so no other entry can overwrite the run identifier.
            wandb.log({**pl_model.to_dict(), **hyperparams, **kwargs, unique_col: unique_val})

    except KeyboardInterrupt:
        # Manual early stop: keep whatever has been logged so far.
        print("Interrupted; finishing the wandb run.")
    finally:
        wandb.finish()


# Fit the sweeps

In [None]:
# wandb sweep IDs (weight decay, label noise, d_model).
WD_SWEEP_ID = "ib21hnk1"
LN_SWEEP_ID = "8783j1j4"
DM_SWEEP_ID = "l1b2mmci"

# d_model sweep exported to CSV.
dm_sweep_2 = pd.read_csv("../unifying/mw_sweep.csv")

# Parallel lists, one entry per sweep to fit. To fit the wandb sweeps instead use:
#   SWEEP_IDS = [WD_SWEEP_ID, LN_SWEEP_ID, DM_SWEEP_ID]
#   UNIQUE_COLS = ["weight_decay", "frac_label_noise", "d_model"]
#   SWEEPS = [get_history(sweep_id, unique_cols=unique_col) for sweep_id, unique_col in zip(SWEEP_IDS, UNIQUE_COLS)]
SWEEP_IDS = ["mw-other"]
UNIQUE_COLS = ["d_model"]
SWEEPS = [dm_sweep_2]

In [None]:
# Forward-fill missing values. `ffill()` replaces the deprecated
# `fillna(method="ffill")`. The fill is deliberately in place: SWEEPS (defined above)
# holds a reference to this same DataFrame and must see the filled values.
dm_sweep_2.ffill(inplace=True)
dm_sweep_2.loc[0, :].to_dict()

In [None]:
# Fit every configured sweep. `log` and `sweep` are logged as metadata with each fit.
for sweep, sweep_id, unique_col in zip(SWEEPS, SWEEP_IDS, UNIQUE_COLS):
    fit_sweep(
        sweep,
        unique_col,
        num_patterns=3,
        max_time=100.0,
        num_epochs=500,
        log=True,
        sweep=sweep_id,
    )

In [None]:
api = wandb.Api()
# ENTITY is not defined earlier in this notebook. Take it from the environment and
# fall back to the API's default entity.
ENTITY = os.environ.get("WANDB_ENTITY") or api.default_entity
runs = api.runs(f"{ENTITY}/fit-toy-model")
list(runs)

In [None]:
# NOTE(review): mw_fit_run_id is not used below. `runs[0]` is whichever run the API lists first.
mw_fit_run_id = "1gthqq5"

run = runs[0]
df = run.history()

col = "d_model"
unique_vals = df[col].unique()
unique_vals

In [None]:
D_MODEL = 115

# First logged fit for this embedding dimension.
model_entry = df[df[col] == D_MODEL].iloc[0]
model_entry

In [None]:
def df_row_to_toy_model(row, max_time=100.):
    """Rebuild a ``PatternLearningModel`` from one row of fit-toy-model history.

    Args:
        row: Series/mapping with ``pattern_{i}/onset``, ``/speed``, ``/strength`` and
            ``/generalization`` entries for each pattern, as written by ``to_dict``.
        max_time: ``max_time`` set on the rebuilt model. NOTE(review): ``fit_sweep``
            calls ``rescale(1.)`` before logging, so logged values correspond to
            ``max_time=1.0``. The default of 100 keeps the existing cells' behaviour — confirm.

    Returns:
        The rebuilt model.
    """
    model = PatternLearningModel(max_time=max_time)

    def _value(key):
        return torch.tensor(float(row[key]), dtype=torch.float32)

    for i, pattern in enumerate(model.patterns):
        prefix = f"pattern_{i}"
        with torch.no_grad():
            pattern.onset.copy_(_value(f"{prefix}/onset"))
            pattern.speed.copy_(_value(f"{prefix}/speed"))
            # `strength` and `generalization` are derived properties (sigmoid / exp of
            # the stored parameters). Setting their `.data` only changes a temporary
            # tensor, so write the inverse transforms into the underlying parameters.
            pattern._strength.copy_(pattern._inv_sigmoid(_value(f"{prefix}/strength")))
            pattern._generalization.copy_(torch.log(_value(f"{prefix}/generalization")))

    return model

model = df_row_to_toy_model(row=model_entry)
# NOTE(review): a no-op while df_row_to_toy_model builds the model with max_time=100.
model.rescale(100)
model


In [None]:
# Get corresponding original run
og_df = get_history(DM_SWEEP_ID, unique_cols="d_model")
run = og_df.loc[og_df.d_model == D_MODEL, :]

# Right edge of the x-range for the fitted curves. NOTE(review): the model's time axis
# is whatever scale its loaded parameters are in, not raw steps. 10000 appears tuned to
# that scale — confirm.
FIT_MAX_STEP = 10000

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))

ax1.plot(run["_step"], run["test/acc"], label="Test", color=RED, linewidth=2)
ax1.plot(run["_step"], run["train/acc"], label="Train", color=BLUE, linewidth=2)
ax1.set_ylabel("Accuracy", fontsize=18)
ax1.set_xlabel("Step", fontsize=18)
ax1.set_xscale("log")
ax1.legend()
ax1.set_title("Truth")

min_step, max_step = og_df["_step"].min(), FIT_MAX_STEP

# The model was fit on log-rescaled steps, so the fit panel uses a linear x-axis.
ts = np.linspace(min_step, max_step, 1000)
with torch.no_grad():
    train_ys = [float(model(t)) for t in ts]
    test_ys = [float(model.test(t)) for t in ts]
ax2.plot(ts, train_ys, label="Train", color=BLUE, linewidth=2)
ax2.plot(ts, test_ys, label="Test", color=RED, linewidth=2)
ax2.set_ylabel("Accuracy", fontsize=18)
ax2.set_xlabel("Step", fontsize=18)
ax2.legend()
ax2.set_title("Fit")

plt.show()

In [None]:
# Ignore any d_model < 50, and sort so the line plots run left to right.
df_cleaned = df.loc[df["d_model"] >= 50, :].sort_values("d_model")
d_models = df_cleaned.loc[:, "d_model"].unique()

# Scaling analysis: onset of each pattern vs embedding dimension
fig, ax = plt.subplots(figsize=(15, 4))

colors = [BLUE, RED, "green"]
y_max = 0

for i, color in enumerate(colors):
    onsets = df_cleaned.loc[:, f"pattern_{i}/onset"]
    y_max = max(y_max, onsets.max())
    # Pair each onset with its own row's d_model, not the unique values, so lengths
    # match even if a d_model appears in more than one row.
    ax.plot(df_cleaned["d_model"], onsets, color=color, linewidth=2)

ax.set_xlabel("d_model", fontsize=18)
ax.set_ylabel("Onset", fontsize=18)


# Fit a power-law to the onsets 
from scipy.optimize import curve_fit

def power_law(x, a, b):
    """Evaluate ``a * x**b`` elementwise."""
    return a * x**b

def fit_power_law(x, y, **curve_fit_kwargs):
    """Least-squares fit of ``y ~ a * x**b``. Returns the fitted ``[a, b]`` array.

    Extra keyword arguments (e.g. ``p0``, ``bounds``, ``maxfev``) are forwarded to
    ``scipy.optimize.curve_fit``.
    """
    popt, _ = curve_fit(power_law, x, y, **curve_fit_kwargs)
    return popt


# Only embedding dims up to CUTOFF are used for fitting; the curves extrapolate beyond it.
CUTOFF = 175

# Loop-invariant: the rows used for fitting are the same for every pattern.
df_to_fit = df_cleaned.loc[df_cleaned["d_model"] <= CUTOFF, :]
d_models_sorted = np.sort(d_models)

# Fit power law to onset
for i in range(3):
    # Pair each onset with its own row's d_model so x and y always have equal length.
    onset_popt = fit_power_law(df_to_fit["d_model"], df_to_fit[f"pattern_{i}/onset"])
    exponent = round(onset_popt[1], 2)
    ax.plot(d_models_sorted, power_law(d_models_sorted, *onset_popt), label=f"$\\nu_{i} = {exponent}$", color=colors[i], linestyle="--", linewidth=2)

ax.vlines(CUTOFF, 0, y_max * 1.05, color="grey", linestyle="--", linewidth=2)
ax.set_xlabel("Embedding dim.", fontsize=18)
ax.set_ylim(0, y_max * 1.05)

ax.legend()

In [None]:
# Let's see if we can fit 