In [1]:
!pip install neural-tangents

Collecting neural-tangents
[?25l  Downloading https://files.pythonhosted.org/packages/69/e3/c191dd23f6a15199902157557b3ac59427673c1f5f0bc06580dca8003fe5/neural_tangents-0.1.9-py2.py3-none-any.whl (77kB)
[K     |████▏                           | 10kB 19.3MB/s eta 0:00:01[K     |████████▍                       | 20kB 1.7MB/s eta 0:00:01[K     |████████████▋                   | 30kB 2.5MB/s eta 0:00:01[K     |████████████████▉               | 40kB 1.7MB/s eta 0:00:01[K     |█████████████████████           | 51kB 2.1MB/s eta 0:00:01[K     |█████████████████████████▎      | 61kB 2.5MB/s eta 0:00:01[K     |█████████████████████████████▌  | 71kB 2.9MB/s eta 0:00:01[K     |████████████████████████████████| 81kB 2.5MB/s 
Collecting frozendict
  Downloading https://files.pythonhosted.org/packages/4e/55/a12ded2c426a4d2bee73f88304c9c08ebbdbadb82569ebdd6a0c007cfd08/frozendict-1.2.tar.gz
Building wheels for collected packages: frozendict
  Building wheel for frozendict (setup.py) ..

## Imports

In [0]:
import time
import itertools

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from  jax.nn import log_softmax

from jax.experimental import optimizers
import jax.experimental.stax as jax_stax
import neural_tangents.stax as nt_stax

import neural_tangents

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

In [0]:
def data_to_numpy(dataloader):
    """Materialize every batch of a dataloader into two dense arrays.

    Args:
      dataloader: iterable yielding `(inputs, targets)` pairs whose elements
        expose a `.numpy()` method (e.g. tensors produced by a torch
        `DataLoader`).

    Returns:
      `(X, y)`: all inputs and all targets joined along the first axis, in
      iteration order, as `np` arrays (`np` is `jax.numpy` in this notebook).
    """
    input_batches = []
    target_batches = []
    for cur_X, cur_y in dataloader:
        input_batches.append(cur_X.numpy())
        target_batches.append(cur_y.numpy())

    if not input_batches:
        # An empty loader yields two empty arrays, as the per-sample version did.
        return np.asarray([]), np.asarray([])

    # Join whole batches at once instead of converting a Python list that
    # holds one entry per sample.
    X = np.concatenate(input_batches, axis=0)
    y = np.concatenate(target_batches, axis=0)
    return X, y

def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

In [0]:
def cifar_10():
  """Download CIFAR-10 and return it as normalized arrays with one-hot labels.

  The images are normalized per channel with fixed mean/std statistics and
  transposed from channel-first to channel-last layout.

  Returns:
    `(train_images, train_labels, test_images, test_labels)`. Images are
    `(N, 32, 32, 3)` arrays and labels are `(N, 10)` one-hot arrays.
  """
  num_classes = 10

  # Seed torch once so the DataLoader shuffling below is repeatable.
  torch.manual_seed(0)

  cifar10_stats = {
      "mean" : (0.4914, 0.4822, 0.4465),
      "std"  : (0.24705882352941178, 0.24352941176470588, 0.2615686274509804),
  }

  simple_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
  ])

  train_loader = torch.utils.data.DataLoader(
                    datasets.CIFAR10(root='./data', train=True, download=True, transform=simple_transform),
                batch_size=2048, shuffle=True, pin_memory=True)

  test_loader  = torch.utils.data.DataLoader(
                    datasets.CIFAR10(root='./data', train=False, download=True, transform=simple_transform),
                batch_size=2048, shuffle=True, pin_memory=True)

  train_images, train_labels = data_to_numpy(train_loader)
  test_images,  test_labels  = data_to_numpy(test_loader)

  # Move the channel axis (axis 1) to the end: NCHW -> NHWC.
  train_images = np.transpose(train_images, (0, 2, 3, 1))
  test_images  = np.transpose(test_images , (0, 2, 3, 1))

  train_labels = _one_hot(train_labels, num_classes)
  test_labels  = _one_hot(test_labels,  num_classes)
  return train_images, train_labels, test_images, test_labels

In [5]:
%%time

# Download (if needed) and preprocess CIFAR-10 via cifar_10(); the result is
# bound to notebook-level names that every later cell reads.
train_images, train_labels, test_images, test_labels = cifar_10()
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified

CPU times: user 1min 22s, sys: 26.3 s, total: 1min 48s
Wall time: 1min 31s


In [6]:
# Sanity check: display the shapes of the image and one-hot label arrays.
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

((50000, 32, 32, 3), (50000, 10), (10000, 32, 32, 3), (10000, 10))

## Define training primitives

Note: The training code is based on the following example: https://github.com/google/jax/blob/master/examples/mnist_classifier.py.

In [0]:
def loss(params, batch):
  """Mean softmax cross-entropy of the model on one batch.

  Args:
    params: parameters passed to the notebook-level `predict` function.
    batch: `(inputs, targets)` pair; `targets` weights the per-class
      log-probabilities (one-hot labels in this notebook).

  Returns:
    Scalar: negative mean over the batch of `sum(log_softmax(logits) * targets)`.
  """
  inputs, targets = batch
  logits = predict(params, inputs)
  log_probs = log_softmax(logits, axis=1)
  per_example = np.sum(log_probs * targets, axis=1)
  return -np.mean(per_example)

def accuracy(params, batch):
  """Fraction of examples in `batch` whose arg-max prediction matches the target.

  Args:
    params: parameters passed to the notebook-level `predict` function.
    batch: `(inputs, targets)` pair; the target class is `argmax(targets, axis=1)`.
  """
  inputs, targets = batch
  expected = np.argmax(targets, axis=1)
  scores = predict(params, inputs)
  guessed = np.argmax(scores, axis=1)
  is_hit = guessed == expected
  return np.mean(is_hit)

@jit
def update(i, opt_state, batch):
  """Apply one optimizer step and return the new optimizer state.

  Args:
    i: step counter forwarded to `opt_update`.
    opt_state: current optimizer state.
    batch: `(inputs, targets)` pair forwarded to `loss`.

  NOTE(review): `get_params`, `opt_update` and (through `loss`) `predict` are
  notebook globals that later cells rebind. Presumably the jitted function
  picks up the new bindings only when it is re-traced, so confirm that each
  training cell triggers a re-trace.
  """
  current_params = get_params(opt_state)
  gradients = grad(loss)(current_params, batch)
  return opt_update(i, gradients, opt_state)

# NOTE(review): not referenced anywhere in the visible notebook;
# data_stream_of builds its own RandomState. Confirm before removing.
rng_state = npr.RandomState(0)

def data_stream_of(images, labels, batch_size=500, batch_limit=None, seed=0):
  """Yield shuffled `(images, labels)` mini-batches from a single pass over the data.

  Args:
    images: indexable array of inputs.
    labels: array of targets with the same length as `images`.
    batch_size: number of examples per batch. A trailing partial batch is
      dropped.
    batch_limit: if not None, stop after this many batches.
    seed: seed for the `RandomState` that draws the permutation. The default
      of 0 matches the previously hard-coded value, so every call yields the
      same batches in the same order; pass e.g. the epoch number to reshuffle
      between epochs.

  Yields:
    `(images[idx], labels[idx])` for consecutive `batch_size`-sized slices of
    one random permutation of the indices.
  """
  assert len(images) == len(labels)
  n = len(images)
  perm = npr.RandomState(seed).permutation(n)

  for i in range(n // batch_size):
    if (batch_limit is not None) and i >= batch_limit:
      break
    batch_idx = perm[i * batch_size:(i + 1) * batch_size]
    yield images[batch_idx], labels[batch_idx]

## Train a small CNN in JAX with NTK parameterization

Here I do a few epochs to make sure that my training code works.

I mix `jax.stax` with `neural_tangents.stax` because I want to use both BatchNorm and the NTK parameterization.

In [11]:
channels = 32
num_classes = 10

# Small CNN: NTK-parameterized conv/dense layers from neural_tangents.stax
# combined with BatchNorm from jax.experimental.stax.
# NOTE(review): AvgPool uses a (1, 1) window, which presumably leaves the
# feature map unchanged -- confirm whether global average pooling was intended.
init_random_params, predict = jax_stax.serial(
      nt_stax.Conv(channels, (3, 3), padding='SAME'),                jax_stax.BatchNorm(), nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.BatchNorm(), nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.BatchNorm(), nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.BatchNorm(), nt_stax.Relu(),
      nt_stax.AvgPool((1, 1)),   nt_stax.Flatten(), 
      nt_stax.Dense(num_classes), jax_stax.Identity,
)
rng = random.PRNGKey(0)

step_size = 10.
num_epochs = 10
batch_size = 500
momentum_mass = 0.9

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  # Pass batch_size explicitly so the constant above actually controls
  # batching instead of relying on data_stream_of's default.
  for batch in data_stream_of(train_images, train_labels, batch_size=batch_size):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch)
                for batch in data_stream_of(train_images, train_labels, batch_size=batch_size)]
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch)
                for batch in data_stream_of(test_images, test_labels, batch_size=batch_size)]
  test_acc   = np.average(test_accs)

  # The reported time covers both the training pass and the two evaluations.
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 13.36 sec
Training set accuracy 0.4674600064754486
Test set accuracy 0.46229997277259827
Epoch 1 in 11.62 sec
Training set accuracy 0.5352200269699097
Test set accuracy 0.5210000872612
Epoch 2 in 11.41 sec
Training set accuracy 0.5723000168800354
Test set accuracy 0.5527001023292542
Epoch 3 in 11.30 sec
Training set accuracy 0.5975600481033325
Test set accuracy 0.57340008020401
Epoch 4 in 11.43 sec
Training set accuracy 0.6060200333595276
Test set accuracy 0.5822001099586487
Epoch 5 in 11.45 sec
Training set accuracy 0.6427800059318542
Test set accuracy 0.607900083065033
Epoch 6 in 11.58 sec
Training set accuracy 0.648140013217926
Test set accuracy 0.6154000759124756
Epoch 7 in 11.45 sec
Training set accuracy 0.6545000076293945
Test set accuracy 0.6219999194145203
Epoch 8 in 11.48 sec
Training set accuracy 0.6666399836540222
Test set accuracy 0.6288000345230103
Epoch 9 in 11.50 sec
Training set accuracy 0.6705000400543213
Test set accuracy 0.63080000877

## Train a ResNet

In [0]:
num_classes = 10

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """Residual block: a conv-BatchNorm-ReLU main branch summed with a shortcut.

  Args:
    channels: number of output channels of every convolution in the block.
    strides: strides of the first main-branch convolution and of the
      shortcut convolution when one is used.
    channel_mismatch: if True the shortcut is a 3x3 convolution; otherwise it
      is the identity.

  Returns:
    A stax-style `(init_fn, apply_fn)` pair.
  """
  if channel_mismatch:
    Shortcut = nt_stax.Conv(channels, (3, 3), strides, padding='SAME')
  else:
    Shortcut = nt_stax.Identity()

  main_layers = [
      nt_stax.Conv(channels, (3, 3), strides, padding='SAME'),
      jax_stax.BatchNorm(),
      nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), padding='SAME'),
      jax_stax.BatchNorm(),
      nt_stax.Relu(),
      jax_stax.Identity,
  ]
  Main = jax_stax.serial(*main_layers)

  # Duplicate the input, run both branches side by side, then add the results.
  return jax_stax.serial(
      jax_stax.FanOut(2),
      jax_stax.parallel(Main, Shortcut),
      jax_stax.FanInSum,
      jax_stax.Identity,
  )

def WideResnetGroup(n, channels, strides=(1, 1)):
  """Chain of `n` residual blocks (at least one block is always created).

  The first block uses `strides` and a convolutional shortcut; any further
  blocks use unit strides and an identity shortcut.
  """
  first_block = WideResnetBlock(channels, strides, channel_mismatch=True)
  remaining = [WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)]
  return jax_stax.serial(first_block, *remaining)

def WideResnet(num_classes, num_channels=32, block_size=1):
  """ResNet classifier: a conv stem, three residual groups and a dense head.

  Args:
    num_classes: width of the final dense layer.
    num_channels: channel count used by the stem and by every group.
    block_size: number of residual blocks per group.

  Returns:
    A stax-style `(init_fn, apply_fn)` pair.
  """
  stem = [
      nt_stax.Conv(num_channels, (3, 3), padding='SAME'),
      jax_stax.BatchNorm(),
      nt_stax.Relu(),
  ]
  # The second and third groups use stride (2, 2) in their first block.
  groups = [
      WideResnetGroup(block_size, num_channels),
      WideResnetGroup(block_size, num_channels, (2, 2)),
      WideResnetGroup(block_size, num_channels, (2, 2)),
  ]
  head = [
      nt_stax.AvgPool((1, 1)),
      nt_stax.Flatten(),
      nt_stax.Dense(num_classes),
      jax_stax.Identity,
  ]
  return jax_stax.serial(*stem, *groups, *head)

In [20]:
init_random_params, predict = WideResnet(num_classes)

rng = random.PRNGKey(0)

step_size = 10.
num_epochs = 10
# Defined here so the cell does not depend on a value left behind by the
# earlier CNN cell; 500 is the value that cell used.
batch_size = 500
momentum_mass = 0.9
# Accuracy is estimated on this many batches per split to keep epochs short.
eval_batch_limit = 4

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels, batch_size=batch_size):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch)
                for batch in data_stream_of(train_images, train_labels,
                                            batch_size=batch_size, batch_limit=eval_batch_limit)]
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch)
                for batch in data_stream_of(test_images, test_labels,
                                            batch_size=batch_size, batch_limit=eval_batch_limit)]
  test_acc   = np.average(test_accs)

  # The reported time covers both the training pass and the two evaluations.
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 10.01 sec
Training set accuracy 0.4259999990463257
Test set accuracy 0.44600000977516174
Epoch 1 in 5.16 sec
Training set accuracy 0.44450002908706665
Test set accuracy 0.4360000193119049
Epoch 2 in 5.13 sec
Training set accuracy 0.5074999928474426
Test set accuracy 0.5080000162124634
Epoch 3 in 5.18 sec
Training set accuracy 0.5400000214576721
Test set accuracy 0.5275000333786011
Epoch 4 in 5.23 sec
Training set accuracy 0.5705000162124634
Test set accuracy 0.5400000214576721
Epoch 5 in 5.19 sec
Training set accuracy 0.5830000042915344
Test set accuracy 0.5649999976158142
Epoch 6 in 5.16 sec
Training set accuracy 0.6200000047683716
Test set accuracy 0.5915000438690186
Epoch 7 in 5.16 sec
Training set accuracy 0.6640000343322754
Test set accuracy 0.6075000762939453
Epoch 8 in 5.14 sec
Training set accuracy 0.655500054359436
Test set accuracy 0.6155000329017639
Epoch 9 in 5.18 sec
Training set accuracy 0.6620000004768372
Test set accuracy 0.6234999895095

## Train a linearization of ResNet

Note: I have removed the BatchNorm layers because with them training didn't work. 

In [0]:


from jax.tree_util import tree_multimap
from jax.api import jvp
from jax.api import vjp

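# Copied from the neural_tangents library (presumably `neural_tangents.linearize`; the source link was left blank -- TODO add it).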
def linearize(f, params):
  """Build `f_lin`, the first-order Taylor expansion of `f` around `params`.

  `f_lin(p, *args, **kwargs)` evaluates
  `f(params, ...) + J(params) . (p - params)`, where the Jacobian-vector
  product is taken with respect to the parameters only.

  Example:
    >>> # Compute the MSE of the first order Taylor series of a function.
    >>> f_lin = linearize(f, params)
    >>> mse = np.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)
  """
  def f_lin(p, *args, **kwargs):
    # Displacement from the expansion point, leaf by leaf.
    delta = tree_multimap(lambda new, base: new - base, p, params)

    def f_of_params(candidate):
      return f(candidate, *args, **kwargs)

    # One jvp call returns both f(params) and the directional derivative along delta.
    f_at_base, first_order = jvp(f_of_params, (params,), (delta,))
    return f_at_base + first_order

  return jit(f_lin)

In [0]:
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """Residual block without BatchNorm: conv-ReLU main branch plus a shortcut.

  Args:
    channels: number of output channels of every convolution in the block.
    strides: strides of the first main-branch convolution and of the
      shortcut convolution when one is used.
    channel_mismatch: selects a 3x3 convolutional shortcut when True and an
      identity shortcut when False.

  Returns:
    A stax-style `(init_fn, apply_fn)` pair.
  """
  first_conv = nt_stax.Conv(channels, (3, 3), strides, padding='SAME')
  second_conv = nt_stax.Conv(channels, (3, 3), padding='SAME')
  Main = jax_stax.serial(
      first_conv, nt_stax.Relu(),
      second_conv, nt_stax.Relu(),
      jax_stax.Identity,
  )

  if not channel_mismatch:
    Shortcut = nt_stax.Identity()
  else:
    Shortcut = nt_stax.Conv(channels, (3, 3), strides, padding='SAME')

  # Fan the input out to both branches and add their outputs.
  branches = jax_stax.parallel(Main, Shortcut)
  return jax_stax.serial(jax_stax.FanOut(2), branches, jax_stax.FanInSum, jax_stax.Identity)

def WideResnetGroup(n, channels, strides=(1, 1)):
  """Stack `n` residual blocks into one group.

  The leading block applies `strides` with a convolutional shortcut; the
  `n - 1` blocks after it use unit strides and the default identity shortcut.
  """
  group = [WideResnetBlock(channels, strides, channel_mismatch=True)]
  group.extend(WideResnetBlock(channels, (1, 1)) for _ in range(n - 1))
  return jax_stax.serial(*group)

def WideResnet(num_classes, num_channels=32, block_size=1):
  """BatchNorm-free ResNet classifier: conv stem, three residual groups, dense head.

  Args:
    num_classes: width of the final dense layer.
    num_channels: channel count used by the stem and by every group.
    block_size: number of residual blocks per group.

  Returns:
    A stax-style `(init_fn, apply_fn)` pair.
  """
  layers = [nt_stax.Conv(num_channels, (3, 3), padding='SAME'), nt_stax.Relu()]

  # Group strides in order: the first keeps resolution, the next two use (2, 2).
  layers.append(WideResnetGroup(block_size, num_channels))
  for group_strides in ((2, 2), (2, 2)):
    layers.append(WideResnetGroup(block_size, num_channels, group_strides))

  layers += [
      nt_stax.AvgPool((1, 1)),
      nt_stax.Flatten(),
      nt_stax.Dense(num_classes),
      jax_stax.Identity,
  ]
  return jax_stax.serial(*layers)

In [0]:
num_classes = 10

init_random_params, predict = WideResnet(num_classes, num_channels=512)

rng = random.PRNGKey(0)

step_size = 1.
num_epochs = 100
momentum_mass = 0.9
# One batch size for both initialization and streaming. Previously the init
# shape used a value leaked from an earlier cell (500) while the streams used
# a literal 100. Only the leading dimension of the init shape changes, which
# presumably does not affect the initialized parameters.
batch_size = 100
# Accuracy is estimated on this many batches per split.
eval_batch_limit = 20

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

# Important: from here on `predict` is the first-order Taylor expansion of
# the network around its initial parameters, so loss/accuracy/update train
# the linearized model.
predict = linearize(predict, init_params) # !important: linearization

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels, batch_size=batch_size):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch)
                for batch in data_stream_of(train_images, train_labels,
                                            batch_size=batch_size, batch_limit=eval_batch_limit)]
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch)
                for batch in data_stream_of(test_images, test_labels,
                                            batch_size=batch_size, batch_limit=eval_batch_limit)]
  test_acc   = np.average(test_accs)

  # The reported time covers both the training pass and the two evaluations.
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 297.33 sec
Training set accuracy 0.46000003814697266
Test set accuracy 0.40749993920326233
Epoch 1 in 293.79 sec
Training set accuracy 0.4874999523162842
Test set accuracy 0.4374999701976776
Epoch 2 in 293.76 sec
Training set accuracy 0.5044999718666077
Test set accuracy 0.4364999234676361
Epoch 3 in 293.63 sec
Training set accuracy 0.5345000624656677
Test set accuracy 0.4519999623298645
