<a href="https://colab.research.google.com/github/neural-tangents/neural-tangents/blob/master/notebooks/function_space_linearization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Copyright 2019 The Neural Tangents Authors.  All rights reserved.

#### Import & Utils

Install JAX, TensorFlow Datasets, and Neural Tangents.

If you install a specific jaxlib wheel instead, note that the "cp36" tag in its name specifies the Python version (3.6) the wheel was built for, so make sure your Colab kernel matches it.

In [0]:
!pip install -q tensorflow-datasets
!pip install -q git+https://www.github.com/neural-tangents/neural-tangents

Import libraries

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import from the top-level package; `jax.api` is not a public module.
from jax import jit
from jax import grad
from jax import random

import jax.numpy as np
from jax.experimental.stax import logsoftmax
from jax.experimental import optimizers

import tensorflow_datasets as tfds



import neural_tangents as nt
from neural_tangents import stax

Define helper functions for processing the data, and a vanilla momentum optimizer.

In [0]:
def process_data(data_chunk):
  """Flatten the images and one-hot encode the labels."""
  image, label = data_chunk['image'], data_chunk['label']
  
  samples = image.shape[0]
  image = np.array(np.reshape(image, (samples, -1)), dtype=np.float32)
  image = (image - np.mean(image)) / np.std(image)
  label = np.eye(10)[label]
  
  return {'image': image, 'label': label}
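
To see what this preprocessing does to shapes and values, here is a minimal sketch on a synthetic batch. It uses plain NumPy (imported as `onp` so it does not shadow the notebook's `np`), and the names `process_data_sketch` and `fake_batch` are hypothetical; the random pixels only stand in for MNIST images.

```python
import numpy as onp

def process_data_sketch(data_chunk, num_classes=10):
  """NumPy restatement of the steps in `process_data` above."""
  image, label = data_chunk['image'], data_chunk['label']
  samples = image.shape[0]
  # Flatten each image into a vector of pixels.
  image = onp.reshape(image, (samples, -1)).astype(onp.float32)
  # Standardize using the statistics of the whole batch.
  image = (image - onp.mean(image)) / onp.std(image)
  # Row `k` of the identity matrix is the one-hot vector for class `k`.
  label = onp.eye(num_classes)[label]
  return {'image': image, 'label': label}

rng = onp.random.RandomState(0)
fake_batch = {
    'image': rng.randint(0, 256, size=(4, 28, 28, 1)).astype(onp.uint8),
    'label': onp.array([3, 0, 9, 3]),
}
out = process_data_sketch(fake_batch)
print(out['image'].shape)  # (4, 784)
print(out['label'].shape)  # (4, 10)
```

Note that the mean and standard deviation are taken over the whole batch rather than per pixel, so the result depends on which examples are in the batch.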

In [0]:
@optimizers.optimizer
def momentum(learning_rate, momentum=0.9):
  """A standard momentum optimizer for testing.

  Different from `jax.experimental.optimizers.momentum` (Nesterov).
  """
  learning_rate = optimizers.make_schedule(learning_rate)
  def init_fn(x0):
    v0 = np.zeros_like(x0)
    return x0, v0
  def update_fn(i, g, state):
    x, velocity = state
    velocity = momentum * velocity + g
    x = x - learning_rate(i) * velocity
    return x, velocity
  def get_params(state):
    x, _ = state
    return x
  return init_fn, update_fn, get_params
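
The update rule inside `update_fn` can be tried in isolation. The following is a minimal sketch on a scalar parameter with the toy loss L(x) = 0.5 * x**2; the helper `momentum_step` and the toy loss are assumptions made for illustration and are not part of the notebook.

```python
def momentum_step(x, velocity, g, learning_rate, momentum=0.9):
  """One update, mirroring `update_fn` above for a scalar parameter."""
  velocity = momentum * velocity + g
  x = x - learning_rate * velocity
  return x, velocity

# Minimize L(x) = 0.5 * x**2, whose gradient at x is simply x.
x, velocity = 1.0, 0.0
history = [x]
for _ in range(200):
  x, velocity = momentum_step(x, velocity, g=x, learning_rate=0.1)
  history.append(x)

# With zero initial velocity, the first step is a plain gradient descent step.
print(history[1])
# After 200 steps the iterate has spiralled in close to the minimum at 0.
print(abs(x))
```

Because the velocity accumulates past gradients, the iterates overshoot and oscillate around the minimum before settling, unlike plain gradient descent on the same problem.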


# Function Space Linearization

Create an MNIST data pipeline using TensorFlow Datasets.

In [0]:
dataset_size = 64

train = tfds.load('mnist', split=tfds.Split.TRAIN, batch_size=dataset_size)
train = process_data(next(tfds.as_numpy(train)))

test = tfds.load('mnist', split=tfds.Split.TEST, batch_size=dataset_size)
test = process_data(next(tfds.as_numpy(test)))

Set up some experiment parameters.

In [0]:
learning_rate = 1e0
training_time = 1000.0
print_every = 100.0

Create a Fully-Connected Network.

In [0]:
init_fn, f, _ = stax.serial(
    stax.Dense(2048, 1., 0.05), 
    stax.Erf(),
    stax.Dense(10, 1., 0.05))

key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 784))

Construct the NTK.

In [0]:
ntk = nt.batch(nt.empirical_ntk_fn(f), batch_size=16, device_count=0)

g_dd = ntk(train['image'], None, params)
g_td = ntk(test['image'], train['image'], params)
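
Here `g_dd` is the train-train kernel and `g_td` the test-train kernel. As a rough illustration of the idea behind an empirical NTK, the sketch below forms the kernel as a product of parameter Jacobians for a tiny scalar-output network. It is written in plain NumPy with a finite-difference Jacobian; `tiny_net`, `jacobian_fd` and `empirical_ntk_sketch` are hypothetical names, and the library's own normalization and its handling of multiple outputs may differ.

```python
import numpy as onp

def tiny_net(theta, x):
  """Scalar-output network f(x) = v . tanh(W x), parameters packed in `theta`."""
  W = theta[:6].reshape(3, 2)   # hidden layer: 2 inputs -> 3 units
  v = theta[6:]                 # readout: 3 units -> 1 output
  return onp.tanh(x @ W.T) @ v

def jacobian_fd(theta, x, eps=1e-6):
  """Central finite-difference Jacobian d f(x_i) / d theta_p, shape (n, P)."""
  cols = []
  for p in range(theta.size):
    step = onp.zeros_like(theta)
    step[p] = eps
    cols.append((tiny_net(theta + step, x) - tiny_net(theta - step, x)) / (2 * eps))
  return onp.stack(cols, axis=1)

def empirical_ntk_sketch(theta, x1, x2=None):
  """Theta(x1, x2) = J(x1) J(x2)^T; `x2=None` means x2 = x1."""
  j1 = jacobian_fd(theta, x1)
  j2 = j1 if x2 is None else jacobian_fd(theta, x2)
  return j1 @ j2.T

rng = onp.random.RandomState(0)
theta = rng.randn(9)
x_train, x_test = rng.randn(5, 2), rng.randn(3, 2)

k_dd = empirical_ntk_sketch(theta, x_train)          # train-train kernel
k_td = empirical_ntk_sketch(theta, x_test, x_train)  # test-train kernel
print(k_dd.shape, k_td.shape)  # (5, 5) (3, 5)
```

The train-train kernel is symmetric and positive semi-definite by construction, which is what makes the closed-form and ODE-based predictors in the following sections well behaved.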

Now that we have the NTK and a network, we can compare the trained network against the linearized predictions under several different training dynamics. Remember to reinitialize the network and the NTK before trying a different one.

## Gradient Descent, MSE Loss

Create an optimizer and initialize it.

In [0]:
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
state = opt_init(params)

Create an MSE loss and a gradient.

In [0]:
loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

Create an MSE predictor and compute the function space values of the network at initialization.

In [0]:
predictor = nt.predict.gradient_descent_mse(g_dd, train['label'])
fx_train = f(params, train['image'])
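
For intuition about what an MSE predictor computes, the sketch below solves the linearized function-space dynamics in closed form for the loss `0.5 * mean((f - y) ** 2)`: under gradient flow the residual decays as `expm(-t * K / y.size)`. This is a NumPy sketch under stated assumptions, not the library's implementation; `mse_predictor_sketch` is a hypothetical name, and the `1 / y.size` factor is tied to the loss used here rather than to whatever normalization `nt.predict.gradient_descent_mse` applies. A model that is exactly linear in its parameters is used so that discrete gradient descent can be compared with the prediction at `t = steps * lr`.

```python
import numpy as onp

def mse_predictor_sketch(k_dd, y_train):
  """Returns predict(t, fx0) = y + expm(-t * K / y.size) @ (fx0 - y)."""
  evals, evecs = onp.linalg.eigh(k_dd)   # K is symmetric positive semi-definite
  def predict(t, fx0):
    # Matrix exponential via the eigendecomposition of K.
    expm = (evecs * onp.exp(-evals * t / y_train.size)) @ evecs.T
    return y_train + expm @ (fx0 - y_train)
  return predict

rng = onp.random.RandomState(0)
phi = rng.randn(4, 6)    # features of an exactly linear model f = phi @ theta
y = rng.randn(4)
theta = rng.randn(6)
fx0 = phi @ theta
predict = mse_predictor_sketch(phi @ phi.T, y)

lr, steps = 1e-3, 2000
for _ in range(steps):
  # Gradient of 0.5 * mean((phi @ theta - y) ** 2) with respect to theta.
  theta = theta - lr * phi.T @ (phi @ theta - y) / y.size

gap = onp.max(onp.abs(phi @ theta - predict(steps * lr, fx0)))
print(gap)  # small: the discrete steps track the continuous-time solution
```

The same correspondence between step index and time, `t = i * learning_rate`, is what the training loop in this section uses when it queries the predictor.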

Train the network.

In [0]:
print('Time\tLoss\tLinear Loss')
print_every_step = int(print_every // learning_rate)

X, Y = train['image'], train['label']

for i in range(int(training_time // learning_rate)):
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)
  
  if i % print_every_step == 0:
    t = i * learning_rate
    exact_loss = loss(f(params, X), Y)
    linear_loss = loss(predictor(t, fx_train), Y)
    print('{}\t{:.4f}\t{:.4f}'.format(t, exact_loss, linear_loss))
    

Time	Loss	Linear Loss
0.0	0.2506	0.2506
100.0	0.1119	0.1120
200.0	0.0771	0.0771
300.0	0.0586	0.0585
400.0	0.0468	0.0467
500.0	0.0386	0.0385
600.0	0.0325	0.0324
700.0	0.0278	0.0276
800.0	0.0240	0.0239
900.0	0.0210	0.0208


## Gradient Descent, Cross Entropy Loss

Create an optimizer and initialize it.

In [0]:
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
state = opt_init(params)

Create a Cross Entropy loss and a gradient.

In [0]:
loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)
grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

Create a Gradient Descent predictor and compute the function space values of the network at initialization.

In [0]:
predictor = nt.predict.gradient_descent(g_dd, train['label'], loss)
fx_train = f(params, train['image'])
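
For a loss without a closed-form solution, the linearized dynamics can still be written as an ODE in function space, `d f / dt = -K @ dL/df`, and integrated numerically. The sketch below does this with forward Euler for the cross-entropy loss defined above. It is plain NumPy with a random stand-in kernel; `integrate_function_space` is a hypothetical name, and forward Euler is chosen only for simplicity, since how `nt.predict.gradient_descent` integrates the ODE is not shown in this notebook.

```python
import numpy as onp

def log_softmax(fx):
  shifted = fx - fx.max(axis=-1, keepdims=True)
  return shifted - onp.log(onp.exp(shifted).sum(axis=-1, keepdims=True))

def cross_entropy(fx, y):
  """Same form as the notebook's loss: -mean(logsoftmax(fx) * y)."""
  return -onp.mean(log_softmax(fx) * y)

def cross_entropy_grad(fx, y):
  """d loss / d fx for one-hot `y`: (softmax(fx) - y) / fx.size."""
  return (onp.exp(log_softmax(fx)) - y) / fx.size

def integrate_function_space(k_dd, y, fx0, t, dt=0.01):
  """Forward-Euler integration of d f / dt = -K @ dL/df from time 0 to t."""
  fx = fx0.copy()
  for _ in range(int(round(t / dt))):
    fx = fx - dt * k_dd @ cross_entropy_grad(fx, y)
  return fx

rng = onp.random.RandomState(0)
phi = rng.randn(6, 8)
k_dd = phi @ phi.T    # stand-in kernel: any symmetric PSD matrix works here
y = onp.eye(3)[onp.array([0, 1, 2, 0, 1, 2])]
fx0 = 0.1 * rng.randn(6, 3)

losses = []
for t in (0.0, 10.0, 100.0):
  losses.append(cross_entropy(integrate_function_space(k_dd, y, fx0, t), y))
  print(t, losses[-1])
```

Only the kernel, the labels, the loss gradient and the initial outputs enter the integration; the network parameters never appear, which is the sense in which the prediction lives in function space.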

Train the network.

In [0]:
print('Time\tLoss\tLinear Loss')
print_every_step = int(print_every // learning_rate)

X, Y = train['image'], train['label']

for i in range(int(training_time // learning_rate)):
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)
  
  if i % print_every_step == 0:
    t = i * learning_rate
    exact_loss = loss(f(params, X), Y)
    linear_loss = loss(predictor(t, fx_train), Y)
    print('{:.0f}\t{:.4f}\t{:.4f}'.format(t, exact_loss, linear_loss))
    

Time	Loss	Linear Loss
0	0.1647	0.1647
100	0.1437	0.1437
200	0.1268	0.1270
300	0.1133	0.1137
400	0.1021	0.1028
500	0.0927	0.0938
600	0.0848	0.0861
700	0.0779	0.0794
800	0.0719	0.0736
900	0.0666	0.0685


## Momentum, Cross Entropy Loss

Create an optimizer and initialize it.

In [0]:
opt_init, opt_apply, get_params = momentum(learning_rate, 0.9)
state = opt_init(params)

Create a Cross Entropy loss and a gradient.

In [0]:
loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)
grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

Create a momentum predictor and initialize it.

In [0]:
pred_init, predictor, get = nt.predict.momentum(
    g_dd, train['label'], loss, learning_rate)
fx_train = f(params, train['image'])
pred_state = pred_init(fx_train)
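
With momentum, the training loop in this section measures time as `i * np.sqrt(learning_rate)` rather than `i * learning_rate`. One way to see why: eliminating the velocity from the `momentum` optimizer's update gives `x[i+1] - 2 x[i] + x[i-1] = -(1 - momentum) * (x[i] - x[i-1]) - learning_rate * g[i]`, a second-order difference equation whose natural time step is `sqrt(learning_rate)`. The sketch below checks this on the one-dimensional loss L(x) = 0.5 * x**2 against the damped oscillator `x'' + friction * x' + x = 0`. It is a parameter-space illustration under these assumptions only; the function names are hypothetical, and the exact form of the function-space predictor returned by `nt.predict.momentum` is not shown here.

```python
import math

def momentum_trajectory(x0, learning_rate, momentum, steps):
  """Runs the discrete update of the `momentum` optimizer on L(x) = 0.5 * x**2."""
  x, velocity = x0, 0.0
  for _ in range(steps):
    velocity = momentum * velocity + x   # the gradient of L at x is x
    x = x - learning_rate * velocity
  return x

def damped_oscillator(x0, friction, t):
  """Solution of x'' + friction * x' + x = 0 with x(0) = x0, x'(0) = 0 (underdamped)."""
  omega = math.sqrt(1.0 - friction ** 2 / 4.0)
  envelope = math.exp(-friction * t / 2.0)
  return x0 * envelope * (
      math.cos(omega * t) + friction / (2.0 * omega) * math.sin(omega * t))

learning_rate, mass = 1e-4, 0.99
dt = math.sqrt(learning_rate)    # time advanced by one optimizer step
friction = (1.0 - mass) / dt     # approximately 1.0 for these values
steps = 300

discrete = momentum_trajectory(1.0, learning_rate, mass, steps)
continuous = damped_oscillator(1.0, friction, steps * dt)
print(discrete, continuous)      # the two values agree closely
```

The agreement relies on a small step size; with `learning_rate = 1e0` as in this notebook, each step advances time by a full unit, so some discrepancy between the discrete and continuous trajectories is to be expected.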

Train the network.

In [0]:
print('Time\tLoss\tLinear Loss')
print_every_step = int(print_every // np.sqrt(learning_rate))

X, Y = train['image'], train['label']

for i in range(int(300.0 // np.sqrt(learning_rate))):
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)
  
  if i % print_every_step == 0:
    t = i * np.sqrt(learning_rate)
    exact_loss = loss(f(params, X), Y)
    linear_loss = loss(get(predictor(pred_state, t)), Y)
    print('{:.0f}\t{:.4f}\t{:.4f}'.format(t, exact_loss, linear_loss))
    

Time	Loss	Linear Loss
0	0.0620	0.0620
100	0.0357	0.0382
200	0.0233	0.0253
