In [None]:
import os
import sys

from scipy.interpolate import interp1d
import jax
import jax.numpy as jnp
import neural_tangents as nt
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from jax import grad, jit, jacfwd, jacrev, lax, random, vmap
from jax.example_libraries import optimizers
from neural_tangents import stax
from tqdm import tqdm

In [None]:
# Matplotlib's 'tab20' colormap.
# NOTE(review): the plots in this notebook colour lines with f'C{i}', so
# `cmap` does not appear to be used — confirm before removing.
cmap = matplotlib.colormaps['tab20']

In [None]:
def phi(z, eps = 0.25):
    """Activation: identity plus a quadratic correction.

    Computes ``z + 0.5 * eps * z**2`` elementwise. With ``eps = 0`` the
    activation is linear; ``eps`` sets the strength of the quadratic term.
    """
    correction = 0.5 * eps * z**2
    return z + correction

def NN_func2(params, X, alpha, eps=0.25):
    """Two-layer network whose readout is a plain mean over hidden units.

    Args:
        params: pair ``(a, W)`` with ``W`` of shape ``(N, D)`` (hidden width N,
            input dimension D).
            NOTE(review): ``a`` is unpacked but does not enter the output, so
            its gradient is identically zero. Confirm this fixed readout is
            intended.
        X: inputs with D rows, i.e. shape ``(D, P)`` with one example per column.
        alpha: scalar multiplying the network output.
        eps: strength of the quadratic term in ``phi``.

    Returns:
        ``alpha * mean_i phi(W_i . x / sqrt(D))`` for every column ``x`` of
        ``X``, i.e. an array of shape ``(P,)``.
    """
    _a, W = params
    D = W.shape[1]
    # Preactivations, scaled by 1/sqrt(D) (fan-in normalisation).
    h = W @ X / jnp.sqrt(D)
    # Average over hidden units (axis 0), leaving one output per example.
    return alpha * jnp.mean(phi(h, eps=eps), axis=0)

def target_fn(beta, X):
    """Quadratic single-index teacher.

    Args:
        beta: teacher direction, shape ``(D,)``.
        X: inputs of shape ``(D, P)``, one example per column.

    Returns:
        ``(beta . x / sqrt(D))**2`` for every column ``x`` of ``X``, shape ``(P,)``.
    """
    # Read the input dimension from X rather than a notebook-global `D`, so the
    # function does not depend on cell execution order.
    D = X.shape[0]
    return (X.T @ beta / jnp.sqrt(D))**2

In [None]:
# Problem sizes: input dimension, number of training points, hidden width.
D, P, N = 100, 550, 500
ntk_interval = 100  # the sweeps save the test-set empirical NTK every this many SGD steps

# Data is stored column-wise: one example per column.
X, Xt = (
    random.normal(random.PRNGKey(0), (D, P)),     # training inputs, (D, P)
    random.normal(random.PRNGKey(1), (D, 1000)),  # test inputs, (D, 1000)
)
beta = random.normal(random.PRNGKey(2), (D,))     # teacher direction

# Noise-free labels from the quadratic teacher.
y, yt = target_fn(beta, X), target_fn(beta, Xt)

# Sweep over the output scale $\alpha$ (training loss divided by $\alpha^2$)

In [None]:
# Train the same initialisation with full-batch SGD for every output scale
# alpha. Losses/accuracies are logged every 2 steps and the empirical NTK on
# the test inputs is saved every `ntk_interval` steps.

# NOTE(review): `a`, `W` (and the training inputs `X`) are all drawn from
# PRNGKey(0); JAX recommends splitting keys instead of reusing them. Kept
# as-is so results stay comparable with previously saved runs.
a = random.normal(random.PRNGKey(0), (N, ))
W = random.normal(random.PRNGKey(0), (N, D))
params = [a, W]

eps = 0.25
eta = 0.5 * N
lamb = 0.0
n_steps = 60000
opt_init, opt_update, get_params = optimizers.sgd(eta)

alphas = [2**(-5),0.25,0.5,1.0,2.0,4.0,8.0,16,32]
losses_folder = "kernels_alpha_sweep"

all_tr_losses = []
all_te_losses = []
all_acc_tr = []
all_acc_te = []

