In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from tqdm import tqdm
import gc

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Global matplotlib styling: white figure background, STIX text with
# Computer Modern math glyphs, and a faint cream axes background.
plt.rcParams.update({
    "figure.dpi": 100,
    "figure.facecolor": (1, 1, 1),
    "font.family": "stixgeneral",
    "font.size": 14,
    "axes.facecolor": (1, .99, .95),
    "axes.titlesize": 18,
    "mathtext.fontset": "cm",
})

import sys
sys.path.append("../")

from kernels import ReluNTK
from feature_decomp import Monomial, generate_hea_monomials
from utils import ensure_torch
from utils import derive_seed, seed_everything
from mlps import MLP, train_network
from tools import trial_count_fn

In [None]:
EXPT_NAME = "mlp-learning-curves"
N_TRAIN = 4_000
N_TEST = 10_000
N_TOT = N_TEST + N_TRAIN
DATASET = "synthetic"  # "synthetic" or "cifar10"
kerneltype = ReluNTK
TARGET_FUNCTION_TYPE = "monomial"  # only "monomial" targets are implemented
ONLYTHRESHOLDS = True  # if True, only record last loss instead of full curve

SEED = 42

# MLP HPs
LR = 1e-2     # base 1e-2
DEPTH = 3     # base 3
WIDTH = 1024  # base 1024
GAMMA = 1     # base 1

# Sample sizes: set SWEEP_NS = True for 20 log-spaced sizes in [10, 1000];
# otherwise a single size is used (base 1024).
SWEEP_NS = False
NS = np.logspace(1, 3, 20, dtype=int) if SWEEP_NS else np.array([1024])

# Training HPs
ONLINE = True  # if True, no fixed training set is sampled (run_trial passes X_tr=None)
LOSS_CHECKPOINTS = (0.1, 1e-12)  # must be a sequence (needs a len atm)
MAX_ITER = int(1e20)  # effectively unbounded; presumably training stops at the loss checkpoints
EMA_SMOOTHER = 0.9
DETERMINISTIC = True
DETERMINSITIC = DETERMINISTIC  # misspelled alias kept for backward compatibility
trial_counts = np.array([trial_count_fn(n) for n in NS], dtype=int)
if ONLINE:
    # Online runs override the sample sizes, trial counts and loss thresholds.
    NS = np.array([1024])
    trial_counts = np.ones_like(NS) * 3
    LOSS_CHECKPOINTS = (0.15, 0.1)

global_config = dict(DEPTH=DEPTH, WIDTH=WIDTH, LR=LR, GAMMA=GAMMA,
    EMA_SMOOTHER=EMA_SMOOTHER, MAX_ITER=MAX_ITER,
    LOSS_CHECKPOINTS=LOSS_CHECKPOINTS, N_TEST=N_TEST,
    SEED=SEED, ONLYTHRESHOLDS=ONLYTHRESHOLDS,
    # Carried explicitly so trial code does not have to read notebook globals.
    ONLINE=ONLINE, TARGET_FUNCTION_TYPE=TARGET_FUNCTION_TYPE,
)

# Dataset HPs
# Note: not all of these are used for all datasets
datasethps = {"normalized": True,
              "cutoff_mode": 20_000,
              "d": 200,
              "offset": 6,
              "alpha": 1.7,  # ~1.14 is cifar
              "noise_size": 1,
              "yoffset": 1.2,
              "beta": 1.2,
              "classes": None,
              "binarize": False,
              "weight_variance": 1,
              "bias_variance": 1,
              "kmax": 6,
              }

