https://github.com/tensorflow/models/tree/master/official/resnet

In [1]:
import collections
import time
import functools

import jax.numpy as np
import numpy.random as npr
import tensorflow as tf

from jax import jit, grad, random, lax
from jax.experimental import optimizers, stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
                                   FanOut, Flatten, GeneralConv, Identity,
                                   MaxPool, Relu, LogSoftmax, Dropout)
from tqdm import tqdm_notebook
from utils import get_ds_batches

In [2]:
# Global experiment configuration shared by the cells below.

# PRNG key (seed 0); passed to the network's init function when the
# parameters are created.
key = random.PRNGKey(0)

# Number of passes of the outer training loop.
num_epochs = 2
# Batch size handed to `get_ds_batches` and used as the leading dimension of
# the input shape given to the network's init function.
batch_size = 32
# Step size passed to the momentum optimizer.
step_size = 1e-3
# Directory passed to `get_ds_batches`
# (presumably the TensorFlow Datasets data directory - TODO confirm).
data_dir = '/projects/tfds'



In [3]:
def pad_layer(**fun_kwargs):
  """Build a stax-style layer that pads its input with `lax.pad`.

  Args:
    **fun_kwargs: keyword arguments forwarded verbatim to `lax.pad` on every
      call. Must contain `padding_config`, a sequence with one
      `(low, high, interior)` triple per input dimension; `lax.pad` also
      needs `padding_value`.

  Returns:
    An `(init_fun, apply_fun)` pair. `init_fun(rng, input_shape)` returns the
    padded output shape as a tuple of Python ints together with an empty
    parameter tuple (the layer has no parameters and ignores `rng`).
    `apply_fun(params, inputs, **kwargs)` returns the padded inputs and
    ignores `params` and `kwargs`.

  Raises:
    KeyError: if `padding_config` is not supplied.
    ValueError: (from `init_fun`) if `input_shape` and `padding_config`
      disagree on the number of dimensions.
  """
  padding_config = tuple(
      tuple(int(amount) for amount in dim_config)
      for dim_config in fun_kwargs['padding_config'])

  def init_fun(rng, input_shape):
    if len(input_shape) != len(padding_config):
      raise ValueError(
          'padding_config has {} entries but input_shape has {} '
          'dimensions'.format(len(padding_config), len(input_shape)))
    # Same shape rule as lax.pad: edge padding is added once, interior
    # padding is inserted between each pair of neighbouring elements.
    output_shape = tuple(
        low + high + size + interior * max(size - 1, 0)
        for size, (low, high, interior) in zip(input_shape, padding_config))
    return output_shape, ()

  def apply_fun(params, inputs, **kwargs):
    return lax.pad(inputs, **fun_kwargs)

  return init_fun, apply_fun


def ConvBlock(kernel_size, filters, strides=(2, 2)):
  """Residual block whose shortcut branch is a strided 1x1 convolution.

  Args:
    kernel_size: side length of the square kernels on the main branch.
    filters: pair `(filters1, filters2)` giving the output channels of the
      first and second main-branch convolutions; the shortcut convolution
      also produces `filters2` channels.
    strides: strides of the first main-branch convolution and of the
      shortcut convolution.

  Returns:
    A stax `(init_fun, apply_fun)` pair that fans the input out to both
    branches, sums the branch outputs and applies a ReLU.
  """
  n_mid, n_out = filters
  window = (kernel_size, kernel_size)

  # Main branch: conv -> batch norm -> relu -> conv -> batch norm.
  residual_path = stax.serial(
      Conv(n_mid, window, strides, padding='SAME'),
      BatchNorm(),
      Relu,
      Conv(n_out, window, padding='SAME'),
      BatchNorm(),
  )

  # Shortcut branch: strided 1x1 conv -> batch norm, so its output can be
  # summed with the main branch.
  projection_path = stax.serial(
      Conv(n_out, (1, 1), strides, W_init=None),
      BatchNorm(),
  )

  both_paths = stax.parallel(residual_path, projection_path)
  return stax.serial(FanOut(2), both_paths, FanInSum, Relu)


