In [1]:
import data as d

In [3]:
ds = d.dataset('test', batch_size=4)
for imgs, labels in ds:
    print("imgs", imgs.shape)
    print(imgs[0])
    print("labels", labels)
    break

W0904 19:51:01.526964 139795642832704 ag_logging.py:146] AutoGraph could not transform <function dataset.<locals>.to_float at 0x7f23900d3950> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
imgs (4, 64, 64, 3)
[[[0.09803922 0.16470589 0.26666668]
  [0.10196079 0.16862746 0.27058825]
  [0.09803922 0.16470589 0.26666668]
  ...
  [0.09803922 0.16470589 0.27450982]
  [0.09411765 0.17254902 0.2784314 ]
  [0.09019608 0.16862746 0.27450982]]

 [[0.09803922 0.16470589 0.26666668]
  [0.10196079 0.16862746 0.27058825]
  [0.09803922 0.16470589 0.26666668]
  ...
  [0.09803922 0.16470589 0.27450982]
  [0.09019608 0.16862746 0.27450982]
  [0.09411765 0.17254902 0.2784314 ]]

 [[0.09411765 0.17254902 0.27058825]
  [0.09019608 0.16862746 0.26666668]
  [0.09411765 0.16078432 0.27058825]
  ...
  [0.09411765 0.17254902 0.27058825]
  [0.09411765 0.16078432 0.27058825]
  [0.09803922 0.16470589 0.27450982]]

 ...

 [[0.09019608 0.16862746 0.26666668]
  [0.09019608 0.16862746 0.266

In [82]:
import jax
import jax.numpy as jnp
from jax import random, lax, vmap
from jax.nn.initializers import glorot_normal, he_normal
from jax.nn.functions import gelu
from functools import partial
import objax
from objax.variable import TrainVar


def _conv_layer(stride, activation, inp, kernel, bias):
    no_dilation = (1, 1)
    some_height_width = 10  # values don't matter; just shape of input
    input_shape = (1, some_height_width, some_height_width, 3)
    kernel_shape = (3, 3, 1, 1)
    input_kernel_output = ('NHWC', 'HWIO', 'NHWC')
    conv_dimension_numbers = lax.conv_dimension_numbers(input_shape,
                                                        kernel_shape,
                                                        input_kernel_output)
    block = lax.conv_general_dilated(inp, kernel, (stride, stride),
                                     'VALID', no_dilation, no_dilation,
                                     conv_dimension_numbers)
    if bias is not None:
        block += bias
    if activation:
        block = activation(block)
    return block

def _dense_layer(inp, activation, kernel, bias):
    block = jnp.dot(inp, kernel) + bias
    if activation:
        block = activation(block)
    return block

# def _conv_block_without_bias(stride, with_non_linearity, inp, kernel):
#     # the need for this method feels a bit clunky :/ is there a better
#     # way to vmap with the None?
#     return _conv_block(stride, with_non_linearity, inp, kernel, None)




In [83]:
class NonEnsembleNet(objax.Module):

    def __init__(self, num_classes, dense_kernel_size=32, seed=0):

        key = random.PRNGKey(seed)
        subkeys = random.split(key, 8)

        # conv stack kernels and biases
        self.conv_kernels = objax.ModuleList()
        self.conv_biases = objax.ModuleList()
        input_channels = 3
        for i, output_channels in enumerate([32, 64, 64, 64]):
            self.conv_kernels.append(TrainVar(he_normal()(subkeys[i], (3, 3, input_channels,
                                                                       output_channels))))
            self.conv_biases.append(TrainVar(jnp.zeros((output_channels))))
            input_channels = output_channels

        # dense layer kernel and bias
        self.dense_kernel = TrainVar(he_normal()(subkeys[6], (output_channels, dense_kernel_size)))
        self.dense_bias = TrainVar(jnp.zeros((dense_kernel_size)))

        # classifier layer kernel and bias
        self.logits_kernel = TrainVar(glorot_normal()(subkeys[6], (dense_kernel_size, num_classes)))
        self.logits_bias = TrainVar(jnp.zeros((num_classes)))

    def logits(self, inp):        
        # conv stack -> (B, 3, 3, 64)
        y = inp
        for kernel, bias in zip(self.conv_kernels, self.conv_biases):
            y = _conv_layer(2, gelu, y, kernel.value, bias.value)
            
        # global spatial pooling -> (B, 64)
        y = jnp.mean(y, axis=(1, 2))
            
        # dense layer with non linearity -> (B, 32)
        y = _dense_layer(y, gelu, self.dense_kernel.value, self.dense_bias.value)
        
        # dense layer with no activation to number classes -> (B, num_classes)
        logits = _dense_layer(y, None, self.logits_kernel.value, self.logits_bias.value)
        
        return logits
        
    def predict(self, inp):
        return jax.nn.softmax(self.logits(inp))



In [84]:
net = NonEnsembleNet(num_classes=10)
print(jnp.around(net.logits(imgs), 3))
print(jnp.around(net.predict(imgs), 3))

[[-0.059      -0.036      -0.083       0.037       0.043       0.005
  -0.089       0.038       0.006       0.051     ]
 [-0.074      -0.045      -0.11100001  0.053       0.054      -0.001
  -0.109       0.047      -0.004       0.064     ]
 [-0.266      -0.18900001 -0.208       0.17300001  0.141      -0.109
  -0.314       0.125       0.041       0.09900001]
 [-0.08800001 -0.054      -0.11800001  0.063       0.062      -0.015
  -0.13900001  0.057      -0.002       0.068     ]]
[[0.09500001 0.097      0.093      0.105      0.105      0.101
  0.09200001 0.105      0.101      0.10600001]
 [0.094      0.097      0.09       0.10700001 0.10700001 0.101
  0.09100001 0.10600001 0.101      0.108     ]
 [0.079      0.086      0.08400001 0.123      0.119      0.093
  0.07600001 0.11700001 0.108      0.11400001]
 [0.093      0.096      0.09       0.108      0.108      0.1
  0.08800001 0.10700001 0.101      0.109     ]]


In [32]:
imgs.shape

(4, 64, 64, 3)

In [91]:
def cross_entropy(imgs, labels):
    logits = net.logits(imgs)
    return jnp.mean(objax.functional.loss.cross_entropy_logits_sparse(logits, labels))

gradient_loss = objax.GradValues(cross_entropy, net.vars())
optimiser = objax.optimizer.Adam(net.vars())
lr = 1e-3

# create a jitted training step
def train_step(imgs, labels):
    grads, loss = gradient_loss(imgs, labels)
    optimiser(lr, grads)
    return loss




In [93]:
for _ in range(10):
    print(train_step(imgs, labels))

[DeviceArray(1.2861435, dtype=float32)]
[DeviceArray(1.175744, dtype=float32)]
[DeviceArray(1.0832641, dtype=float32)]
[DeviceArray(1.0269818, dtype=float32)]
[DeviceArray(0.9848052, dtype=float32)]
[DeviceArray(0.93037367, dtype=float32)]
[DeviceArray(0.8638202, dtype=float32)]
[DeviceArray(0.8097404, dtype=float32)]
[DeviceArray(0.77046627, dtype=float32)]
[DeviceArray(0.70041114, dtype=float32)]


In [97]:
jnp.around(net.predict(imgs), 2)

DeviceArray([[0.        , 0.38      , 0.        , 0.17      , 0.01      ,
              0.        , 0.        , 0.02      , 0.01      , 0.41      ],
             [0.        , 0.35999998, 0.        , 0.31      , 0.        ,
              0.        , 0.        , 0.01      , 0.02      , 0.29      ],
             [0.        , 0.        , 0.        , 0.02      , 0.        ,
              0.        , 0.        , 0.        , 0.97999996, 0.        ],
             [0.        , 0.22      , 0.        , 0.53      , 0.        ,
              0.        , 0.        , 0.        , 0.05      , 0.19      ]],            dtype=float32)

In [98]:
labels

array([9, 1, 8, 3])