<a href="https://colab.research.google.com/github/bobby-he/Neural_Tangent_Kernel/blob/master/notebooks/GANimation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from scipy.special import logit, expit
!git clone https://github.com/bobby-he/Neural_Tangent_Kernel.git

fatal: destination path 'Neural_Tangent_Kernel' already exists and is not an empty directory.


In [0]:

from Neural_Tangent_Kernel.src.NTK_net import LinearNeuralTangentKernel, FourLayersNet, train_net, circle_transform, variance_est, cpu_tuple,\
                                              AnimationPlot_lsq, kernel_leastsq_update, kernel_mats, kernel_mats_d_gan
import time
import copy
from google.colab import files

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy.stats import norm
from matplotlib import animation, rc
from IPython.display import HTML
# Use the qualitative Set2 palette as the default colour cycle for every axes in this notebook.
plt.rcParams['axes.prop_cycle'] = plt.cycler('color', plt.cm.Set2.colors)
def get_distribution_sampler(mu, sigma):
    """Build a sampler for the real-data distribution N(mu, sigma**2).

    The returned callable takes a sample count ``n`` and returns an ``(n, 1)``
    float tensor of i.i.d. Gaussian draws taken from NumPy's global RNG.
    """
    def sample(n):
        draws = np.random.normal(mu, sigma, (n, 1))  # Gaussian
        return torch.Tensor(draws)
    return sample

def get_generator_input_sampler():
    """Build a sampler for the generator's latent input.

    The returned callable takes a sample count ``n`` and returns an ``(n, 1)``
    tensor drawn uniformly from ``[-pi, pi)`` using torch's global RNG.
    The input is uniform, not Gaussian.
    """
    def sample(n):
        unit = torch.rand(n, 1)
        return 2*np.pi * unit - np.pi
    return sample

#def get_generator_input_sampler():
  #return lambda n: torch.Tensor(np.random.normal(0, 1, (n, 1)))  # Alternative (disabled): standard-Gaussian data into generator instead of the uniform sampler above

class Generator(nn.Module):
    """Three-layer MLP built from ``LinearNeuralTangentKernel`` layers.

    Layer widths are ``input_size -> hidden_size -> hidden_size -> output_size``.
    Every layer uses ``beta=0.1``. The first layer uses ``w_sig=1`` and the other
    two use ``w_sig=sqrt(5)``. The activation ``f`` is applied after the first
    two layers only; no activation follows the final layer.
    """

    def __init__(self, input_size, hidden_size, output_size, f):
        super().__init__()
        # Layers are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.map1 = LinearNeuralTangentKernel(input_size, hidden_size, w_sig=1, beta=0.1)
        self.map2 = LinearNeuralTangentKernel(hidden_size, hidden_size, w_sig=np.sqrt(5), beta=0.1)
        self.map3 = LinearNeuralTangentKernel(hidden_size, output_size, w_sig=np.sqrt(5), beta=0.1)
        self.f = f

    def forward(self, x):
        """Apply map1 -> f -> map2 -> f -> map3 to ``x``."""
        hidden = self.f(self.map1(x))
        hidden = self.f(self.map2(hidden))
        return self.map3(hidden)
      
class Discriminator(nn.Module):
    """Three-layer MLP built from ``LinearNeuralTangentKernel`` layers.

    The first layer uses ``w_sig=1`` and the other two use ``w_sig=sqrt(10)``.
    The activation ``f`` follows the first two layers. The final layer's
    output is returned un-squashed; callers apply ``torch.sigmoid`` themselves
    to get probabilities.
    """

    def __init__(self, input_size, hidden_size, output_size, f):
        super().__init__()
        deep_sig = np.sqrt(10)
        self.map1 = LinearNeuralTangentKernel(input_size, hidden_size, w_sig=1)
        self.map2 = LinearNeuralTangentKernel(hidden_size, hidden_size, w_sig=deep_sig)
        self.map3 = LinearNeuralTangentKernel(hidden_size, output_size, w_sig=deep_sig)
        self.f = f

    def forward(self, x):
        """Apply the two activated hidden layers, then the linear read-out."""
        for layer in (self.map1, self.map2):
            x = self.f(layer(x))
        return self.map3(x)
      
