In [1]:
%env CUDA_VISIBLE_DEVICES=7
import jax
from jax import numpy as jnp, vmap, jit, random, lax, value_and_grad
from jax.numpy import linalg as jla
from jax.flatten_util import ravel_pytree
from neural_tangents import taylor_expand

from util import fold, laxmap

from matplotlib import pyplot as plt
import numpy as np
import pickle
import scipy

import numpy as np
import torchvision
import torch

from flax import linen as nn
from typing import Callable

from flax import linen as nn
from flax.linen import initializers as jinit
from functools import partial

env: CUDA_VISIBLE_DEVICES=7


In [2]:
# Weight initializer shared by every layer: uniform, variance = (1/2) / fan_in.
# NOTE(review): PyTorch's stock nn.Linear / nn.Conv2d init is kaiming_uniform_(a=sqrt(5)),
# i.e. Var[w] = 1/(3 * fan_in), with a non-zero uniform bias. This scale=1/2 init with zero
# bias is therefore not literally PyTorch's default -- confirm which init was intended.
torch_init = jinit.variance_scaling(1 / 2, "fan_in", "uniform")

# Common keyword arguments for the layer factories below: shared kernel init, zero bias,
# and dtype=None (flax infers the computation dtype from inputs and params).
_torch_layer_kwargs = dict(kernel_init=torch_init, bias_init=jinit.zeros, dtype=None)

TorchLinear = partial(nn.Dense, **_torch_layer_kwargs)
TorchConv = partial(nn.Conv, **_torch_layer_kwargs)


class CNN(nn.Module):
    """Stack of (3x3 conv -> sigma -> 2x2 average pool) stages with a scalar linear readout.

    Fields:
        sigma: elementwise activation applied after every convolution.
        width: output channels of every convolution (default 512).
        depth: number of conv/activation/pool stages (default 4).

    The defaults reproduce the original fixed 4-stage, 512-channel architecture; because the
    submodules are created in the same order, flax names them Conv_0..Conv_{depth-1} and
    Dense_0 exactly as before, so the parameter tree is unchanged.
    """

    sigma: Callable
    width: int = 512
    depth: int = 4

    @nn.compact
    def __call__(self, x):
        """Map a batch of inputs to one scalar per example.

        Expects a leading batch axis on x (each example is flattened separately before the
        readout). Returns an array of shape (batch,).
        """
        for _ in range(self.depth):
            x = TorchConv(self.width, (3, 3))(x)
            x = self.sigma(x)
            x = nn.avg_pool(x, (2, 2), (2, 2), "SAME")
        x = x.reshape(x.shape[0], -1)
        # Dense(1) produces (batch, 1); drop the trailing singleton axis.
        return TorchLinear(1)(x)[..., 0]

In [None]:
DATA_DIR = "~/datasets/"  # torchvision expands "~" itself

traindata = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True)
testdata = torchvision.datasets.CIFAR10(DATA_DIR, train=False, download=True)

# Per-channel mean / std on the raw 0-255 pixel scale; broadcast over the last axis.
CIFAR10_MEAN = np.array([125.30691805, 122.95039414, 113.86538318])
CIFAR10_STD = np.array([62.99321928, 62.08870764, 66.70489964])

d = 3*32*32

# The two classes of the binary task: class0 -> label +1, class1 -> label -1.
class0 = 3
class1 = 7


def _binary_cifar(dataset):
    """Return (x, y) for the class0-vs-class1 subset of a torchvision CIFAR10 dataset.

    x is the per-channel standardized image array (float64); y is an int array with +1
    for class0 and -1 for class1. Applied identically to the train and test splits.
    """
    x = np.array(dataset.data)
    y = np.array(dataset.targets)
    keep = np.logical_or(y == class0, y == class1)
    # Filtering before normalizing gives the same values as the reverse order
    # (the operations are elementwise) but touches fewer pixels.
    x = (x[keep] - CIFAR10_MEAN) / CIFAR10_STD
    y = np.array(y[keep] == class0, dtype=int)*2 - 1
    return x, y


train_x, train_y = _binary_cifar(traindata)
test_x, test_y = _binary_cifar(testdata)

print(train_x.shape)
print(test_x.shape)

In [None]:
n_train = 10000  # training examples used (the training loop batches over this many)
n_test = 2000    # test examples used for test losses / accuracies
n = n_train
# NOTE(review): n, m and T are not read by the training code in this notebook (the step
# functions derive their own per-epoch batch count); m is only printed below -- confirm
# before removing.
m = 100
T = 10000
lr = 5e-3  # SGD step size
print(d, n_train, m)

seed = 1

# Slice to the configured sizes so the minibatch loop (which uses n_train) and the
# full-dataset losses/accuracies all see exactly the same examples.
X_train = train_x[:n_train]
y_train = train_y[:n_train]
assert len(X_train) == n_train, f"only {len(X_train)} training examples available, n_train={n_train}"

