In [0]:
import time
import itertools

import numpy as regurlar_np
import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax import tree_util
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

In [0]:
def data_to_numpy(dataloader, flatten=True):
    """Materialise every (inputs, labels) batch of `dataloader` as two arrays.

    Each batch element must expose `.numpy()` (e.g. torch tensors yielded by a
    torch DataLoader). Batches are joined along axis 0 in iteration order.

    Args:
        dataloader: iterable yielding `(inputs, labels)` pairs.
        flatten: if True, reshape inputs to `(num_examples, -1)`.

    Returns:
        `(X, y)` converted with `np.asarray` (jax.numpy in this notebook).

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    X_batches = []
    y_batches = []
    for cur_X, cur_y in dataloader:
        X_batches.append(cur_X.numpy())
        y_batches.append(cur_y.numpy())
    if not X_batches:
        raise ValueError("dataloader yielded no batches")
    # Join whole batches on the host and convert to a device array once,
    # instead of building a per-example Python list that has to be stacked
    # element by element.
    X = regurlar_np.concatenate(X_batches, axis=0)
    y = regurlar_np.concatenate(y_batches, axis=0)
    if flatten:
        X = X.reshape(len(X), -1)
    return np.asarray(X), np.asarray(y)

In [0]:
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

In [0]:
def fashionMnist(root='.', load_batch_size=4096):
  """Load Fashion-MNIST into memory as flattened images and one-hot labels.

  Both splits are fetched into `root` if missing (torchvision skips the
  download when the files are already present) and converted with
  `data_to_numpy`, so images come back flattened to one row per example.
  `torch.manual_seed(0)` fixes the training-split shuffle order so the
  returned arrays are reproducible.

  Args:
    root: directory where torchvision stores / looks for the dataset files.
    load_batch_size: DataLoader batch size used only while reading the data.

  Returns:
    (train_images, train_labels, test_images, test_labels); labels are
    one-hot encoded over 10 classes.
  """
  torch.manual_seed(0)
  num_classes = 10

  # The training split is shuffled reproducibly via the seed above. The test
  # split only feeds accuracy evaluation, so its order is irrelevant.
  # pin_memory is not used: tensors are only converted with `.numpy()`,
  # never copied to a CUDA device through torch.
  train_loader = torch.utils.data.DataLoader(
    FashionMNIST(root=root, train=True, download=True,
                 transform=transforms.ToTensor()),
    batch_size=load_batch_size, shuffle=True)

  test_loader = torch.utils.data.DataLoader(
    FashionMNIST(root=root, train=False, download=True,
                 transform=transforms.ToTensor()),
    batch_size=load_batch_size, shuffle=False)

  train_images, train_labels = data_to_numpy(train_loader)
  test_images,  test_labels  = data_to_numpy(test_loader)

  train_labels = _one_hot(train_labels, num_classes)
  test_labels  = _one_hot(test_labels,  num_classes)
  return train_images, train_labels, test_images, test_labels

In [5]:
%%time

# Download (first run only) and load both Fashion-MNIST splits into memory;
# the trailing tuple of shapes is the cell's displayed result.
train_images, train_labels, test_images, test_labels = fashionMnist()
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

CPU times: user 53.6 s, sys: 17.9 s, total: 1min 11s
Wall time: 55 s


In [0]:
def loss(params, batch):
  """Mean negative log-likelihood of the one-hot `targets` under the model.

  `predict` ends in LogSoftmax, so multiplying its output by a one-hot row
  and summing picks out the log-probability assigned to the true class.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example_ll = (log_probs * targets).sum(axis=1)
  return -per_example_ll.mean()

def accuracy(params, batch):
  """Fraction of examples whose top-scoring class equals the one-hot target's class."""
  inputs, targets = batch
  scores = predict(params, inputs)
  matches = np.argmax(scores, axis=1) == np.argmax(targets, axis=1)
  return np.mean(matches)

In [7]:
# Optimisation hyperparameters.
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9

# Two 1024-unit ReLU hidden layers; LogSoftmax makes `predict` return
# per-class log-probabilities.
init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

rng = random.PRNGKey(0)

print(train_images.shape)
print(train_labels.shape)

# Minibatches per epoch, counting one extra partial batch when the training
# set size is not a multiple of batch_size.
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield (images, labels) minibatches forever, reshuffling at every epoch.

  The shuffle is seeded, so the batch sequence is reproducible across runs.
  """
  shuffler = npr.RandomState(0)
  while True:
    order = shuffler.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  """Apply one momentum step to `opt_state` using the gradient of `loss` on `batch`.

  `i` is the global step counter, passed straight through to the optimizer.
  """
  current_params = get_params(opt_state)
  loss_grads = grad(loss)(current_params, batch)
  return opt_update(i, loss_grads, opt_state)

# Input width comes from the (already flattened) training images rather than
# a hard-coded 28 * 28.
_, init_params = init_random_params(rng, (-1, train_images.shape[1]))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  params = get_params(opt_state)
  # JAX dispatches work asynchronously: `update` can return before the device
  # has finished computing. Wait for the final parameters so epoch_time
  # measures the training work, not just the time spent enqueueing it.
  for leaf in tree_util.tree_flatten(params)[0]:
    leaf.block_until_ready()
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

(60000, 784)
(60000, 10)

Starting training...
Epoch 0 in 6.60 sec
Training set accuracy 0.781166672706604
Test set accuracy 0.7704000473022461
Epoch 1 in 3.53 sec
Training set accuracy 0.8212500214576721
Test set accuracy 0.807200014591217
Epoch 2 in 3.44 sec
Training set accuracy 0.831933319568634
Test set accuracy 0.8203000426292419
Epoch 3 in 3.53 sec
Training set accuracy 0.8410166501998901
Test set accuracy 0.8296000361442566
Epoch 4 in 3.46 sec
Training set accuracy 0.8459500074386597
Test set accuracy 0.835900068283081
Epoch 5 in 3.30 sec
Training set accuracy 0.8499000072479248
Test set accuracy 0.836400032043457
Epoch 6 in 3.56 sec
Training set accuracy 0.8526666760444641
Test set accuracy 0.8388000130653381
Epoch 7 in 3.41 sec
Training set accuracy 0.8567166924476624
Test set accuracy 0.8438000679016113
Epoch 8 in 3.39 sec
Training set accuracy 0.8595499992370605
Test set accuracy 0.8446000218391418
Epoch 9 in 3.50 sec
Training set accuracy 0.8629000186920166
Test set accura