class Discriminator_no_sig(Discriminator):
    """Discriminator returning un-squashed scores (no output sigmoid).

    The layers, ``w_sig`` values and forward pass are exactly those of
    ``Discriminator``, which also returns raw scores. Instead of keeping a
    second verbatim copy of that code, this class inherits it. It keeps its
    own name and constructor signature ``(input_size, hidden_size,
    output_size, f)`` so existing callers keep working.
    """

In [0]:
class GANimation(object):
  """Animate a 1-D GAN trained by backprop next to a kernel-gradient-descent approximation.

  Each call to ``plot_train_step`` is intended as the ``func`` of a
  ``matplotlib.animation.FuncAnimation``. A call:

  * trains the real networks ``G``/``D`` with SGD;
  * applies, in lock-step, an explicit kernel update to their outputs on fixed
    test grids;
  * redraws the line artists and the two histogram panels.

  Networks and sampled batches are moved to CUDA, so a GPU is required.

  Args:
    generator: ``nn.Module`` fed circle-transformed latent samples.
    discriminator: ``nn.Module`` returning raw scores; ``torch.sigmoid`` is
      applied here to obtain probabilities.
    line_tuple: ``(g_backprop_line, d_backprop_line, g_kernel_line,
      d_kernel_line)`` line artists updated every frame.
    fig: figure whose suptitle shows the epoch counter.
    ax: axes array indexed ``ax[row, col]``; ``ax[1, 0]`` and ``ax[1, 1]`` are
      cleared and redrawn with histograms every frame.
    g_learning_rate, d_learning_rate: SGD step sizes, reused by the kernel updates.
    momentum: SGD momentum, mirrored by the kernel updates.
    minibatch_size: samples per real/fake batch.
    dis_iterations: discriminator updates per epoch.
    g_iterations: generator updates per epoch.
    n_pts: number of points on each test grid.
    print_every: stored but not read by this class.
    epochs_per_frame: training epochs run per animation frame.
    data_mean, data_stddev: mean and standard deviation of the Gaussian target data.
    noise_prop: accepted for interface compatibility; not read by this class.
  """
  def __init__(self, generator, discriminator, line_tuple, fig, ax, g_learning_rate = 0.001, d_learning_rate = 0.001,
               momentum = 0.9, minibatch_size = 10, dis_iterations = 5, g_iterations = 1, n_pts = 100,
               print_every = 50, epochs_per_frame = 100, data_mean = 4, data_stddev = 1.25, noise_prop = 0.05):
    # Assume CUDA is available
    self.d_learning_rate = d_learning_rate
    self.g_learning_rate = g_learning_rate
    self.G = generator
    self.G_opt = optim.SGD(self.G.parameters(), lr=g_learning_rate, momentum=momentum)
    self.D = discriminator
    self.D_opt = optim.SGD(self.D.parameters(), lr=d_learning_rate, momentum=momentum)
    self.dis_iterations = dis_iterations
    self.g_iterations = g_iterations
    self.minibatch_size = minibatch_size
    self.data_mean = data_mean
    self.data_stddev = data_stddev
    self.g_backprop_line, self.d_backprop_line, self.g_kernel_line, self.d_kernel_line = line_tuple
    self.ax = ax
    self.fig = fig
    self.n_pts = n_pts
    # Fixed evaluation grids: latent inputs span [-pi, pi]; data inputs span
    # data_mean +/- 4 standard deviations, as an (n_pts, 1) column.
    self.g_test_data = torch.tensor(np.linspace(-np.pi, np.pi, n_pts)).float()
    self.d_test_data = torch.tensor(np.linspace(data_mean - 4*data_stddev,
                                                data_mean + 4 *data_stddev, n_pts)).float().reshape((self.n_pts,1))
    self.print_every = print_every
    self.epochs_per_frame = epochs_per_frame

    self.criterion = nn.BCELoss()
    # Constant target tensors. Not read elsewhere in this class (labels are
    # drawn per batch in plot_train_step); kept as public attributes.
    self.valid = Variable(torch.ones([minibatch_size,1]), requires_grad=False).cuda()
    self.fake = Variable(torch.zeros([minibatch_size,1]),  requires_grad=False).cuda()

    self.g_lr = g_learning_rate
    self.d_lr = d_learning_rate
    self.momentum = momentum

    self.G.cuda()
    self.D.cuda()

    self.d_sampler = get_distribution_sampler(self.data_mean, self.data_stddev)
    self.g_sampler = get_generator_input_sampler()

    # Kernel-GD trajectories start from the networks' initial outputs on the
    # test grids; the *_prev_* copies drive the momentum term.
    self.g_test_circle = circle_transform(self.g_test_data).cuda()
    self.g_kernel_output = self.G(self.g_test_circle).cpu().detach().numpy()
    self.g_prev_kernel_output = self.G(self.g_test_circle).cpu().detach().numpy()

    # Frozen copies of the networks at construction time. The kernel matrices
    # are computed from these, so presumably the kernel is the one at
    # initialisation. TODO confirm against kernel_mats / kernel_mats_d_gan.
    self.G_copy = copy.deepcopy(self.G)
    self.D_no_sig = copy.deepcopy(self.D)

    self.d_kernel_output = self.D_no_sig(self.d_test_data.cuda()).cpu().detach().numpy()
    self.d_prev_kernel_output = self.d_kernel_output

  def _generator_train_step(self, fake_data, noisy_labels):
    """Take one SGD step on G, pushing D(fake_data) towards ``noisy_labels``.

    Returns ``sigmoid(D(fake_data))`` still attached to the autograd graph. The
    graph is retained so the caller can differentiate through it again.

    NOTE: ``backward`` here also accumulates gradients into D's parameters;
    ``_discriminator_train_step`` zeroes them before its own step.
    """
    self.G_opt.zero_grad()
    noisy_labels = torch.tensor(noisy_labels)

    d_fake = torch.sigmoid(self.D(fake_data))

    # calculate loss and optimise
    g_loss = self.criterion(d_fake, noisy_labels.cuda())
    g_loss.backward(retain_graph = True)
    self.G_opt.step()

    return d_fake

  def _discriminator_train_step(self, real_data, fake_data, noisy_real_labels, noisy_fake_labels):
    """Take one SGD step on D using BCE against the (smoothed) labels.

    Returns:
      ``(d_fake, d_real)`` as NumPy arrays. These are the sigmoid outputs
      computed *before* the optimiser step, which the kernel update reuses.
    """
    noisy_real_labels = torch.tensor(noisy_real_labels)
    noisy_fake_labels = torch.tensor(noisy_fake_labels)
    self.D_opt.zero_grad()
    d_fake = torch.sigmoid(self.D(fake_data))
    d_real = torch.sigmoid(self.D(real_data))

    # calculate loss and optimise
    real_loss = self.criterion(d_real, noisy_real_labels.cuda())
    fake_loss = self.criterion(d_fake, noisy_fake_labels.cuda())
    d_loss = real_loss + fake_loss

    d_loss.backward()
    self.D_opt.step()

    return d_fake.cpu().detach().numpy(), d_real.cpu().detach().numpy()

  def _plot_histogram(self, axis, samples, title):
    """Redraw ``axis`` with a density histogram of ``samples`` over the true data density.

    Bins, limits and the reference curve come from this instance's
    ``data_mean``/``data_stddev``, never from module-level names.
    """
    lo_view = self.data_mean - 4 * self.data_stddev
    hi_view = self.data_mean + 4 * self.data_stddev
    lo_bins = self.data_mean - 3 * self.data_stddev
    hi_bins = self.data_mean + 3 * self.data_stddev
    axis.clear()
    axis.set_title(title)
    axis.set_ylim((0,0.8))
    axis.set_xlim((lo_view, hi_view))
    axis.hist(samples, bins = np.linspace(lo_bins, hi_bins, 20), density = True)
    x = np.linspace(lo_bins, hi_bins, 100)
    axis.plot(x, norm.pdf(x,self.data_mean, self.data_stddev), alpha = 0.7, color = 'c', label = 'True density')
    axis.legend()

  def plot_train_step(self, i):
    """Advance training by one animation frame and redraw the figure.

    Frames 0-2 only draw the current state. From frame 3 onwards each frame
    runs ``epochs_per_frame`` epochs. Each epoch performs ``dis_iterations``
    discriminator updates, then ``g_iterations`` generator updates, and every
    update is applied both by backprop and by the kernel approximation.

    Args:
      i: frame index supplied by ``FuncAnimation``.

    Returns:
      The four line artists, as required when animating with ``blit=True``.
    """
    j = 0
    if i%10==0 and i!=0:
      print('{} steps gone'.format(i) )
    if i>2:
      for _ in range(self.epochs_per_frame):
        for _ in range(self.dis_iterations):
          real_data = self.d_sampler(self.minibatch_size).cuda()
          gen_samples = self.g_sampler(self.minibatch_size)
          fake_data = self.G(circle_transform(gen_samples).reshape(self.minibatch_size, 2).cuda())
          # Smoothed labels: real samples get targets in [0, 0.1) and fake
          # samples targets in (0.9, 1]. One randomly chosen index is flipped
          # in both batches.
          noisy_real_labels = np.random.uniform(low = 0, high = 0.1, size = (self.minibatch_size,1)).astype(np.float32)
          noisy_fake_labels = 1 - np.random.uniform(low = 0, high = 0.1, size = (self.minibatch_size,1)).astype(np.float32)
          flipped_idx = np.random.choice(np.arange(self.minibatch_size), size=1)
          noisy_real_labels[flipped_idx] = 1 - noisy_real_labels[flipped_idx]
          noisy_fake_labels[flipped_idx] = 1 - noisy_fake_labels[flipped_idx]
          d_fake, d_real = self._discriminator_train_step(real_data, fake_data, noisy_real_labels, noisy_fake_labels)

          real_k_testvtrain = kernel_mats_d_gan(self.D_no_sig, real_data, self.d_test_data, use_cuda = True, kernels='testvtrain').cpu().detach().numpy()
          fake_k_testvtrain = kernel_mats_d_gan(self.D_no_sig, fake_data, self.d_test_data, use_cuda = True, kernels='testvtrain').cpu().detach().numpy()

          # Kernel step on D's raw outputs over d_test_data, plus a momentum
          # term mirroring SGD momentum; temp becomes the new d_kernel_output.
          temp = self.d_kernel_output + self.d_learning_rate *real_k_testvtrain @ (noisy_real_labels*(1-(d_real)) - (1-noisy_real_labels)*(d_real))/self.minibatch_size \
                 + self.d_learning_rate *fake_k_testvtrain @ (noisy_fake_labels*(1-(d_fake)) - (1-noisy_fake_labels)*(d_fake))/self.minibatch_size \
                 + self.momentum * (self.d_kernel_output - self.d_prev_kernel_output)

          self.d_prev_kernel_output = self.d_kernel_output
          self.d_kernel_output = temp

        for _ in range(self.g_iterations):
          gen_samples = self.g_sampler(self.minibatch_size)
          fake_data = self.G(circle_transform(gen_samples).reshape(self.minibatch_size, 2).cuda())
          # Generator targets lie in [0, 0.1), the label given to real data
          # above, with one index flipped.
          noisy_labels = np.random.uniform(low = 0, high = 0.1, size = (self.minibatch_size,1)).astype(np.float32)
          flipped_idx = np.random.choice(np.arange(self.minibatch_size), size=1)
          noisy_labels[flipped_idx] = 1 - noisy_labels[flipped_idx]
          d_fake = self._generator_train_step(fake_data, noisy_labels)

          loss = sum(d_fake)/self.minibatch_size
          d_fake = d_fake.cpu().detach().numpy()
          fake_k_testvtrain = kernel_mats(self.G_copy, gen_samples,  self.g_test_data, n_train = self.minibatch_size, kernels = 'testvtrain').cpu().detach().numpy()
          d_prime_vec = torch.autograd.grad(loss,fake_data, only_inputs=True, retain_graph=True)[0].cpu().numpy() # vector of derivatives of D wrt each fake output of G
          temp = self.g_kernel_output + self.g_learning_rate*fake_k_testvtrain \
                 @ (d_prime_vec * (noisy_labels/d_fake - (1-noisy_labels)/(1-d_fake)))\
                 + self.momentum * (self.g_kernel_output - self.g_prev_kernel_output)
          self.g_prev_kernel_output = self.g_kernel_output
          self.g_kernel_output = temp

      j = i-2

    self.fig.suptitle('Epoch {}'.format(self.epochs_per_frame *j))
    g_current = self.G(circle_transform(self.g_test_data).cuda()).cpu().detach().numpy()
    self.g_backprop_line.set_data(self.g_test_data.numpy(), g_current)
    self.d_backprop_line.set_data(self.d_test_data.numpy(), torch.sigmoid(self.D(self.d_test_data.cuda())).cpu().detach().numpy())

    self.g_kernel_line.set_data(self.g_test_data.numpy(), self.g_kernel_output)
    self.d_kernel_line.set_data(self.d_test_data.numpy(), expit(self.d_kernel_output))

    np.seterr(divide='ignore', invalid='ignore')

    self._plot_histogram(self.ax[1,0], g_current, 'Approx Backprop Histogram')
    self._plot_histogram(self.ax[1,1], self.g_kernel_output, 'Approx Kernel Histogram')

    return(self.g_backprop_line, self.d_backprop_line, self.g_kernel_line, self.d_kernel_line, )