X_test = test_x[:n_test]
y_test = test_y[:n_test]
assert len(X_test) == n_test, f"only {len(X_test)} test examples available, n_test={n_test}"

In [None]:
res_tot = {}

# Independent PRNG streams derived from `seed`. model_rng initializes the network and
# `key` is carried in the training state; train_rng, test_rng and fn_rng are not used
# elsewhere in this notebook.
model_rng, train_rng, test_rng, fn_rng, key = random.split(random.PRNGKey(seed), 5)


def sigma(z):
    """ReLU activation applied after every convolution."""
    return jax.nn.relu(z)


model = CNN(sigma=sigma)

# Initialize on a single training image, then flatten the parameter pytree into one
# vector so training code can treat the parameters as a plain array; `unravel` maps
# the vector back to the pytree.
init_params, unravel = ravel_pytree(model.init(model_rng, train_x[:1]))


def f(p, x):
    """Network output on inputs x for the flat parameter vector p."""
    return model.apply(unravel(p), x)

@jit
def full(x, W):
    """Evaluate the full (nonlinear) network with flat parameters W on inputs x.

    Note the argument order (x, W) is the reverse of f(p, x).
    """
    outputs = f(W, x)
    return outputs

def _taylor_model(x, W, degree):
    """Degree-`degree` Taylor expansion of f(., x) in its parameters around init_params, evaluated at W."""
    return taylor_expand(lambda params: f(params, x), init_params, degree)(W)


@jit
def lin_plus_quad(x, W):
    """Second-order Taylor model of the network around init_params, evaluated at W on inputs x.

    Jitted (like `full`) so eager evaluation on the whole test set is compiled rather than
    traced op-by-op; nesting inside the jitted step functions is unaffected.
    """
    return _taylor_model(x, W, 2)


@jit
def linearize(x, W):
    """First-order (linearized) Taylor model of the network around init_params, evaluated at W on inputs x."""
    return _taylor_model(x, W, 1)

def ntk_feature(x):
    """Gradient of the network output w.r.t. the flat parameters, taken at init_params.

    f returns one scalar per example (shape (batch,)) and jax.grad only accepts
    scalar-valued functions, so differentiating f(params, x) directly always failed.
    Instead:
      * a single unbatched image (ndim 3) gets a batch axis added and its scalar output
        is differentiated -> returns a vector of shape (P,);
      * a batched input is differentiated per example with jax.jacrev -> (batch, P).

    P is init_params.size. NOTE: the batched form materializes a dense (batch, P) array.
    """
    if jnp.ndim(x) == 3:
        return jax.grad(lambda params: f(params, x[None])[0])(init_params)
    return jax.jacrev(lambda params: f(params, x))(init_params)


def _mse(targets, predictions):
    """Mean squared error between targets and predictions."""
    return jnp.mean((targets - predictions)**2)


# Full-dataset MSE of each model (full network, f_L + f_Q, f_L) as a function of the
# flat parameters W. Data globals are looked up at call time, as before.
def full_loss_fn(W):
    return _mse(y_train, full(X_train, W))


def full_test_loss_fn(W):
    return _mse(y_test, full(X_test, W))


def quad_loss_fn(W):
    return _mse(y_train, lin_plus_quad(X_train, W))


def quad_test_loss_fn(W):
    return _mse(y_test, lin_plus_quad(X_test, W))


def lin_loss_fn(W):
    return _mse(y_train, linearize(X_train, W))


def lin_test_loss_fn(W):
    return _mse(y_test, linearize(X_test, W))


num_epochs = 100
batch_size = 64

def _sgd_epoch(params, predict):
    """Run one epoch of plain minibatch SGD on the squared loss of `predict`.

    Visits the training set in its stored order as n_train // batch_size consecutive
    minibatches of batch_size examples (any remainder is dropped), taking one gradient
    step of size lr per minibatch.

    Args:
      params: flat parameter vector.
      predict: callable (x, W) -> predictions, e.g. full, linearize or lin_plus_quad.

    Returns:
      (updated params, mean of the per-minibatch losses, each measured before its step).

    Raises:
      ValueError: if n_train < batch_size (no complete minibatch).
    """
    n_batches = n_train // batch_size
    if n_batches == 0:
        raise ValueError(f"n_train={n_train} is smaller than batch_size={batch_size}")
    n_used = n_batches * batch_size
    # Stack minibatches along a new leading axis so lax.scan iterates over them; unlike
    # an unrolled Python loop under jit, the compiled program does not grow with the
    # number of batches.
    X_batches = X_train[:n_used].reshape((n_batches, batch_size) + X_train.shape[1:])
    y_batches = y_train[:n_used].reshape((n_batches, batch_size) + y_train.shape[1:])

    def batch_loss(W, X_batch, y_batch):
        return jnp.mean((y_batch - predict(X_batch, W))**2)

    def sgd_update(W, batch):
        loss, grads = value_and_grad(batch_loss)(W, *batch)
        return W - lr*grads, loss

    params, losses = lax.scan(sgd_update, params, (X_batches, y_batches))
    return params, jnp.mean(losses)