param_movement = []
for alpha in alphas:
    # neural_tangents passes one example per row; NN_func2 wants one per column.
    # Use the cell's `eps` so the NTK and the training loss share the same phi.
    def nn_wrapper(params, X):
        return NN_func2(params, X.T, alpha=alpha, eps=eps)

    folder = f"kernels_alpha_sweep/alpha_{alpha}"
    os.makedirs(folder, exist_ok=True)

    # Reference NTK at initialisation; later kernels are compared against it.
    ntk_fn = nt.empirical_ntk_fn(nn_wrapper, vmap_axes=0, trace_axes=())
    K_0 = ntk_fn(Xt.T, None, params)
    np.save(os.path.join(folder, "k_0"), K_0)

    opt_state = opt_init(params)
    # The optimised loss is MSE / alpha**2; the logged values below multiply by
    # alpha**2 again, so `train_loss` / `test_loss` are plain MSE.
    loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X, alpha, eps=eps) - y)**2 / alpha**2))
    acc_fn = jit(lambda p, X, y: jnp.mean((y * NN_func2(p, X, alpha, eps=eps)) > 0.0))
    reg_loss = jit(lambda p, X, y: loss_fn(p, X, y) + lamb / alpha * optimizers.l2_norm(p)**2)
    grad_loss = jit(grad(reg_loss, 0))

    tr_losses = []
    te_losses = []
    tr_acc = []
    te_acc = []

    for t in range(n_steps):
        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)
        if t % 2 == 0:
            current = get_params(opt_state)
            train_loss = alpha**2 * loss_fn(current, X, y)
            test_loss = alpha**2 * loss_fn(current, Xt, yt)
            tr_losses.append(train_loss)
            te_losses.append(test_loss)
            tr_acc.append(acc_fn(current, X, y))
            te_acc.append(acc_fn(current, Xt, yt))
        if t % 10000 == 0:
            print(" ")

        if t % ntk_interval == 0:
            K_test = ntk_fn(Xt.T, None, get_params(opt_state))
            # jnp.linalg.norm of a matrix defaults to the Frobenius norm.
            sys.stdout.write(f'\r  t: {t} | train loss: {train_loss} | test loss: {test_loss} |frob ||kt-k0||_F: {jnp.linalg.norm(K_test-K_0)}')
            np.save(os.path.join(folder, f"k_{t}"), K_test)

    all_tr_losses.append(tr_losses)
    all_te_losses.append(te_losses)
    all_acc_tr.append(tr_acc)
    all_acc_te.append(te_acc)

    # Relative squared parameter movement ||p_final - p_0||^2 / ||p_0||^2.
    paramsf = get_params(opt_state)
    dparam = (jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)) / (jnp.sum(params[0]**2) + jnp.sum(params[1]**2))
    param_movement.append(dparam)

    # Rewritten after every alpha so a partially completed sweep is kept on disk.
    np.save(os.path.join(losses_folder, "train_loss"), all_tr_losses)
    np.save(os.path.join(losses_folder, "test_loss"), all_te_losses)
    np.save(os.path.join(losses_folder, "train_accuracy"), all_acc_tr)
    np.save(os.path.join(losses_folder, "test_accuracy"), all_acc_te)

In [None]:
# Release the alpha-sweep curves (already written to kernels_alpha_sweep/)
# before starting the next experiment.
del all_tr_losses, all_te_losses, all_acc_tr, all_acc_te

# Sweep over the initial weight scale (weight norm at initialization)

In [None]:
# Sweep over the scale of the initial weights at fixed alpha = 1. For each
# scale, `a` and `W` are multiplied by `wscale`, the learning rate is divided
# by wscale**2, and the model trained is the *centered* output
# f(p) - f(params_init), which is exactly 0 at initialisation.
weight_norms = [0.125,0.25,0.5,1.0,2.0]
alpha = 1.0
eps = 0.25
eta = 0.5 * N
lamb = 0.0
n_steps = 50000
loss_folder = "kernels_wnorm_sweep"

all_tr_losses_w = []
all_te_losses_w = []
all_acc_tr_w = []
all_acc_te_w = []

param_movement_w = []

