In [1]:
%env CUDA_VISIBLE_DEVICES=7
import jax
from jax import numpy as jnp, vmap, jit, random, lax, value_and_grad
from jax.numpy import linalg as jla
from jax.flatten_util import ravel_pytree
from neural_tangents import taylor_expand

from util import fold, laxmap

from matplotlib import pyplot as plt
import numpy as np
import pickle

env: CUDA_VISIBLE_DEVICES=7


In [2]:
# Problem sizes and optimisation settings for the experiment.
#   d  : input dimension                 m  : hidden width of the network
#   n  : number of training samples, n = d^{3/2}
#   T  : gradient-descent steps per run  lr : step size
d, m, T, lr = 100, 10_000, 20_000, 1.0
n = int(d ** 1.5)
print(d, n, m)

100 1000 10000


In [3]:
# Seeds are the repunits 1, 11, ..., 11111; lambs are the strengths of the
# lamb-weighted regularisation term swept for every seed.
seeds = [int('1' * k) for k in range(1, 6)]
lambs = [
    0.0,
    1e-5,
    1e-4,
    1e-3,
    1e-2,
    1e-1,
]

In [4]:
# Regularisation sweep. For every seed:
#   * build a two-layer sigmoid network with symmetric initialisation,
#   * draw a low-rank (single-direction) quadratic + linear target,
#   * train the second-order Taylor expansion of the network around W = 0 with
#     gradient descent, once for every strength in `lambs`,
#   * pickle the per-step metrics of all runs to 'seed<seed>data.npy'.
n_k = d          # number of top NTK singular directions split off from J
reg_coef = 0.01  # fixed weight on the J_large regulariser (reg_fn)
n_test = 10000   # number of held-out test points
n_r1 = 100       # fresh Gaussian inputs drawn per step for the randomized R_1 estimate

assert m % 2 == 0, "m must be even for the symmetric (+/-) initialisation"

for seed in seeds:
    res_tot = {}
    model_rng, train_rng, test_rng, fn_rng, key = random.split(random.PRNGKey(seed), 5)

    # Symmetric initialisation: each unit-norm hidden direction appears twice, once
    # with output weight +1/sqrt(m) and once with -1/sqrt(m), so the two copies
    # cancel in the output at W = 0 (up to rounding).
    half = m // 2
    W_0 = random.normal(model_rng, (d, half))
    W_0 = W_0 / jla.norm(W_0, axis=0)
    W_0 = jnp.concatenate([W_0, W_0], axis=1)
    a = jnp.concatenate([jnp.ones(half), -jnp.ones(half)]) / jnp.sqrt(m)
    sigma = lambda z: jax.nn.sigmoid(z - 1)

    @jit
    def net(x, W):
        """Network output at input(s) ``x`` for the flattened weight perturbation ``W`` (length d*m)."""
        W_mat = W.reshape(d, m)
        return a @ sigma((W_0 + W_mat).T @ x)

    @jit
    def lin_plus_quad(x, W):
        """Degree-2 Taylor expansion of ``net(x, .)`` around W = 0, evaluated at ``W``."""
        return taylor_expand(lambda params: net(x, params), jnp.zeros(d * m), 2)(W)

    def linearize(x, W):
        """Degree-1 Taylor expansion of ``net(x, .)`` around W = 0, evaluated at ``W``."""
        return taylor_expand(lambda params: net(x, params), jnp.zeros(d * m), 1)(W)

    def ntk_feature(x):
        """Gradient of ``net(x, .)`` with respect to the weights at W = 0 (length d*m)."""
        return jax.grad(lambda params: net(x, params))(jnp.zeros(d * m))

    # Target: quadratic + linear function of the single unit direction beta.
    beta = random.normal(fn_rng, (d,))
    beta = beta / jla.norm(beta)

    def f_star(x):
        # For x ~ N(0, I) and unit beta, z = beta @ x ~ N(0, 1), so z**2 - 1 + z has
        # variance 2 + 1 = 3; dividing by sqrt(3) gives a unit-variance target.
        z = beta @ x
        return (z ** 2 - 1 + z) / jnp.sqrt(3)

    X_train = random.normal(train_rng, (d, n))
    y_train = vmap(f_star)(X_train.T)
    loss_fn = lambda W: jnp.mean((y_train - lin_plus_quad(X_train, W)) ** 2)

    X_test = random.normal(test_rng, (d, n_test))
    y_test = vmap(f_star)(X_test.T)
    test_loss_fn = lambda W: jnp.mean((y_test - lin_plus_quad(X_test, W)) ** 2)

    # Jacobian of the network at W = 0: one row of length d*m per training point.
    J = vmap(ntk_feature)(X_train.T)
    U, S, Vh = jla.svd(J, full_matrices=False)

    # Rank-n_k truncation of J (its top n_k singular directions) and the remainder.
    # Slicing the SVD factors yields the same matrix as zeroing the tail of S in
    # U @ diag(S) @ Vh, without the full (n x n) @ (n x d*m) product.
    J_project = (U[:, :n_k] * S[:n_k]) @ Vh[:n_k]
    J_large = J - J_project
    V_top = Vh[:n_k]  # top n_k right singular vectors, projected out of W inside R_1

    linear_fn = lambda W: jla.norm(linearize(X_train, W)) ** 2 / n

    def reg_fn(W):
        """Squared norm of ``J_large @ W`` divided by the number of training points n."""
        r = J_large @ W
        return r @ r / n

    @jit
    def reg_step(input, lamb):
        """One gradient-descent step on loss_fn + reg_coef * reg_fn + lamb * R_1.

        ``input`` is ``(params, key)``. Returns a dict whose ``state`` entry is the
        next ``(params, key)`` and whose ``save`` entry is
        ``(loss, test_loss, reg, linear)``, all evaluated at the pre-step params.
        """
        params, key = input

        # Randomized R_1: mean squared output of the linearized network on fresh
        # Gaussian inputs, after removing the component of W that lies in the span
        # of the top n_k right singular vectors of J.
        key, subkey = random.split(key, 2)
        X_sample = random.normal(subkey, (d, n_r1))
        r1 = lambda W: jnp.mean(linearize(X_sample, W - V_top.T @ (V_top @ W)) ** 2)

        loss, loss_grad = value_and_grad(loss_fn)(params)
        grads = loss_grad + reg_coef * jax.grad(reg_fn)(params) + lamb * jax.grad(r1)(params)

        test_loss = test_loss_fn(params)
        reg = reg_fn(params)
        linear = linear_fn(params)

        params = params - lr * grads
        return dict(state=(params, key), save=(loss, test_loss, reg, linear))

    # Every run starts from the expansion point W = 0 and the same key, so runs for
    # different lamb values differ only in the regularisation strength.
    params = jnp.zeros(d * m)

    for lamb in lambs:
        res = fold(lambda input: reg_step(input, lamb), (params, key), steps=T, show_progress=True)
        res_tot[lamb] = res['save']

    # NOTE: despite the .npy suffix this is a pickled dict {lamb: save}; the name is
    # kept so existing readers of these files keep working.
    filename = 'seed' + str(seed) + 'data.npy'
    with open(filename, 'wb') as fh:
        pickle.dump(res_tot, fh)
    