In [1]:
import data as d

In [2]:
ds = d.dataset('test', batch_size=4)
for imgs, labels in ds:
    print("imgs", imgs.shape)
    print(imgs[0])
    print("labels", labels)
    break

W0904 22:26:17.107448 140421431543616 ag_logging.py:146] AutoGraph could not transform <function dataset.<locals>.to_float at 0x7fb65d249620> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
imgs (4, 64, 64, 3)
[[[0.14117648 0.21960784 0.3137255 ]
  [0.13333334 0.21960784 0.30980393]
  [0.15294118 0.23921569 0.32156864]
  ...
  [0.15294118 0.25490198 0.32156864]
  [0.15686275 0.25882354 0.3254902 ]
  [0.14901961 0.2509804  0.31764707]]

 [[0.13725491 0.21960784 0.3019608 ]
  [0.13725491 0.22352941 0.30588236]
  [0.14509805 0.24313726 0.32156864]
  ...
  [0.15294118 0.25490198 0.32156864]
  [0.15294118 0.25490198 0.32156864]
  [0.14901961 0.2509804  0.31764707]]

 [[0.15686275 0.24705882 0.31764707]
  [0.14509805 0.24705882 0.3137255 ]
  [0.15686275 0.25882354 0.3254902 ]
  ...
  [0.15294118 0.25490198 0.32156864]
  [0.14901961 0.25490198 0.32156864]
  [0.14509805 0.2509804  0.31764707]]

 ...

 [[0.15686275 0.25882354 0.31764707]
  [0.15686275 0.25882354 0.317

In [15]:
import jax
import jax.numpy as jnp
from jax import random, lax, vmap
from jax.nn.initializers import glorot_normal, he_normal
from jax.nn.functions import gelu
from functools import partial
import objax
from objax.variable import TrainVar


def _conv_layer(stride, activation, inp, kernel, bias):
    no_dilation = (1, 1)
    some_height_width = 10  # values don't matter; just shape of input
    input_shape = (1, some_height_width, some_height_width, 3)
    kernel_shape = (3, 3, 1, 1)
    input_kernel_output = ('NHWC', 'HWIO', 'NHWC')
    conv_dimension_numbers = lax.conv_dimension_numbers(input_shape,
                                                        kernel_shape,
                                                        input_kernel_output)
    block = lax.conv_general_dilated(inp, kernel, (stride, stride),
                                     'VALID', no_dilation, no_dilation,
                                     conv_dimension_numbers)
    if bias is not None:
        block += bias
    if activation:
        block = activation(block)
    return block

def _dense_layer(activation, inp, kernel, bias):
    block = jnp.dot(inp, kernel) + bias
    if activation:
        block = activation(block)
    return block

# def _conv_block_without_bias(stride, with_non_linearity, inp, kernel):
#     # the need for this method feels a bit clunky :/ is there a better
#     # way to vmap with the None?
#     return _conv_block(stride, with_non_linearity, inp, kernel, None)




In [34]:
class NonEnsembleNet(objax.Module):

    def __init__(self, num_classes, dense_kernel_size=32, seed=0):

        key = random.PRNGKey(seed)
        subkeys = random.split(key, 8)

        # conv stack kernels and biases
        self.conv_kernels = objax.ModuleList()
        self.conv_biases = objax.ModuleList()
        input_channels = 3
        for i, output_channels in enumerate([32, 64, 64, 64]):
            self.conv_kernels.append(TrainVar(he_normal()(subkeys[i], (3, 3, input_channels,
                                                                       output_channels))))
            self.conv_biases.append(TrainVar(jnp.zeros((output_channels))))
            input_channels = output_channels

        # dense layer kernel and bias
        self.dense_kernel = TrainVar(he_normal()(subkeys[6], (output_channels, dense_kernel_size)))
        self.dense_bias = TrainVar(jnp.zeros((dense_kernel_size)))

        # classifier layer kernel and bias
        self.logits_kernel = TrainVar(glorot_normal()(subkeys[6], (dense_kernel_size, num_classes)))
        self.logits_bias = TrainVar(jnp.zeros((num_classes)))

    def logits(self, inp):        
        # conv stack -> (B, 3, 3, 64)
        y = inp
        for kernel, bias in zip(self.conv_kernels, self.conv_biases):
            y = _conv_layer(2, gelu, y, kernel.value, bias.value)
            
        # global spatial pooling -> (B, 64)
        y = jnp.mean(y, axis=(1, 2))
            
        # dense layer with non linearity -> (B, 32)
        y = _dense_layer(gelu, y, self.dense_kernel.value, self.dense_bias.value)
        
        # dense layer with no activation to number classes -> (B, num_classes)
        logits = _dense_layer(None, y, self.logits_kernel.value, self.logits_bias.value)
        
        return logits
        
    def predict(self, inp):
        return jax.nn.softmax(self.logits(inp))



In [30]:
class EnsembleNet(objax.Module):

    def __init__(self, num_models, num_classes, dense_kernel_size=32, seed=0):

        key = random.PRNGKey(seed)
        subkeys = random.split(key, 8)

        # conv stack kernels and biases
        self.conv_kernels = objax.ModuleList()
        self.conv_biases = objax.ModuleList()
        input_channels = 3
        for i, output_channels in enumerate([32, 64, 64, 64]):
            self.conv_kernels.append(TrainVar(he_normal()(subkeys[i], (num_models, 3, 3, input_channels,
                                                                       output_channels))))
            self.conv_biases.append(TrainVar(jnp.zeros((num_models, output_channels))))
            input_channels = output_channels

        # dense layer kernel and bias
        self.dense_kernels = TrainVar(he_normal()(subkeys[6], 
                                                 (num_models, output_channels, dense_kernel_size)))
        self.dense_biases = TrainVar(jnp.zeros((num_models, dense_kernel_size)))

        # classifier layer kernel and bias
        self.logits_kernel = TrainVar(glorot_normal()(subkeys[6], 
                                                      (num_models, dense_kernel_size, num_classes)))
        self.logits_biases = TrainVar(jnp.zeros((num_models, num_classes)))

    def logits(self, inp):                
        # the first call vmaps over the first conv params for a single input
        y = vmap(partial(_conv_layer, 2, gelu, inp))(
            self.conv_kernels[0].value, self.conv_biases[0].value)        

        # subsequent calls vmap over both the prior input and the conv params
        # the first representing the batched input with the second representing
        # the batched models (i.e. the ensemble)
        for conv_kernel, conv_bias in zip(self.conv_kernels[1:],
                                          self.conv_biases[1:]):
            y = vmap(partial(_conv_layer, 2, gelu))(
                y, conv_kernel.value, conv_bias.value)
            
    
        # global spatial pooling
        # (M, B, 64)
        y = jnp.mean(y, axis=(2, 3))

        # dense layer with non linearity
        # (M, B, 32)
        y = vmap(partial(_dense_layer, gelu))(y, self.dense_kernels.value, self.dense_biases.value)
        
        # dense layer with no activation to number classes 
        # (M, B, num_classes)
        logits = vmap(partial(_dense_layer, None))(y, self.logits_kernel.value, self.logits_biases.value)        
        return logits
        
    def predict(self, inp):
        return jax.nn.softmax(self.logits(inp))



In [33]:
net = EnsembleNet(num_models=3, num_classes=10)
print(net.logits(imgs).shape)
print(net.predict(imgs))

(3, 4, 10)
[[[0.10004424 0.09997072 0.10037523 0.09999789 0.09989534 0.09969421
   0.0997958  0.0999829  0.09994222 0.10030156]
  [0.10007226 0.09998271 0.10032347 0.09999003 0.09988236 0.09975057
   0.09981922 0.09997445 0.09994675 0.10025822]
  [0.10015408 0.10009208 0.10018627 0.09996849 0.10000385 0.09928628
   0.09951436 0.10046441 0.09990092 0.10042922]
  [0.09998829 0.09993898 0.10031915 0.09999322 0.09994603 0.09977533
   0.09986364 0.09995876 0.09999613 0.10022039]]

 [[0.10038323 0.09982816 0.09969807 0.10003973 0.10005305 0.10000939
   0.09987422 0.10008911 0.09997763 0.10004737]
  [0.10029979 0.09988333 0.09974554 0.10003419 0.10004012 0.09995164
   0.09993222 0.10008802 0.09998222 0.10004304]
  [0.10099845 0.09958512 0.09945692 0.09989917 0.10006463 0.10026194
   0.09961688 0.1000125  0.10013781 0.09996644]
  [0.1003595  0.09983999 0.09973717 0.1000357  0.10005657 0.10006701
   0.09982106 0.1000558  0.09995658 0.10007066]]

 [[0.0995431  0.10018515 0.10007327 0.10008618 0.

In [35]:
net = NonEnsembleNet(num_classes=10)
print(jnp.around(net.logits(imgs), 3))
print(jnp.around(net.predict(imgs), 3))

[[-0.074      -0.045      -0.11100001  0.053       0.054      -0.001
  -0.109       0.047      -0.004       0.064     ]
 [-0.059      -0.036      -0.083       0.037       0.043       0.005
  -0.089       0.038       0.006       0.051     ]
 [-0.42200002 -0.28800002 -0.252       0.27400002  0.19700001 -0.22000001
  -0.45200002  0.216       0.1         0.115     ]
 [-0.063      -0.039      -0.10300001  0.048       0.049      -0.003
  -0.10200001  0.041      -0.011       0.06500001]]
[[0.094      0.097      0.09       0.10700001 0.10700001 0.101
  0.09100001 0.10600001 0.101      0.108     ]
 [0.09500001 0.097      0.093      0.105      0.105      0.101
  0.09200001 0.105      0.101      0.10600001]
 [0.068      0.078      0.081      0.13700001 0.127      0.083
  0.066      0.12900001 0.115      0.11700001]
 [0.09500001 0.097      0.09100001 0.10600001 0.10600001 0.101
  0.09100001 0.105      0.1        0.108     ]]


In [32]:
imgs.shape

(4, 64, 64, 3)

In [91]:
def cross_entropy(imgs, labels):
    logits = net.logits(imgs)
    return jnp.mean(objax.functional.loss.cross_entropy_logits_sparse(logits, labels))

gradient_loss = objax.GradValues(cross_entropy, net.vars())
optimiser = objax.optimizer.Adam(net.vars())
lr = 1e-3

# create a jitted training step
def train_step(imgs, labels):
    grads, loss = gradient_loss(imgs, labels)
    optimiser(lr, grads)
    return loss

train_step = objax.Jit(train_step,
                       gradient_loss.vars() + optimiser.vars())



In [93]:
for _ in range(10):
    print(train_step(imgs, labels))

[DeviceArray(1.2861435, dtype=float32)]
[DeviceArray(1.175744, dtype=float32)]
[DeviceArray(1.0832641, dtype=float32)]
[DeviceArray(1.0269818, dtype=float32)]
[DeviceArray(0.9848052, dtype=float32)]
[DeviceArray(0.93037367, dtype=float32)]
[DeviceArray(0.8638202, dtype=float32)]
[DeviceArray(0.8097404, dtype=float32)]
[DeviceArray(0.77046627, dtype=float32)]
[DeviceArray(0.70041114, dtype=float32)]


In [97]:
jnp.around(net.predict(imgs), 2)

DeviceArray([[0.        , 0.38      , 0.        , 0.17      , 0.01      ,
              0.        , 0.        , 0.02      , 0.01      , 0.41      ],
             [0.        , 0.35999998, 0.        , 0.31      , 0.        ,
              0.        , 0.        , 0.01      , 0.02      , 0.29      ],
             [0.        , 0.        , 0.        , 0.02      , 0.        ,
              0.        , 0.        , 0.        , 0.97999996, 0.        ],
             [0.        , 0.22      , 0.        , 0.53      , 0.        ,
              0.        , 0.        , 0.        , 0.05      , 0.19      ]],            dtype=float32)

In [98]:
labels

array([9, 1, 8, 3])