In [1]:
import os

import numpy as np
import tensorflow_datasets as tfds
from tqdm import trange

import objax
from objax.util import EasyDict

In [2]:
def simple_net_block(nin, nout):
    return objax.nn.Sequential([
        objax.nn.Conv2D(nin, nout, k=3), objax.functional.leaky_relu,
        objax.functional.max_pool_2d,
        objax.nn.Conv2D(nout, nout, k=3), objax.functional.leaky_relu,
    ])


class SimpleNet(objax.Module):
    def __init__(self, nclass, colors, n):
        self.pre_conv = objax.nn.Sequential([objax.nn.Conv2D(colors, n, k=3), objax.functional.leaky_relu])
        self.block1 = simple_net_block(1 * n, 2 * n)
        self.block2 = simple_net_block(2 * n, 4 * n)
        self.post_conv = objax.nn.Conv2D(4 * n, nclass, k=3)

    def __call__(self, x, training=False):  # x = (batch, colors, height, width)
        y = self.pre_conv(x)
        y = self.block1(y)
        y = self.block2(y)
        logits = self.post_conv(y).mean((2, 3))  # logits = (batch, nclass)
        if training:
            return logits
        return objax.functional.softmax(logits)


# Data
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
train = EasyDict(image=data['train']['image'].transpose(0, 3, 1, 2) / 255, label=data['train']['label'])
test = EasyDict(image=data['test']['image'].transpose(0, 3, 1, 2) / 255, label=data['test']['label'])
del data


def augment(x, shift=4):  # Shift all images in the batch by up to "shift" pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]])
    rx, ry = np.random.randint(0, shift, size=2)
    return x_pad[:, :, rx:rx + 28, ry:ry + 28]

# Settings
batch = 512
test_batch = 2048
weight_decay = 0.0001
epochs = 40
lr = 0.0004 * (batch / 64)
train_size = train.image.shape[0]

# Model
model = SimpleNet(nclass=10, colors=1, n=16)  # Use higher values of n to get higher accuracy.
model_ema = objax.optimizer.ExponentialMovingAverageModule(model, momentum=0.999, debias=True)
opt = objax.optimizer.Adam(model.vars())


@objax.Function.with_vars(model.vars())
def loss(x, y):
    logits = model(x, training=True)
    loss_xe = objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()
    loss_l2 = 0.5 * sum((v.value ** 2).sum() for k, v in model.vars().items() if k.endswith('.w'))
    return loss_xe + weight_decay * loss_l2, {'loss/xe': loss_xe, 'loss/l2': loss_l2}


gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() + model_ema.vars())
def train_op(x, y):
    g, v = gv(x, y)
    opt(lr, g)
    model_ema.update_ema()
    return v


train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)  # Compile predict to make it run faster.


[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/jovyan/TFDS/mnist/3.0.1...[0m


Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /home/jovyan/TFDS/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [12]:
g = objax.Grad(loss, model.vars())
single_gradients = objax.Vectorize(g, batch_axis=(0, 0)) 


batch = 10
sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])
x = train.image[sel]
y = train.label[sel]

# Compute standard gradients
print([v.shape for v in g(x, y)])              # [(4,), (3, 4)]


[(16, 1, 1), (3, 3, 1, 16), (32, 1, 1), (3, 3, 16, 32), (32, 1, 1), (3, 3, 32, 32), (64, 1, 1), (3, 3, 32, 64), (64, 1, 1), (3, 3, 64, 64), (10, 1, 1), (3, 3, 64, 10)]


In [13]:
# Compute per batch entry gradients
print([v.shape for v in single_gradients(np.expand_dims(x,1), np.expand_dims(y,1))])   # [(10, 4), (10, 3, 4)]

[(10, 16, 1, 1), (10, 3, 3, 1, 16), (10, 32, 1, 1), (10, 3, 3, 16, 32), (10, 32, 1, 1), (10, 3, 3, 32, 32), (10, 64, 1, 1), (10, 3, 3, 32, 64), (10, 64, 1, 1), (10, 3, 3, 64, 64), (10, 10, 1, 1), (10, 3, 3, 64, 10)]


