In [3]:
# Re-import edited Python modules automatically before each cell executes.
# NOTE(review): no project-local modules are imported anywhere in this notebook
# as shown (only stdlib, numpy and jax), so autoreload is probably unnecessary.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Snapshot GPU utilisation and memory before this notebook does any GPU work.
# NOTE(review): the recorded output shows both GPUs already nearly full,
# including GPU 1 (11232MiB / 12196MiB), which is the device selected below.
# Confirm the device is otherwise idle before comparing with the later snapshot.
!nvidia-smi

Wed May 27 06:54:58 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 440.64.00    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN Xp            Off  | 00000000:05:00.0 Off |                  N/A |
| 39%   64C    P2    64W / 250W |  11737MiB / 12194MiB |     11%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN Xp            Off  | 00000000:09:00.0 Off |                  N/A |
| 38%   59C    P2    65W / 250W |  11232MiB / 12196MiB |      3%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

In [5]:
# Number CUDA devices in PCI bus order (the order nvidia-smi lists them in) and
# expose only device 1 to this process. These must take effect before JAX
# initialises its GPU backend; jax is first imported in a later cell.
# NOTE(review): the "already loaded" autoreload message above indicates the
# kernel was already running. If jax was imported earlier in that session,
# restart the kernel so this device selection actually applies.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [7]:
import sys
import time
from collections import defaultdict
import numpy as onp

# PyTorch

# JAX

In [8]:
import jax
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.scipy.special import logsumexp
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
                                   FanOut, Flatten, GeneralConv, Identity,
                                   MaxPool, Relu, LogSoftmax)

In [None]:
def ConvBlock(kernel_size, filters, strides=(2,2)):
    """Bottleneck residual block with a projection shortcut, as a stax layer.

    Main path: strided 1x1 conv -> k x k 'SAME' conv -> 1x1 conv. Each conv is
    followed by BatchNorm, and the first two are also followed by ReLU.
    Shortcut: strided 1x1 conv + BatchNorm, so its channel count (filters[2])
    and spatial stride match the main path. The two branches are summed and
    passed through a final ReLU.

    Args:
        kernel_size: side length k of the square middle convolution.
        filters: three channel counts (f1, f2, f3) for the main-path convs;
            f3 is also the block's output channel count.
        strides: stride of the first main-path conv and of the shortcut conv.

    Returns:
        A stax ``(init_fun, apply_fun)`` pair.
    """
    reduce_ch, mid_ch, out_ch = filters
    window = (kernel_size, kernel_size)

    main_path = stax.serial(
        Conv(reduce_ch, (1, 1), strides), BatchNorm(), Relu,
        Conv(mid_ch, window, padding='SAME'), BatchNorm(), Relu,
        Conv(out_ch, (1, 1)), BatchNorm(),
    )
    # The projection uses the same stride and output width as the main path,
    # so the two branch outputs have identical shapes for FanInSum.
    projection = stax.serial(Conv(out_ch, (1, 1), strides), BatchNorm())

    branches = stax.parallel(main_path, projection)
    return stax.serial(FanOut(2), branches, FanInSum, Relu)

In [None]:
def IdentityBlock(kernel_size, filters):
    """Bottleneck residual block with an identity shortcut, as a stax layer.

    Main path: 1x1 conv -> k x k 'SAME' conv -> 1x1 conv. Each conv is followed
    by BatchNorm, and the first two are also followed by ReLU. The output width
    of the last conv is taken from the input shape at init/apply time, so the
    main path can be summed with the unchanged input. A final ReLU follows the
    sum.

    Args:
        kernel_size: side length k of the square middle convolution.
        filters: two channel counts (f1, f2) for the first two convs.

    Returns:
        A stax ``(init_fun, apply_fun)`` pair.
    """
    reduce_ch, mid_ch = filters
    window = (kernel_size, kernel_size)

    def _main_for(in_shape):
        # Axis 3 is the channel axis of the NHWC activations produced after the
        # stem; restoring that width makes the residual sum well-formed.
        restore_ch = in_shape[3]
        return stax.serial(
            Conv(reduce_ch, (1, 1)), BatchNorm(), Relu,
            Conv(mid_ch, window, padding='SAME'), BatchNorm(), Relu,
            Conv(restore_ch, (1, 1)), BatchNorm(),
        )

    residual = stax.parallel(stax.shape_dependent(_main_for), Identity)
    return stax.serial(FanOut(2), residual, FanInSum, Relu)

