<a href="https://colab.research.google.com/github/bobby-he/Neural_Tangent_Kernel/blob/master/training_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

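### Comparing neural-network training via backpropagation with kernel gradient descent

This notebook computes the neural tangent kernel of a wide network at initialization. It then animates the functions learned by backprop alongside the functions given by kernel gradient descent.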

In [0]:
# Fetch the repo that provides the NTK helpers (src/NTK_net.py) and data/mean_vec.npy used below.
!git clone https://github.com/bobby-he/Neural_Tangent_Kernel.git

Cloning into 'Neural_Tangent_Kernel'...
remote: Enumerating objects: 131, done.[K
remote: Counting objects: 100% (131/131), done.[K
remote: Compressing objects: 100% (101/101), done.[K
remote: Total 131 (delta 51), reused 28 (delta 8), pack-reused 0[K
Receiving objects: 100% (131/131), 10.37 MiB | 6.72 MiB/s, done.
Resolving deltas: 100% (51/51), done.


In [0]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [0]:
import seaborn as sns
sns.set()  # apply seaborn's default theme to subsequent matplotlib figures

In [0]:
from collections import OrderedDict

In [0]:
import sys
# NOTE(review): the parent directory is presumably appended so the package import below
# also resolves when the notebook is run from inside a local checkout of the repo;
# in Colab the clone sits in the working directory and is importable directly.
sys.path.append("../")
from Neural_Tangent_Kernel.src.NTK_net import LinearNeuralTangentKernel, FourLayersNet, train_net, circle_transform, variance_est, cpu_tuple,\
                                              AnimationPlot_lsq, kernel_leastsq_update

In [0]:
import copy
from google.colab import files

In [0]:
# Use the GPU whenever one is available; is_available() already returns a bool.
use_cuda = torch.cuda.is_available()

### Define the training data and test grid

In [0]:
from matplotlib import animation, rc
from IPython.display import HTML
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors)

# Number of evenly spaced test angles gamma in [-pi, pi] at which functions are plotted.
n_pts = 100
# animate over some set of x, y
gamma_vec = torch.tensor(np.linspace(-np.pi, np.pi, n_pts))
# circle_transform presumably maps each angle to a point on the unit circle (TODO confirm).
circle_test = circle_transform(gamma_vec)
if use_cuda:
  # Only move to the GPU when one exists, so the notebook also runs on CPU-only runtimes.
  circle_test = circle_test.cuda()
# The four training angles and their regression targets.
gamma_data = torch.tensor(np.array([-2.2, -1, 1, 2.2]))
target_data = torch.tensor(np.array([-0.4, -0.2, 0.3, 0.3])).float()
input_data = circle_transform(gamma_data)
# Precomputed curve drawn as 'Mean' on the test grid gamma_vec, so its length must equal n_pts.
mean_vec = np.load('Neural_Tangent_Kernel/data/mean_vec.npy').flatten()
assert mean_vec.shape[0] == n_pts, 'mean_vec has length {}; set n_pts to match'.format(mean_vec.shape[0])

### Define the network and compute the neural tangent kernel

In [0]:
# Network with hidden width n_width. The kernel cells below take gradients of this
# network exactly as initialized: no training step runs before them.
n_width = 10000
net = FourLayersNet(n_width)
if use_cuda:
  net.cuda()

# Deliberately disabled: running it would change the parameters before the kernel is
# computed. (Argument meaning assumed: net, number of steps, inputs, targets — TODO confirm.)
#train_net(net, 1000, input_data, target_data)


In [0]:
# For each training angle, the gradient of the network output with respect to every
# parameter. Each entry is the tuple of per-parameter gradients after cpu_tuple
# (presumably moved to host memory to spare the GPU — TODO confirm).
grad_list = []
for gamma in gamma_data:
  circle_pt = circle_transform(gamma)
  if use_cuda:
    circle_pt = circle_pt.cuda()
  # This is the network output f_theta(x), not a loss.
  output = net(circle_pt)
  # Each forward pass builds its own graph that is used only once, so no retain_graph.
  grad_list.append(cpu_tuple(torch.autograd.grad(output, net.parameters())))

In [0]:
%%time
# K_testvtrain is kappa on p7 of NTK paper
K_testvtrain = torch.zeros((n_pts,4))
for i, gamma in enumerate(gamma_vec):
  if i%10 == 0:
    print('point {}'.format(i))
  circle_pt = circle_transform(gamma)
  if use_cuda and torch.cuda.is_available():
    circle_pt = circle_pt.cuda()
  loss = net(circle_pt)
  grads = cpu_tuple(torch.autograd.grad(loss,net.parameters(), retain_graph = True)) # extract NN gradients 
  for j in range(len(grad_list)):
    pt_grad = grad_list[j] # the gradients at the jth (out of 4) data point
    K_testvtrain[i, j] = sum([torch.sum(torch.mul(grads[u], pt_grad[u])) for u in range(len(grads))])
K_testvtrain = K_testvtrain.cpu().detach().numpy()

point 0
point 10
point 20
point 30
point 40
point 50
point 60
point 70
point 80
point 90
CPU times: user 3min 42s, sys: 5 s, total: 3min 47s
Wall time: 3min 47s


In [0]:
# Ktilde from p7 of NTK paper: the kernel Gram matrix over the training points,
# K_trainvtrain[i, j] = sum over parameter tensors of <grad_i, grad_j>.
# It is symmetric, so only the lower triangle is computed and then mirrored.
n_train = len(grad_list)
K_trainvtrain = torch.zeros((n_train, n_train))
for i, grad_i in enumerate(grad_list):
  for j in range(i + 1):
    grad_j = grad_list[j]
    K_trainvtrain[i, j] = sum(torch.sum(a * b) for a, b in zip(grad_i, grad_j))
    K_trainvtrain[j, i] = K_trainvtrain[i, j]
# Allocated on the CPU, so convert to NumPy directly.
K_trainvtrain = K_trainvtrain.detach().numpy()

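### Class structure for gradient updates in the animation

The update logic lives in `AnimationPlot_lsq`, which is imported from `Neural_Tangent_Kernel/src/NTK_net.py`.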

### Generate the animation

In [0]:
if use_cuda:
  input_data = input_data.cuda()
  target_data = target_data.cuda()
# First set up the figure, the axes, and the plot element
fig, ax = plt.subplots()
plt.close()  # presumably so the empty static figure is not shown inline; only the animation is
ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-0.5, 0.5))
# Raw strings: '\g' in an ordinary string literal is an invalid escape sequence
# (a DeprecationWarning/SyntaxWarning in recent Python versions).
ax.set_xlabel(r'$\gamma$')
ax.set_ylabel(r'$f_{ \theta}(sin(\gamma),cos(\gamma))$')

line, = ax.plot([], [], lw=2, color='darkmagenta', label='Mean')
# Empty black lines that exist only to add the two line-style entries to the legend.
ax.plot([], [], lw=1, linestyle='--', alpha=0.9, color='black', label='Backprop')
ax.plot([], [], lw=1, alpha=0.9, color='black', label='Kernel GD')

# One colour per network: dashed lines for backprop, then solid lines for kernel GD.
# The order (all dashed, then all solid) matches what is passed to AnimationPlot_lsq.
net_colors = ('red', 'green', 'cyan', 'gold', 'darkorange')
backprop_lines = tuple(ax.plot([], [], lw=1, linestyle='--', alpha=0.9, color=c)[0] for c in net_colors)
kernel_lines = tuple(ax.plot([], [], lw=1, alpha=0.9, color=c)[0] for c in net_colors)
line_tuple = backprop_lines + kernel_lines

ax.legend(loc='upper left')

# initialization function: plot the background of each frame
def init():
  """Draw the static mean curve; the returned artists are what FuncAnimation blits."""
  line.set_data(gamma_vec.numpy(), mean_vec)
  return (line,)

# animation function: this is called sequentially
anim_plot = AnimationPlot_lsq(n_nets=len(net_colors), n_wid=50, input_data=input_data, K_testvtrain=K_testvtrain,
                              K_trainvtrain=K_trainvtrain, train_target=target_data, line_tuple=line_tuple, ax=ax)

anim = animation.FuncAnimation(fig, anim_plot.plot_step, init_func=init, frames=82, interval=150, blit=True)
rc('animation', html='jshtml')

In [0]:
# Display the animation inline (rendered as JavaScript HTML via the rc setting).
anim

To save the animation as `anim_kernel.mp4`, run the next cell *before* running the display cell above.

In [0]:
# Write the animation to MP4 (matplotlib's default movie writer is ffmpeg),
# then trigger a browser download through Colab's files helper.
anim.save('anim_kernel.mp4')
files.download('anim_kernel.mp4')