In [0]:
# Target data distribution N(data_mean, data_stddev**2). It sets the axis
# limits below and is passed to the GANimation instance.
data_mean = 5
data_stddev = 1.25
fig, ax = plt.subplots(nrows = 2, ncols = 2, figsize = (13,13))
plt.subplots_adjust(wspace=0.15,hspace=0.15)
# Presumably closes the figure so the notebook does not also render it as a
# static image; only the saved animation is wanted.
plt.close()

# Top row: generator output vs latent z (left), discriminator probability vs x (right).
ax[0,0].set_xlim((-np.pi, np.pi))
ax[0,0].set_ylim((data_mean - 4 * data_stddev, data_mean + 4 * data_stddev))
ax[0,1].set_xlim((data_mean - 4 * data_stddev, data_mean + 4 * data_stddev))
ax[0,1].set_ylim((0,1))
ax[0,0].set_xlabel('$z$')
ax[0,0].set_ylabel(r'$G_{ \theta}(z)$')
ax[0,1].set_xlabel('$x$')
# Raw string: '\p' is not a valid escape sequence in a normal string literal.
ax[0,1].set_ylabel(r'$D_{ \phi}(x)$')
ax[0,0].set_title('Generator')
ax[0,1].set_title('Discriminator')
ax[1,0].set_ylabel('Density')
ax[1,1].set_ylabel('Density')
ax[0,1].axhline(0.5, linestyle='--', color = 'darkgrey', alpha = 0.9, linewidth = 1)