In [None]:
def ResNet50(num_classes):
    """ResNet-style classifier as a stax ``(init_fun, apply_fun)`` pair.

    Inputs are NCHW; the stem conv converts to NHWC for the rest of the network.
    Layer order: stem (7x7/2 conv, BatchNorm, ReLU, 3x3/2 max-pool), a 64/256
    bottleneck stage at stride 1, a 128/512 bottleneck stage at stride 2, then
    AvgPool -> Flatten -> Dense(num_classes) -> LogSoftmax. The apply function
    therefore returns log-probabilities.

    NOTE(review): despite the name, this is not the full ResNet-50:
      * The 256/1024 and 512/2048 stages are absent.
      * The 128/512 stage has 2 identity blocks where ResNet-50 has 3.
      * AvgPool((7, 7)) assumes the 7x7 final map of the full network. By my
        reading of stax's defaults (VALID padding, pooling stride 1 when not
        given), a 224x224 input reaches it at 28x28, which leaves 22x22x512
        features for the Dense head.
    Confirm whether the truncation is intentional (e.g. to fit GPU memory).
    """
    stem = [
        GeneralConv(('NCHW', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
    ]
    stage_64 = [ConvBlock(3, [64, 64, 256], strides=(1, 1))]
    stage_64 += [IdentityBlock(3, [64, 64]) for _ in range(2)]
    stage_128 = [ConvBlock(3, [128, 128, 512])]
    stage_128 += [IdentityBlock(3, [128, 128]) for _ in range(2)]
    head = [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]
    return stax.serial(*stem, *stage_64, *stage_128, *head)

In [None]:
# Root PRNG key (seed 0) used to initialise the network parameters.
rng_key = random.PRNGKey(0)

In [None]:
# Data / model configuration. Inputs are NCHW, matching the 'NCHW' input
# spec of the network's first GeneralConv.
num_classes = 10
batch_size = 1
input_shape = (batch_size, 3, 224, 224)

# Optimizer configuration: learning rate and number of update steps.
step_size, num_steps = 0.1, 10

In [None]:
# Build the network. init_fun(rng, input_shape) creates parameters, and
# predict_fun(params, inputs) runs the forward pass. Because the model ends in
# LogSoftmax, predict_fun returns log-probabilities.
init_fun, predict_fun = ResNet50(num_classes)

In [None]:
# Initialise parameters for `input_shape`; the returned output shape is unused.
_, init_params = init_fun(rng_key, input_shape)

In [None]:
def loss(params, batch):
    """Mean negative log-likelihood of the one-hot targets under the model.

    Args:
        params: network parameters accepted by ``predict_fun``.
        batch: ``(inputs, targets)``, where ``targets`` is one-hot (bool or
            float) over the last axis.

    Returns:
        Scalar loss averaged over the batch.
    """
    inputs, targets = batch
    log_probs = predict_fun(params, inputs)  # model ends in LogSoftmax
    # Summing over classes picks out each example's log-probability of its true
    # class. Averaging over the batch, rather than summing, keeps the gradient
    # scale (and so the effective step_size) independent of batch_size, and
    # matches `accuracy`, which also averages over the batch.
    per_example = np.sum(log_probs * targets, axis=-1)
    return -np.mean(per_example)

In [None]:
def accuracy(params, batch):
    """Fraction of examples whose arg-max prediction equals the one-hot target.

    Args:
        params: network parameters accepted by ``predict_fun``.
        batch: ``(inputs, targets)`` with one-hot ``targets`` over the last axis.

    Returns:
        Scalar mean of per-example correctness.
    """
    inputs, targets = batch
    predicted = np.argmax(predict_fun(params, inputs), axis=-1)
    hits = predicted == np.argmax(targets, axis=-1)
    return np.mean(hits)

In [None]:
def synth_batches():
    """Yield an endless stream of synthetic ``(images, onehot_labels)`` batches.

    Reads the notebook-level ``input_shape``, ``batch_size`` and ``num_classes``
    on every iteration. The generator is seeded with 0, so the stream is
    reproducible. Images are uniform [0, 1) float32 arrays of ``input_shape``.
    Labels are random class ids, encoded as a boolean one-hot array of shape
    ``(batch_size, num_classes)``. Labels are independent of the images.
    """
    gen = onp.random.RandomState(0)
    while True:
        # Draw images before labels; this fixes the order of RNG consumption.
        x = gen.rand(*input_shape).astype('float32')
        class_ids = gen.randint(num_classes, size=(batch_size, 1))
        yield x, class_ids == onp.arange(num_classes)

In [None]:
# SGD with momentum (mass 0.9). opt_init wraps params into optimizer state,
# opt_update(i, grads, state) applies one step, and get_params(state) reads the
# params back out.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)

In [None]:
# Infinite iterator of synthetic (images, one-hot labels) training batches.
batches = synth_batches()

In [None]:
@jit
def update(i, opt_state, batch):
    """Apply one jitted optimizer step.

    Differentiates ``loss`` with respect to the current parameters on ``batch``
    and returns the optimizer state after the momentum update at step ``i``.
    """
    current = get_params(opt_state)
    grads = grad(loss)(current, batch)
    return opt_update(i, grads, opt_state)

In [None]:
# Wrap the freshly initialised parameters in momentum-optimizer state.
opt_state = opt_init(init_params)

In [None]:
# Run `num_steps` optimizer updates, one synthetic batch per step. Each
# iteration is a single step, not a pass over a dataset, so progress is
# reported as steps.
for i in range(num_steps):
    sys.stdout.write(f'\rStep {i+1}/{num_steps}...')
    sys.stdout.flush()
    opt_state = update(i, opt_state, next(batches))
trained_params = get_params(opt_state)
# JAX dispatches device work asynchronously, so the loop can finish before the
# last update has run. Block on every parameter array so that the following
# nvidia-smi snapshot (and any timing) reflects completed work.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), trained_params)
sys.stdout.write('\n')  # terminate the '\r' progress line

In [None]:
# Snapshot GPU memory again after training, for comparison with the first
# nvidia-smi call.
# NOTE(review): JAX's GPU allocator preallocates a large fraction of device
# memory by default. This figure may therefore reflect the pool size rather than
# what the model actually needs. Consider setting
# XLA_PYTHON_CLIENT_PREALLOCATE=false before jax is imported (confirm against
# the installed JAX version).
!nvidia-smi