In [None]:
!pip install pytorch-lightning --quiet
!pip install -q git+https://www.github.com/google/neural-tangents

import os
import sys

if os.path.isdir('/content/eigenlearning'):
    !rm -r '/content/eigenlearning'
## [REDACTED EIGENLEARNING LIBRARY IMPORT]
sys.path.insert(0,'/content/eigenlearning')

!pip3 install pickle5
import pickle5 as pickle

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import neural_tangents as nt
import numpy as np
import jax
from jax import numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import time

import image_datasets
import measures
import powerlaws
import utils

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl
import torchmetrics

from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

import shutil

def rcsetup():
    """Apply the notebook's matplotlib style.

    Sets a 150-dpi white 6x3.5 figure, STIX text at size 18, axis titles at
    size 19, and Computer Modern mathtext.
    """
    style_groups = {
        "figure": dict(dpi=150, facecolor=(1, 1, 1), figsize=(6, 3.5)),
        "font": dict(family='stixgeneral', size=18),
        "axes": dict(titlesize=19),
        "mathtext": dict(fontset='cm'),
    }
    for group_name, group_params in style_groups.items():
        plt.rc(group_name, **group_params)

def get_plot_color(ind, ncolors=10):
    """Return the RGB color for series `ind` from an `ncolors`-step hue sweep.

    Hues are spaced evenly over [0, 0.8] at saturation 1 and value 0.7;
    `ind` wraps modulo `ncolors`.

    Args:
        ind: series index (any int; taken modulo `ncolors`).
        ncolors: number of distinct hues in the sweep.

    Returns:
        A length-3 float numpy array (r, g, b) with components in [0, 1].
    """
    import colorsys
    # Only the requested hue is converted; colorsys uses the same HSV->RGB
    # formula as matplotlib.colors.hsv_to_rgb.
    hue = np.linspace(0, 0.8, ncolors)[ind % ncolors]
    return np.array(colorsys.hsv_to_rgb(hue, 1, .7))

# Reset matplotlib rc params to their defaults, then apply the notebook style.
plt.rcdefaults()
rcsetup()

# Data generation

In [None]:
def powerlaw_lambdas_v2s(alpha, beta, M, normalize_lambdas=False):
    """Build a synthetic power-law eigensystem with M modes (k = 1..M).

    Args:
        alpha: eigenvalue decay exponent, lambda_k = k**-alpha.
        beta: target decay exponent, v_k**2 proportional to k**-beta.
        M: number of modes.
        normalize_lambdas: if True, rescale the eigenvalues to sum to 1.

    Returns:
        (lambdas, v2s, f_terms): the eigenvalues; the squared target
        coefficients normalised to sum to 1; and a dict mapping the 0-based
        mode index i to sqrt(v2s[i]).
    """
    k = np.arange(1, M + 1, dtype=float)

    lambdas = k ** -alpha
    if normalize_lambdas:
        lambdas = lambdas / lambdas.sum()

    v2s = k ** -beta
    v2s = v2s / v2s.sum()

    f_terms = dict(enumerate(np.sqrt(v2s)))
    return lambdas, v2s, f_terms

def theory_mses(lambdas, f_terms, deltas_over_n, ns, M):
    """Evaluate theoretical KRR learning measures over an (n, ridge) grid.

    For every training-set size ``n_train`` in ``ns``, the ridges swept are
    ``n_train * deltas_over_n``. Each grid point calls
    ``measures.learning_measure_predictions`` with the given eigenvalues and
    target coefficients. A '.' is printed per grid point and a newline per n.

    Returns:
        dict mapping ``(n_train, ridge)`` to whatever
        ``learning_measure_predictions`` returns. It also holds the metadata
        keys 'M', 'ns', 'eigenvals' and 'deltas_over_n'.
    """
    grid = {}
    for n_train in ns:
        for ridge in n_train * deltas_over_n:
            grid[(n_train, ridge)] = measures.learning_measure_predictions(
                None, None, n=n_train, ridge=ridge,
                f_terms=f_terms, lambdas=lambdas, mults=1)
            print('.', end='')
        print()
    grid.update(M=M, ns=ns, eigenvals=lambdas, deltas_over_n=deltas_over_n)
    return grid

