In [0]:
import time
import itertools

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

In [0]:
def data_to_numpy(dataloader):
    X = []
    y = []
    for batch_id, (cur_X, cur_y) in enumerate(dataloader):
        X.extend(cur_X.numpy())
        y.extend(cur_y.numpy())
    X = np.asarray(X)
    y = np.asarray(y)
    return X, y

def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

In [0]:
def cifar_10():
  torch.manual_seed(0)

  D = 32
  num_classes = 10

  torch.manual_seed(0)

  if torch.cuda.is_available():
      device = torch.device('cuda:0')
  else:
      device = torch.device('cpu')

  cifar10_stats = {
      "mean" : (0.4914, 0.4822, 0.4465),
      "std"  : (0.24705882352941178, 0.24352941176470588, 0.2615686274509804),
  }

  simple_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
  ])

  train_loader = torch.utils.data.DataLoader(
                    datasets.CIFAR10(root='./data', train=True, download=True, transform=simple_transform),
                batch_size=2048, shuffle=True, pin_memory=True)

  test_loader  = torch.utils.data.DataLoader(
                    datasets.CIFAR10(root='./data', train=False, download=True, transform=simple_transform),
                batch_size=2048, shuffle=True, pin_memory=True)
  
  train_images, train_labels = data_to_numpy(train_loader)
  test_images,  test_labels  = data_to_numpy(test_loader)

  train_images = np.transpose(train_images, (0, 2, 3, 1))
  test_images  = np.transpose(test_images , (0, 2, 3, 1))

  train_labels = _one_hot(train_labels, num_classes)
  test_labels  = _one_hot(test_labels,  num_classes)
  return train_images, train_labels, test_images, test_labels

In [4]:
%%time

train_images, train_labels, test_images, test_labels = cifar_10()
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 54.7 s, sys: 17.3 s, total: 1min 12s
Wall time: 56.2 s


In [5]:
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

((50000, 32, 32, 3), (50000, 10), (10000, 32, 32, 3), (10000, 10))

In [0]:
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(np.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

In [53]:
channels = 32
num_classes = 10

init_random_params, predict = stax.serial(
      stax.Conv(channels, (3, 3), padding='SAME'),                stax.BatchNorm(), stax.Relu,
      stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), stax.BatchNorm(), stax.Relu,
      stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), stax.BatchNorm(), stax.Relu,
      stax.Conv(channels, (3, 3), strides=(2,2), padding='SAME'), stax.BatchNorm(), stax.Relu,
      stax.AvgPool((1, 1)), stax.Flatten, 
      stax.Dense(num_classes), stax.LogSoftmax
)
rng = random.PRNGKey(0)

step_size = 0.05
num_epochs = 10
batch_size = 500
momentum_mass = 0.9

num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream_of(images, labels):
  assert len(images) == len(labels)
  rng = npr.RandomState(0)

  n = len(images)
  perm = rng.permutation(n)
  for i in range(n // batch_size):
    batch_idx = perm[i * batch_size:(i + 1) * batch_size]
    yield images[batch_idx], labels[batch_idx]

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (batch_size, 32, 32, 3))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()

  for batch in data_stream_of(train_images, train_labels):
    opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  train_accs = [accuracy(params, batch) for batch in data_stream_of(train_images, train_labels)]  
  train_acc  = np.average(train_accs)
  test_accs  = [accuracy(params, batch) for batch in data_stream_of(test_images, test_labels)]  
  test_acc   = np.average(test_accs)
  
  epoch_time = time.time() - start_time

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")


Starting training...
Epoch 0 in 11.06 sec
Training set accuracy 0.541379988193512
Test set accuracy 0.5235000252723694
Epoch 1 in 9.74 sec
Training set accuracy 0.6169600486755371
Test set accuracy 0.5885000824928284
Epoch 2 in 9.74 sec
Training set accuracy 0.6624999642372131
Test set accuracy 0.621399998664856
Epoch 3 in 9.62 sec
Training set accuracy 0.6905400156974792
Test set accuracy 0.6450001001358032
Epoch 4 in 9.67 sec
Training set accuracy 0.7164599299430847
Test set accuracy 0.6626001000404358
Epoch 5 in 9.90 sec
Training set accuracy 0.7408199906349182
Test set accuracy 0.6811999678611755
Epoch 6 in 9.88 sec
Training set accuracy 0.7523201107978821
Test set accuracy 0.6846001148223877
Epoch 7 in 9.83 sec
Training set accuracy 0.7669000029563904
Test set accuracy 0.6910000443458557
Epoch 8 in 9.65 sec
Training set accuracy 0.76910001039505
Test set accuracy 0.6869000792503357
Epoch 9 in 9.68 sec
Training set accuracy 0.7783799767494202
Test set accuracy 0.6927000284194946