def IdentityBlock(kernel_size, filters):
  """Residual block whose shortcut branch is the identity.

  Args:
    kernel_size: side length of the square kernel of the second convolution
      (the first convolution is 1x1).
    filters: pair `(filters1, filters2)` giving the output channels of the
      first and second convolutions. NOTE(review): the branch outputs are
      summed, so `filters2` presumably has to equal the number of input
      channels - nothing here checks it.

  Returns:
    A stax `(init_fun, apply_fun)` pair that fans the input out to the main
    branch and an identity shortcut, sums both and applies a ReLU.
  """
  n_mid, n_out = filters
  window = (kernel_size, kernel_size)

  def build_residual(_input_shape):
    # The input shape is supplied by `shape_dependent` but is not needed.
    layers = [
        Conv(n_mid, (1, 1), padding='SAME'),
        BatchNorm(),
        Relu,
        Conv(n_out, window, padding='SAME'),
        BatchNorm(),
    ]
    return stax.serial(*layers)

  residual_path = stax.shape_dependent(build_residual)
  both_paths = stax.parallel(residual_path, Identity)
  return stax.serial(FanOut(2), both_paths, FanInSum, Relu)


def ResNet20(num_classes):
  """Assemble the ResNet-20 classifier as a stax layer pair.

  The network is a padded 3x3 convolution stem, three stages of one
  `ConvBlock` followed by two `IdentityBlock`s (16, 32 and 64 filters; the
  first stage uses stride 1, the other two stride 2), and an average-pool /
  flatten / dense / log-softmax head.

  Args:
    num_classes: number of units of the final dense layer.

  Returns:
    A stax `(init_fun, apply_fun)` pair for the whole network.
  """
  # Pad the two middle dimensions by one element on each side before the
  # first convolution.
  stem = [
      pad_layer(padding_value=0.0,
                padding_config=((0, 0, 0), (1, 1, 0), (1, 1, 0), (0, 0, 0))),
      Conv(16, (3, 3)),
      BatchNorm(),
      Relu,
  ]

  stages = []
  for width, stage_strides in ((16, (1, 1)), (32, (2, 2)), (64, (2, 2))):
    stages.append(ConvBlock(3, [width, width], strides=stage_strides))
    stages.extend(IdentityBlock(3, [width, width]) for _ in range(2))

  head = [AvgPool((8, 8)), Flatten, Dense(num_classes), LogSoftmax]

  return stax.serial(*stem, *stages, *head)

In [4]:
def cross_entropy(logits, labels):
  """Return the negated mean over examples of `sum(logits * labels, axis=-1)`.

  Args:
    logits: per-class scores, classes on the last axis. NOTE(review): no
      softmax is applied here, so these are presumably log-probabilities
      already (the network in this file ends in LogSoftmax).
    labels: array of the same shape as `logits` (presumably one-hot - TODO
      confirm against the data pipeline).

  Returns:
    A scalar array.
  """
  per_example = np.sum(logits * labels, axis=-1)
  return -np.mean(per_example)


def loss_fun(params, batch, predict_fun, rng=None):
  """Run the model on a batch and return its cross-entropy loss.

  Args:
    params: model parameters passed to `predict_fun`.
    batch: pair `(inputs, labels)`.
    predict_fun: callable invoked as `predict_fun(params, inputs, rng=rng)`.
    rng: optional value forwarded to `predict_fun` as the `rng` keyword.

  Returns:
    The scalar returned by `cross_entropy` for the model outputs and labels.
  """
  features, targets = batch
  return cross_entropy(predict_fun(params, features, rng=rng), targets)


def accuracy(logits, labels):
  """Return the fraction of rows whose arg-max agrees in both arrays.

  Args:
    logits: per-class scores, classes on axis 1.
    labels: array with classes on axis 1 (presumably one-hot - TODO confirm).

  Returns:
    A scalar array: the mean of the element-wise equality of the two
    arg-max index vectors.
  """
  hits = np.argmax(logits, axis=1) == np.argmax(labels, axis=1)
  return np.mean(hits)


def jit_update_fun(model_fun, loss, opt):
  """Build a jit-compiled single-step parameter update.

  Args:
    model_fun: model apply function; passed to `loss` as its third argument.
    loss: callable `loss(params, batch, model_fun, rng)` returning a scalar.
      It is differentiated with respect to `params`.
    opt: pair `(opt_update, get_params)` where `get_params(opt_state)`
      extracts the parameters and `opt_update(i, grads, opt_state)` returns
      the next optimizer state.

  Returns:
    A jitted function `update(i, opt_state, batch, rng=None)` that returns
    the optimizer state after one gradient step on `batch`.
  """
  opt_update, get_params = opt

  def update(i, opt_state, batch, rng=None):
    params = get_params(opt_state)
    # Differentiate the loss the caller supplied rather than a module-level
    # global, so this factory works with any compatible loss function.
    grads = grad(loss)(params, batch, model_fun, rng)
    return opt_update(i, grads, opt_state)

  return jit(update)