## Synthetic experiments

In [None]:
# Theory curves for a synthetic power-law task: lambda_k = k^-alpha and
# v_k^2 proportional to k^-beta with beta = alpha, truncated to M modes.
alpha = 2
beta = alpha
M = 500000
lambdas, v2s, f_terms = powerlaw_lambdas_v2s(alpha, beta, M)

ns = [16, 125, 1000, 8000, 256000]
# Per-sample ridge delta/n, swept log-uniformly from 1e2 down to 1e-12 (30 points).
deltas_over_n = np.array([10**i for i in np.linspace(2, -12, 30)])

theory_results = theory_mses(lambdas, f_terms, deltas_over_n, ns, M)
fname = "r_spectrum_" + str(alpha)
# NOTE(review): the file name says "30k" although M = 500000 here; the
# synthetic plotting cell reads this exact path, so rename both together.
with open("/content/drive/My Drive/eigenlearning DB/kernel/{}_theory_30k_orders2.pickle".format(fname), "wb") as handle:
    pickle.dump(theory_results, handle, protocol=pickle.HIGHEST_PROTOCOL)

## MNIST KRR (theory vs empirical)

In [None]:
from jax import random
from image_datasets import get_image_eigendata

# Theory curves for MNIST KRR (digits 0-4 vs 5-9). Eigendata for M <= 16000
# is computed from the width-500, 4-hidden-layer network kernel. The M = 30000
# eigendata is loaded from a cached pickle on Drive.
M = 30000

if M == 30000:
    with open("/content/drive/My Drive/eigenlearning DB/kernel/mnist_30k_eigendata.pickle", "rb") as handle:
        eigendata = pickle.load(handle)
elif M <= 16000:
    _, _, kernel_fn = utils.get_net_fns(width=500, d_out=1, n_hidden_layers=4)
    classes = [[0,1,2,3,4], [5,6,7,8,9]]
    eigendata = get_image_eigendata('mnist', M, kernel_fn, classes)
else:
    # Previously fell through silently and failed later with a NameError.
    raise ValueError(
        "No eigendata source for M={}: use M <= 16000 (computed) "
        "or M == 30000 (cached).".format(M))

lambdas = eigendata['eigenvals']
f_terms = eigendata['f_terms']
# Per-sample ridge delta/n from 1e2 down to 1e-8 (30 points).
deltas_over_n = np.array([10**i for i in np.linspace(2, -8, 30)])
ns = [16, 125, 1000, 8000]

results = theory_mses(lambdas, f_terms, deltas_over_n, ns, M)
# NOTE(review): the MNIST plotting cell reads "r_mnist_theory_30k_04_59_orders2.pickle",
# not this file; confirm which run the figure should use.
with open("/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_theory_30k_04_59.pickle", "wb") as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Empirical KRR on MNIST (digits 0-4 vs 5-9) with the width-500,
# 4-hidden-layer network kernel, over the same (n, delta/n) grid as the
# theory cell. Each grid point is evaluated over n_trials trials on 2000 test
# points.
ns = [16, 125, 1000, 8000]
deltas_over_n = np.array([10**i for i in np.linspace(2, -8, 30)])
n_trials = 10

results = {}
results['ns'] = ns
results['deltas_over_n'] = deltas_over_n
results['n_trials'] = n_trials

