In [None]:
import gc

from scipy.interpolate import interp1d
import jax
import jax.numpy as jnp
import neural_tangents as nt
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from jax import grad, jit, jacfwd, jacrev, lax, random, vmap
from jax.example_libraries import optimizers
from neural_tangents import stax
from tqdm import tqdm

In [None]:
# Qualitative 20-colour palette used to colour the loss curves below.
cmap = matplotlib.colormaps['tab20']

In [None]:
# Problem sizes: input dimension, number of training samples, network width.
D, P, N = 100, 550, 500

def target_fn(beta, X):
    """Quadratic single-index target: y = (x . beta)^2 / 2 for every column x of X.

    Args:
        beta: direction vector of length D.
        X: inputs stored column-wise, shape (D, n).

    Returns:
        Array of shape (n,) with one target per column of X.
    """
    projections = X.T @ beta
    return projections**2 / 2.0

# Inputs are stored column-wise, shape (D, n_samples). Entries are N(0, 1/D),
# so each input has unit expected squared norm.
# NOTE(review): PRNGKey(0) is also used to initialise `W` and `a`; JAX advises
# never reusing a key (use `random.split`). Kept as-is so results match earlier runs.
def _gaussian_inputs(seed, n_samples):
    return random.normal(random.PRNGKey(seed), (D, n_samples)) / jnp.sqrt(D)

X = _gaussian_inputs(0, P)       # training inputs
Xt = _gaussian_inputs(1, 1000)   # test inputs
beta = random.normal(random.PRNGKey(2), (D,))  # target direction

y, yt = target_fn(beta, X), target_fn(beta, Xt)

In [None]:
# Sanity check of train/test shapes: inputs (D, n), targets (n,).
for inputs, targets in ((X, y), (Xt, yt)):
    print(inputs.shape, targets.shape)

In [None]:
# Initial parameters: readouts `a` (length N) and first-layer weights `W` (N x D),
# both standard normal.
# NOTE(review): `W`, `a` and the training inputs `X` all come from PRNGKey(0);
# JAX recommends never reusing a key (use `random.split`). Left unchanged so
# that results stay reproducible with earlier runs.
init_key = random.PRNGKey(0)
W = random.normal(init_key, (N, D))
a = random.normal(init_key, (N, ))
params = [a, W]

alpha = 1   # output scaling parameter, NOT weight norm scale
eps = 0.02  # strength of the quadratic term in phi

def NN_func2(params, X):
    """Mean-field two-layer network evaluated without its readout weights.

    Computes f(x) = alpha * mean_i phi(w_i . x, eps), averaging over the rows
    w_i of W. ``alpha`` and ``eps`` are module-level globals read when the
    function runs (or is traced under jit/grad). Rebinding them, as the
    training sweep does, therefore changes the model.

    Args:
        params: ``[a, W]``. The readouts ``a`` are unpacked but not used by the
            forward pass ("w/o readouts"), so their gradient is zero.
            ``W`` has shape (N, D).
        X: inputs with one sample per row, shape (n, D).

    Returns:
        Array of shape (n,) with one scalar output per sample.
    """
    _, W = params
    h = W @ X.T  # pre-activations, shape (N, n)
    # jnp.mean keeps this a pure JAX computation under jit/grad/vmap.
    return alpha * jnp.mean(phi(h, eps), axis=0)


def phi(z, eps = 0.25):
    """Quadratic activation: the identity plus a 0.5 * eps * z**2 correction."""
    correction = 0.5 * eps * z**2
    return z + correction

In [None]:
# This cell was a verbatim duplicate of the previous one. It re-created `W`, `a`,
# `params`, `alpha`, `eps`, `NN_func2` and `phi` with identical values and
# definitions. It has been removed so the single definition above is the one in effect.

In [None]:
# Empirical NTK of NN_func2. Samples are batched along axis 0 (vmap_axes=0) and
# no output axes are traced out (trace_axes=()). ntk_fn(x1, x2, params) therefore
# gives the kernel between the rows of x1 and x2. Per the neural_tangents docs,
# x2=None means x2 = x1.
# Optional: implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES can be passed.
ntk_fn = nt.empirical_ntk_fn(NN_func2, trace_axes=(), vmap_axes=0)

