In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch

from einops import rearrange, reduce, repeat

import sys
import os
from tqdm import tqdm

sys.path.append("../")

In [2]:
from ImageData import ImageData, preprocess
from ExptTrace import ExptTrace
from FileManager import FileManager
from kernels import GaussianKernel, LaplaceKernel, ExponentialKernel, krr
from feature_decomp import generate_fra_monomials
from utils import ensure_torch, ensure_numpy
from data import get_matrix_hermites, get_gaussian_data, get_hermite_target, get_powerlaw
from eigenlearning import eigenlearning

In [3]:
# TOP LEVEL HYPERPARAMS
#######################

EXPT_NAME = "hehe-eigenlearning"
DATASET = "gaussian"

N_SAMPLES = 20_000
P_MODES = 10_000
DATA_DIM = 200

# SETUP FILE MANAGEMENT
#######################

# Dataset and experiment roots are read from the environment so that no
# machine-specific absolute path is baked into the notebook.
datapath = os.getenv("DATASETPATH")
exptpath = os.getenv("EXPTPATH")
if datapath is None:
    raise ValueError("must set $DATASETPATH environment variable")
if exptpath is None:
    raise ValueError("must set $EXPTPATH environment variable")
expt_dir = os.path.join(exptpath, "phlab", EXPT_NAME, DATASET)

# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(expt_dir, exist_ok=True)
expt_fm = FileManager(expt_dir)

In [4]:
def emp_eigsys(kernel, y):
    """Return the kernel's eigenvalues and the unit-norm coefficients of y in its eigenbasis.

    Args:
        kernel: object exposing ``eigendecomp()`` that returns a pair
            ``(eigvals, eigvecs)`` of torch tensors.
        y: torch tensor that ``eigvecs.T`` can be matrix-multiplied with.

    Returns:
        ``(eigvals, eigcoeffs)`` as numpy arrays, where ``eigcoeffs`` is
        ``eigvecs.T @ y`` divided by its norm.
    """
    spectrum, basis = kernel.eigendecomp()
    # Project the target onto the basis, then rescale to unit norm.
    projections = basis.T @ y
    unit_projections = projections / torch.linalg.norm(projections)
    return spectrum.cpu().numpy(), unit_projections.cpu().numpy()


def fra_eigsys(X, y, eval_level_coeff, kmax=10):
    """Build the monomial/Hermite eigensystem for X and fit y in that basis.

    Args:
        X: 2-D torch tensor; its number of rows ``N`` is passed to
            ``generate_fra_monomials`` as the second argument.
        y: torch tensor used as the right-hand side of the least-squares fit
            against the Hermite feature matrix.
        eval_level_coeff: forwarded unchanged to ``generate_fra_monomials``.
        kmax: forwarded to ``generate_fra_monomials`` as ``kmax``. Defaults to
            10, the value that was previously hard-coded.

    Returns:
        ``(eigvals, eigcoeffs)``: ``eigvals`` exactly as returned by
        ``generate_fra_monomials``; ``eigcoeffs`` the least-squares solution,
        scaled to unit norm and converted to a numpy array.
    """
    N, _ = X.shape
    # Squared singular values of X, normalized to sum to one.
    squared_svals = torch.linalg.svdvals(X) ** 2
    data_eigvals = squared_svals / squared_svals.sum()

    eigvals, monomials = generate_fra_monomials(data_eigvals, N, eval_level_coeff, kmax=kmax)
    H = get_matrix_hermites(X, monomials)
    eigcoeffs = torch.linalg.lstsq(H, y).solution
    eigcoeffs = eigcoeffs / torch.linalg.norm(eigcoeffs)

    return eigvals, eigcoeffs.cpu().numpy()


def learning_curve(ntrains, eigvals, eigcoeffs, ridge=0, noise_var=0):
    """Evaluate ``eigenlearning`` at each training-set size and collect the MSEs.

    Args:
        ntrains: sized iterable of training-set sizes.
        eigvals: forwarded unchanged to ``eigenlearning``.
        eigcoeffs: forwarded unchanged to ``eigenlearning``.
        ridge: forwarded positionally to ``eigenlearning``.
        noise_var: forwarded positionally to ``eigenlearning``.

    Returns:
        ``(train_mses, test_mses)``: two numpy float arrays of length
        ``len(ntrains)`` holding ``res["train_mse"]`` and ``res["test_mse"]``
        for each size, in order.
    """
    # Only the two MSE curves are returned, so only those buffers are allocated.
    train_mses = np.zeros(len(ntrains))
    test_mses = np.zeros(len(ntrains))
    for i, n in enumerate(ntrains):
        res = eigenlearning(n, eigvals, eigcoeffs, ridge, noise_var)
        train_mses[i] = res["train_mse"]
        test_mses[i] = res["test_mse"]
    return train_mses, test_mses

In [5]:
# Sweep grids.
data_eigval_exps = np.linspace(1.0, 2.0, num=3)
zca_strengths = [0, 0.005, 0.03]
kerneltypes = [GaussianKernel, LaplaceKernel]
kernel_widths = [1, 4]

# Three traces that share the same variable axes.
var_axes = ["d_eff", "kernel", "kernel_width"]
et_pathnames, et_emp_eigvals, et_fra_eigvals = ExptTrace.multi_init(3, var_axes)