classes = [[0,1,2,3,4],[5,6,7,8,9]]
net_fns = utils.get_net_fns(width=500, d_out=1, n_hidden_layers=4)
for n in ns:
    for delta_over_n in deltas_over_n:
        # Keyed by the absolute ridge delta = (delta/n) * n, the same product
        # the plotting cell recomputes as `deltas_over_n * n`.
        delta = delta_over_n * n
        stats = measures.learning_measure_statistics(net_fns, 'mnist', n,
                                            classes=classes, pred_type='kernel',
                                            n_trials=n_trials, n_test=2000, ridge=delta,
                                            compute_train_measures=True)
        # NOTE(review): entries are presumably (mean, std) pairs; the plotting
        # cell unpacks result['mse'] and result['train_mse'] that way.
        results[(n, delta)] = stats['kernel']
        print('.', end='')
    print()

# NOTE(review): the MNIST plotting cell reads "r_mnist_orders2_04_59.pickle",
# not this file; confirm which run the figure should use.
with open("/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_04_59.pickle", "wb") as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Neural networks

In [None]:
def modified_resnet():
    """Return a 10-class torchvision ResNet-18 with a small-image stem.

    The stem convolution is replaced by a 3x3, stride-1, padding-1 conv with
    no bias, and the stem max-pool is replaced by an identity. This presumably
    suits 32x32 CIFAR inputs. The base network is constructed first, so the
    random-initialisation order is that of the original definition.
    """
    net = resnet18(num_classes=10)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=(3, 3),
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    return net

