In [6]:
import models
import numpy as np
import jax.numpy as jnp
import objax
import jax

def pp(a, decimals=3):
    """Print an array's shape, then its values rounded for compact display.

    Args:
        a: array with a ``.shape`` attribute (e.g. a jax or numpy array).
        decimals: number of decimal places passed to ``jnp.around`` before
            printing. Defaults to 3, the previously hard-coded precision, so
            existing calls ``pp(x)`` print exactly what they did before.
    """
    print(a.shape)
    print(jnp.around(a, decimals))

In [13]:
# Toy problem sizes: M is passed as num_models and C as num_classes to
# EnsembleNet below; B is the second axis of the random inputs.
M = 3
B = 4
C = 5
SEED = 0  # fixed seed so Restart & Run All reproduces the same inputs

rng = np.random.default_rng(SEED)
# Random images with values in [0, 1) and integer labels in [0, C).
# The leading axis has size M (presumably one batch per ensemble member --
# confirm against EnsembleNet.logits).
imgs = jnp.array(rng.uniform(0.0, 1.0, (M, B, 64, 64, 3)))
labels = jnp.array(rng.integers(0, C, (M, B)))

labels

DeviceArray([[1, 2, 0, 2],
             [3, 4, 0, 1],
             [0, 2, 2, 0]], dtype=int32)

In [18]:
# Build the ensemble model under test, configured with M members and C classes.
net = models.EnsembleNet(num_classes=C, num_models=M)

In [19]:
# single_result=False keeps a separate set of logits per leading index (the
# printed shape is (M, B, C)); logits_dropout=False presumably turns off
# dropout on the logits -- confirm in models.EnsembleNet.
logits = net.logits(imgs, logits_dropout=False, single_result=False)
pp(logits)

(3, 4, 5)
[[[ 0.004 -0.001 -0.003  0.    -0.005]
  [ 0.003 -0.003 -0.004  0.002 -0.006]
  [ 0.003 -0.004 -0.004  0.    -0.007]
  [ 0.004 -0.002 -0.002  0.002 -0.007]]

 [[ 0.    -0.005  0.005 -0.005 -0.003]
  [-0.003 -0.01   0.002 -0.006 -0.003]
  [-0.002 -0.007  0.003 -0.006 -0.002]
  [-0.003 -0.009  0.003 -0.005 -0.001]]

 [[-0.001 -0.003 -0.012 -0.001 -0.002]
  [-0.001 -0.002 -0.014 -0.001 -0.002]
  [ 0.    -0.002 -0.013  0.001 -0.004]
  [-0.001 -0.004 -0.015 -0.001 -0.001]]]


In [20]:
# Ensemble by summing logits over the leading axis (size M): result is (B, C).
ensembled_logits = logits.sum(axis=0)
pp(ensembled_logits)

(4, 5)
[[ 0.003 -0.009 -0.01  -0.006 -0.01 ]
 [-0.001 -0.015 -0.016 -0.005 -0.01 ]
 [ 0.001 -0.013 -0.014 -0.004 -0.014]
 [ 0.001 -0.015 -0.014 -0.003 -0.009]]


In [21]:
# Repeat the ensemble sum along a new leading axis so it lines up
# element-wise with `logits`. Broadcasting to logits.shape (instead of tiling
# by the global M) guarantees the shapes agree even if M is changed or the
# net returns a different number of members; the values are identical.
tiled_ensemble_logits = jnp.broadcast_to(ensembled_logits, logits.shape)
pp(tiled_ensemble_logits)

(3, 4, 5)
[[[ 0.003 -0.009 -0.01  -0.006 -0.01 ]
  [-0.001 -0.015 -0.016 -0.005 -0.01 ]
  [ 0.001 -0.013 -0.014 -0.004 -0.014]
  [ 0.001 -0.015 -0.014 -0.003 -0.009]]

 [[ 0.003 -0.009 -0.01  -0.006 -0.01 ]
  [-0.001 -0.015 -0.016 -0.005 -0.01 ]
  [ 0.001 -0.013 -0.014 -0.004 -0.014]
  [ 0.001 -0.015 -0.014 -0.003 -0.009]]

 [[ 0.003 -0.009 -0.01  -0.006 -0.01 ]
  [-0.001 -0.015 -0.016 -0.005 -0.01 ]
  [ 0.001 -0.013 -0.014 -0.004 -0.014]
  [ 0.001 -0.015 -0.014 -0.003 -0.009]]]


In [22]:
# Leave-one-out logits: slice i is the ensemble sum minus member i's own
# logits, i.e. the sum of the logits of all *other* members.
held_one_out_logits = jnp.subtract(tiled_ensemble_logits, logits)
pp(held_one_out_logits)

(3, 4, 5)
[[[-0.001 -0.008 -0.007 -0.006 -0.005]
  [-0.004 -0.012 -0.012 -0.007 -0.004]
  [-0.002 -0.009 -0.01  -0.004 -0.007]
  [-0.003 -0.013 -0.012 -0.006 -0.002]]

 [[ 0.004 -0.004 -0.015 -0.001 -0.007]
  [ 0.002 -0.005 -0.018  0.001 -0.008]
  [ 0.003 -0.006 -0.017  0.002 -0.011]
  [ 0.003 -0.006 -0.017  0.001 -0.008]]

 [[ 0.004 -0.006  0.002 -0.005 -0.008]
  [ 0.    -0.013 -0.001 -0.004 -0.009]
  [ 0.001 -0.011 -0.001 -0.005 -0.009]
  [ 0.001 -0.011  0.001 -0.002 -0.008]]]


In [23]:
# Predicted class (index over the last axis) under each held-one-out ensemble.
held_one_out_logits.argmax(axis=-1)

DeviceArray([[0, 0, 0, 4],
             [0, 0, 0, 0],
             [0, 0, 0, 0]], dtype=int32)

In [24]:
def held_one_out_predictions(member_logits):
    """Return leave-one-out ensemble predictions.

    For each index i on the leading axis, sums the logits of every *other*
    member (the total over the leading axis minus slice i) and takes the
    argmax over the last axis.

    Args:
        member_logits: array whose leading axis indexes ensemble members and
            whose last axis indexes classes -- (M, B, C) in this notebook.

    Returns:
        Integer array of shape ``member_logits.shape[:-1]``: the predicted
        class of each example when member i is held out.
    """
    ensemble_sum = jnp.sum(member_logits, axis=0)
    # Broadcasting the (B, C) sum against (M, B, C) replaces the explicit
    # jnp.tile and removes the dependency on the global M.
    held_one_out = ensemble_sum - member_logits
    return jnp.argmax(held_one_out, axis=-1)


logits = net.logits(imgs, single_result=False, logits_dropout=False)
predictions = held_one_out_predictions(logits)
print(predictions)

[[0 0 0 4]
 [0 0 0 0]
 [0 0 0 0]]