In [3]:
# Training
print(model.vars())
for epoch in range(epochs):
    # Train one epoch
    loop = trange(0, train_size, batch,
                  leave=False, unit='img', unit_scale=batch,
                  desc='Epoch %d/%d ' % (1 + epoch, epochs))
    for it in loop:
        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])
        v = train_op(augment(train.image[sel]), train.label[sel])

    # Eval
    accuracy = 0
    for it in trange(0, test.image.shape[0], test_batch, leave=False, desc='Evaluating'):
        x = test.image[it: it + test_batch]
        xl = test.label[it: it + test_batch]
        accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
    accuracy /= test.image.shape[0]
    print(f'Epoch {epoch + 1:04d}  Accuracy {100 * accuracy:.2f}')

Epoch 1/40 :   0%|          | 0/60416 [00:00<?, ?img/s]

(SimpleNet).pre_conv(Sequential)[0](Conv2D).b       16 (16, 1, 1)
(SimpleNet).pre_conv(Sequential)[0](Conv2D).w      144 (3, 3, 1, 16)
(SimpleNet).block1(Sequential)[0](Conv2D).b         32 (32, 1, 1)
(SimpleNet).block1(Sequential)[0](Conv2D).w       4608 (3, 3, 16, 32)
(SimpleNet).block1(Sequential)[3](Conv2D).b         32 (32, 1, 1)
(SimpleNet).block1(Sequential)[3](Conv2D).w       9216 (3, 3, 32, 32)
(SimpleNet).block2(Sequential)[0](Conv2D).b         64 (64, 1, 1)
(SimpleNet).block2(Sequential)[0](Conv2D).w      18432 (3, 3, 32, 64)
(SimpleNet).block2(Sequential)[3](Conv2D).b         64 (64, 1, 1)
(SimpleNet).block2(Sequential)[3](Conv2D).w      36864 (3, 3, 64, 64)
(SimpleNet).post_conv(Conv2D).b                     10 (10, 1, 1)
(SimpleNet).post_conv(Conv2D).w                   5760 (3, 3, 64, 10)
+Total(12)                                       75242


Epoch 2/40 :  10%|█         | 6144/60416 [00:00<00:00, 54395.71img/s] 

Epoch 0001  Accuracy 81.13


Epoch 3/40 :  11%|█         | 6656/60416 [00:00<00:01, 52973.48img/s] 

Epoch 0002  Accuracy 92.02


Epoch 4/40 :  10%|█         | 6144/60416 [00:00<00:00, 55131.54img/s] 

Epoch 0003  Accuracy 94.80


Epoch 5/40 :  10%|█         | 6144/60416 [00:00<00:00, 56902.31img/s] 

Epoch 0004  Accuracy 96.06


Epoch 6/40 :  12%|█▏        | 7168/60416 [00:00<00:00, 64699.25img/s] 

Epoch 0005  Accuracy 96.86


Epoch 7/40 :  10%|█         | 6144/60416 [00:00<00:00, 55611.00img/s] 

Epoch 0006  Accuracy 97.38


Epoch 8/40 :  10%|█         | 6144/60416 [00:00<00:00, 54334.01img/s] 

Epoch 0007  Accuracy 97.80


Epoch 9/40 :  10%|█         | 6144/60416 [00:00<00:00, 54880.51img/s] 

Epoch 0008  Accuracy 98.09


Epoch 10/40 :  10%|█         | 6144/60416 [00:00<00:00, 56188.90img/s]

Epoch 0009  Accuracy 98.37


Epoch 11/40 :  10%|█         | 6144/60416 [00:00<00:00, 54351.54img/s] 

Epoch 0010  Accuracy 98.54


Epoch 12/40 :  10%|█         | 6144/60416 [00:00<00:00, 54324.73img/s] 

Epoch 0011  Accuracy 98.65


Epoch 13/40 :  10%|█         | 6144/60416 [00:00<00:01, 53927.41img/s] 

Epoch 0012  Accuracy 98.79


Epoch 14/40 :  10%|█         | 6144/60416 [00:00<00:00, 57170.69img/s] 

Epoch 0013  Accuracy 98.90


