##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

#### Import & Utils

Install JAX, TensorFlow Datasets, and Neural Tangents.

The first line specifies the version of jaxlib that we would like to install. Note that "cp36" specifies the version of Python (version 3.6) used by JAX. Make sure your Colab kernel matches this version.

In [0]:
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.12-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
!pip install -q tensorflow-datasets
!pip install -q git+https://www.github.com/google/neural-tangents

Import libraries

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from jax.api import jit
from jax.api import grad
from jax import random

import jax.numpy as np
from jax.experimental import stax
from jax.experimental import optimizers

import tensorflow_datasets as tfds



from neural_tangents import tangents
from neural_tangents import layers

Define helper functions for processing data and for constructing a vanilla momentum optimizer.

In [0]:
def process_data(data_chunk):
  """Turn a raw batch into flat, standardized images and one-hot labels.

  Args:
    data_chunk: mapping with an 'image' array whose leading axis indexes
      examples and a 'label' array of integer class ids in [0, 10).

  Returns:
    A dict with 'image' (float32, one row per example, shifted and scaled by
    the mean and standard deviation of the whole batch) and 'label' (rows of
    the 10x10 identity matrix selected by the class ids).
  """
  raw_images = data_chunk['image']
  class_ids = data_chunk['label']

  # Collapse every non-batch axis into a single feature axis.
  batch = raw_images.shape[0]
  flat = np.reshape(raw_images, (batch, -1))
  flat = np.array(flat, dtype=np.float32)

  # Standardize with statistics taken over the entire batch, not per pixel.
  centered = flat - np.mean(flat)
  standardized = centered / np.std(flat)

  one_hot = np.eye(10)[class_ids]

  return {'image': standardized, 'label': one_hot}

In [0]:
@optimizers.optimizer
def momentum(learning_rate, momentum=0.9):
  """Heavy-ball momentum optimizer used for testing.

  The original author notes that this differs from
  `jax.experimental.optimizers.momentum`, which they describe as Nesterov.

  Args:
    learning_rate: a scalar step size, or anything accepted by
      `optimizers.make_schedule`.
    momentum: multiplier applied to the previous velocity at every step.

  Returns:
    An `(init_fun, update_fun, get_params)` triple.
  """
  # Keep the schedule under its own name instead of rebinding the argument.
  step_size = optimizers.make_schedule(learning_rate)

  def init_fun(x0):
    # The velocity starts at zero with the same shape as the parameters.
    return x0, np.zeros_like(x0)

  def update_fun(i, g, state):
    params, old_velocity = state
    new_velocity = momentum * old_velocity + g
    new_params = params - step_size(i) * new_velocity
    return new_params, new_velocity

  def get_params(state):
    params, _unused_velocity = state
    return params

  return init_fun, update_fun, get_params


# Function Space Linearization

Create MNIST data pipeline using TensorFlow Datasets.

In [0]:
# Number of examples per split; each split is read as one batch of this size.
dataset_size = 64

# Take the first batch yielded by the MNIST training split and preprocess it.
train = process_data(next(tfds.as_numpy(
    tfds.load('mnist', split=tfds.Split.TRAIN, batch_size=dataset_size))))

# Same for the test split.
test = process_data(next(tfds.as_numpy(
    tfds.load('mnist', split=tfds.Split.TEST, batch_size=dataset_size))))

Setup some experiment parameters.

In [0]:
# Step size handed to every optimizer (and to the momentum predictor) below.
learning_rate = 1.0
# Total training time; the training loops divide it by the step size to get
# the number of optimizer steps.
training_time = 1000.0
# Time between two logged lines, in the same units as `training_time`.
print_every = 100.0

Create a Fully-Connected Network.

In [0]:
# Fully-connected network: a Dense(4096) layer, a tanh nonlinearity, and a
# Dense(10) output layer. Note that `layers.Dense` is the neural_tangents
# layer, not the stax one.
init_fn, f = stax.serial(
    layers.Dense(4096), 
    stax.Tanh,
    layers.Dense(10))

# Fixed seed, so the initial parameters are the same on every run.
key = random.PRNGKey(0)
# Input shape passed to the initializer: 784 features per example, with -1
# presumably standing for an unspecified batch size.
_, params = init_fn(key, (-1, 784))

Construct the NTK.