@jit
def full_step(input):
    """One SGD epoch on the full network f, followed by test-set diagnostics.

    Args:
      input: (params, key). NOTE(review): key is passed through unused, so minibatches
        are not reshuffled between epochs.

    Returns:
      dict(state=(params, key),
           save=(train_loss, test_loss, test_acc, lin_acc, quad_acc)), where the
      accuracies are test-set sign accuracies, at the updated params, of f, of the
      linearized model, and of the quadratic-only term (lin_plus_quad - linearize).
    """
    params, key = input
    params, train_loss = _sgd_epoch(params, full)

    test_loss = full_test_loss_fn(params)
    # Evaluate each Taylor model once; the quadratic-only term is their difference.
    lin_out = linearize(X_test, params)
    quad_only = lin_plus_quad(X_test, params) - lin_out

    test_acc = jnp.mean(full(X_test, params)*y_test > 0)
    lin_acc = jnp.mean(lin_out*y_test > 0)
    quad_acc = jnp.mean(quad_only*y_test > 0)

    return dict(state=(params, key), save=(train_loss, test_loss, test_acc, lin_acc, quad_acc))


@jit
def lin_step(input):
    """One SGD epoch on the linearized model; returns (train_loss, test_loss) in `save`.

    input is (params, key); key is passed through unused.
    """
    params, key = input
    params, train_loss = _sgd_epoch(params, linearize)
    test_loss = lin_test_loss_fn(params)
    return dict(state=(params, key), save=(train_loss, test_loss))


@jit
def quad_step(input):
    """One SGD epoch on the second-order Taylor model; returns (train_loss, test_loss) in `save`.

    input is (params, key); key is passed through unused.
    """
    params, key = input
    params, train_loss = _sgd_epoch(params, lin_plus_quad)
    test_loss = quad_test_loss_fn(params)
    return dict(state=(params, key), save=(train_loss, test_loss))


def _train_and_pickle(step_fn, filename):
    """Run util.fold over step_fn for num_epochs steps from (params, key) and pickle the result.

    NOTE: the output files keep their historical '.npy' names but contain pickle data,
    not the NumPy .npy format -- read them back with pickle.load.
    """
    res = fold(step_fn, (params, key), steps=num_epochs, show_progress=True)
    # The context manager flushes and closes the file even if pickling raises.
    with open(filename, 'wb') as fh:
        pickle.dump(res, fh)
    return res


params = init_params

full_res = _train_and_pickle(full_step, 'full_data.npy')
lin_res = _train_and_pickle(lin_step, 'lin_data.npy')
quad_res = _train_and_pickle(quad_step, 'quad_data.npy')


In [None]:
p = full_res['state'][0]

# Test-set sign accuracy of the trained full-network parameters p under the full model
# and its first- / second-order Taylor models. A single dict is the cell's last
# expression so that all three values are displayed.
{
    'f': float(np.sum(full(X_test, p)*y_test > 0)/len(y_test)),
    'f_L': float(np.sum(linearize(X_test, p)*y_test > 0)/len(y_test)),
    'f_L + f_Q': float(np.sum(lin_plus_quad(X_test, p)*y_test > 0)/len(y_test)),
}

In [None]:
# Test MSE of the trained full-network parameters p under f, f_L and f_L + f_Q (in that order).
for test_loss_fn in (full_test_loss_fn, lin_test_loss_fn, quad_test_loss_fn):
    print(test_loss_fn(p))

In [None]:
plt.rcParams['font.size'] = 16

# (results, color, legend label) for each trained model, plotted in this order.
curves = [
    (full_res, 'orange', r'$f$'),
    (quad_res, 'blue', r'$f_L + f_Q$'),
    (lin_res, 'purple', r'$f_L$'),
]
# (index into res['save'], y-axis label): save[0] is the epoch's train loss, save[1] the test loss.
panels = [(0, "Train Loss"), (1, "Test Loss")]

# One panel per metric (the previous 1x3 grid left its third panel empty).
fig, axs = plt.subplots(1, len(panels), figsize=(5*len(panels), 5))
for ax, (save_idx, ylabel) in zip(axs, panels):
    for res, color, label in curves:
        ax.plot(res['save'][save_idx], color=color, label=label)
    ax.set_xlabel("epochs")
    ax.set_ylabel(ylabel)
    ax.legend()
# Anchor the train-loss axis at zero (top stays at its autoscaled value).
axs[0].set_ylim(bottom=0.)

fig.tight_layout()