In [None]:
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from scipy.interpolate import interp1d

In [None]:
def atoi(text: str) -> "int | str":
    """Return ``int(text)`` if ``text`` consists only of decimal digits, else ``text`` unchanged.

    ``str.isdecimal`` is used instead of ``str.isdigit``: ``isdigit`` is also True for
    characters such as superscripts (``'²'``) that ``int()`` cannot parse, which would
    make this helper raise ``ValueError`` instead of returning the chunk as a string.
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text: str) -> list:
    '''
    Sort key for "human" ordering: ``alist.sort(key=natural_keys)`` orders
    ``k_3.npy`` before ``k_20.npy`` before ``k_100.npy``.

    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Splitting on a capturing group ``(\\d+)`` makes the pieces alternate
    non-digit / digit, starting with a (possibly empty) non-digit piece, so keys
    of different strings compare str-to-str and int-to-int position by position.
    '''
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

# Compute $K_{\text{diff}}(t) = \|K_t - K_0\|_F$ for each sweep (cached to `frob_k.npy`; recomputed only when that file does not exist).

In [None]:
# ||K_t - K_0||_F curves for the alpha sweep, cached to disk.
# Load-or-compute so `alpha_frob` is defined on every run, not only when it was recomputed.
alpha_frob_path = "kernels_alpha_sweep/frob_k.npy"
alphas = [2**(-5), 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16, 32]

if os.path.isfile(alpha_frob_path):
    alpha_frob = np.load(alpha_frob_path)
else:
    alpha_frob = []
    for alpha in alphas:
        print(f"Alpha = {alpha}")
        pattern = f"kernels_alpha_sweep/alpha_{alpha}/k_*.npy"
        # Snapshots are named k_<step>.npy; sort numerically by step, not lexicographically.
        kernels_list = sorted(glob.glob(pattern), key=natural_keys)
        # Linear interpolation needs at least two snapshots.
        if len(kernels_list) < 2:
            raise FileNotFoundError(
                f"Need at least two kernel snapshots matching {pattern}, found {len(kernels_list)}"
            )

        steps = [int(path.split("_")[-1][:-len(".npy")]) for path in kernels_list]
        k_0 = np.load(kernels_list[0])

        print("Calculating Frobenius norm")
        # With default ord/axis, np.linalg.norm is the 2-norm of the flattened array,
        # i.e. the Frobenius norm of the kernel difference.
        all_fb = [np.linalg.norm(np.load(path) - k_0) for path in kernels_list]

        print("Interpolating")
        # Interpolate onto every integer step using the steps parsed from the file names.
        # A fixed np.arange(0, t_max + 1, interval) grid fails with a length mismatch
        # as soon as a snapshot is missing or off-grid.
        interpolator = interp1d(steps, all_fb, kind='linear', fill_value='extrapolate')
        alpha_frob.append(interpolator(np.arange(max(steps))))

    # NOTE(review): stacking assumes every alpha run ends at the same step, so the
    # curves have equal length; ragged curves would make np.save fail.
    np.save(alpha_frob_path, alpha_frob)

In [None]:
# ||K_t - K_0||_F curves for the weight-norm sweep, cached to disk.
# Load-or-compute so `w_frob` is defined on every run, not only when it was recomputed.
# (The previous version called np.load on this file inside the "file is missing"
# branch, which always raised FileNotFoundError.)
w_frob_path = "kernels_wnorm_sweep/frob_k.npy"
weight_norms = [0.125, 0.25, 0.5, 1.0, 2.0]

if os.path.isfile(w_frob_path):
    w_frob = np.load(w_frob_path)
else:
    w_frob = []
    for wnorm in weight_norms:
        print(f"Weight = {wnorm}")
        pattern = f"kernels_wnorm_sweep/{wnorm}/k_*.npy"
        # Snapshots are named k_<step>.npy; sort numerically by step, not lexicographically.
        kernels_list = sorted(glob.glob(pattern), key=natural_keys)
        # Linear interpolation needs at least two snapshots.
        if len(kernels_list) < 2:
            raise FileNotFoundError(
                f"Need at least two kernel snapshots matching {pattern}, found {len(kernels_list)}"
            )

        steps = [int(path.split("_")[-1][:-len(".npy")]) for path in kernels_list]
        k_0 = np.load(kernels_list[0])

        print("Calculating Frobenius norm")
        # With default ord/axis, np.linalg.norm is the 2-norm of the flattened array,
        # i.e. the Frobenius norm (replaces an undefined `frobenius` helper).
        all_fb = [np.linalg.norm(np.load(path) - k_0) for path in kernels_list]

        print("Interpolating")
        # Interpolate onto every integer step using the steps parsed from the file names.
        # A fixed np.arange(0, t_max + 1, interval) grid fails with a length mismatch
        # as soon as a snapshot is missing or off-grid.
        interpolator = interp1d(steps, all_fb, kind='linear', fill_value='extrapolate')
        w_frob.append(interpolator(np.arange(max(steps))))

    # NOTE(review): stacking assumes every run ends at the same step, so the curves
    # have equal length; ragged curves would make np.save fail.
    np.save(w_frob_path, w_frob)

In [None]:
# ||K_t - K_0||_F curves for the epsilon sweep, cached to disk.
# Load-or-compute so `eps_frob` is defined on every run, not only when it was recomputed.
eps_frob_path = "kernels_eps_sweep/frob_k.npy"
epsilons = [2**(-2), 2**(-1), 1, 2, 4]

if os.path.isfile(eps_frob_path):
    eps_frob = np.load(eps_frob_path)
else:
    eps_frob = []
    for eps in epsilons:
        print(f"Eps = {eps}")
        pattern = f"kernels_eps_sweep/{eps}/k_*.npy"
        # Snapshots are named k_<step>.npy; sort numerically by step, not lexicographically.
        kernels_list = sorted(glob.glob(pattern), key=natural_keys)
        # Linear interpolation needs at least two snapshots.
        if len(kernels_list) < 2:
            raise FileNotFoundError(
                f"Need at least two kernel snapshots matching {pattern}, found {len(kernels_list)}"
            )

        steps = [int(path.split("_")[-1][:-len(".npy")]) for path in kernels_list]
        k_0 = np.load(kernels_list[0])

        print("Calculating Frobenius norm")
        # With default ord/axis, np.linalg.norm is the 2-norm of the flattened array,
        # i.e. the Frobenius norm (replaces an undefined `frobenius` helper).
        all_fb = [np.linalg.norm(np.load(path) - k_0) for path in kernels_list]

        print("Interpolating")
        # Interpolate onto every integer step using the steps parsed from the file names.
        # A fixed np.arange(0, t_max + 1, interval) grid fails with a length mismatch
        # as soon as a snapshot is missing or off-grid.
        interpolator = interp1d(steps, all_fb, kind='linear', fill_value='extrapolate')
        eps_frob.append(interpolator(np.arange(max(steps))))

    # NOTE(review): stacking assumes every run ends at the same step, so the curves
    # have equal length; ragged curves would make np.save fail.
    np.save(eps_frob_path, eps_frob)

# Sweep

## Alpha Sweep

In [None]:
folder = "kernels_alpha_sweep"
all_tr_losses = np.load(os.path.join(folder, "train_loss.npy"))
all_te_losses = np.load(os.path.join(folder, "test_loss.npy"))
all_acc_tr_losses = np.load(os.path.join(folder, "train_accuracy.npy"))
all_acc_te_losses = np.load(os.path.join(folder, "test_accuracy.npy"))
# Per-alpha ||K_t - K_0||_F curves, read from the cache file so this cell works on a
# fresh kernel even when the compute cell skipped recomputation.
alpha_frob = np.load(os.path.join(folder, "frob_k.npy"))

plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True)

alphas = [2**(-5), 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16, 32]
# NOTE(review): the last alpha (32) is left out of this figure — TODO confirm why.
for i, alpha in enumerate(alphas[:-1]):
    label = r'$\alpha = 2^{%0.0f}$' % np.log2(alpha)
    # Steps are numbered from 1 so the log-scaled x-axis has no zero.
    steps = np.arange(1, len(all_tr_losses[i]) + 1)
    # Left panel: losses normalised by their initial value; dashed = train, solid = test.
    ax[0].plot(steps, np.asarray(all_tr_losses[i]) / all_tr_losses[i][0], '--', color=f'C{i}')
    ax[0].plot(steps, np.asarray(all_te_losses[i]) / all_te_losses[i][0], color=f'C{i}', label=label)
    # Right panel: distance of the kernel from its initial value.
    ax[1].plot(alpha_frob[i], color=f'C{i}', label=label)

# With sharex=True the x scale is shared, so the right panel is log-scaled too.
ax[0].set_xscale('log')
ax[0].set_xlabel(r'$t$', fontsize=20)
ax[0].set_ylabel('Loss', fontsize=20)
ax[1].set_xlabel(r'$t$', fontsize=20)
ax[1].set_ylabel(r'$\|K_t - K_0\|_F$', fontsize=20)
ax[1].legend(loc='upper left')

plt.tight_layout()
# Saving is disabled for this figure; enable (with os.makedirs("figures", exist_ok=True)) if needed.
#plt.savefig("figures/alpha_sweep.png") #, bbox_inches="tight")
plt.show()

## Weight-Norm Sweep

In [None]:
weight_norms = [0.125, 0.25, 0.5, 1.0, 2.0]
w_folder = "kernels_wnorm_sweep"

all_tr_losses_w = np.load(os.path.join(w_folder, "train_loss.npy"))
all_te_losses_w = np.load(os.path.join(w_folder, "test_loss.npy"))
all_acc_tr_losses_w = np.load(os.path.join(w_folder, "train_accuracy.npy"))
all_acc_te_losses_w = np.load(os.path.join(w_folder, "test_accuracy.npy"))
# Per-weight-norm ||K_t - K_0||_F curves, read from the cache file so this cell works
# on a fresh kernel even when the compute cell skipped recomputation.
w_frob = np.load(os.path.join(w_folder, "frob_k.npy"))

plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True)
for i, wscale in enumerate(weight_norms):
    label = r'$\sigma = 2^{%0.0f}$' % np.log2(wscale)
    # Steps are numbered from 1 so the log-scaled x-axis has no zero.
    steps = np.arange(1, len(all_tr_losses_w[i]) + 1)
    # Left panel: losses normalised by their initial value; dashed = train, solid = test.
    ax[0].plot(steps, np.asarray(all_tr_losses_w[i]) / all_tr_losses_w[i][0], '--', color=f'C{i}')
    ax[0].plot(steps, np.asarray(all_te_losses_w[i]) / all_te_losses_w[i][0], color=f'C{i}', label=label)
    # Right panel: distance of the kernel from its initial value.
    ax[1].plot(w_frob[i], color=f'C{i}', label=label)

# With sharex=True the x scale is shared, so the right panel is log-scaled too.
ax[0].set_xscale('log')
ax[0].set_xlabel(r'$t$', fontsize=20)
ax[0].set_ylabel('Loss', fontsize=20)

ax[1].set_xlabel(r'$t$', fontsize=20)
ax[1].set_ylabel(r'$\|K_t - K_0\|_F$', fontsize=20)
ax[1].legend(loc='upper left')

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/weightnorm_sweep.png") #, bbox_inches="tight")
plt.show()

## Epsilon Sweep

In [None]:
epsilons = [2**(-2), 2**(-1), 1, 2, 4]
eps_folder = "kernels_eps_sweep"
all_tr_losses_eps = np.load(os.path.join(eps_folder, "train_loss.npy"))
all_te_losses_eps = np.load(os.path.join(eps_folder, "test_loss.npy"))
all_acc_tr_losses_eps = np.load(os.path.join(eps_folder, "train_accuracy.npy"))
all_acc_te_losses_eps = np.load(os.path.join(eps_folder, "test_accuracy.npy"))
# Per-epsilon ||K_t - K_0||_F curves, read from the cache file so this cell works on a
# fresh kernel even when the compute cell skipped recomputation.
eps_frob = np.load(os.path.join(eps_folder, "frob_k.npy"))

plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True)
for i, eps in enumerate(epsilons):
    label = r'$\epsilon = 2^{%0.0f}$' % np.log2(eps)
    # Steps are numbered from 1 so the log-scaled x-axis has no zero.
    steps = np.arange(1, len(all_tr_losses_eps[i]) + 1)
    # Left panel: losses normalised by their initial value; dashed = train, solid = test.
    ax[0].plot(steps, np.asarray(all_tr_losses_eps[i]) / all_tr_losses_eps[i][0], '--', color=f'C{i}')
    ax[0].plot(steps, np.asarray(all_te_losses_eps[i]) / all_te_losses_eps[i][0], color=f'C{i}', label=label)
    # Right panel: distance of the kernel from its initial value.
    ax[1].plot(eps_frob[i], color=f'C{i}', label=label)

# With sharex=True the x scale is shared, so the right panel is log-scaled too.
ax[0].set_xscale('log')
ax[0].set_xlabel(r'$t$', fontsize=20)
ax[0].set_ylabel('Loss', fontsize=20)
ax[1].set_xlabel(r'$t$', fontsize=20)
ax[1].set_ylabel(r'$\|K_t - K_0\|_F$', fontsize=20)
ax[1].legend(loc='upper left')

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/eps_sweep.png", bbox_inches="tight")
plt.show()