In [0]:
# Kernel function for the network `f`. `batch_size=16` presumably makes
# neural_tangents evaluate the kernel in batches of 16 examples -- TODO confirm.
theta = tangents.ntk(f, batch_size=16)

# Kernel evaluated at the current `params` (the freshly initialized ones when
# the cells are run top to bottom): train-vs-train and test-vs-train.
# `g_td` is not used by the training cells visible below.
g_dd = theta(params, train['image'])
g_td = theta(params, test['image'], train['image'])

Now that we have the NTK and a network, we can compare against several different training dynamics. Remember to reinitialize the network and the NTK if you want to try different dynamics.

## Gradient Descent, MSE Loss

Create an optimizer and initialize it.

In [0]:
# SGD optimizer with step size `learning_rate`. The training loop below feeds
# it gradients computed on the whole training batch.
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
# Optimizer state built from the current `params`.
state = opt_init(params)

Create an MSE loss and a gradient.

In [0]:
def loss(fx, y_hat):
  """Half the mean, over all entries, of the squared difference fx - y_hat."""
  return 0.5 * np.mean((fx - y_hat) ** 2)


# Gradient of the loss with respect to the network parameters (the first
# argument), compiled with jit.
grad_loss = jit(
    grad(lambda params, x, y: loss(f(params, x), y)))

Create an MSE predictor and compute the function space values of the network at initialization.

In [0]:
# Predictor built from the train-train kernel and the training labels;
# presumably the closed-form solution of the linearized MSE dynamics --
# see the neural_tangents documentation to confirm.
predictor = tangents.analytic_mse_predictor(g_dd, train['label'])
# Network outputs on the training images at the current `params`; used as the
# starting point passed to `predictor` in the training loop.
fx_train = f(params, train['image'])

Train the network.

