<a href="https://colab.research.google.com/github/bobby-he/Neural_Tangent_Kernel/blob/master/training_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Compare the training in a NN through backprop and using kernel gradients 

In [1]:
!git clone https://github.com/bobby-he/Neural_Tangent_Kernel.git

fatal: destination path 'Neural_Tangent_Kernel' already exists and is not an empty directory.


In [0]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [0]:
import seaborn as sns
sns.set()

In [0]:
from collections import OrderedDict

In [0]:
from Neural_Tangent_Kernel.src.NTK_net import LinearNeuralTangentKernel, FourLayersNet, train_net, circle_transform, variance_est, cpu_tuple

In [0]:
import copy
from google.colab import files

In [0]:
use_cuda = True if torch.cuda.is_available() else False

### define net and data

In [0]:
from matplotlib import animation, rc
from IPython.display import HTML
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors)

n_pts = 100
# animate over some set of x, y
gamma_vec = torch.tensor(np.linspace(-np.pi, np.pi, n_pts))
circle_test = circle_transform(gamma_vec).cuda()
gamma_data = torch.tensor(np.array([-2.2, -1, 1, 2.2]))
target_data = torch.tensor(np.array([-0.4, -0.2, 0.3, 0.3])).float()
input_data = circle_transform(gamma_data)
# note mean_vec has length 100 so better to set n_pts=100
mean_vec = np.load('Neural_Tangent_Kernel/data/mean_vec.npy').flatten()

### define variables for kernel gradient

In [0]:
def kernel_leastsq_update(test_output, train_output, K_testvtrain, K_trainvtrain, train_target, n_steps = 1, learning_rate = 1): 
  test_output = test_output + np.matmul(K_testvtrain, train_target - train_output).flatten()
  train_output = train_output + np.matmul(K_trainvtrain, train_target - train_output).flatten()
  return test_output, train_output

In [0]:
n_width = 10000
net = FourLayersNet(n_width)
if use_cuda:
  net.cuda()

#train_net(net, 1000, input_data, target_data)


In [0]:
grad_list = []
for gamma in gamma_data:
  circle_pt = circle_transform(gamma)
  if use_cuda and torch.cuda.is_available():
    circle_pt = circle_pt.cuda()
  loss = net(circle_pt)
  grad_list.append(cpu_tuple(torch.autograd.grad(loss,net.parameters(), retain_graph = True)))

In [12]:
%%time
# K_testvtrain is kappa on p7 of NTK paper
K_testvtrain = torch.zeros((n_pts,4))
for i, gamma in enumerate(gamma_vec):
  if i%10 == 0:
    print('point {}'.format(i))
  circle_pt = circle_transform(gamma)
  if use_cuda and torch.cuda.is_available():
    circle_pt = circle_pt.cuda()
  loss = net(circle_pt)
  grads = cpu_tuple(torch.autograd.grad(loss,net.parameters(), retain_graph = True)) # extract NN gradients 
  for j in range(len(grad_list)):
    pt_grad = grad_list[j] # the gradients at the jth (out of 4) data point
    K_testvtrain[i, j] = sum([torch.sum(torch.mul(grads[u], pt_grad[u])) for u in range(len(grads))])
K_testvtrain = K_testvtrain.cpu().detach().numpy()

point 0
point 10
point 20
point 30
point 40
point 50
point 60
point 70
point 80
point 90
CPU times: user 3min 42s, sys: 5.75 s, total: 3min 47s
Wall time: 3min 48s


In [0]:
# let's create Ktilde matrix from p7 of NTK paper
K_trainvtrain = torch.zeros((4,4))
for i in range(4):
  grad_i = grad_list[i]
  for j in range(i+1):
    grad_j = grad_list[j]
    K_trainvtrain[i, j] = sum([torch.sum(torch.mul(grad_i[u], grad_j[u])) for u in range(len(grad_j))])
    K_trainvtrain[j, i] = K_trainvtrain[i, j]
K_trainvtrain = K_trainvtrain.cpu().detach().numpy()

### generate animation

In [0]:
aasdfsdf=1
bsadfasdf=2
global aasdfsdf,bsadfasdf

In [0]:
net1 = FourLayersNet(n_wid=1000, n_out=1).cuda()
copynet1 = copy.deepcopy(net1)
test_output_1 = net1(circle_test.cuda()).cpu().detach().numpy().flatten()
train_output_1 = net1(input_data.cuda()).cpu().detach().numpy().flatten()

net2 = FourLayersNet(n_wid=1000, n_out=1).cuda()
copynet2 = copy.deepcopy(net2)
test_output_2 = net2(circle_test.cuda()).cpu().detach().numpy().flatten()
train_output_2 = net2(input_data.cuda()).cpu().detach().numpy().flatten()

net3 = FourLayersNet(n_wid=1000, n_out=1).cuda()
copynet3 = copy.deepcopy(net3)
test_output_3 = net3(circle_test.cuda()).cpu().detach().numpy().flatten()
train_output_3 = net3(input_data.cuda()).cpu().detach().numpy().flatten()

net4 = FourLayersNet(n_wid=1000, n_out=1).cuda()
copynet4 = copy.deepcopy(net4)
test_output_4 = net4(circle_test.cuda()).cpu().detach().numpy().flatten()
train_output_4 = net4(input_data.cuda()).cpu().detach().numpy().flatten()

