# jax-bayes CIFAR10 Example --- Traditional ML Approach

## Set up the environment

In [2]:
#see https://github.com/google/jax#pip-installation
!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
!pip install --upgrade jax
!pip install git+https://github.com/deepmind/dm-haiku
!pip install git+https://github.com/jamesvuc/jax-bayes

Collecting jaxlib==0.1.51
[?25l  Downloading https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl (71.5MB)
[K     |████████████████████████████████| 71.5MB 42kB/s 
Installing collected packages: jaxlib
  Found existing installation: jaxlib 0.1.52
    Uninstalling jaxlib-0.1.52:
      Successfully uninstalled jaxlib-0.1.52
Successfully installed jaxlib-0.1.51
Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)
Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-qx61eemy
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-qx61eemy
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=289739 sha256=0ea4611f09ee7534f77a37f5f875814f9437bb2aa72d43f19d3b69d4892aabfb
  Stored 

In [3]:
import haiku as hk

import jax.numpy as jnp
from jax.experimental import optimizers
import jax

import sys, os, math, time
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import tensorflow_datasets as tfds

## Build the dataset loader and CNN

In [4]:
def load_dataset(split, is_training, batch_size, repeat=True, seed=0):
  if repeat:
    ds = tfds.load('cifar10', split=split).cache().repeat()
  else:
    ds = tfds.load('cifar10', split=split).cache()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=seed)
  ds = ds.batch(batch_size)
  return tfds.as_numpy(ds)

# build a 32-32-64-32 CNN with max-pooling 
# followed by a 128-10-n_classes MLP
class Net(hk.Module):
  def __init__(self, dropout=0.1, n_classes=10):
    super(Net, self).__init__()
    self.conv_stage = hk.Sequential([
      #block 1
      hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), 
      jax.nn.relu, 
      hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),
      # block 2
      hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), 
      jax.nn.relu, 
      hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),
      # block 3
      hk.Conv2D(64, kernel_shape=3, stride=1, padding='SAME'), 
      jax.nn.relu, 
      hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),
      # block 4
      hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME')
    ])

    self.mlp_stage = hk.Sequential([
      hk.Flatten(),
      hk.Linear(128), 
      jax.nn.relu, 
      hk.Linear(n_classes)
    ])

    self.p_dropout = dropout

  def __call__(self, x, use_dropout=True):
    x = self.conv_stage(x)
    
    dropout_rate = self.p_dropout if use_dropout else 0.0
    x = hk.dropout(hk.next_rng_key(), dropout_rate, x)

    return self.mlp_stage(x)

# standard normalization constants
mean_norm = jnp.array([[0.4914, 0.4822, 0.4465]])
std_norm = jnp.array([[0.247, 0.243, 0.261]])

#define the net-function 
def net_fn(batch, use_dropout):
  net = Net(dropout=0.0)
  x = batch['image']/255.0
  x = (x - mean_norm) / std_norm
  return net(x, use_dropout)

In [5]:
# hyperparameters
lr = 1e-3
reg = 1e-4

# instantiate the network
net = hk.transform(net_fn)

# build the optimizer
opt_init, opt_update, opt_get_params = optimizers.rmsprop(lr)

# standard L2-regularized crossentropy loss function
def loss(params, rng, batch):
    logits = net.apply(params, rng, batch, use_dropout=True)
    labels = jax.nn.one_hot(batch['label'], 10)

    l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) 
                        for p in jax.tree_leaves(params))
    softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))

    return softmax_crossent + reg * l2_loss

@jax.jit
def accuracy(params, batch):
  preds = net.apply(params, jax.random.PRNGKey(101), batch, use_dropout=False)
  return jnp.mean(jnp.argmax(preds, axis=-1) == batch['label'])

@jax.jit
def train_step(i, opt_state, rng, batch):
	params = opt_get_params(opt_state)
	fx, dx = jax.value_and_grad(loss)(params, rng, batch)
	opt_state = opt_update(i, dx, opt_state)
	return fx, opt_state

## Load the Initialization, Val and Test Batches & Do the Optimization

In [None]:
init_batches = load_dataset("train", is_training=True, batch_size=256)
val_batches = load_dataset("train", is_training=False, batch_size=1_000)
test_batches = load_dataset("test", is_training=False, batch_size=1_000)

In [7]:
%%time

# intialize the paramaeters
params = net.init(jax.random.PRNGKey(42), next(init_batches), use_dropout=True)
opt_state = opt_init(params)

# initialize a key for the dropout
rng = jax.random.PRNGKey(2)

for epoch in range(100):
	 #generate a shuffled epoch of training data
  train_batches = load_dataset("train", is_training=True,
                              batch_size=256, repeat=False, seed=epoch)
  
  for batch in train_batches:
    # run an optimization step
    train_loss, opt_state = train_step(epoch, opt_state, rng, batch)
    
    # make more rng for the dropout
    rng, _ = jax.random.split(rng)
	
  if epoch % 5 == 0:
    params = opt_get_params(opt_state)
    val_acc = accuracy(params, next(val_batches))
    test_acc = accuracy(params, next(test_batches))
    print(f"epoch = {epoch}"
          f" | train loss = {train_loss:.4f}"
          f" | val acc = {val_acc:.3f}"
          f" | test acc = {test_acc:.3f}")

epoch = 0 | train loss = 0.1405 | val acc = 0.489 | test acc = 0.515
epoch = 5 | train loss = 0.0659 | val acc = 0.788 | test acc = 0.688
epoch = 10 | train loss = 0.0596 | val acc = 0.818 | test acc = 0.669
epoch = 15 | train loss = 0.0554 | val acc = 0.896 | test acc = 0.702
epoch = 20 | train loss = 0.0598 | val acc = 0.880 | test acc = 0.646
epoch = 25 | train loss = 0.0547 | val acc = 0.939 | test acc = 0.709
epoch = 30 | train loss = 0.0504 | val acc = 0.966 | test acc = 0.714
epoch = 35 | train loss = 0.0502 | val acc = 0.953 | test acc = 0.705
epoch = 40 | train loss = 0.0637 | val acc = 0.954 | test acc = 0.723
epoch = 45 | train loss = 0.0494 | val acc = 0.957 | test acc = 0.718
epoch = 50 | train loss = 0.0472 | val acc = 0.952 | test acc = 0.731
epoch = 55 | train loss = 0.0458 | val acc = 0.972 | test acc = 0.717
epoch = 60 | train loss = 0.0503 | val acc = 0.952 | test acc = 0.730
epoch = 65 | train loss = 0.0490 | val acc = 0.962 | test acc = 0.705
epoch = 70 | train los