<a href="https://colab.research.google.com/github/bobby-he/Neural_Tangent_Kernel/blob/master/notebooks/training_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Comparing neural-network training by backpropagation with training by kernel gradient descent

In [1]:
# Fetch the project repository (it provides the NTK_net helper module and the
# precomputed data/mean_vec.npy used below). Skip the clone when the directory
# already exists so this cell can be re-run; check=True makes a failed clone
# raise instead of silently printing an error and continuing.
import os
import subprocess
if not os.path.isdir('Neural_Tangent_Kernel'):
  subprocess.run(['git', 'clone', 'https://github.com/bobby-he/Neural_Tangent_Kernel.git'], check=True)

Cloning into 'Neural_Tangent_Kernel'...
remote: Enumerating objects: 159, done.[K
remote: Counting objects: 100% (159/159), done.[K
remote: Compressing objects: 100% (124/124), done.[K
remote: Total 159 (delta 64), reused 35 (delta 10), pack-reused 0[K
Receiving objects: 100% (159/159), 16.10 MiB | 12.47 MiB/s, done.
Resolving deltas: 100% (64/64), done.


In [0]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [0]:
import seaborn as sns
sns.set()

In [0]:
from collections import OrderedDict

In [0]:
import sys
sys.path.append("../")
from Neural_Tangent_Kernel.src.NTK_net import LinearNeuralTangentKernel, FourLayersNet, train_net, circle_transform, variance_est, cpu_tuple,\
                                              AnimationPlot_lsq, kernel_leastsq_update, kernel_mats

In [0]:
import copy
from google.colab import files

In [0]:
# Run on the GPU whenever one is available (torch.cuda.is_available() already returns a bool).
use_cuda = torch.cuda.is_available()

# Seed the RNGs so the random network initialisations (and therefore the kernel
# matrices and the animation) are reproducible across Restart & Run All.
# torch.manual_seed also seeds the CUDA generators.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

### Define the training data and the evaluation grid

In [0]:
from matplotlib import animation, rc
from IPython.display import HTML
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors)

n_pts = 100
# Evaluation grid: n_pts angles gamma in [-pi, pi] over which the fitted functions are animated.
gamma_vec = torch.tensor(np.linspace(-np.pi, np.pi, n_pts))
circle_test = circle_transform(gamma_vec)
# Only move to the GPU when one exists; an unconditional .cuda() crashes on CPU-only runtimes.
if use_cuda:
  circle_test = circle_test.cuda()

# Four training angles and their regression targets.
gamma_data = torch.tensor(np.array([-2.2, -1, 1, 2.2]))
target_data = torch.tensor(np.array([-0.4, -0.2, 0.3, 0.3])).float()
input_data = circle_transform(gamma_data)

# Precomputed mean curve; it is plotted against gamma_vec in the animation, so
# its length must equal n_pts (the saved file has 100 points).
mean_vec = np.load('Neural_Tangent_Kernel/data/mean_vec.npy').flatten()
assert mean_vec.shape == (n_pts,), (
    'mean_vec has {} points but n_pts={}; they must match'.format(mean_vec.size, n_pts))

### Compute the kernel matrices used for kernel gradient descent

In [0]:
# Width argument passed to FourLayersNet for the network whose kernel matrices are computed below.
n_width = 10000
net = FourLayersNet(n_width)
if use_cuda:
  # nn.Module.cuda() moves the parameters in place and returns the same module.
  net = net.cuda()

# Optional: train the net before computing its kernel, e.g.
# train_net(net, 1000, input_data, target_data)


In [11]:
# Kernel matrices for `net` (judging by the names: evaluation grid vs. training
# points, and training vs. training points). kernel_mats prints its own progress.
# NOTE(review): it receives raw angles (gamma_data, gamma_vec), not circle-transformed
# inputs; presumably it applies circle_transform internally — confirm in NTK_net.
K_testvtrain, K_trainvtrain = kernel_mats(net, gamma_data, gamma_vec)

K_testvtrain is 10% complete
K_testvtrain is 20% complete
K_testvtrain is 30% complete
K_testvtrain is 40% complete
K_testvtrain is 50% complete
K_testvtrain is 60% complete
K_testvtrain is 70% complete
K_testvtrain is 80% complete
K_testvtrain is 90% complete
K_testvtrain is 100% complete


In [0]:
def _as_numpy(mat):
  """Return `mat` as a NumPy array.

  Tensors are detached from autograd, copied to host memory and converted;
  arrays are returned unchanged, so re-running this cell is a no-op instead of
  raising AttributeError ('numpy.ndarray' has no attribute 'cpu').
  """
  if isinstance(mat, np.ndarray):
    return mat
  return mat.detach().cpu().numpy()

K_trainvtrain = _as_numpy(K_trainvtrain)
K_testvtrain = _as_numpy(K_testvtrain)

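### Class structure for the gradient updates in the animation

The class that drives the gradient updates in the animation, `AnimationPlot_lsq`, is imported from `Neural_Tangent_Kernel/src/NTK_net.py`.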

### Generate the animation

In [0]:
if use_cuda:
  input_data = input_data.cuda()
  target_data = target_data.cuda()

# Animation settings.
N_FRAMES = 82
FRAME_INTERVAL_MS = 150  # FuncAnimation `interval` is the delay between frames in milliseconds

# First set up the figure, the axes, and the plot elements.
fig, ax = plt.subplots()
# Close the figure so the inline backend does not also show a static (empty)
# copy; FuncAnimation still draws into `fig`.
plt.close()
ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-0.5, 0.5))
# Raw strings: '\g' is not a valid Python escape sequence and warns on newer Pythons.
ax.set_xlabel(r'$\gamma$')
ax.set_ylabel(r'$f_{ \theta}(sin(\gamma),cos(\gamma))$')

line, = ax.plot([], [], lw=2, color='darkmagenta', label='Mean')
# Label-only proxy artists: they never receive data, but give the legend one
# black entry per line style (dashed = Backprop, solid = Kernel GD).
legend_backprop, = ax.plot([], [], lw=1, linestyle='--', alpha=0.9, color='black', label='Backprop')
legend_kernel, = ax.plot([], [], lw=1, alpha=0.9, color='black', label='Kernel GD')

# One colour per animated network (presumably a dashed backprop curve and a
# solid kernel-GD curve share a colour). Lines are created in the same order as
# before: all dashed lines first, then all solid lines.
net_colors = ('red', 'green', 'cyan', 'gold', 'darkorange')
backprop_lines = tuple(ax.plot([], [], lw=1, linestyle='--', alpha=0.9, color=c)[0] for c in net_colors)
kernel_lines = tuple(ax.plot([], [], lw=1, alpha=0.9, color=c)[0] for c in net_colors)
line1, line2, line3, line4, line0 = backprop_lines
line1a, line2a, line3a, line4a, line0a = kernel_lines

# Keep this order (dashed lines, then solid lines); AnimationPlot_lsq presumably indexes into it.
line_tuple = backprop_lines + kernel_lines

ax.legend(loc='upper left')

# initialization function: plot the background of each frame
def init():
  """Draw the static mean curve; the returned artists are used for blitting."""
  line.set_data(gamma_vec.numpy(), mean_vec)
  return (line,)

# animation function: plot_step is called once per frame.
# n_nets is tied to the colour list so the two cannot drift apart (still 5).
anim_plot = AnimationPlot_lsq(n_nets=len(net_colors), n_wid=50, input_data=input_data,
                              K_testvtrain=K_testvtrain, K_trainvtrain=K_trainvtrain,
                              train_target=target_data, line_tuple=line_tuple, ax=ax)

anim = animation.FuncAnimation(fig, anim_plot.plot_step, init_func=init, frames=N_FRAMES,
                               interval=FRAME_INTERVAL_MS, blit=True)
# Render animations as interactive JavaScript/HTML when displayed in the notebook.
rc('animation', html='jshtml')

In [19]:
# Display the animation; the rc('animation', html='jshtml') setting renders it as an interactive JS/HTML player.
anim

To save the animation to a file, run the next cell *before* the display cell above (`anim`).

In [0]:
# Write the animation to an MP4 with matplotlib's default movie writer (ffmpeg
# unless rcParams say otherwise), then download it via Colab's `files` helper.
anim_path = 'anim_kernel.mp4'
anim.save(anim_path)
files.download(anim_path)