net5 = FourLayersNet(n_wid=1000, n_out=1).cuda()
copynet5 = copy.deepcopy(net5)
test_output_5 = net5(circle_test.cuda()).cpu().detach().numpy().flatten()
train_output_5 = net5(input_data.cuda()).cpu().detach().numpy().flatten()

global test_output_1, test_output_2, test_output_3, test_output_4, test_output_5
global train_output_1, train_output_2, train_output_3, train_output_4, train_output_5

if use_cuda:
  input_data = input_data.cuda()
  target_data = target_data.cuda()
# First set up the figure, the axes, and the plot element
fig, ax = plt.subplots()
plt.close()
ax.set_xlim(( -np.pi, np.pi))
ax.set_ylim((-0.5, 0.5))
ax.set_xlabel('$\gamma$')
ax.set_ylabel('$f_{ \\theta}(sin(y),cos(y))$')

line0, = ax.plot([], [], lw=2, color = 'darkmagenta')
line1, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'red')
line2, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'green')
line3, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'cyan')
line4, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'gold')
line5, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'darkorange')

line1a, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'red')
line2a, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'green')
line3a, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'cyan')
line4a, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'gold')
line5a, = ax.plot([], [], lw=1, linestyle = '--', alpha = 0.9, color = 'darkorange')

# initialization function: plot the background of each frame
def init():   
  line0.set_data(gamma_vec.numpy(), mean_vec)
  return (line0,)
  
# animation function: this is called sequentially
def animate(i):
  j = 0
  if i>2:
    train_net(net1, 1, input_data, target_data)
    train_net(net2, 1, input_data, target_data)
    train_net(net3, 1, input_data, target_data)
    train_net(net4, 1, input_data, target_data)
    train_net(net5, 1, input_data, target_data)
    j = i - 2
  
  if i == 0:
    test_output_1 = copynet1(circle_test.cuda()).cpu().detach().numpy().flatten()
    train_output_1 = copynet1(input_data.cuda()).cpu().detach().numpy().flatten()
    
    test_output_2 = copynet2(circle_test.cuda()).cpu().detach().numpy().flatten()
    train_output_2 = copynet2(input_data.cuda()).cpu().detach().numpy().flatten()

    test_output_3 = copynet3(circle_test.cuda()).cpu().detach().numpy().flatten()
    train_output_3 = copynet3(input_data.cuda()).cpu().detach().numpy().flatten()
    
    test_output_4 = copynet4(circle_test.cuda()).cpu().detach().numpy().flatten()
    train_output_4 = copynet4(input_data.cuda()).cpu().detach().numpy().flatten()
    
    test_output_5 = copynet5(circle_test.cuda()).cpu().detach().numpy().flatten()
    train_output_5 = copynet5(input_data.cuda()).cpu().detach().numpy().flatten()
  
  test_output_1, train_output_1 = kernel_leastsq_update(test_output_1, train_output_1, K_testvtrain,
                      K_trainvtrain, target_data.cpu().detach().numpy(), n_steps = 1, learning_rate = 1)
  test_output_2, train_output_2 = kernel_leastsq_update(test_output_2, train_output_2, K_testvtrain,
                      K_trainvtrain, target_data.cpu().detach().numpy(), n_steps = 1, learning_rate = 1)
  
  test_output_3, train_output_3 = kernel_leastsq_update(test_output_3, train_output_3, K_testvtrain,
                      K_trainvtrain, target_data.cpu().detach().numpy(), n_steps = 1, learning_rate = 1)

  test_output_4, train_output_4 = kernel_leastsq_update(test_output_4, train_output_4, K_testvtrain,
                      K_trainvtrain, target_data.cpu().detach().numpy(), n_steps = 1, learning_rate = 1)
  
  test_output_5, train_output_5 = kernel_leastsq_update(test_output_5, train_output_5, K_testvtrain,
                      K_trainvtrain, target_data.cpu().detach().numpy(), n_steps = 1, learning_rate = 1)
  
  line1.set_data(gamma_vec.numpy(), net1(circle_test).cpu().detach().numpy())
  line2.set_data(gamma_vec.numpy(), net2(circle_test).cpu().detach().numpy())
  line3.set_data(gamma_vec.numpy(), net3(circle_test).cpu().detach().numpy())
  line4.set_data(gamma_vec.numpy(), net4(circle_test).cpu().detach().numpy())
  line5.set_data(gamma_vec.numpy(), net5(circle_test).cpu().detach().numpy())
  
  line1a.set_data(gamma_vec.numpy(), test_output_1)
  line2a.set_data(gamma_vec.numpy(), test_output_2)
  line3a.set_data(gamma_vec.numpy(), test_output_3)
  line4a.set_data(gamma_vec.numpy(), test_output_4)
  line5a.set_data(gamma_vec.numpy(), test_output_5)
  
  ax.set_title('Epoch {}'.format(j))
  return (line1, line2, line3, line4, line5, )

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=52, interval=150, blit=True)
rc('animation', html='jshtml')

If you want to see the animation, call the next cell

the above needs work, write a class in order to avoid global variables

In [23]:
anim

UnboundLocalError: ignored

<matplotlib.animation.FuncAnimation at 0x7f2a3612eeb8>

If you want to save the animation, call the next cell before the cell above

In [18]:
anim.save('anim.mp4')
files.download('anim.mp4')


UnboundLocalError: ignored