In [None]:
def run_trial(job, global_config, bfn_config=None, jobid=None):
    """Train one MLP on a single (target, n, trial) job and return its loss record.

    Parameters
    ----------
    job : tuple
        ``(target, n, trial)``: the target monomial, the sample size ``n``
        (passed as ``bsz`` to the batch function), and the trial index.
    global_config : dict
        Experiment configuration. Reads ``SEED``, ``N_TEST``, ``DIM``, ``DEPTH``,
        ``WIDTH``, ``LR``, ``MAX_ITER``, ``LOSS_CHECKPOINTS``, ``GAMMA``,
        ``EMA_SMOOTHER`` and ``ONLYTHRESHOLDS``. ``ONLINE`` and
        ``TARGET_FUNCTION_TYPE`` are read from it when present, otherwise from
        the notebook globals of the same name.
    bfn_config : dict, optional
        Extra keyword arguments forwarded to ``polynomial_batch_fn``.
    jobid : int, optional
        Passed together with ``SEED`` to ``derive_seed`` / ``seed_everything``.

    Returns
    -------
    tuple
        ``(n, str(target), int(trial), train_losses, test_losses, timekeys)``
        as reported by ``train_network``.

    Raises
    ------
    ValueError
        If the target function type is not ``"monomial"``.
    """
    base_seed = global_config.get("SEED", None)
    job_seed = derive_seed(base_seed, jobid)
    GEN, _ = seed_everything(job_seed, jobid)

    torch.set_num_threads(1)  # avoid CPU contention when many procs

    # Prefer values carried in the config; fall back to the notebook globals.
    online = global_config.get("ONLINE", ONLINE)
    target_type = global_config.get("TARGET_FUNCTION_TYPE", TARGET_FUNCTION_TYPE)
    bfn_kwargs = dict(bfn_config or {})  # tolerate the default bfn_config=None

    if target_type != "monomial":
        raise ValueError(f"Unsupported TARGET_FUNCTION_TYPE: {target_type!r}")
    from data import polynomial_batch_fn

    def batch_function(target_monomial, n, X, y):
        return polynomial_batch_fn(**bfn_kwargs, monomials=target_monomial,
                                   bsz=n, gen=GEN, X=X, y=y)

    target, n, trial = job
    # The callable returned by the batch function is invoked with 0 for the
    # test set and with `trial` for the training set.
    X_te, y_te = batch_function(target, global_config["N_TEST"], X=None, y=None)(0)

    # Explicit branches: the one-liner `a, b = f() if cond else None, None` parses as
    # `a, b = (f() if cond else None), None`, which put the whole (X, y) tuple into X_tr.
    if online:
        X_tr, y_tr = None, None
    else:
        X_tr, y_tr = batch_function(target, n, X=None, y=None)(trial)

    bfn = batch_function(target, n, X=X_tr, y=y_tr)

    model = MLP(d_in=global_config["DIM"], depth=global_config["DEPTH"],
                d_out=1, width=global_config["WIDTH"]).to(device)

    outdict = train_network(
        model=model,
        batch_function=bfn,
        lr=global_config["LR"],
        max_iter=global_config["MAX_ITER"],
        loss_checkpoints=global_config["LOSS_CHECKPOINTS"],
        gamma=global_config["GAMMA"],
        ema_smoother=global_config["EMA_SMOOTHER"],
        only_thresholds=global_config["ONLYTHRESHOLDS"],
        X_tr=X_tr, y_tr=y_tr,
        X_te=X_te, y_te=y_te,
        verbose=False,
    )

    timekeys = outdict["timekeys"]
    train_losses = outdict["train_losses"]
    test_losses = outdict["test_losses"]

    # Drop references, collect reference cycles, then release cached CUDA blocks
    # (emptying the cache first would miss tensors only freed by the collector).
    del outdict, model, bfn, X_tr, y_tr, X_te, y_te
    gc.collect()
    torch.cuda.empty_cache()

    return (n, str(target), int(trial), train_losses, test_losses, timekeys)

In [None]:
def select_indices_with_geometric_decay(values, ratio=.9):
    """Greedily pick indices whose values drop by at least a factor ``ratio``.

    Walks ``values`` from the first (largest) entry and keeps index ``i``
    whenever ``values[i]`` is strictly below ``ratio`` times the last kept
    value. The first index is always kept, so a dense spectrum is thinned to
    roughly one representative per geometric step.

    Parameters
    ----------
    values : array-like or torch.Tensor
        1-D, strictly positive, sorted in non-increasing order.
    ratio : float, default 0.9
        Decay factor between consecutive selected values.

    Returns
    -------
    list of int
        Selected indices in increasing order (empty if ``values`` is empty).

    Raises
    ------
    AssertionError
        If ``values`` is not strictly positive or not sorted non-increasing.
    """
    # Accept lists and (possibly CUDA) tensors as well as numpy arrays.
    if torch.is_tensor(values):
        values = values.detach().cpu().numpy()
    values = np.asarray(values)

    assert np.all(values > 0), "values must be strictly positive"
    assert np.all(np.diff(values) <= 0), "values must be sorted in non-increasing order"

    selected_indices = []
    threshold = np.inf  # guarantees the first (largest) value is kept
    for i, value in enumerate(values):
        if value < threshold:
            selected_indices.append(i)
            threshold = value * ratio
    return selected_indices

from data import get_powerlaw

# Power-law data spectrum over `d` directions, and the ReLU-NTK level
# coefficients built from it (variances taken from the dataset config).
data_eigvals = get_powerlaw(P=datasethps['d'], exp=datasethps['alpha'],
                            offset=datasethps['offset'], normalize=True)
level_coeff_fn = ReluNTK.get_level_coeff_fn(data_eigvals=data_eigvals,
                                            bias_variance=datasethps['bias_variance'],
                                            weight_variance=datasethps['weight_variance'])

# Build HEA eigenmodes only over a handful of representative data directions.
# (A full-spectrum kmax=6 call used to run here but its result was discarded.)
data_indices_of_interest = [0, 1, 2, 3, 5, 10, 20, 40, 60, 100, 150]
gammas_of_interest = data_eigvals.cpu().numpy()[data_indices_of_interest]
hea_eigvals, monomials = generate_hea_monomials(gammas_of_interest, datasethps['cutoff_mode'],
                                                level_coeff_fn, kmax=4)  # keep kmax <= 6

# Thin the HEA spectrum geometrically, then drop negligible modes.
selected_indices = select_indices_with_geometric_decay(hea_eigvals, .9)
hea_eigenval_cutoff = 1e-6
selected_indices = [i for i in selected_indices if hea_eigvals[i] > hea_eigenval_cutoff]

