In [None]:
import glob
import os
import re

from scipy.interpolate import interp1d
import jax
import jax.numpy as jnp
import neural_tangents as nt
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from jax import grad, jit, jacfwd, jacrev, lax, random, vmap
from jax.example_libraries import optimizers
from neural_tangents import stax
from tqdm import tqdm

In [None]:
# Qualitative 'tab20' palette; cmap(0) colours the loss curves in the plots below.
cmap = matplotlib.colormaps['tab20']

In [None]:
# Problem setup: D-dimensional inputs, target is a degree-2 (quadratic) polynomial.
# Sample-complexity intuition behind the choice of P: a feature-learning network
# should need on the order of D (= 100) samples, whereas a fixed kernel (the NTK)
# should need on the order of D**2 (= 10000).
D, P, N = 100, 550, 500  # input dimension, number of training samples, hidden-layer width

def target_fn(beta, X):
    """Quadratic single-index target: (X^T beta)**2 / 2, one value per column of X."""
    projection = X.T @ beta
    return projection**2 / 2.0

# Gaussian inputs with entries N(0, 1/D), so E||x||^2 = 1. Columns are samples.
# Fixed seeds: 0 -> training inputs, 1 -> 1000 test inputs, 2 -> teacher vector beta.
X, Xt = (
    random.normal(random.PRNGKey(seed), (D, n_samples)) / jnp.sqrt(D)
    for seed, n_samples in ((0, P), (1, 1000))
)
beta = random.normal(random.PRNGKey(2), (D,))

y, yt = target_fn(beta, X), target_fn(beta, Xt)

In [None]:
# Sanity check: inputs are (D, n) and targets are (n,) for the train and test splits.
print(*(arr.shape for arr in (X, y)))
print(*(arr.shape for arr in (Xt, yt)))

In [None]:
# Initial first-layer weights W (N, D) and readouts a (N,).
# Seeds 0, 1 and 2 already generate X, Xt and beta, so draw from a fresh seed and
# split it. JAX keys must never be reused; with the counter-based PRNG, reusing
# PRNGKey(0) here can tie W and a to each other and to the training inputs X.
W_key, a_key = random.split(random.PRNGKey(3))
W = random.normal(W_key, (N, D))
a = random.normal(a_key, (N, ))
print(W.shape, a.shape)

In [None]:
params = [a, W]       # [readouts, first-layer weights]; NN_func2 only uses W
alpha, eps = 1, 0.02  # output scaling parameter (NOT a weight-norm scale); quadratic strength in phi

def NN_func2(params, X):
    """Two-layer network without readout weights.

    f(x) = alpha * mean_i phi(w_i . x, eps), where w_i are the rows of W and the
    mean runs over the hidden units (axis 0 of W @ X.T).

    Args:
        params: ``[a, W]``. ``W`` is (N, D). ``a`` is unpacked but does not enter
            the output (no readouts).
        X: inputs with samples along axis 0, i.e. (n_samples, D).

    Returns:
        Array of shape (n_samples,).

    ``alpha``, ``eps`` and ``phi`` are looked up in the notebook's globals when the
    function runs; under ``jit`` their values are fixed at trace time.
    """
    _, W = params
    h = W @ X.T  # (N, n_samples) pre-activations
    # jnp (not np) so the reduction is a native JAX op under jit / grad / vmap.
    return alpha * jnp.mean(phi(h, eps), axis=0)


def phi(z, eps=0.25):
    """Activation: identity plus a small quadratic term, z + (eps / 2) * z**2."""
    quadratic = z**2
    return z + 0.5 * eps * quadratic

In [None]:
# Empirical NTK of NN_func2: vmap over the sample axis (0) and trace over no output
# axes, so ntk_fn(x1, x2, params) presumably yields the full (n1, n2) sample kernel.
ntk_fn = nt.empirical_ntk_fn(NN_func2, trace_axes=(), vmap_axes=0)

def kernel_regression(X, y, Xt, yt, params, which='test'):
    """Ridgeless kernel regression with the empirical NTK evaluated at `params`.

    Solves K_train c = y with K_train = ntk_fn(X.T, None, params), then predicts
    ntk_fn(X_eval.T, X.T, params) @ c on the chosen evaluation set.

    Args:
        X, Xt: training / test inputs laid out as (D, n) (samples along axis 1);
            they are transposed before being passed to ntk_fn.
        y, yt: training / test targets, shape (n,).
        params: network parameters at which the empirical NTK is evaluated.
        which: 'test' reports the MSE on (Xt, yt), 'train' on (X, y).

    Returns:
        (mse, K_train): MSE on the chosen set and the train-train kernel matrix.

    Raises:
        ValueError: if `which` is neither 'test' nor 'train'.
    """
    if which not in ('test', 'train'):
        raise ValueError(f"which must be 'test' or 'train', got {which!r}")

    K_train = ntk_fn(X.T, None, params)
    coeffs = jnp.linalg.solve(K_train, y)

    X_eval, labels = (Xt, yt) if which == 'test' else (X, y)
    # One kernel evaluation for the whole evaluation set. reshape (rather than
    # squeeze) drops singleton output axes but keeps the (n_eval, n_train) layout
    # even when either set holds a single sample.
    K_eval_train = jnp.reshape(ntk_fn(X_eval.T, X.T, params), (X_eval.shape[1], X.shape[1]))
    estimates = K_eval_train @ coeffs
    mse = jnp.mean((estimates - labels) ** 2)
    return mse, K_train


def kalignment(K, train_y):
    """Centered kernel-target alignment between kernel K and targets y.

    A(K, y) = y_c^T K_c y_c / (||K_c||_F * ||y_c||^2), with y_c = y - mean(y) and
    K_c = H K H the doubly centered kernel (H = I - 11^T / n). This equals the
    Frobenius cosine between K_c and y_c y_c^T.

    Args:
        K: (n, n) kernel matrix.
        train_y: n targets, shape (n,) or (n, 1).

    Returns:
        Scalar alignment (1.0 when K_c is proportional to y_c y_c^T).
    """
    yc = jnp.ravel(train_y)
    yc = yc - yc.mean()
    # H K H: subtract column means and row means, add back the grand mean.
    Kc = K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()
    numerator = yc @ Kc @ yc
    denominator = jnp.linalg.norm(Kc) * jnp.dot(yc, yc)
    return numerator / denominator


def get_norms(K):
    """
    Frobenius norm, spectral norm and 2-norm condition number of a symmetric kernel K.

    All three come from the eigenvalues of K (computed with eigvalsh, so no
    eigenvectors are formed):
        ||K||_F   = sqrt(sum_i lambda_i**2)
        ||K||_2   = max_i |lambda_i|
        cond_2(K) = max_i |lambda_i| / min_i |lambda_i|
    Absolute values keep the spectral norm and condition number correct when
    round-off makes the smallest eigenvalue of a PSD kernel slightly negative.

    Returns:
        (frobenius_norm, spectral_norm, condition_number) as JAX scalars.
    """
    eigenvals = jnp.linalg.eigvalsh(K)
    abs_eigenvals = jnp.abs(eigenvals)
    frobenius_norm = jnp.sqrt(jnp.sum(eigenvals**2))
    # jnp reductions instead of Python's max/min, which iterate the array element by element.
    spectral_norm = jnp.max(abs_eigenvals)
    condition_number = spectral_norm / jnp.min(abs_eigenvals)

    return frobenius_norm, spectral_norm, condition_number

In [None]:
# Train-train NTK at initialisation; the MSE returned alongside it is not needed here.
K_train = kernel_regression(X, y, Xt, yt, params)[1]

In [None]:
# Save the test-set NTK at initialisation (K_0) as the reference for the
# kernel-evolution plot. Create the directory so the cell works on a fresh checkout.
os.makedirs("kernels", exist_ok=True)
K_0 = ntk_fn(Xt.T, None, params)
np.save("kernels/k_test_0", K_0)  # np.save appends ".npy"

In [None]:
# ---- Training configuration -------------------------------------------------
alphas = [1]                # output scales to sweep (NN_func2 reads the global `alpha`)
epsilons = [0.02]           # quadratic strengths to sweep (NN_func2 reads the global `eps`)
epochs = 100000             # full-batch SGD steps per configuration
CENTER_LOSS = False         # fit y - f(params_0) instead of y; applied to both models
TRAIN_READOUTS = False      # NOTE(review): not read anywhere; NN_func2 ignores the readouts `a`
ntk_interval = 100          # evaluate the test-set NTK every `ntk_interval` steps
COMPUTE_NORMS = True        # also track Frobenius/spectral norm and condition number of K_test
SAVE_KERNELS = False        # dump every K_test to kernels/ (used by the kernel-evolution plot; large)
checkpoint_interval = 5000  # save weights and refresh the interpolated curves this often

os.makedirs("weights", exist_ok=True)
if SAVE_KERNELS:
    os.makedirs("kernels", exist_ok=True)


def _interpolate_to_epochs(values, n_epochs, interval):
    """Linearly interpolate a sparsely recorded metric onto loop steps 0..n_epochs-1.

    ``values[k]`` is taken to be recorded at loop step ``(k + 1) * interval``
    (the NTK is not evaluated at step 0). Steps outside the sampled range are
    linearly extrapolated. Returns None when fewer than two samples exist,
    because ``interp1d`` needs at least two points.
    """
    if len(values) < 2:
        return None
    sample_steps = interval * np.arange(1, len(values) + 1)
    interpolator = interp1d(sample_steps, values, kind='linear', fill_value='extrapolate')
    return interpolator(np.arange(n_epochs))


# NOTE(review): each (alpha, eps) configuration overwrites the result variables
# and the files in weights/, so only the last configuration survives.
for alpha in alphas:
    for eps in epsilons:
        lamb = 0  # L2 penalty strength (divided by alpha in reg_loss)
        # Step size N / alpha**2, presumably offsetting the 1/N from NN_func2's mean
        # over hidden units and its alpha output scale.
        eta = N / alpha**2
        opt_init, opt_update, get_params = optimizers.sgd(eta)
        opt_init_lin, opt_update_lin, get_params_lin = optimizers.sgd(eta)
        opt_state = opt_init(params)
        opt_state_lin = opt_init_lin(params)

        # First-order Taylor expansion of the network around the initial params.
        f_lin = nt.linearize(NN_func2, params)

        # The network and its linearisation are trained on the same objective so
        # that their learning curves are directly comparable.
        if CENTER_LOSS:
            loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X.T) - NN_func2(params, X.T) - y)**2))
            lin_loss = jit(lambda p, X, y: jnp.mean((f_lin(p, X.T) - f_lin(params, X.T) - y)**2))
        else:
            loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X.T) - y)**2))
            lin_loss = jit(lambda p, X, y: jnp.mean((f_lin(p, X.T) - y)**2))

        reg_loss = jit(lambda p, X, y: loss_fn(p, X, y) + lamb / alpha * optimizers.l2_norm(p)**2)
        grad_loss = jit(grad(reg_loss, 0))
        grad_loss_lin = jit(grad(lin_loss, 0))

        tr_losses, te_losses = [], []
        lin_tr_losses, lin_te_losses = [], []
        # Test-set NTK diagnostics, one entry per `ntk_interval` steps.
        kaligns_test = []
        k_test_frob, k_test_spec, k_test_cond = [], [], []
        metric_lists = (kaligns_test, k_test_frob, k_test_spec, k_test_cond)

        # Kernel-regression baseline with the NTK at initialisation
        # (opt_state has not been updated yet).
        kmse, _ = kernel_regression(X, y, Xt, yt, get_params(opt_state))

        for t in tqdm(range(epochs)):
            # Losses are logged after the update: index t of every list is t + 1 SGD steps.
            opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)
            pars = get_params(opt_state)
            tr_losses.append(loss_fn(pars, X, y))
            te_losses.append(loss_fn(pars, Xt, yt))

            # Same for the linearised model: update first, then log.
            opt_state_lin = opt_update_lin(t, grad_loss_lin(get_params_lin(opt_state_lin), X, y), opt_state_lin)
            lin_pars = get_params_lin(opt_state_lin)
            lin_tr_losses.append(lin_loss(lin_pars, X, y))
            lin_te_losses.append(lin_loss(lin_pars, Xt, yt))

            if t > 0 and t % ntk_interval == 0:
                K_test = ntk_fn(Xt.T, None, pars)
                if SAVE_KERNELS:
                    np.save(f"kernels/k_test_{t}", K_test)
                kaligns_test.append(kalignment(K_test, yt))
                if COMPUTE_NORMS:
                    frob_norm, spectral_norm, condition_number = get_norms(K_test)
                    k_test_frob.append(frob_norm)
                    k_test_spec.append(spectral_norm)
                    k_test_cond.append(condition_number)

            if t > 0 and t % checkpoint_interval == 0:
                # Refresh the curves so partial results survive an interrupted run.
                interpolated_kaligns, interpolated_frob, interpolated_spec, interpolated_cond = (
                    _interpolate_to_epochs(m, t, ntk_interval) for m in metric_lists)
                np.save(f"weights/w_{t}", pars[1])
                np.save(f"weights/a_{t}", pars[0])

        # Final refresh so the kernel curves span the same steps as the loss curves.
        interpolated_kaligns, interpolated_frob, interpolated_spec, interpolated_cond = (
            _interpolate_to_epochs(m, epochs, ntk_interval) for m in metric_lists)

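# Kernel evolution with respect to $K_0$

Frobenius distance $\|K_t - K_0\|_F$ between the test-set NTK saved at epoch $t$ and the NTK at initialisation $K_0$. This requires the per-epoch kernels in `kernels/`. By default the training loop does not save them (only $K_0$ is written), so on a fresh run the curve has a single point.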

In [None]:
def atoi(text):
    """Return int(text) when `text` consists only of digits, otherwise `text` unchanged."""
    if not text.isdigit():
        return text
    return int(text)

def natural_keys(text):
    '''
    Sort key for "human" (natural) order, e.g. k_test_20 before k_test_100:
    alist.sort(key=natural_keys)
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))

In [None]:
# Saved test-set kernels in epoch order (natural sort: k_test_20 before k_test_100).
kernels_list = sorted(glob.glob("kernels/k_test_*.npy"), key=natural_keys)

def frobenius(x):
    """Frobenius norm of a symmetric matrix, computed from its eigenvalues.

    For symmetric x, ||x||_F = sqrt(sum_i lambda_i**2). Only eigenvalues are needed,
    so eigvalsh is used (no eigenvectors), and the reduction stays in jnp.
    NOTE(review): not used by the distance computation below, which calls
    jnp.linalg.norm directly.
    """
    eigenvals = jnp.linalg.eigvalsh(x)
    return jnp.sqrt(jnp.sum(eigenvals**2))

if not kernels_list:
    raise FileNotFoundError(
        "No kernels/k_test_*.npy files found; run the cells that save the test-set "
        "kernels (K_0 and, optionally, the per-epoch kernels) first.")

# The natural sort puts kernels/k_test_0.npy (the kernel at initialisation) first.
k_0 = np.load(kernels_list[0])
k_test = kernels_list
# Epoch of each saved kernel as a string, parsed from ".../k_test_<epoch>.npy".
indexes = [k.split("_")[-1][:-4] for k in kernels_list]

# Frobenius distance ||K_t - K_0||_F for every saved kernel (0 for K_0 itself).
all_fb = [jnp.linalg.norm(np.load(k) - k_0) for k in kernels_list]

# Plot against the epoch each kernel was saved at (parsed from its file name), so
# the x-axis is in epochs even if kernels were saved at uneven intervals.
saved_epochs = [int(i) for i in indexes]
fig, ax = plt.subplots(1, 1)
ax.plot(saved_epochs, all_fb)
ax.set_ylabel(r"Frobenius norm $\|K_t - K_0\|_F$")
ax.set_xlabel("Epochs")
# Same output directory as every other figure in this notebook.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/kernel_evolution.png", bbox_inches="tight")
plt.show()

# Plots

## Train/test losses and NTK alignment

In [None]:
fig, ax1 = plt.subplots()

# Loss and interpolated-metric lists are indexed by loop step 0..n-1; plot them
# against 1..n so the first point is not dropped by the log-scaled x-axis.
loss_epochs = np.arange(1, len(tr_losses) + 1)
col = cmap(0)
ax1.plot(loss_epochs, np.array(tr_losses), linestyle='--', label='Train Loss', color=col, lw=2)
ax1.plot(loss_epochs, np.array(te_losses), label='Test Loss', color=col, lw=2)
ax1.plot(loss_epochs, np.array(lin_tr_losses), color='black', linestyle='--', label='Linearized train loss')
ax1.plot(loss_epochs, np.array(lin_te_losses), color='black', label='Linearized test loss')

ax1.axhline(kmse, color='r', label=r'$K_0$ regression MSE')
ax1.set_xlabel('Epochs', fontsize=20)
ax1.set_xscale('log')
ax1.set_ylabel('MSE', fontsize=20)
ax1.legend(loc='lower left', bbox_to_anchor=(0, 0.05))

ax2 = ax1.twinx()
# The alignment is measured on the current test-set NTK K_t during training, not on K_0.
ax2.plot(np.arange(1, len(interpolated_kaligns) + 1), interpolated_kaligns, linestyle='--', color='green',
         label=r'NTK alignment, $\frac{y^T K_t y}{||K_t||_F||y||^2}$', lw=2)
ax2.legend(loc='upper left', bbox_to_anchor=(0, 0.5))
ax2.set_ylabel('NTK alignment', fontsize=20)
plt.tight_layout()
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/main_metrics.png", bbox_inches="tight")
plt.show()

## MSE vs Spectral Norm

In [None]:
fig, ax1 = plt.subplots()

# Lists are indexed by loop step 0..n-1; plot against 1..n so the first point
# is not dropped by the log-scaled x-axis.
loss_epochs = np.arange(1, len(tr_losses) + 1)
col = cmap(0)
ax1.plot(loss_epochs, np.array(tr_losses), linestyle='--', label='Train Loss', color=col, lw=2)
ax1.plot(loss_epochs, np.array(te_losses), label='Test Loss', color=col, lw=2)

ax2 = ax1.twinx()
ax2.plot(np.arange(1, len(interpolated_spec) + 1), interpolated_spec, color='red', label='Spectral Norm')
ax2.legend(loc='upper left', bbox_to_anchor=(0, 0.45))
ax2.set_ylabel('Spectral Norm', fontsize=20)

ax1.set_xlabel('Epochs', fontsize=20)
ax1.set_xscale('log')
ax1.set_ylabel('MSE', fontsize=20)
ax1.legend(loc='lower left', bbox_to_anchor=(0, 0.20))
plt.tight_layout()
os.makedirs("figures", exist_ok=True)
# Explicit extension instead of relying on rcParams['savefig.format'].
plt.savefig("figures/loss_vs_spectral_norm.png")
plt.show()

## MSE vs Frobenius Norm

In [None]:
fig, ax1 = plt.subplots()

# Lists are indexed by loop step 0..n-1; plot against 1..n so the first point
# is not dropped by the log-scaled x-axis.
loss_epochs = np.arange(1, len(tr_losses) + 1)
col = cmap(0)
ax1.plot(loss_epochs, np.array(tr_losses), linestyle='--', label='Train Loss', color=col, lw=2)
ax1.plot(loss_epochs, np.array(te_losses), label='Test Loss', color=col, lw=2)

ax1.set_xlabel('Epochs', fontsize=20)
ax1.set_xscale('log')
ax1.set_ylabel('MSE', fontsize=20)
ax1.legend(loc='lower left', bbox_to_anchor=(0, 0.2))

ax2 = ax1.twinx()
ax2.plot(np.arange(1, len(interpolated_frob) + 1), interpolated_frob, color='green', label='Frobenius Norm')
ax2.legend(loc='upper left', bbox_to_anchor=(0, 0.58))
ax2.set_ylabel('Frobenius Norm', fontsize=20)
plt.tight_layout()
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/loss_vs_frobenius.png")
plt.show()

## MSE vs Condition Number

In [None]:
fig, ax1 = plt.subplots()

# Lists are indexed by loop step 0..n-1; plot against 1..n so the first point
# is not dropped by the log-scaled x-axis.
loss_epochs = np.arange(1, len(tr_losses) + 1)
col = cmap(0)
ax1.plot(loss_epochs, np.array(tr_losses), linestyle='--', label='Train Loss', color=col, lw=2)
ax1.plot(loss_epochs, np.array(te_losses), label='Test Loss', color=col, lw=2)
ax1.set_xlabel('Epochs', fontsize=20)
ax1.set_xscale('log')
ax1.set_ylabel('MSE', fontsize=20)
ax1.legend(loc='lower left', bbox_to_anchor=(0, 0.1))

ax2 = ax1.twinx()
ax2.plot(np.arange(1, len(interpolated_cond) + 1), interpolated_cond, linestyle='-', color='orange',
         label=r'Condition Number', lw=2)
ax2.legend(loc='upper left', bbox_to_anchor=(0, 0.38))
ax2.set_ylabel('Condition Number', fontsize=20)

plt.tight_layout()
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/loss_vs_condition.png")
plt.show()