In [None]:
import os
import subprocess

# Fetch the experiment code and make it the working directory so the `src.*`
# imports in the next cell resolve. Guarded so that re-running this cell (when
# the cwd is already inside the repo) neither clones a nested copy nor changes
# into it.
REPO_URL = "https://github.com/oopir/cld_optimization_experiments"
REPO_DIR = "cld_optimization_experiments"

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        # check=True raises here on a failed clone instead of failing later
        # with a confusing chdir/import error.
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)

print(os.getcwd())

In [None]:
import numpy as np
import random
import torch

from src.data import load_digits_data
from src.training import train_multiseed
from src.plots import plot_ex1_multiseed

# Return cached-but-unused CUDA memory held by PyTorch's allocator; this is a
# no-op when CUDA has not been initialized (e.g. CPU-only machines).
torch.cuda.empty_cache()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed intended for the data-loading step (distinct from the per-run seeds below).
data_seed = 0
# One training run per seed; presumably aggregated by train_multiseed.
run_seeds = [seed for seed in range(5)]

In [None]:
# Problem size.
n = 40      # passed as `n` to load_digits_data (presumably # of training samples)
m = 30000   # used as the network's hidden_width (and as lam_fc2)

# beta for the two configurations, both scaled linearly with n.
# NOTE(review): presumably an inverse temperature (larger beta -> less noise),
# which would match the "clean"/"noisy" labels — confirm against src.training.
clean_beta = 1e5 * n
noisy_beta = 1e3 * n

In [None]:
epochs = 200

# Strides (in epochs) for metric tracking (~100 points) and progress printing
# (~5 lines). Both are clamped to >= 1: with a short run (epochs < 5), a bare
# `epochs // 5` would be 0, which is not a usable stride (e.g. for an
# `epoch % every` check inside the trainer).
track_every = max(1, epochs // 100)
print_every = max(1, epochs // 5)

In [None]:
# Seed every RNG the pipeline may touch, using the configured data seed rather
# than a hardcoded 0, so changing `data_seed` actually changes the data draw.
torch.manual_seed(data_seed)
np.random.seed(data_seed)
random.seed(data_seed)

data = load_digits_data(n=n, random_labels=False, device=device, seed=data_seed)
# Input dimensionality; assumes X_train is (num_samples, num_features) — TODO confirm.
d = data["X_train"].shape[1]

# Hyperparameters shared by both runs; only `beta` differs between them.
common = {
    "seeds": run_seeds,
    "data": data,
    "eta": 1e-5,
    "epochs": epochs,
    # First-layer regularization: input dim divided by the squared tanh gain
    # (calculate_gain("tanh") is 5/3).
    "lam_fc1": d / (torch.nn.init.calculate_gain("tanh") ** 2),
    "lam_fc2": m,
    "hidden_width": m,
    "regularization_scale": 1.0,
    "track_jacobian": False,
    # Kept True for now: plotting has bugs when this is False.
    "use_linearized": True,
    "device": device,
    "track_every": track_every,
    "print_every": print_every,
}

# Train "clean" first, then "noisy"; the dict keeps that insertion order.
results = {
    label: train_multiseed(beta=beta, **common)
    for label, beta in (("clean", clean_beta), ("noisy", noisy_beta))
}

In [None]:
# Visualize the multi-seed results for both configurations ("clean" and "noisy",
# the keys of `results`) produced by the training cell above.
plot_ex1_multiseed(results)