selected_hea_eigvals = hea_eigvals[selected_indices]
selected_monomials = [monomials[i] for i in selected_indices]
print(f"selected {len(selected_indices)} HEA eigenmodes.")

# Monomial keys are positions within `data_indices_of_interest`; map them back to
# indices of the full data spectrum so the targets address the real directions.
monomials_as_dicts = [monomial.basis() for monomial in selected_monomials]
mapped_monomials_as_dicts = [
    {data_indices_of_interest[key]: power for key, power in monomial_dict.items()}
    for monomial_dict in monomials_as_dicts
]

In [None]:
## --- Grab targets and such ---
# NOTE: the synthetic branch rebinds `monomials`, `hea_eigvals` and `data_eigvals`
# from the previous cell; the cifar10 branch rebinds `data_eigvals`.
if DATASET == "synthetic":
    from notebook_fns import get_synthetic_dataset

    # Create the generator on `device` (chosen at the top of the notebook) rather
    # than a hard-coded 'cuda', so CPU-only machines do not fail here.
    X_full, _, H, monomials, hea_eigvals, _, data_eigvals = get_synthetic_dataset(
        **datasethps, N=N_TOT, kerneltype=kerneltype,
        gen=torch.Generator(device=device).manual_seed(SEED))

elif DATASET == "cifar10":
    from imdata import ImageData
    PIXEL_NORMALIZED = False
    classes = datasethps['classes']
    normalized = datasethps['normalized']

    # One-hot labels only for non-binary class subsets; plain labels otherwise.
    onehot = classes is not None and len(classes) != 2
    imdata = ImageData('cifar10', "../data", classes=classes, onehot=onehot, format="N")
    X_train, y_train = imdata.get_dataset(N_TRAIN, **datasethps, get='train',
                                          centered=True, normalize=PIXEL_NORMALIZED)
    X_test, y_test = imdata.get_dataset(N_TEST, **datasethps, get='test',
                                        centered=True, normalize=PIXEL_NORMALIZED)
    X_train, y_train, X_test, y_test = map(ensure_torch, (X_train, y_train, X_test, y_test))
    y_train = y_train.squeeze()
    y_test = y_test.squeeze()
    if normalized:
        # Scale each tensor to unit Frobenius norm, then by sqrt(#samples).
        X_train, y_train, X_test, y_test = [t / torch.linalg.norm(t)
                                            for t in (X_train, y_train, X_test, y_test)]
        X_train *= N_TRAIN**(0.5); X_test *= N_TEST**(0.5)
        y_train *= N_TRAIN**(0.5); y_test *= N_TEST**(0.5)
    X_full = torch.cat((X_train, X_test), dim=0)
    y_full = torch.cat((y_train, y_test), dim=0)
    # Normalized squared singular values of the pooled data.
    data_eigvals = torch.linalg.svdvals(X_full)**2
    data_eigvals /= data_eigvals.sum()

else:
    raise ValueError(f"Unknown DATASET: {DATASET!r}")


# Thin SVD of the pooled data. Note `lambdas` holds singular values (not squared).
U, lambdas, Vt = torch.linalg.svd(X_full, full_matrices=False)
dim = X_full.shape[1]

## --- Target function defs ---
if TARGET_FUNCTION_TYPE == "monomial":
    target_monomials = [Monomial(mmd) for mmd in mapped_monomials_as_dicts]
    targets = target_monomials
    bfn_config = dict(lambdas=lambdas, Vt=Vt, data_eigvals=data_eigvals, N=N_TOT)
else:
    # Fail here instead of with a NameError on `targets` further down.
    raise ValueError(f"Unsupported TARGET_FUNCTION_TYPE: {TARGET_FUNCTION_TYPE!r}")

global_config.update(DIM=dim)

# One job per (target, sample size, trial); trial counts are per sample size.
jobs = [(target, n, trial)
        for target in targets
        for nidx, n in enumerate(NS)
        for trial in range(int(trial_counts[nidx]))]

total = len(jobs)
# Every trial's raw output is kept here for later analysis / plotting.
results = []
print("Training base case")
with tqdm(total=total, desc="Runs", dynamic_ncols=True) as pbar:
    for jobid, job in enumerate(jobs):
        n, tstr, trial, train_losses, test_losses, timekeys = run_trial(job, global_config, bfn_config, jobid)
        results.append(dict(n=n, target=tstr, trial=trial,
                            train_losses=train_losses, test_losses=test_losses,
                            timekeys=timekeys))
        # With full curves recorded, summarise by the last entry; with
        # ONLYTHRESHOLDS the returned losses are shown as-is.
        final_train = train_losses if ONLYTHRESHOLDS else train_losses[-1]
        final_test = test_losses if ONLYTHRESHOLDS else test_losses[-1]
        pbar.set_postfix_str(
            f"train {final_train:.3g} | test {final_test:.3g} | timekey {timekeys} | n={n} | target={tstr} | trial={trial}",
            refresh=False)
        pbar.update(1)

torch.cuda.empty_cache()