In [32]:
# Train the network with gradient descent on the MSE loss and, at regular
# intervals, log its loss next to the loss of the kernel-based predictor.
print ('Time\tLoss\tLinear Loss')
# Number of optimizer steps between two log lines.
print_every_step = int(print_every // learning_rate)

X, Y = train['image'], train['label']

for i in range(int(training_time // learning_rate)):
  # Parameters are read *before* this step's update, so the losses logged
  # below describe the model after exactly `i` updates.
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)
  
  if i % print_every_step == 0:
    # Convert the step count into the time passed to the predictor.
    t = i * learning_rate
    exact_loss = loss(f(params, X), Y)
    # `predictor(fx_train, t)` presumably returns the linearized model's
    # training outputs at time t, starting from `fx_train`.
    linear_loss = loss(predictor(fx_train, t), Y)
    print('{}\t{:.4f}\t{:.4f}'.format(t, exact_loss, linear_loss))
    

Time	Loss	Linear Loss
0.0	0.2234	0.2234
100.0	0.0997	0.0998
200.0	0.0756	0.0756
300.0	0.0600	0.0601
400.0	0.0492	0.0492
500.0	0.0413	0.0413
600.0	0.0354	0.0353
700.0	0.0307	0.0307
800.0	0.0269	0.0269
900.0	0.0239	0.0238


## Gradient Descent, Cross Entropy Loss

Create an optimizer and initialize it.

In [0]:
# SGD optimizer with step size `learning_rate`.
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
# NOTE(review): if the preceding training cell has been run, `params` was
# rebound inside its loop, so this state starts from those trained parameters
# rather than from a fresh initialization, while `g_dd` still comes from the
# parameters it was originally computed with.
state = opt_init(params)

Create a Cross Entropy loss and a gradient.

In [0]:
def loss(fx, y_hat):
  """Negated mean, over all entries, of log-softmax(fx) times y_hat."""
  return -np.mean(stax.logsoftmax(fx) * y_hat)


# Gradient of the loss with respect to the network parameters (the first
# argument), compiled with jit.
grad_loss = jit(
    grad(lambda params, x, y: loss(f(params, x), y)))

Create a Gradient Descent predictor and compute the function space values of the network at initialization.

In [0]:
# Predictor built from the train-train kernel, the training labels and the
# loss; presumably it integrates the linearized gradient-descent dynamics for
# a general loss -- see the neural_tangents documentation to confirm.
predictor = tangents.gradient_descent_predictor(g_dd, train['label'], loss)
# Network outputs on the training images at the current `params`; used as the
# starting point passed to `predictor` in the training loop.
fx_train = f(params, train['image'])

Train the network.

In [36]:
# Train the network with gradient descent on the cross-entropy loss and, at
# regular intervals, log its loss next to the loss of the kernel-based
# predictor.
print ('Time\tLoss\tLinear Loss')
# Number of optimizer steps between two log lines.
print_every_step = int(print_every // learning_rate)

X, Y = train['image'], train['label']

for i in range(int(training_time // learning_rate)):
  # Parameters are read *before* this step's update, so the losses logged
  # below describe the model after exactly `i` updates.
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)
  
  if i % print_every_step == 0:
    # Convert the step count into the time passed to the predictor.
    t = i * learning_rate
    exact_loss = loss(f(params, X), Y)
    # `predictor(fx_train, t)` presumably returns the linearized model's
    # training outputs at time t, starting from `fx_train`.
    linear_loss = loss(predictor(fx_train, t), Y)
    print('{:.0f}\t{:.4f}\t{:.4f}'.format(t, exact_loss, linear_loss))
    

Time	Loss	Linear Loss
0	0.1701	0.1701
100	0.1506	0.1507
200	0.1354	0.1357
300	0.1232	0.1238
400	0.1132	0.1141
500	0.1047	0.1059
600	0.0973	0.0987
700	0.0908	0.0925
800	0.0851	0.0868
900	0.0799	0.0818


## Momentum, Cross Entropy Loss

Create an optimizer and initialize it.

In [0]:
# The momentum optimizer defined earlier in this file, with step size
# `learning_rate` and momentum coefficient 0.9.
opt_init, opt_apply, get_params = momentum(learning_rate, 0.9)
# NOTE(review): if the preceding training cells have been run, `params` was
# rebound inside their loops, so this state starts from those trained
# parameters rather than from a fresh initialization.
state = opt_init(params)

Create a Cross Entropy loss and a gradient.

In [0]:
def loss(fx, y_hat):
  """Negated mean, over all entries, of log-softmax(fx) times y_hat."""
  return -np.mean(stax.logsoftmax(fx) * y_hat)


# Gradient of the loss with respect to the network parameters (the first
# argument), compiled with jit.
grad_loss = jit(
    grad(lambda params, x, y: loss(f(params, x), y)))

Create a momentum predictor and initialize it.

In [0]:
# Momentum predictor built from the train-train kernel, the training labels,
# the loss and the learning rate. It returns an initializer, the predictor
# itself, and `get`, which the training loop applies to the predictor's output
# before computing the loss (presumably to extract function values from the
# predictor state -- confirm in the neural_tangents documentation).
pred_init, predictor, get = tangents.momentum_predictor(
    g_dd, train['label'], loss, learning_rate)
# Network outputs on the training images at the current `params`.
fx_train = f(params, train['image'])
# Initial predictor state derived from those outputs.
pred_state = pred_init(fx_train)

Train the network.

In [40]:
# Train the network with the momentum optimizer on the cross-entropy loss and,
# at regular intervals, log its loss next to the loss of the momentum
# predictor.
print ('Time\tLoss\tLinear Loss')
# Time advances by sqrt(learning_rate) per step in this cell (rather than by
# learning_rate as in the gradient-descent cells) -- presumably the time scale
# expected by the momentum predictor; confirm in neural_tangents.
print_every_step = int(print_every // np.sqrt(learning_rate))

X, Y = train['image'], train['label']

# Runs for a hard-coded 300.0 time units instead of `training_time`.
for i in range(int(300.0 // np.sqrt(learning_rate))):
  # Parameters are read *before* this step's update, so the losses logged
  # below describe the model after exactly `i` updates.
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)
  
  if i % print_every_step == 0:
    t = i * np.sqrt(learning_rate)
    exact_loss = loss(f(params, X), Y)
    # `pred_state` is never reassigned in this loop, so every call evaluates
    # the predictor from the same initial state up to time t.
    linear_loss = loss(get(predictor(pred_state, t)), Y)
    print('{:.0f}\t{:.4f}\t{:.4f}'.format(t, exact_loss, linear_loss))
    

Time	Loss	Linear Loss
0	0.0753	0.0753
100	0.0471	0.0496
200	0.0321	0.0343
