<a href="https://colab.research.google.com/github/bobby-he/Neural_Tangent_Kernel/blob/master/training_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/bobby-he/Neural_Tangent_Kernel.git

Cloning into 'Neural_Tangent_Kernel'...
remote: Enumerating objects: 76, done.[K
remote: Counting objects: 100% (76/76), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 76 (delta 24), reused 28 (delta 8), pack-reused 0[K
Unpacking objects: 100% (76/76), done.


In [0]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [0]:
import seaborn as sns
# Apply seaborn's default plot styling to the matplotlib figures created below.
sns.set()

In [0]:
from collections import OrderedDict

In [0]:
from Neural_Tangent_Kernel.src.NTK_net import LinearNeuralTangentKernel, FourLayersNet, train_net, circle_transform, variance_est

In [0]:
import copy
from google.colab import files

In [0]:
# Flag telling later cells whether a CUDA device is usable in this runtime.
use_cuda = torch.cuda.is_available()

### Define the networks and data

In [0]:
from matplotlib import animation, rc
from IPython.display import HTML
# Lines without an explicit color cycle through matplotlib's Set2 palette.
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors)

n_pts = 100
# Angles in [-pi, pi] at which the networks are evaluated for every frame.
gamma_vec = torch.tensor(np.linspace(-np.pi, np.pi, n_pts))
circle_test = circle_transform(gamma_vec)
# Only move the evaluation inputs to the GPU when one is available; the
# unconditional .cuda() call made this cell crash on CPU-only runtimes.
if use_cuda:
  circle_test = circle_test.cuda()
# Four training angles and the values later passed to train_net as targets
# (cast to float32).
gamma_data = torch.tensor(np.array([-2.2, -1, 1, 2.2]))
target_data = torch.tensor(np.array([-0.4, -0.2, 0.3, 0.3])).float()
input_data = circle_transform(gamma_data)
# mean_vec is plotted against gamma_vec in init(), so n_pts has to match its
# length (100 according to the original note).
mean_vec = np.load('Neural_Tangent_Kernel/data/mean_vec.npy').flatten()

In [0]:
# Five separate width-1000, single-output networks, constructed in the order
# net1 .. net5.
net1, net2, net3, net4, net5 = (
  FourLayersNet(n_wid=1000, n_out=1) for _ in range(5)
)

# Move the networks together with the training data, and only when a GPU is
# available; previously the networks called .cuda() unconditionally while the
# data was guarded by use_cuda.
if use_cuda:
  net1, net2, net3, net4, net5 = (
    net.cuda() for net in (net1, net2, net3, net4, net5)
  )
  input_data = input_data.cuda()
  target_data = target_data.cuda()
# Build the figure and axes once; the animation callbacks only update line data.
fig, ax = plt.subplots()
# Close the freshly created figure (presumably so the empty axes are not also
# shown as a static image under this cell).
plt.close()
ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-0.5, 0.5))
# Raw string: '\g' is not a valid escape sequence in an ordinary string literal.
ax.set_xlabel(r'$\gamma$')
# NOTE(review): this label uses y where the x-axis uses gamma -- confirm which
# symbol is intended.
ax.set_ylabel('$f_{ \\theta}(sin(y),cos(y))$')

# Solid reference curve; its data is filled in by init().
line0, = ax.plot([], [], lw=2, color = 'darkmagenta')
# One thin dashed line per network, created in the order line1 .. line5 so the
# colors are taken from the property cycle in the same order as before.
line1, line2, line3, line4, line5 = (
  ax.plot([], [], lw=1, linestyle='--', alpha=0.9)[0] for _ in range(5)
)
# Initialization callback handed to FuncAnimation as init_func.
def init():
  """Put the reference curve (mean_vec over gamma_vec) on line0.

  Returns a one-element tuple holding the artist that was updated.
  """
  gammas = gamma_vec.numpy()
  line0.set_data(gammas, mean_vec)
  return (line0,)
  
# animation function: this is called sequentially
def animate(i):
  """Update the five network curves for frame ``i``.

  Frames 0-2 only draw the networks as they currently are. From frame 3 on,
  ``train_net(net, 1, input_data, target_data)`` is called once per network
  before drawing, and the title shows ``i - 2``.

  train_net is applied to the module-level networks, so rendering the
  animation more than once keeps calling it on the same network objects
  rather than starting over.

  Returns the tuple of line artists that were updated.
  """
  nets = (net1, net2, net3, net4, net5)
  lines = (line1, line2, line3, line4, line5)

  # The first three frames show epoch 0; afterwards the epoch is i - 2.
  epoch = max(i - 2, 0)
  if epoch > 0:
    for net in nets:
      train_net(net, 1, input_data, target_data)

  gammas = gamma_vec.numpy()
  # These forward passes are only plotted, so skip autograd bookkeeping.
  with torch.no_grad():
    for line, net in zip(lines, nets):
      line.set_data(gammas, net(circle_test).cpu().detach().numpy())
  ax.set_title('Epoch {}'.format(epoch))
  return lines

# Make animations display as an interactive JavaScript/HTML player.
rc('animation', html='jshtml')
# 52 frames, 150 ms apart; init() supplies the first draw and animate() each frame.
anim = animation.FuncAnimation(
  fig,
  animate,
  init_func=init,
  frames=52,
  interval=150,
  blit=True,
)

To view the animation, run the next cell.

In [10]:
# Evaluating `anim` as the last expression of the cell displays it; the earlier
# rc('animation', html='jshtml') call selects the JavaScript/HTML representation.
anim

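To save the animation, run the next cell *before* the display cell above. Every render of the animation calls the training step on the same networks, so only the first render starts from the freshly initialised networks.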

In [0]:
# Write the animation to a video file, then hand that file to the Colab
# download helper.
video_path = 'anim.mp4'
anim.save(video_path)
files.download(video_path)
