<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/function_space_linearization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

#### Import & Utils

Install JAX, TensorFlow Datasets, and Neural Tangents.

The second command installs a CUDA build of JAX: the `cuda11_cudnn805` extra selects wheels built for CUDA 11 and cuDNN 8.0.5. Make sure your Colab runtime provides a matching GPU environment.

In [0]:
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q tensorflow-datasets
!pip install -q git+https://www.github.com/google/neural-tangents

Import libraries

In [0]:
from jax import jit
from jax import grad
from jax import random

import jax.numpy as np
from jax.nn import log_softmax
from jax.example_libraries import optimizers

import tensorflow_datasets as tfds

import neural_tangents as nt
from neural_tangents import stax

Define helper functions for processing the data, and a vanilla momentum optimizer.

In [0]:
def process_data(data_chunk):
  """Flatten the images and one-hot encode the labels."""
  image, label = data_chunk['image'], data_chunk['label']

  samples = image.shape[0]
  image = np.array(np.reshape(image, (samples, -1)), dtype=np.float32)
  image = (image - np.mean(image)) / np.std(image)
  label = np.eye(10)[label]

  return {'image': image, 'label': label}
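
To make the preprocessing concrete, here is a minimal sketch that mirrors `process_data` in plain NumPy and runs it on a made-up batch. The name `process_data_np`, the alias `onp`, and the random `fake_chunk` are assumptions for illustration only. The real pipeline feeds in MNIST chunks from TensorFlow Datasets.

```python
import numpy as onp  # plain NumPy; the notebook's `np` is jax.numpy

def process_data_np(data_chunk):
  """NumPy mirror of `process_data`: flatten, standardize, one-hot encode."""
  image, label = data_chunk['image'], data_chunk['label']
  samples = image.shape[0]
  image = onp.reshape(image, (samples, -1)).astype(onp.float32)
  # Standardize with the statistics of the whole chunk (not per pixel).
  image = (image - onp.mean(image)) / onp.std(image)
  label = onp.eye(10)[label]
  return {'image': image, 'label': label}

# Hypothetical stand-in for an MNIST chunk: 4 random 28x28x1 uint8 images.
rng = onp.random.default_rng(0)
fake_chunk = {
    'image': rng.integers(0, 256, size=(4, 28, 28, 1), dtype=onp.uint8),
    'label': onp.array([3, 0, 9, 3]),
}
batch = process_data_np(fake_chunk)
print(batch['image'].shape, batch['label'].shape)  # (4, 784) (4, 10)
```

Each image becomes a 784-dimensional row, which matches the `(-1, 784)` input shape used when the network is initialized later in the notebook.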

In [0]:
@optimizers.optimizer
def momentum(learning_rate, momentum=0.9):
  """A standard momentum optimizer for testing.

  Different from `jax.example_libraries.optimizers.momentum` (Nesterov).
  """
  learning_rate = optimizers.make_schedule(learning_rate)
  def init_fn(x0):
    v0 = np.zeros_like(x0)
    return x0, v0
  def update_fn(i, g, state):
    x, velocity = state
    velocity = momentum * velocity + g
    x = x - learning_rate(i) * velocity
    return x, velocity
  def get_params(state):
    x, _ = state
    return x
  return init_fn, update_fn, get_params
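
The update rule inside `update_fn` can be tried in isolation. The sketch below applies the same two lines (`velocity = momentum * velocity + g`, then `x = x - learning_rate * velocity`) to a toy quadratic in plain NumPy. The helper name `heavy_ball`, the matrix `A`, and the step size are assumptions chosen for illustration and are not part of the notebook.

```python
import numpy as onp  # plain NumPy; the notebook's `np` is jax.numpy

def heavy_ball(grad_fn, x0, learning_rate, momentum, steps):
  """Iterates v <- momentum * v + g, then x <- x - learning_rate * v."""
  x = onp.array(x0, dtype=float)
  velocity = onp.zeros_like(x)
  for _ in range(steps):
    velocity = momentum * velocity + grad_fn(x)
    x = x - learning_rate * velocity
  return x

# Toy objective (an assumption): 0.5 * x^T A x, which is minimized at x = 0.
A = onp.diag([1.0, 10.0])
grad_fn = lambda x: A @ x

x_final = heavy_ball(grad_fn, x0=[1.0, 1.0], learning_rate=0.05,
                     momentum=0.9, steps=500)
print(onp.linalg.norm(x_final))  # distance to the minimizer after 500 steps
```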


# Function Space Linearization

Create MNIST data pipeline using TensorFlow Datasets.

In [0]:
dataset_size = 64

ds_train, ds_test = tfds.as_numpy(
    tfds.load('mnist:3.*.*', split=['train[:%d]' % dataset_size,
                                    'test[:%d]' % dataset_size],
              batch_size=-1)
)

train = process_data(ds_train)
test = process_data(ds_test)

Set up some experiment parameters.

In [0]:
learning_rate = 1e0
training_steps = np.arange(1000)
print_every = 100.0

Create a Fully-Connected Network.

In [0]:
init_fn, f, _ = stax.serial(
    stax.Dense(512, 1., 0.05),
    stax.Erf(),
    stax.Dense(10, 1., 0.05))

key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 784))

Construct the NTK.

In [0]:
ntk = nt.batch(nt.empirical_ntk_fn(f, vmap_axes=0),
               batch_size=64, device_count=0)

g_dd = ntk(train['image'], None, params)
g_td = ntk(test['image'], train['image'], params)

Now that we have a network and its NTK, we can compare the trained network against several linearized training dynamics. Remember to reinitialize the network and the NTK if you want to try different dynamics.
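
For intuition, the empirical NTK is the Gram matrix of parameter Jacobians, Θ(x1, x2) = J(x1) J(x2)ᵀ. The sketch below computes it by hand for a tiny scalar-output network in plain NumPy. The network `f_toy`, its analytic `jacobian`, and all shapes are hypothetical and chosen for illustration. It does not reproduce how `nt.empirical_ntk_fn` treats the 10 output dimensions of the network above.

```python
import numpy as onp  # plain NumPy; the notebook's `np` is jax.numpy

def f_toy(params, x):
  """Hypothetical scalar-output network: a . tanh(W x) / sqrt(width)."""
  W, a = params
  return onp.tanh(x @ W.T) @ a / onp.sqrt(a.shape[0])

def jacobian(params, x):
  """Analytic Jacobian of `f_toy` w.r.t. (W, a), one row per input."""
  W, a = params
  width = a.shape[0]
  h = onp.tanh(x @ W.T)                                   # (n, width)
  d_a = h / onp.sqrt(width)                               # (n, width)
  d_W = (a * (1.0 - h ** 2))[:, :, None] * x[:, None, :]  # (n, width, d)
  d_W = d_W / onp.sqrt(width)
  return onp.concatenate([d_W.reshape(len(x), -1), d_a], axis=1)

def empirical_ntk(params, x1, x2=None):
  """Theta(x1, x2) = J(x1) J(x2)^T; passing x2=None uses x2 = x1."""
  j1 = jacobian(params, x1)
  j2 = j1 if x2 is None else jacobian(params, x2)
  return j1 @ j2.T

rng = onp.random.default_rng(0)
toy_params = (rng.normal(size=(8, 3)), rng.normal(size=8))
x_train, x_test = rng.normal(size=(5, 3)), rng.normal(size=(2, 3))

theta_dd = empirical_ntk(toy_params, x_train)          # train-train, like g_dd
theta_td = empirical_ntk(toy_params, x_test, x_train)  # test-train, like g_td
print(theta_dd.shape, theta_td.shape)  # (5, 5) (2, 5)
```

The train-train kernel is symmetric and positive semi-definite by construction, which is what the predictors in the following sections rely on.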

## Gradient Descent, MSE Loss

Create an optimizer and initialize it.

In [0]:
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
state = opt_init(params)

Create an MSE loss and a gradient.

In [0]:
loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

Create an MSE predictor and compute the function space values of the network at initialization.

In [0]:
predictor = nt.predict.gradient_descent_mse(g_dd, train['label'],
                                            learning_rate=learning_rate)
fx_train = f(params, train['image'])
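
For an MSE loss, the linearized training dynamics have a closed form, which is what makes a dedicated MSE predictor possible. As a minimal sketch, assume gradient flow on the unnormalized loss 0.5 * ||f - y||² with a symmetric positive semi-definite kernel Θ. Then df/dt = -η Θ (f - y), so f(t) = y + exp(-η Θ t) (f(0) - y). The code below evaluates this with an eigendecomposition in plain NumPy. The function name `linear_mse_predict`, the toy kernel, and the loss normalization are assumptions, and the library's own conventions may differ.

```python
import numpy as onp  # plain NumPy; the notebook's `np` is jax.numpy

def linear_mse_predict(theta, y, f0, t, learning_rate=1.0):
  """Solves df/dt = -learning_rate * theta @ (f - y) from f(0) = f0."""
  evals, evecs = onp.linalg.eigh(theta)            # theta is symmetric PSD
  decay = onp.exp(-learning_rate * t * evals)      # one factor per eigenmode
  return y + evecs @ (decay[:, None] * (evecs.T @ (f0 - y)))

# Toy problem (all values are illustrative assumptions).
rng = onp.random.default_rng(0)
B = rng.normal(size=(6, 6))
toy_theta = B @ B.T / 6 + 0.5 * onp.eye(6)         # well-conditioned PSD kernel
toy_y = onp.eye(3)[[0, 1, 2, 0, 1, 2]]             # one-hot targets, shape (6, 3)
toy_f0 = 0.1 * rng.normal(size=(6, 3))             # outputs at initialization

for t in (0.0, 1.0, 10.0):
  f_t = linear_mse_predict(toy_theta, toy_y, toy_f0, t)
  print(t, 0.5 * onp.mean((f_t - toy_y) ** 2))     # loss shrinks as t grows
```

Each eigenmode of the kernel decays independently at a rate set by its eigenvalue, so directions with large eigenvalues are fit first.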

Train the network.

In [0]:
print('Time\tLoss\tLinear Loss')

X, Y = train['image'], train['label']

predictions = predictor(training_steps, fx_train)

for i in training_steps:
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)

  if i % print_every == 0:
    exact_loss = loss(f(params, X), Y)
    linear_loss = loss(predictions[i], Y)
    print('{}\t{:.4f}\t{:.4f}'.format(i, exact_loss, linear_loss))

Time	Loss	Linear Loss
0	0.2444	0.2444
100	0.1231	0.1234
200	0.0854	0.0855
300	0.0652	0.0649
400	0.0523	0.0519
500	0.0434	0.0427
600	0.0367	0.0359
700	0.0315	0.0306
800	0.0273	0.0263
900	0.0239	0.0229


## Gradient Descent, Cross Entropy Loss

Create an optimizer and initialize it.

In [0]:
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
state = opt_init(params)

Create a cross-entropy loss and its gradient.

In [0]:
loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)
grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

Create a Gradient Descent predictor and compute the function space values of the network at initialization.

In [0]:
predictor = nt.predict.gradient_descent(loss, g_dd, train['label'], learning_rate=learning_rate)
fx_train = f(params, train['image'])
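
Unlike MSE, cross entropy has no closed-form solution, so the linearized outputs have to be advanced numerically. The sketch below uses the simplest option, discrete steps f ← f − η Θ ∇L(f), which is what gradient descent on a linearized model looks like in function space. It is written in plain NumPy on a toy kernel. The helper names, the toy data, and the choice of plain discrete steps are assumptions for illustration and are not a description of how `nt.predict.gradient_descent` integrates the dynamics internally.

```python
import numpy as onp  # plain NumPy; the notebook's `np` is jax.numpy

def log_softmax_np(f):
  f = f - f.max(axis=-1, keepdims=True)
  return f - onp.log(onp.exp(f).sum(axis=-1, keepdims=True))

def cross_entropy(f, y):
  """Same form as the notebook's loss: mean over all entries."""
  return -onp.mean(log_softmax_np(f) * y)

def cross_entropy_grad(f, y):
  """Gradient of `cross_entropy` w.r.t. the logits f, for one-hot y."""
  return (onp.exp(log_softmax_np(f)) - y) / f.size

def linearized_gd(theta, y, f0, learning_rate, steps):
  """Function-space gradient descent: f <- f - lr * theta @ dL/df."""
  f, losses = f0.copy(), []
  for _ in range(steps):
    losses.append(cross_entropy(f, y))
    f = f - learning_rate * theta @ cross_entropy_grad(f, y)
  return f, losses

# Toy problem (all values are illustrative assumptions).
rng = onp.random.default_rng(0)
B = rng.normal(size=(6, 6))
toy_theta = B @ B.T / 6 + 0.5 * onp.eye(6)   # PSD stand-in for g_dd
toy_y = onp.eye(3)[[0, 1, 2, 0, 1, 2]]       # one-hot labels, shape (6, 3)
toy_f0 = 0.1 * rng.normal(size=(6, 3))       # logits at initialization

f_final, losses = linearized_gd(toy_theta, toy_y, toy_f0, 1.0, steps=200)
print(losses[0], losses[-1])
```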

Train the network.

In [0]:
print('Time\tLoss\tLinear Loss')

X, Y = train['image'], train['label']

predictions = predictor(training_steps, fx_train)

for i in training_steps:
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)

  if i % print_every == 0:
    exact_loss = loss(f(params, X), Y)
    linear_loss = loss(predictions[i], Y)
    print('{:.0f}\t{:.4f}\t{:.4f}'.format(i, exact_loss, linear_loss))

Time	Loss	Linear Loss
0	0.1696	0.1696
100	0.1497	0.1493
200	0.1336	0.1329
300	0.1204	0.1195
400	0.1093	0.1083
500	0.0998	0.0987
600	0.0916	0.0906
700	0.0845	0.0835
800	0.0783	0.0773
900	0.0728	0.0719


## Momentum, Cross Entropy Loss

Create an optimizer and initialize it.

In [0]:
mass = 0.9
opt_init, opt_apply, get_params = momentum(learning_rate, mass)
state = opt_init(params)

Create a Cross Entropy loss and a gradient.

In [0]:
loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)
grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

Create a momentum predictor and initialize it.

In [0]:
predictor = nt.predict.gradient_descent(loss,
    g_dd, train['label'], learning_rate=learning_rate, momentum=mass)
fx_train = f(params, train['image'])
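
The momentum case can also be written directly in function space. For a model linearized around its initial parameters, multiplying the parameter-space velocity by the Jacobian J turns the optimizer's update into u ← m·u + Θ ∇L(f), f ← f − η·u, with Θ = J Jᵀ. The sketch below iterates this recursion in plain NumPy on a toy kernel. The helper names and toy data are assumptions for illustration. This is a discrete-time sketch, and `nt.predict.gradient_descent` may use a different formulation internally.

```python
import numpy as onp  # plain NumPy; the notebook's `np` is jax.numpy

def softmax_np(f):
  z = onp.exp(f - f.max(axis=-1, keepdims=True))
  return z / z.sum(axis=-1, keepdims=True)

def cross_entropy(f, y):
  return -onp.mean(onp.log(softmax_np(f)) * y)

def linearized_momentum(theta, grad_fn, f0, learning_rate, momentum, steps):
  """Function-space heavy ball on a linearized model."""
  f, u = f0.copy(), onp.zeros_like(f0)
  for _ in range(steps):
    u = momentum * u + theta @ grad_fn(f)   # u plays the role of J @ velocity
    f = f - learning_rate * u
  return f

# Toy problem (all values are illustrative assumptions).
rng = onp.random.default_rng(0)
B = rng.normal(size=(6, 6))
toy_theta = B @ B.T / 6 + 0.5 * onp.eye(6)   # PSD stand-in for g_dd
toy_y = onp.eye(3)[[0, 1, 2, 0, 1, 2]]       # one-hot labels, shape (6, 3)
toy_f0 = 0.1 * rng.normal(size=(6, 3))       # logits at initialization
grad_fn = lambda f: (softmax_np(f) - toy_y) / f.size  # d(cross_entropy)/df

f_gd = linearized_momentum(toy_theta, grad_fn, toy_f0, 1.0, 0.0, steps=200)
f_mom = linearized_momentum(toy_theta, grad_fn, toy_f0, 1.0, 0.9, steps=200)
print(cross_entropy(f_gd, toy_y), cross_entropy(f_mom, toy_y))
```

Setting `momentum=0.0` recovers plain function-space gradient descent, which gives a convenient baseline for comparison.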

Train the network.

In [0]:
print('Time\tLoss\tLinear Loss')

X, Y = train['image'], train['label']

predictions = predictor(training_steps, fx_train)

for i in training_steps:
  params = get_params(state)
  state = opt_apply(i, grad_loss(params, X, Y), state)

  if i % print_every == 0:
    exact_loss = loss(f(params, X), Y)
    linear_loss = loss(predictions[i], Y)
    print('{:.0f}\t{:.4f}\t{:.4f}'.format(i, exact_loss, linear_loss))


Time	Loss	Linear Loss
0	0.0680	0.0680
100	0.0399	0.0401
200	0.0262	0.0266
300	0.0191	0.0195
400	0.0148	0.0153
500	0.0120	0.0126
600	0.0101	0.0106
700	0.0086	0.0092
800	0.0075	0.0081
900	0.0067	0.0072