Epoch 15/40 :  10%|█         | 6144/60416 [00:00<00:01, 53870.59img/s] 

Epoch 0014  Accuracy 98.99


Epoch 16/40 :  10%|█         | 6144/60416 [00:00<00:01, 52559.90img/s] 

Epoch 0015  Accuracy 99.04


Epoch 17/40 :  10%|█         | 6144/60416 [00:00<00:00, 54981.09img/s] 

Epoch 0016  Accuracy 99.12


Epoch 18/40 :  10%|█         | 6144/60416 [00:00<00:01, 52460.39img/s] 

Epoch 0017  Accuracy 99.18


Epoch 19/40 :  10%|█         | 6144/60416 [00:00<00:00, 57335.03img/s] 

Epoch 0018  Accuracy 99.26


Epoch 20/40 :  10%|█         | 6144/60416 [00:00<00:01, 53510.62img/s] 

Epoch 0019  Accuracy 99.31


Epoch 21/40 :  10%|█         | 6144/60416 [00:00<00:00, 55123.64img/s] 

Epoch 0020  Accuracy 99.36


Epoch 22/40 :  10%|█         | 6144/60416 [00:00<00:00, 54484.72img/s] 

Epoch 0021  Accuracy 99.39


Epoch 23/40 :  10%|█         | 6144/60416 [00:00<00:01, 54074.40img/s] 

Epoch 0022  Accuracy 99.41


Epoch 24/40 :  10%|█         | 6144/60416 [00:00<00:01, 53878.25img/s] 

Epoch 0023  Accuracy 99.42


Epoch 25/40 :  10%|█         | 6144/60416 [00:00<00:00, 55241.68img/s] 

Epoch 0024  Accuracy 99.43


Epoch 26/40 :  10%|█         | 6144/60416 [00:00<00:00, 54719.00img/s] 

Epoch 0025  Accuracy 99.47


Epoch 27/40 :  10%|█         | 6144/60416 [00:00<00:00, 55292.41img/s] 

Epoch 0026  Accuracy 99.46


Epoch 28/40 :  10%|█         | 6144/60416 [00:00<00:01, 53361.36img/s] 

Epoch 0027  Accuracy 99.47


Epoch 29/40 :  10%|█         | 6144/60416 [00:00<00:01, 54121.87img/s] 

Epoch 0028  Accuracy 99.48


Epoch 30/40 :  10%|█         | 6144/60416 [00:00<00:00, 54410.41img/s] 

Epoch 0029  Accuracy 99.49


Epoch 31/40 :  10%|█         | 6144/60416 [00:00<00:00, 55004.10img/s] 

Epoch 0030  Accuracy 99.48


Epoch 32/40 :  10%|█         | 6144/60416 [00:00<00:01, 53526.85img/s] 

Epoch 0031  Accuracy 99.49


Epoch 33/40 :  10%|█         | 6144/60416 [00:00<00:01, 53645.96img/s] 

Epoch 0032  Accuracy 99.51


Epoch 34/40 :  10%|█         | 6144/60416 [00:00<00:00, 54949.32img/s] 

Epoch 0033  Accuracy 99.50


Epoch 35/40 :  10%|█         | 6144/60416 [00:00<00:00, 55157.50img/s] 

Epoch 0034  Accuracy 99.52


Epoch 36/40 :  10%|█         | 6144/60416 [00:00<00:01, 52774.53img/s] 

Epoch 0035  Accuracy 99.53


Epoch 37/40 :  10%|█         | 6144/60416 [00:00<00:00, 55297.75img/s] 

Epoch 0036  Accuracy 99.49


Epoch 38/40 :  10%|█         | 6144/60416 [00:00<00:00, 55001.16img/s] 

Epoch 0037  Accuracy 99.51


Epoch 39/40 :  10%|█         | 6144/60416 [00:00<00:00, 55309.03img/s] 

Epoch 0038  Accuracy 99.53


Epoch 40/40 :  10%|█         | 6144/60416 [00:00<00:01, 53560.44img/s] 

Epoch 0039  Accuracy 99.54


                                                                       

Epoch 0040  Accuracy 99.55
