In [1]:
!pip install neural-tangents



In [0]:
import time
import itertools

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from  jax.nn import log_softmax

from jax.experimental import optimizers
import jax.experimental.stax as jax_stax
import neural_tangents.stax as nt_stax

import neural_tangents

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

In [0]:
def data_to_numpy(dataloader):
    X = []
    y = []
    for batch_id, (cur_X, cur_y) in enumerate(dataloader):
        X.extend(cur_X.numpy())
        y.extend(cur_y.numpy())
    X = np.asarray(X)
    y = np.asarray(y)
    return X, y

def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

In [0]:
def cifar_10():
  torch.manual_seed(0)

  D = 32
  num_classes = 10

  torch.manual_seed(0)

  if torch.cuda.is_available():
      device = torch.device('cuda:0')
  else:
      device = torch.device('cpu')

  cifar10_stats = {
      "mean" : (0.4914, 0.4822, 0.4465),
      "std"  : (0.24705882352941178, 0.24352941176470588, 0.2615686274509804),
  }

  simple_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
  ])

  train_loader = torch.utils.data.DataLoader(
                    datasets.CIFAR10(root='./data', train=True, download=True, transform=simple_transform),
                batch_size=2048, shuffle=True, pin_memory=True)

  test_loader  = torch.utils.data.DataLoader(
                    datasets.CIFAR10(root='./data', train=False, download=True, transform=simple_transform),
                batch_size=2048, shuffle=True, pin_memory=True)
  
  train_images, train_labels = data_to_numpy(train_loader)
  test_images,  test_labels  = data_to_numpy(test_loader)

  train_images = np.transpose(train_images, (0, 2, 3, 1))
  test_images  = np.transpose(test_images , (0, 2, 3, 1))

  train_labels = _one_hot(train_labels, num_classes)
  test_labels  = _one_hot(test_labels,  num_classes)
  return train_images, train_labels, test_images, test_labels

In [5]:
%%time

train_images, train_labels, test_images, test_labels = cifar_10()
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 51.2 s, sys: 17.2 s, total: 1min 8s
Wall time: 53.4 s


In [6]:
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

((50000, 32, 32, 3), (50000, 10), (10000, 32, 32, 3), (10000, 10))

In [0]:
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(np.sum(log_softmax(preds, axis=1) * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

rng_state = npr.RandomState(0)

def data_stream_of(images, labels, batch_size=500, batch_limit=None):
  assert len(images) == len(labels)
  rng = npr.RandomState(0)

  n = len(images)
  perm = rng.permutation(n)
  for i in range(n // batch_size):
    if (batch_limit is not None) and i >= batch_limit:
      break
    batch_idx = perm[i * batch_size:(i + 1) * batch_size]
    yield images[batch_idx], labels[batch_idx]

In [8]:
channels = 32
num_classes = 10

init_random_params, predict = jax_stax.serial(
      jax_stax.Conv(channels, (3, 3), padding='SAME'),                jax_stax.BatchNorm(), jax_stax.Relu,
      jax_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.BatchNorm(), jax_stax.Relu,
      jax_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.BatchNorm(), jax_stax.Relu,
      jax_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.BatchNorm(), jax_stax.Relu,
      jax_stax.AvgPool((1, 1)),    jax_stax.Flatten, 
      jax_stax.Dense(num_classes)
)
rng = random.PRNGKey(0)

step_size = 0.05
num_epochs = 2
batch_size = 500
momentum_mass = 0.9

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch) for batch in data_stream_of(train_images, train_labels)]  
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch) for batch in data_stream_of(test_images, test_labels)]  
  test_acc   = np.average(test_accs)
  
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 15.22 sec
Training set accuracy 0.5342999696731567
Test set accuracy 0.5160000324249268
Epoch 1 in 6.35 sec
Training set accuracy 0.6130399703979492
Test set accuracy 0.5840000510215759


In [9]:
channels = 32
num_classes = 10

init_random_params, predict = jax_stax.serial(
      jax_stax.Conv(channels, (3, 3), padding='SAME'),                jax_stax.Relu,
      jax_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.Relu,
      jax_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.Relu,
      jax_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), jax_stax.Relu,
      jax_stax.AvgPool((1, 1)),    jax_stax.Flatten, 
      jax_stax.Dense(num_classes)
)
rng = random.PRNGKey(0)