def jit_predict_fun(model_fun):
  """Build a jit-compiled prediction function around `model_fun`.

  Args:
    model_fun: callable invoked as `model_fun(params, inputs, rng=rng)`.

  Returns:
    A function `predict(params, inputs, rng=None)` returning the model
    outputs.
  """
  # Wrap once here instead of on every call, so each prediction reuses the
  # same jitted callable.
  jitted_model = jit(model_fun)

  def predict(params, inputs, rng=None):
    return jitted_model(params, inputs, rng=rng)

  return predict

In [None]:
# Load both splits once to obtain dataset metadata. The first element of each
# returned tuple is discarded here; the training cell below calls
# `get_ds_batches` again to obtain fresh batches for every epoch.
train_ds = get_ds_batches('cifar10', data_dir, 10, 'train', batch_size)
_, train_len, img_shape, num_classes = train_ds
test_ds = get_ds_batches('cifar10', data_dir, 10, 'test', batch_size)
_, test_len, _, _ = test_ds
# Full input shape: batch dimension followed by the per-example shape.
input_shape = (batch_size,) + img_shape

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)

# Size the classifier head from the dataset metadata instead of a hard-coded
# class count, so model and data cannot drift apart.
init_fun, predict_fun = ResNet20(num_classes)
_, init_params = init_fun(key, input_shape)
opt_state = opt_init(init_params)

update_step = jit_update_fun(predict_fun, loss_fun, (opt_update, get_params))
predict_step = jit_predict_fun(predict_fun)

In [None]:
# Training driver: for each epoch, run one pass over the training batches
# (logging metrics every 100 steps), then evaluate on the test batches.
# `step` counts optimizer updates across all epochs.
step = 0
for ep in range(num_epochs):
  # Request the training batches again at the start of every epoch.
  train_batches, _, _, _ = get_ds_batches('cifar10', data_dir, 10, 'train', batch_size)
  start_time = time.time()
  for i, batch in tqdm_notebook(enumerate(train_batches), total=train_len):
    opt_state = update_step(step, opt_state, batch)
    if step % 100 == 0:
      # Metrics are computed on the batch that was just used for the update,
      # with the already-updated parameters.
      inputs, labels = batch
      logits = predict_step(get_params(opt_state), inputs)
      # Time/Step is the average over the current epoch so far
      # (`start_time` and `i` are both reset per epoch).
      print(f'Step: {step:d},  '
            f'Time/Step: {(time.time() - start_time) / (i + 1):.3f}s  '
            f'Loss: {cross_entropy(logits, labels):.5f}  '
            f'Acc: {accuracy(logits, labels):.3f}')
    step += 1
  trained_params = get_params(opt_state)
  
  # Evaluation: accumulate per-batch loss and accuracy over the test split.
  test_metrics = collections.defaultdict(float)
  test_batches, _, _, _ = get_ds_batches('cifar10', data_dir, 10, 'test', batch_size)
  start_time = time.time()
  for i, batch in tqdm_notebook(enumerate(test_batches), total=test_len):
    inputs, labels = batch
    logits = predict_step(trained_params, inputs)
    test_metrics['loss'] += cross_entropy(logits, labels)
    test_metrics['acc'] += accuracy(logits, labels)
  
  # The sums of per-batch means are divided by `test_len`, so `test_len` is
  # presumably the number of test batches - TODO confirm against
  # `get_ds_batches`.
  # NOTE(review): if `test_batches` yields nothing, `i` still holds its value
  # from the training loop and the reported Time/Step is meaningless.
  print(f'Epoch: {ep:d}  '
        f'Time/Step: {(time.time() - start_time) / (i + 1):.3f}s  '
        f"Eval Loss: {test_metrics['loss'] / test_len:.5f}  "
        f"Eval Acc: {test_metrics['acc'] / test_len:.3f}")

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

Step: 0,  Time/Step: 52.984s  Loss: 3.00426  Acc: 0.094
Step: 100,  Time/Step: 0.536s  Loss: 2.68707  Acc: 0.125
Step: 200,  Time/Step: 0.396s  Loss: 2.42020  Acc: 0.188
Step: 300,  Time/Step: 0.350s  Loss: 2.23849  Acc: 0.156
Step: 400,  Time/Step: 0.320s  Loss: 2.05773  Acc: 0.188
Step: 500,  Time/Step: 0.302s  Loss: 2.13917  Acc: 0.188
Step: 600,  Time/Step: 0.290s  Loss: 1.96442  Acc: 0.312