def kernel_regression(X, y, Xt, yt, params, which='test'):
    """MSE of ridgeless kernel regression with the empirical NTK at ``params``.

    Solves K_train @ coef = y on the training set, then predicts with
    k(x_eval, X_train) @ coef on the split selected by ``which``.

    Args:
        X, Xt: training / test inputs stored column-wise, shape (D, n).
        y, yt: corresponding targets, shape (n,).
        params: parameters at which ``ntk_fn`` is evaluated.
        which: 'test' scores on (Xt, yt); any other value scores on (X, y).

    Returns:
        Scalar mean squared error of the kernel predictor on the chosen split.
    """
    # ntk_fn batches samples along axis 0, hence the transposes.
    K_train = ntk_fn(X.T, None, params)
    coef = jnp.linalg.solve(K_train, y)

    if which == 'test':
        X_eval, labels = Xt, yt
    else:
        X_eval, labels = X, y

    # One kernel evaluation between every evaluation point and every training
    # point. This replaces a per-point vmap that ignored its argument.
    k_eval_train = ntk_fn(X_eval.T, X.T, params)
    k_eval_train = jnp.reshape(k_eval_train, (X_eval.shape[1], X.shape[1]))
    estimates = k_eval_train @ coef
    return jnp.mean((estimates - labels) ** 2)


def kalignment(K, train_y):
    """Centered kernel-target alignment between a kernel matrix and targets.

    Returns  (y_c^T K_c y_c) / (||K_c||_F * ||y_c||^2),  where y_c = y - mean(y)
    and K_c = H K H is the doubly centered kernel, with H = I - 11^T / n.
    The value lies in [-1, 1] and equals 1 when K_c is proportional to y_c y_c^T.

    Args:
        K: square kernel matrix, shape (n, n).
        train_y: targets with n entries (any shape that flattens to (n,)).

    Returns:
        Scalar alignment. It is NaN if y is constant or K_c is zero.
    """
    y_c = jnp.ravel(train_y)
    y_c = y_c - y_c.mean()
    # Double centering: subtract column means (H K), then row means ((H K) H).
    # Centering only one side inflates ||K_c||_F and biases the alignment low.
    K_c = K - K.mean(axis=0, keepdims=True)
    K_c = K_c - K_c.mean(axis=1, keepdims=True)
    numerator = y_c @ K_c @ y_c
    denominator = jnp.linalg.norm(K_c) * jnp.dot(y_c, y_c)
    return numerator / denominator

In [None]:
# Training sweep over the output scale `alpha` and the nonlinearity strength `eps`.
# `alpha` and `eps` are module-level globals read by NN_func2, so the loops below
# deliberately rebind them. Every jitted loss is re-created inside the loop so
# it is traced with the current values.
alphas = [1]
epsilons = [0.02]
epochs = 100000
CENTER_LOSS = True      # subtract the network's output at initialization from f
TRAIN_READOUTS = False  # NOTE(review): not read anywhere in this cell
ntk_interval = 100      # epochs between NTK-alignment measurements on the test set