step_size = 0.05
num_epochs = 2
momentum_mass = 0.9

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch) for batch in data_stream_of(train_images, train_labels)]  
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch) for batch in data_stream_of(test_images, test_labels)]  
  test_acc   = np.average(test_accs)
  
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 3.88 sec
Training set accuracy 0.46296000480651855
Test set accuracy 0.46550002694129944
Epoch 1 in 3.19 sec
Training set accuracy 0.5372400283813477
Test set accuracy 0.5246999859809875


In [10]:
channels = 32
num_classes = 10

init_random_params, predict, _ = nt_stax.serial(
      nt_stax.Conv(channels, (3, 3), padding='SAME'),                nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), nt_stax.Relu(),
      nt_stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), nt_stax.Relu(),
      nt_stax.AvgPool((1, 1)),   nt_stax.Flatten(), 
      nt_stax.Dense(num_classes)
)
rng = random.PRNGKey(0)

step_size = 10.
num_epochs = 10
batch_size = 500
momentum_mass = 0.9

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch) for batch in data_stream_of(train_images, train_labels)]  
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch) for batch in data_stream_of(test_images, test_labels)]  
  test_acc   = np.average(test_accs)
  
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 5.73 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 1 in 5.00 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 2 in 4.98 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 3 in 5.13 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 4 in 4.95 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 5 in 4.91 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 6 in 4.89 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 7 in 4.89 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 8 in 4.89 sec
Training set accuracy 0.0
Test set accuracy 0.0
Epoch 9 in 4.95 sec
Training set accuracy 0.0
Test set accuracy 0.0


In [25]:
num_classes = 10

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = nt_stax.serial(
      nt_stax.Relu(), nt_stax.Conv(channels, (3, 3), strides, padding='SAME'),
      nt_stax.Relu(), nt_stax.Conv(channels, (3, 3), padding='SAME')
  )
  Shortcut = nt_stax.Identity() if not channel_mismatch else nt_stax.Conv(channels, (3, 3), strides, padding='SAME')
  return nt_stax.serial(nt_stax.FanOut(2),
                        nt_stax.parallel(Main, Shortcut),
                        nt_stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return nt_stax.serial(*blocks)

def WideResnet(num_classes, num_channels=32, block_size=1):
  return nt_stax.serial(
      nt_stax.Conv(num_channels, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, num_channels),
      WideResnetGroup(block_size, num_channels, (2, 2)),
      WideResnetGroup(block_size, num_channels, (2, 2)),
      nt_stax.Relu(),
      nt_stax.AvgPool((1, 1)),
      nt_stax.Flatten(),
      nt_stax.Dense(num_classes)
  )

init_random_params, predict, _ = WideResnet(num_classes)

rng = random.PRNGKey(0)

step_size = 10.
num_epochs = 10
momentum_mass = 0.9

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch) for batch in data_stream_of(train_images, train_labels, batch_limit=4)]  
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch) for batch in data_stream_of(test_images, test_labels, batch_limit=4)]  
  test_acc   = np.average(test_accs)
  
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 5.48 sec
Training set accuracy 0.4625000059604645
Test set accuracy 0.4825000464916229
Epoch 1 in 5.48 sec
Training set accuracy 0.5559999942779541
Test set accuracy 0.5355000495910645
Epoch 2 in 5.46 sec
Training set accuracy 0.5850000381469727
Test set accuracy 0.562999963760376
Epoch 3 in 5.45 sec
Training set accuracy 0.6270000338554382
Test set accuracy 0.5889999866485596
Epoch 4 in 5.47 sec
Training set accuracy 0.6610000133514404
Test set accuracy 0.6050000190734863
Epoch 5 in 5.49 sec
Training set accuracy 0.6920000314712524
Test set accuracy 0.6365000605583191
Epoch 6 in 5.53 sec
Training set accuracy 0.7000000476837158
Test set accuracy 0.6355000138282776
Epoch 7 in 5.55 sec
Training set accuracy 0.7105000019073486
Test set accuracy 0.6390000581741333
Epoch 8 in 5.51 sec
Training set accuracy 0.7260000109672546
Test set accuracy 0.6405000686645508
Epoch 9 in 5.48 sec
Training set accuracy 0.722000002861023
Test set accuracy 0.6260000467300415