# Dashed lines show backprop training; solid lines show the kernel-GD approximation.
line0, = ax[0,0].plot([], [], lw=1, linestyle = '--', color = 'darkmagenta', label = 'Backprop')
line0a, = ax[0,0].plot([], [], lw=1, color = 'darkmagenta', label = 'Kernel GD')
line1, = ax[0,1].plot([], [], lw=1, color = 'r', linestyle = '--', label = 'Backprop')
line1a, = ax[0,1].plot([], [], lw=1, color = 'r', label = 'Kernel GD')

# Order expected by GANimation: (g_backprop, d_backprop, g_kernel, d_kernel).
line_tuple = (line0, line1, line0a, line1a)

ax[0,0].legend(loc = 'upper left')
ax[0,1].legend(loc = 'upper left')
g_hidden_size = 500
d_hidden_size = 500
generator_activation_function = torch.tanh
discriminator_activation_function = torch.sigmoid
G = Generator(input_size = 2,
                  hidden_size=g_hidden_size,
                  output_size = 1,
                  f=generator_activation_function).cuda()
D = Discriminator(input_size = 1,
                  hidden_size=d_hidden_size,
                  output_size = 1,
                  f=discriminator_activation_function).cuda()
GAN = GANimation(G,D,data_mean = data_mean, data_stddev = data_stddev, dis_iterations=10, line_tuple = line_tuple, ax = ax, fig = fig, epochs_per_frame=10)


# Time the whole render/save; the elapsed time is computed in a later cell.
start=time.time()
anim = animation.FuncAnimation(fig, GAN.plot_train_step, frames = 150, interval = 150, blit = True)
rc('animation', html='jshtml')
anim.save('anim_gan_good.mp4')
files.download('anim_gan_good.mp4')

In [0]:
# Wall-clock seconds elapsed since `start` was set, just before the animation was
# built and saved. Despite its name, `end` holds a duration, not a timestamp.
end = time.time() - start

In [0]:
# Bare expression: the notebook displays the elapsed seconds as this cell's output.
end

88.37035846710205