for wscale in weight_norms:
    # neural_tangents passes one example per row; NN_func2 wants one per column.
    def nn_wrapper(params, X):
        return NN_func2(params, X.T, alpha=alpha, eps=eps)
    w_folder = f"kernels_wnorm_sweep/{wscale}"
    os.makedirs(w_folder, exist_ok=True)

    a = wscale * random.normal(random.PRNGKey(0), (N, ))
    W = wscale * random.normal(random.PRNGKey(0), (N, D))
    params = [a, W]

    # Reference NTK at initialisation.
    ntk_fn = nt.empirical_ntk_fn(nn_wrapper, vmap_axes=0, trace_axes=())
    K_0 = ntk_fn(Xt.T, None, params)
    np.save(os.path.join(w_folder, "k_0"), K_0)

    opt_init, opt_update, get_params = optimizers.sgd(eta / wscale**2)
    opt_state = opt_init(params)

    # `params` (this iteration's initialisation) is captured by the closures and
    # becomes a trace-time constant under jit. eps is passed explicitly so the
    # loss uses the same phi as the NTK wrapper.
    loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X, alpha, eps=eps) - NN_func2(params, X, alpha, eps=eps) - y)**2 / alpha**2))
    acc_fn = jit(lambda p, X, y: jnp.mean((y * (NN_func2(p, X, alpha, eps=eps) - NN_func2(params, X, alpha, eps=eps))) > 0.0))
    reg_loss = jit(lambda p, X, y: loss_fn(p, X, y) + lamb / alpha * optimizers.l2_norm(p)**2)
    grad_loss = jit(grad(reg_loss, 0))

    tr_losses = []
    te_losses = []
    tr_acc = []
    te_acc = []
    for t in range(n_steps):
        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)

        if t % 2 == 0:
            current = get_params(opt_state)
            train_loss = alpha**2 * loss_fn(current, X, y)
            test_loss = alpha**2 * loss_fn(current, Xt, yt)
            tr_losses.append(train_loss)
            te_losses.append(test_loss)
            tr_acc.append(acc_fn(current, X, y))
            te_acc.append(acc_fn(current, Xt, yt))
        if t % 10000 == 0:
            print(" ")

        if t % ntk_interval == 0:
            K_test = ntk_fn(Xt.T, None, get_params(opt_state))
            # `frobenius` was never defined; jnp.linalg.norm of a matrix is the Frobenius norm.
            sys.stdout.write(f'\r  t: {t} | train loss: {train_loss} | test loss: {test_loss} |frob ||kt-k0||_F: {jnp.linalg.norm(K_test-K_0)}')
            np.save(os.path.join(w_folder, f"k_{t}"), K_test)

    all_tr_losses_w.append(tr_losses)
    all_te_losses_w.append(te_losses)
    all_acc_tr_w.append(tr_acc)
    all_acc_te_w.append(te_acc)

    # Relative squared parameter movement ||p_final - p_0||^2 / ||p_0||^2.
    paramsf = get_params(opt_state)
    dparam = (jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)) / (jnp.sum(params[0]**2) + jnp.sum(params[1]**2))
    param_movement_w.append(dparam)

    # Rewritten after every scale so a partially completed sweep is kept on disk.
    np.save(os.path.join(loss_folder, "train_loss"), all_tr_losses_w)
    np.save(os.path.join(loss_folder, "test_loss"), all_te_losses_w)
    np.save(os.path.join(loss_folder, "train_accuracy"), all_acc_tr_w)
    np.save(os.path.join(loss_folder, "test_accuracy"), all_acc_te_w)

In [None]:
# Release the weight-scale sweep curves (already written to
# kernels_wnorm_sweep/) before starting the next experiment.
del all_tr_losses_w, all_te_losses_w, all_acc_tr_w, all_acc_te_w

# Sweep over the nonlinearity strength $\epsilon$ in $\phi(z) = z + \tfrac{\epsilon}{2} z^2$

In [None]:
# Rebuild the data so this cell does not depend on the earlier sweeps.
D = 100
P = 550
N = 500
ntk_interval = 100

X = random.normal(random.PRNGKey(0), (D,P))
Xt = random.normal(random.PRNGKey(1), (D,1000))
beta = random.normal(random.PRNGKey(2), (D,))

y = target_fn(beta, X)
yt = target_fn(beta, Xt)

# Same initialisation for every eps.
a = random.normal(random.PRNGKey(0), (N, ))
W = random.normal(random.PRNGKey(0), (N, D))
params = [a, W]

alpha = 1.0
eta = 0.5 * N
lamb = 0.0
n_steps = 60000
opt_init, opt_update, get_params = optimizers.sgd(eta)

epsilons = [2**(-2),2**(-1),1,2,4]
loss_folder = "kernels_eps_sweep"

all_tr_losses_eps = []
all_te_losses_eps = []
all_acc_tr_eps = []
all_acc_te_eps = []

# NOTE(review): this reuses the name `param_movement` from the alpha sweep and
# overwrites its results; rename if both are needed later.
param_movement = []