for alpha in alphas:
    for eps in epsilons:
        # Bookkeeping containers. Several are not filled in this cell; they are
        # kept so that other cells referring to them still find them.
        kaligns_test, kalign_epochs = [], []
        epochs_to_plot = []
        dots = []
        Cs, As = [], []
        actual_w1aligns, actual_w2aligns = [], []
        w1_aligns, w2_aligns = [], []
        w1_vars, w2_vars, ws_covs = [], [], []
        vars_compute_interval = 50
        t1s, t2s, t3s, epochs_to_compute = [], [], [], []
        t1sm, t2sm, t3sm, ts_summ = [], [], [], []
        ts_sum = []
        alignments, alignmentst = [], []

        lamb = 0              # L2 regularization strength
        eta = N / alpha**2    # learning rate, scaled with width and output scale
        opt_init, opt_update, get_params = optimizers.sgd(eta)
        opt_init_lin, opt_update_lin, get_params_lin = optimizers.sgd(eta)
        opt_state = opt_init(params)
        opt_state_lin = opt_init_lin(params)

        # First-order Taylor expansion of the network around the initial params.
        f_lin = nt.linearize(NN_func2, params)
        f_lin0 = f_lin  # same linearization; kept as a separate name for other cells
        lin_tr_losses, lin_te_losses = [], []
        tr_losses, te_losses = [], []
        W_all = []  # snapshots of W every 10000 epochs

        # Inputs are stored (D, n) while NN_func2 expects one sample per row,
        # so both branches transpose X.
        if CENTER_LOSS:
            loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X.T) - NN_func2(params, X.T) - y)**2))
        else:
            # NOTE(review): only this branch divides by alpha**2 — confirm intended.
            loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X.T) - y)**2 / alpha**2))

        lin_loss = jit(lambda p, X, y: jnp.mean((f_lin(p, X.T) - f_lin0(params, X.T) - y)**2))
        grad_loss_lin = jit(grad(lin_loss, 0))

        reg_loss = jit(lambda p, X, y: loss_fn(p, X, y) + lamb / alpha * optimizers.l2_norm(p)**2)
        grad_loss = jit(grad(reg_loss, 0))

        # Baseline: kernel regression with the NTK at initialization
        # (depends on the current alpha/eps through NN_func2).
        kmse = kernel_regression(X, y, Xt, yt, get_params(opt_state))

        for t in tqdm(range(epochs)):
            # Full-batch gradient step on the nonlinear network.
            opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)
            pars = get_params(opt_state)
            tr_losses.append(loss_fn(pars, X, y))
            te_losses.append(loss_fn(pars, Xt, yt))
            if t % 10000 == 0:
                W_all.append(pars[1])

            # Same step for the linearized model. Losses are recorded AFTER the
            # update so that index t of both curves refers to the same epoch.
            opt_state_lin = opt_update_lin(t, grad_loss_lin(get_params_lin(opt_state_lin), X, y), opt_state_lin)
            lin_pars = get_params_lin(opt_state_lin)
            lin_tr_losses.append(lin_loss(lin_pars, X, y))
            lin_te_losses.append(lin_loss(lin_pars, Xt, yt))

            if t % vars_compute_interval == 0:
                epochs_to_compute.append(t)
            if t % ntk_interval == 0 and t > 0:
                K_test = ntk_fn(Xt.T, None, pars)
                kaligns_test.append(kalignment(K_test, yt))
                kalign_epochs.append(t)  # the epoch this measurement belongs to

        # Interpolate the sparse alignment measurements onto every epoch, once,
        # using the epochs at which they were actually taken.
        t_values = np.asarray(kalign_epochs)
        if len(kaligns_test) >= 2:
            interpolator = interp1d(t_values, np.asarray(kaligns_test), kind='linear', fill_value='extrapolate')
            interpolated_kaligns = interpolator(np.arange(epochs))
        else:
            interpolated_kaligns = np.asarray(kaligns_test)

In [None]:
# Number of W snapshots collected during training.
n_W_snapshots = len(W_all)
n_W_snapshots

In [None]:
# Final first-layer weights. Show the shape and a small slice rather than
# dumping the whole (N, D) array into the notebook output.
W_final = pars[1]
W_final.shape, W_final[:5]

In [None]:
fig, ax1 = plt.subplots()

# tr_losses[t] is recorded after update t, i.e. after t + 1 epochs. Start the
# x-axes at 1 so that the first point is not dropped by the log scale.
loss_epochs = np.arange(1, len(tr_losses) + 1)
lin_epochs = np.arange(1, len(lin_tr_losses) + 1)
kalign_axis = np.arange(1, len(interpolated_kaligns) + 1)

col = cmap(0)
ax1.plot(loss_epochs, np.array(tr_losses), linestyle='--', label='Train Loss', color=col, lw=2)
ax1.plot(loss_epochs, np.array(te_losses), label='Test Loss', color=col, lw=2)
ax1.plot(lin_epochs, np.array(lin_tr_losses), color='black', linestyle='--', label='Linearized train loss')
ax1.plot(lin_epochs, np.array(lin_te_losses), color='black', label='Linearized test loss')
ax1.axhline(kmse, color='r', label=r'$K_0$ regression MSE')
ax1.set_xlabel('Epochs', fontsize=20)
ax1.set_xscale('log')
ax1.set_ylabel('MSE', fontsize=20)
ax1.set_title('Network vs. linearized model and $K_0$ regression')
ax1.legend(loc='lower left', bbox_to_anchor=(0, 0.1))

# The alignment is measured with the NTK at the *current* parameters (K_t),
# on the test set, so the label must not say K_0.
ax2 = ax1.twinx()
ax2.plot(kalign_axis, interpolated_kaligns, linestyle='--', color='green',
         label=r'Test NTK alignment $A(K_t, y)$', lw=2)
ax2.legend(loc='upper left', bbox_to_anchor=(0, 0.58))
ax2.set_ylabel('NTK alignment', fontsize=20)
fig.tight_layout()
plt.show()