# Raw images are loaded only for the image datasets; any other DATASET value
# leaves X_raw untouched.
if DATASET == "cifar10":
    data_dir = os.path.join(datapath, "cifar10")
    cifar10 = ImageData("cifar10", data_dir, classes=None)
    X_raw, _ = cifar10.get_dataset(N_SAMPLES, get="train")
elif DATASET == "imagenet32":
    fn = os.path.join(datapath, "imagenet", f"{DATASET}.npz")
    data = np.load(fn)
    # Flat rows are unpacked into (channel, height, width) images.
    X_raw = rearrange(
        data["data"][:N_SAMPLES].astype(float),
        "n (c h w) -> n c h w",
        c=3,
        h=32,
        w=32,
    )


# Gaussian data

In [14]:
# Build one synthetic dataset from a power-law spectrum, a Gaussian kernel on it,
# and a noisy target expressed in the Hermite feature basis.
# NOTE(review): no random seed is set in this cell; if the data/target helpers
# draw random numbers, X and y will differ between runs — confirm.
data_eigval_exp = data_eigval_exps[0]
data_eigvals = get_powerlaw(DATA_DIM, data_eigval_exp, offset=6)
X = get_gaussian_data(N_SAMPLES, data_eigvals)

# Inverse of the sum of squared data eigenvalues; not read again in this cell.
d_eff = 1/(data_eigvals**2).sum()

# create kernel
print("Creating kernel")

kerneltype = GaussianKernel
kwidth = 4
kernel = kerneltype(X, kernel_width=kwidth)

# Generate synthetic target
print("Generating synthetic target")

# The level-coefficient function comes from the same kernel class and width as
# `kernel`, and is used to generate P_MODES monomials for the feature matrix H.
eval_level_coeff = kerneltype.get_level_coeff_fn(data_eigvals, kernel_width=kwidth)
eigvals, monomials = generate_fra_monomials(data_eigvals, P_MODES, eval_level_coeff)
H = get_matrix_hermites(X, monomials)

# Target coefficients: power law with exponent beta over the P_MODES features.
beta = 1.1
noise_var = 2e-1
squared_coeffs = get_powerlaw(P_MODES, beta, offset=6)
y, snr = get_hermite_target(H, squared_coeffs, noise_var=noise_var)

print(f"SNR = {snr}")

Creating kernel
Generating synthetic target
SNR = 5.055937767028809


In [15]:
# Empirical KRR learning curve: for every trial and every training-set size,
# run krr on the precomputed kernel matrix and record train/test MSE.
ntrains = np.logspace(1, 4, base=10, num=20).astype(int)  # 20 log-spaced sizes from 10 to 10_000, truncated to ints
et_test_mse = ExptTrace(["trial", "n"])
et_train_mse = ExptTrace(["trial", "n"])
ystar_idx = 5  # not used in this cell
ridge = 1e-3
ntrials = 5

K = ensure_torch(kernel.K)

for trial in tqdm(range(ntrials)):
    for ntrain in ntrains:
        # krr also returns yhattest, which is not stored here.
        train_mse, test_mse, yhattest = krr(K, y, ntrain, n_test=2000, ridge=ridge)
        et_test_mse[trial, ntrain] = test_mse
        et_train_mse[trial, ntrain] = train_mse


100%|██████████| 5/5 [00:13<00:00,  2.63s/it]


# CIFAR-10

In [None]:
# Two-class CIFAR-10 problem (class 0 vs class 1), train and test concatenated.
classes = [[0], [1]]
NTRAIN = 10_000
NTEST = 2_000

# The dataset directory is derived from $DATASETPATH (held in `datapath`).
data_dir = os.path.join(datapath, "cifar10")
cifar10 = ImageData('cifar10', data_dir, classes=classes)
X_train, y_train = cifar10.get_dataset(NTRAIN, get="train")
X_test, y_test = cifar10.get_dataset(NTEST, get="test")
X_train, y_train, X_test, y_test = [torch.Tensor(t) for t in (X_train, y_train, X_test, y_test)]
X_raw = torch.cat([X_train, X_test])
y = torch.cat([y_train, y_test])
# NOTE(review): the same scalar mean is subtracted from both columns, so it
# cancels (up to rounding) in the difference below and the final y is not
# centered by this line — confirm intent.
y -= y.mean()
# Collapse the two label columns into one signed target.
y = y[:,0]-y[:,1]

X = preprocess(X_raw, center=True, grayscale=True, zca_strength=0)
X = ensure_torch(X)

N, _ = X.shape
S = torch.linalg.svdvals(X)
# to make norm(x)~1 on average
X *= torch.sqrt(N / (S**2).sum())
# S was computed before the rescale; the normalized spectrum is scale-invariant.
data_eigvals = S**2 / (S**2).sum()

print(f"d_eff={1/(data_eigvals**2).sum().item():.2f}")

Files already downloaded and verified
Files already downloaded and verified
d_eff=6.60


In [13]:
# Redo the preprocessing with ZCA strength 0.02 and report the effective dimension.
X = preprocess(X_raw, center=True, grayscale=True, zca_strength=.02)
X = ensure_torch(X)
S = torch.linalg.svdvals(X)
# (sum S^2)^2 / sum S^4, i.e. 1 / sum(p**2) for p = S^2 / sum S^2.
print(f"d_eff={(S**2).sum().item()**2/(S**4).sum().item():.2f}")

d_eff=38.20