class SoftError(torchmetrics.Metric):
    """Soft classification error: 1 minus the mean probability on the true class.

    ``update`` receives ``preds`` indexed as [sample, class] and integer class
    labels ``target``. In this notebook the caller passes ``exp(log_softmax)``,
    i.e. probabilities.
    """
    full_state_update = False

    def __init__(self):
        super().__init__()
        # Despite its name, "correct" accumulates the error mass sum(1 - p_true).
        self.add_state("correct", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Number of samples seen so far.
        self.add_state("total", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        n_samples = target.numel()
        rows = torch.arange(preds.shape[0], device=preds.device)
        p_true = preds[rows, target]
        self.correct += n_samples - p_true.sum()
        self.total += n_samples

    def compute(self):
        return self.correct / self.total


class LitResnet(pl.LightningModule):
    """LightningModule that trains `modified_resnet()` on a random CIFAR-10 subset.

    Training minimises the NLL of log-softmax outputs with SGD (momentum 0.9)
    and a cosine-annealed learning rate stepped once per epoch. Validation
    runs on the CIFAR-10 test split and logs hard test error, soft error, and
    the accumulated train error.

    Reads the notebook globals ``BATCH_SZ`` (dataloaders and schedule length)
    and ``MAX_STEPS`` (schedule length); both must be defined before training.
    """

    def __init__(self, lr=0.1, dataset_size=50000):
        """
        Args:
            lr: SGD learning rate.
            dataset_size: number of CIFAR-10 training images drawn by
                ``train_dataloader``.
        """
        super().__init__()

        # NOTE(review): this seeded generator is never used; `train_dataloader`
        # draws its subset from the global torch RNG instead.
        self.rng = torch.Generator().manual_seed(40)
        self.lr = lr

        self.n_classes = 10
        self.dims = (3, 32, 32)
        self.datasize = dataset_size

        self.model = modified_resnet()

        # Error metrics built as 1 - accuracy (composed torchmetrics objects).
        self.test_error = 1 - torchmetrics.Accuracy()
        self.train_error = 1 - torchmetrics.Accuracy()
        self.soft_error = SoftError()

        # Counts training steps so validation only logs TrainError after at
        # least one training batch has updated `train_error`.
        self.train_step_num = 0
        

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the NLL loss for one batch and accumulate train error."""
        x, y = batch
        logits = self(x)  # log-probabilities (forward applies log_softmax)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_error.update(preds, y)

        self.log("TrainLoss", loss)
        self.train_step_num += 1
        return loss

    def training_epoch_end(self, outs):
        # Intentionally a no-op.
        pass

    def validation_step(self, batch, batch_idx):
        """Update and log test error, soft error, and (once trained) train error."""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.test_error.update(preds, y)
        # SoftError expects probabilities, so undo the log-softmax.
        self.soft_error.update(torch.exp(logits), y)

        self.log("TestError", self.test_error, prog_bar=True)
        self.log("SoftError", self.soft_error)
        if self.train_step_num > 0:
            self.log("TrainError", self.train_error)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return raw model outputs (no log-softmax, unlike `forward`)."""
        x, y = batch
        y_hat = self.model(x)
        return y_hat

    def configure_optimizers(self):
        """SGD with momentum plus a per-epoch cosine-annealed learning rate."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=0.9,
        )
        # Schedule length in epochs: MAX_STEPS optimizer steps of BATCH_SZ
        # samples over a dataset of self.datasize samples (ignores the partial
        # last batch). NOTE(review): if training stops before MAX_STEPS steps,
        # the LR is not fully annealed.
        max_epochs = MAX_STEPS / self.datasize * BATCH_SZ
        scheduler_dict = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=max_epochs
            ),
            "interval": "epoch",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download data
        CIFAR10("cifar10/", download=True)
    
    def setup(self, stage):
        """Build the augmented train split and the un-augmented test split."""
        train_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
        ])
        self.traindata = CIFAR10("cifar10/", train=True, download=True,
                                 transform=train_transforms)
        self.testdata = CIFAR10("cifar10/", train=False, download=True,
                                transform=transforms.ToTensor())

    def train_dataloader(self):
        """Shuffled loader over a random subset of `self.datasize` training images.

        `random_split` is called without a generator, so each call draws a
        subset from the global torch RNG.
        """
        small_trainset, _ = random_split(self.traindata,
                                         [self.datasize, len(self.traindata)-self.datasize])
        return DataLoader(small_trainset, batch_size=BATCH_SZ, num_workers=2, shuffle=True)

    def val_dataloader(self):
        """Loader over the full CIFAR-10 test split (not shuffled)."""
        return DataLoader(self.testdata, batch_size=BATCH_SZ, num_workers=2)


from pytorch_lightning.loops.epoch import TrainingEpochLoop
class LogValLoop(TrainingEpochLoop):
    """TrainingEpochLoop that validates on an explicit global-step schedule.

    ``_should_check_val_fx`` is copied from Lightning's ``TrainingEpochLoop``
    with one addition. If a schedule was set via ``set_val_schedule``, the
    stock ``val_check_interval`` decision is ignored and validation runs once
    each time ``total_batch_idx`` reaches the next scheduled step. After the
    schedule is exhausted, no further scheduled validations run; a stop
    request still triggers one.
    """

    # Class-level default. Previously `_should_check_val_fx` raised
    # AttributeError unless `set_val_schedule` had been called, which made
    # the `is not None` fallback below unreachable.
    val_sched = None

    def set_val_schedule(self, val_schedule):
        """Set the global batch indices at which to validate (expected ascending).

        Passing ``None`` restores Lightning's default validation schedule.
        """
        self.val_sched = val_schedule
    
    def _should_check_val_fx(self) -> bool:
        """Decide if we should run validation."""
        if not self._should_check_val_epoch():
            return False

        # val_check_batch is inf for iterable datasets with no length defined
        is_infinite_dataset = self.trainer.val_check_batch == float("inf")
        is_last_batch = self.batch_progress.is_last_batch
        if is_last_batch and is_infinite_dataset:
            return True

        if self.trainer.should_stop:
            return True

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch
        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
            is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
        elif self.trainer.val_check_batch != float("inf"):
            # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches
            # else condition it based on the batch_idx of the current epoch
            current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
            is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

        # Scheduled validation: when a schedule is set it overrides
        # `is_val_check_batch` entirely.
        if self.val_sched is not None:
            threshold = self.val_sched[0] if len(self.val_sched)>0 else np.inf
            if self.total_batch_idx < threshold:
                return False
            else:
                # Consume one schedule entry per validation run.
                self.val_sched = self.val_sched[1:]
                return True
        return is_val_check_batch

In [None]:
# Global training configuration read by LitResnet and the training loop.
torch.manual_seed(42)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # defined once (was duplicated)
BATCH_SZ = 128 if torch.cuda.is_available() else 64  # smaller batches on CPU
MAX_STEPS = 15000  # step budget used to size the cosine LR schedule
print(DEVICE)

In [None]:
# Live TensorBoard view of the Lightning logs written under ./lightning_logs/.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [None]:
# Train LitResnet for each training-subset size n (n_trials repeats each),
# validating at the global steps in `val_sched`. Each run's CSV metrics are
# moved to Drive as metrics_{n}_{run}.csv, and the experiment metadata is
# saved alongside for the plotting cells.
max_steps = [4000, 6000, 10000, 14000]  # per-n step budget, aligned with ns
ns =  [500, 2000, 8000, 32000]
n_trials = 5
# 50 geometrically spaced validation steps between 10 and 13800.
val_sched = np.geomspace(10, 13800, 50).astype(int)
assert len(max_steps) == len(ns), "max_steps must give one step budget per n"

exp_details = {
    "ns": ns,
    "n_trials": n_trials,
    "val_sched": val_sched
}
with open("/content/drive/My Drive/eigenlearning DB/NN_DB_exp_details.pickle", "wb") as handle:
    pickle.dump(exp_details, handle, protocol=pickle.HIGHEST_PROTOCOL)

for run in range(n_trials):
    for i, n in enumerate(ns):
        model = LitResnet(lr=0.1, dataset_size=n)

        loggers = [TensorBoardLogger(save_dir="."),
                   CSVLogger(save_dir=".")]
        csv_logger = loggers[1]

        trainer = pl.Trainer(
            accelerator="auto",
            devices=1 if torch.cuda.is_available() else None,
            callbacks=[TQDMProgressBar(refresh_rate=1),
                       LearningRateMonitor(),],
            logger=loggers,
            log_every_n_steps=20,
            val_check_interval=10, check_val_every_n_epoch=None,
        )
        # Custom epoch loop: validates on `val_sched`, with this n's step budget.
        val_loop = LogValLoop(max_steps=max_steps[i])
        val_loop.set_val_schedule(val_sched)
        val_loop.trainer = trainer
        trainer.fit_loop.connect(epoch_loop=val_loop)
        trainer.fit(model)

        # Ask the CSV logger where it wrote its metrics. Assuming
        # /content/lightning_logs/version_<loop counter> breaks when earlier
        # versions already exist, when the TensorBoard logger sharing the
        # directory takes a different version, or when cwd is not /content.
        exp_name = '{}'.format(n)
        metrics_path = os.path.join(csv_logger.log_dir, "metrics.csv")
        shutil.move(metrics_path,
                    "/content/drive/My Drive/eigenlearning DB/metrics_{}_{}.csv".format(exp_name, run))

        del model

In [None]:
# !rm -rf lightning_logs

# Plotting

In [None]:
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from matplotlib.colors import hsv_to_rgb

# Re-apply the notebook's matplotlib style before the plotting cells.
rcsetup()

## Synthetic experiments

In [None]:
# Synthetic power-law task: theoretical test (solid) and train (dotted) MSE
# vs. tau_eff = 1 / (delta/n) for each n. The largest n is drawn dash-dotted
# in black and labelled n -> infinity; dashed verticals mark 1/kappa_0.
expo=2
fname = "r_spectrum_" + str(expo)
with open("/content/drive/My Drive/eigenlearning DB/kernel/{}_theory_30k_orders2.pickle".format(fname), "rb") as handle:
    results_theory = pickle.load(handle)

fig, ax = plt.subplots()

M = results_theory['M']
lambdas = results_theory['eigenvals']
deltas_over_n = results_theory['deltas_over_n']
show_ns = results_theory['ns']

skip = []
label_pos = [.76, .565, .365, .165]  # axes-fraction heights of the n labels
for i, n in enumerate(show_ns):
    if n in skip:
        continue
    mses_theory, train_mses_theory = [np.zeros(len(deltas_over_n)) for _ in range(2)]
    # Results are keyed by the absolute ridge n * delta_over_n; rebuild the same floats.
    for j, delta in enumerate(deltas_over_n * n):
        result = results_theory[(n,delta)]
        mses_theory[j] = result['mse']
        train_mses_theory[j] = result['train_mse']

    # Correction: (M*mse - n*train_mse)/(M-n) is presumably the MSE over the
    # M - n non-training points, assuming 'mse' averages over all M points.
    # The 1e-4 guards against n == M.
    mses_theory = 1/(M-n+1e-4) * (M * mses_theory - n * train_mses_theory)

    tau_eff = 1 / deltas_over_n

    if i < len(show_ns) - 1:
        color = get_plot_color(i, len(show_ns))
        ax.text(0.99, label_pos[i], r'$n={}$'.format(n), color=color, ha='right', transform=ax.transAxes, fontsize=15)
        # NOTE(review): presumably the ridgeless effective regularisation for a
        # k^-expo spectrum (large-M closed form).
        kappa_0 = ((np.pi / expo) / np.sin(np.pi / expo) / n) ** expo
        ax.axvline(1/kappa_0, ls="--", color=color, alpha=0.23)

        ax.plot(tau_eff, mses_theory, color=color, ls='-')
        ax.plot(tau_eff, train_mses_theory, color=color, ls=':')
    else:
        color = (0, 0, 0, .75)
        ax.plot(tau_eff, mses_theory, color=color, ls='-.')


ax.set_xlabel(r'$\tau_\mathrm{eff}$')
ax.set_ylabel(r'$\mathcal{E}(f)$')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_ylim(4e-5,1.3e0)
ax.set_xlim(1e-2,1e12)
plt.legend(
        (
            Line2D([0], [0], color='xkcd:navy', alpha=.75, ls=":"),
            Line2D([0], [0], color='xkcd:navy', alpha=.75, ls="-"),
            Line2D([0], [0], color='xkcd:navy', alpha=.65, ls="-."),),
        (
            "train",
            "test",
            r"$n\to\infty$"),
        framealpha=0.95,
        fontsize=15.5,
        loc='lower left',

    )
plt.axhline(1, ls='--', color='black', lw=1, alpha=0.4)
plt.tight_layout()
plt.savefig("KRR_synthetic_deep_bootstrap.pdf", bbox_inches='tight')

## MNIST KRR (theory vs empirical)

In [None]:
# MNIST KRR: theory (solid; test and train) vs. empirical test (dash-dotted)
# and train (dotted) MSE with +/- 1 std bands, against tau_eff = 1 / (delta/n).
# NOTE(review): these "_orders2" pickles are not the files written by the MNIST
# data-generation cells (r_mnist_theory_30k_04_59.pickle / r_mnist_04_59.pickle);
# confirm which runs are intended.
SAVE_FIG = False  # set True to also write KRR_deep_bootstrap.pdf

with open("/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_theory_30k_04_59_orders2.pickle", "rb") as handle:
    results_theory = pickle.load(handle)
with open("/content/drive/My Drive/eigenlearning DB/kernel/r_mnist_orders2_04_59.pickle", "rb") as handle:
    results = pickle.load(handle)

fig, ax = plt.subplots()

M = results_theory['M']
lambdas = results_theory['eigenvals']
deltas_over_n = results_theory['deltas_over_n']
show_ns = results_theory['ns']

skip = []
label_pos = [.89, .7, .43, .15]  # axes-fraction heights of the n labels
for i, n in enumerate(show_ns):
    if n in skip:
        continue
    mses, mses_std, train_mses, train_mses_std = [np.zeros(len(deltas_over_n)) for _ in range(4)]
    mses_theory, train_mses_theory = [np.zeros(len(deltas_over_n)) for _ in range(2)]
    # Both result dicts are keyed by the absolute ridge n * delta_over_n.
    for j, delta in enumerate(deltas_over_n * n):
        result = results[(n,delta)]
        mses[j], mses_std[j] = result['mse']
        train_mses[j], train_mses_std[j] = result['train_mse']

        result = results_theory[(n,delta)]
        mses_theory[j] = result['mse']
        train_mses_theory[j] = result['train_mse']

    color = get_plot_color(i, len(show_ns))
    # Correction: (M*mse - n*train_mse)/(M-n) is presumably the MSE over the
    # M - n non-training points, assuming 'mse' averages over all M points.
    mses_theory = 1/(M-n) * (M * mses_theory - n * train_mses_theory)
    tau_eff = 1 / deltas_over_n

    kappa_0 = measures.find_C(n, lambdas)
    ax.axvline(1/kappa_0, ls="--", color=color, alpha=0.4)

    ax.plot(tau_eff, mses_theory, color=color, label=n)
    ax.plot(tau_eff, train_mses_theory, color=color)

    ax.plot(tau_eff, mses, ls="-.", color=color, alpha=0.5)
    ax.fill_between(tau_eff, mses-mses_std, mses+mses_std, color=color, alpha=0.13)
    ax.plot(tau_eff, train_mses, ls=":", color=color, alpha=0.5)
    ax.fill_between(tau_eff, train_mses-train_mses_std, train_mses+train_mses_std, color=color, alpha=0.13)

    ax.text(0.99, label_pos[i], r'$n={}$'.format(n), color=color, ha='right', transform=ax.transAxes, fontsize=14)

ax.text(0.01, .92, "B", color='black', transform=ax.transAxes, fontsize=15, fontweight='bold')
ax.set_xlabel(r'$\tau_\mathrm{eff}$')
ax.set_ylabel(r'$\mathcal{E}(f)$')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_ylim(1e-1,1.3e0)
ax.set_xlim(1e-2,1e6)
plt.legend(
    (
        (Line2D([0], [0], color='xkcd:navy', alpha=.5, ls=":"), Patch(color='k', alpha=0.13, lw=0)),
        (Line2D([0], [0], color='xkcd:navy', alpha=.5, ls="-."), Patch(color='k', alpha=0.13, lw=0)),
        Line2D([0], [0], color='xkcd:navy'),),
    (
        "KRR train",
        "KRR test",
        "theory",),
     framealpha=0.95, fontsize=13)
plt.axhline(1, ls='--', color='black', lw=1, alpha=0.4)
plt.tight_layout()
# Save before show(): with inline backends show() may close the figure, so a
# savefig placed after it can write an empty page.
if SAVE_FIG:
    plt.savefig("KRR_deep_bootstrap.pdf", bbox_inches='tight')
plt.show()

## Neural networks

In [None]:
import csv
from scipy.ndimage import gaussian_filter1d

# Load each run's CSV metrics (moved to Drive by the training cell) into
# all_data["{n}_{run}"][metric] = (steps, values). Only rows where that
# metric was logged are kept.
with open("/content/drive/My Drive/eigenlearning DB/NN_DB_exp_details.pickle", "rb") as handle:
    exp_details = pickle.load(handle)

ns = exp_details['ns']
n_trials = exp_details['n_trials']
trainsteps = exp_details['val_sched']

all_data = {}
skip = []  # run indices to exclude (also read by make_DB_plot)
for n in ns:
    for run in range(n_trials):
        if run in skip:
            continue
        exp_name = str(n)
        fname = "/content/drive/My Drive/eigenlearning DB/metrics_{}_{}.csv".format(exp_name, run)
        with open(fname, 'r', newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='|')
            lines = list(reader)
        header, table = lines[0], lines[1:]
        # Look the step column up by name: the logger's column order depends on
        # which metrics were logged, so a hard-coded index 1 is fragile.
        step_col = header.index('step')
        run_data = {}
        for i, metric in enumerate(header):
            rows = [line for line in table if line[i] != '']
            steps = [float(line[step_col]) for line in rows]
            data = [float(line[i]) for line in rows]
            run_data[metric] = (steps, data)
        all_data["{}_{}".format(exp_name, run)] = run_data

def make_DB_plot(all_data, log=True, train=True, error=False, save_name="NN_deep_bootstrap.pdf"):
    """Plot mean NN test (solid) and train (dotted) classification error vs. train steps.

    Curves are averaged over runs, skipping the run indices in the global
    ``skip``, and smoothed with a sigma=1 Gaussian. Also reads the notebook
    globals ``ns``, ``n_trials`` and ``trainsteps``.

    Args:
        all_data: mapping "{n}_{run}" -> {metric: (steps, values)}.
        log: use log-log axes (otherwise linear axes).
        train: also draw the train-error curves.
        error: shade +/- one (smoothed) standard deviation across runs.
        save_name: output path passed to ``plt.savefig``.
    """
    ax = plt.gca()
    # x positions: the validation schedule plus a far-right sentinel at 1e5.
    x = np.hstack((trainsteps, [1e5])).astype(int)
    label_pos = [.79, .645, .38, .165]  # axes-fraction heights of the n labels
    for i, n in enumerate(ns):
        color = get_plot_color(i, len(ns))
        for metric in ["TestError", "TrainError"]:
            y = []
            for run in range(n_trials):
                if run in skip:
                    continue
                _, data = all_data["{}_{}".format(n, run)][metric]
                # Pad short runs with their last value and truncate long ones,
                # so every row has len(x) entries and the rows can be averaged.
                data = (data + data[-1:]*(len(x)-len(data)))[:len(x)]
                y.append(data)
            yerr = np.std(y, axis=0)
            y = np.mean(y, axis=0)
            y_smooth = y.copy()
            # Leave the first point unsmoothed.
            y_smooth[1:] = gaussian_filter1d(y_smooth[1:], sigma=1, mode='nearest')
            yerr_smooth = gaussian_filter1d(yerr, sigma=1, mode='nearest')

            if metric == "TestError":
                plt.plot(x, y_smooth, '-', label=n, color=color, alpha=0.85)
                if error:
                    ax.fill_between(x, y_smooth-yerr_smooth, y_smooth+yerr_smooth, color=color, alpha=0.13)
            else:
                if train:
                    plt.plot(x, y_smooth, ':', color=color)
                    if error:
                        ax.fill_between(x, y_smooth-yerr_smooth, y_smooth+yerr_smooth, color=color, alpha=0.13)

        if i < len(label_pos):
            ax.text(0.99, label_pos[i], r'$n={}$'.format(n), color=color, ha='right', transform=ax.transAxes, fontsize=14)

    ax.text(0.01, .92, "A", color='black', transform=ax.transAxes, fontsize=15, fontweight='bold')
    if log:
        plt.xscale('log')
        plt.yscale('log')
        if train:
            plt.ylim(7e-2, 1.2e0)
        else:
            plt.ylim(7e-2, 1e0)
        plt.xlim(1e1, 6e4)
    else:
        plt.ylim(0, 1e0)
        plt.xlim(0, 1.3e4)

    plt.legend(
        (
            Line2D([0], [0], color='xkcd:navy', alpha=.65, ls=":"),
            Line2D([0], [0], color='xkcd:navy', alpha=.65, ls="-"),),
        (
            "NN train",
            "NN test",),
        framealpha=0.95)
    plt.xlabel('Train steps')
    plt.ylabel("Classification Error")
    plt.tight_layout()
    # Reference line at 0.9 error (chance level for 10 balanced classes).
    plt.axhline(0.9, ls='--', color='black', lw=1, alpha=0.4)

    plt.savefig(save_name, bbox_inches='tight')

# Default figure: log-log axes, train curves included, no error bands.
make_DB_plot(all_data)