for eps in epsilons:
    # neural_tangents passes one example per row; NN_func2 wants one per column.
    def nn_wrapper(params, X):
        return NN_func2(params, X.T, alpha=alpha, eps=eps)
    eps_folder = f"kernels_eps_sweep/{eps}"
    os.makedirs(eps_folder, exist_ok=True)

    # Reference NTK at initialisation.
    ntk_fn = nt.empirical_ntk_fn(nn_wrapper, vmap_axes=0, trace_axes=())
    K_0 = ntk_fn(Xt.T, None, params)
    np.save(os.path.join(eps_folder, "k_0"), K_0)

    opt_state = opt_init(params)
    loss_fn = jit(lambda p, X, y: jnp.mean((NN_func2(p, X, alpha=alpha, eps=eps) - y)**2 / alpha**2))
    acc_fn = jit(lambda p, X, y: jnp.mean((y * NN_func2(p, X, alpha=alpha, eps=eps)) > 0.0))
    reg_loss = jit(lambda p, X, y: loss_fn(p, X, y) + lamb / alpha * optimizers.l2_norm(p)**2)
    grad_loss = jit(grad(reg_loss, 0))

    tr_losses = []
    te_losses = []
    tr_acc = []
    te_acc = []
    for t in range(n_steps):
        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)
        if t % 2 == 0:
            current = get_params(opt_state)
            train_loss = alpha**2 * loss_fn(current, X, y)
            test_loss = alpha**2 * loss_fn(current, Xt, yt)
            tr_losses.append(train_loss)
            te_losses.append(test_loss)
            tr_acc.append(acc_fn(current, X, y))
            te_acc.append(acc_fn(current, Xt, yt))
        if t % 10000 == 0:
            print(" ")

        if t % ntk_interval == 0:
            K_test = ntk_fn(Xt.T, None, get_params(opt_state))
            # `frobenius` was never defined; jnp.linalg.norm of a matrix is the Frobenius norm.
            sys.stdout.write(f'\r  t: {t} | train loss: {train_loss} | test loss: {test_loss} |frob ||kt-k0||_F: {jnp.linalg.norm(K_test-K_0)}')
            np.save(os.path.join(eps_folder, f"k_{t}"), K_test)

    all_tr_losses_eps.append(tr_losses)
    all_te_losses_eps.append(te_losses)
    all_acc_tr_eps.append(tr_acc)
    all_acc_te_eps.append(te_acc)

    # Relative squared parameter movement ||p_final - p_0||^2 / ||p_0||^2.
    paramsf = get_params(opt_state)
    dparam = (jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)) / (jnp.sum(params[0]**2) + jnp.sum(params[1]**2))
    param_movement.append(dparam)

    # Rewritten after every eps so a partially completed sweep is kept on disk.
    np.save(os.path.join(loss_folder, "train_loss"), all_tr_losses_eps)
    np.save(os.path.join(loss_folder, "test_loss"), all_te_losses_eps)
    np.save(os.path.join(loss_folder, "train_accuracy"), all_acc_tr_eps)
    np.save(os.path.join(loss_folder, "test_accuracy"), all_acc_te_eps)

In [None]:
# Reload the eps-sweep curves. The sweep cell writes them into
# "kernels_eps_sweep" (not the per-eps subfolder), and np.save appended ".npy".
loss_folder = "kernels_eps_sweep"
all_tr_losses_eps = np.load(os.path.join(loss_folder, "train_loss.npy"))
all_te_losses_eps = np.load(os.path.join(loss_folder, "test_loss.npy"))
all_acc_tr_eps = np.load(os.path.join(loss_folder, "train_accuracy.npy"))
all_acc_te_eps = np.load(os.path.join(loss_folder, "test_accuracy.npy"))

# The sweep logs after steps t = 0, 2, 4, ... (its `t % 2 == 0` branch), i.e.
# after 1, 3, 5, ... SGD updates; plot against that count on a log axis.
log_every = 2
os.makedirs("figures", exist_ok=True)

plt.rcParams.update({'font.size': 14})
plt.figure()
for i, eps in enumerate(epsilons):
    tr_curve = np.asarray(all_tr_losses_eps[i])
    te_curve = np.asarray(all_te_losses_eps[i])
    updates = log_every * np.arange(len(tr_curve)) + 1
    # Dashed: train loss; solid: test loss. Each is normalised by its first logged value.
    plt.plot(updates, tr_curve / tr_curve[0], '--', color=f'C{i}')
    plt.plot(
        updates,
        te_curve / te_curve[0],
        color=f'C{i}',
        label=r'$\epsilon = 2^{%0.0f}$' % np.log2(eps),
    )

plt.xscale('log')
plt.xlabel(r'$t$', fontsize=20)
plt.ylabel('Loss / initial loss', fontsize=20)
plt.title('Train (dashed) vs test (solid)')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig("figures/eps_sweep.png", bbox_inches="tight")
plt.show()

In [None]:
# free memory before next exp
# Release the eps-sweep curves (stored under kernels_eps_sweep/) now that
# they have been plotted.
del all_tr_losses_eps, all_te_losses_eps, all_acc_tr_eps, all_acc_te_eps