In [1]:
import models
import numpy as np
import jax.numpy as jnp
import objax

## Non-ensemble net
TODO: create a non-ensemble net and train it.

## create ensemble net

In [2]:
M = 2  # leading axis of the data; also passed as num_models when the net is built
B = 4  # second axis of the data: examples per slice of the leading axis
num_classes = 10
SEED = 0

# Seeded generator so that Restart & Run All reproduces the same synthetic data
# (the global np.random state would give different values on every run).
rng = np.random.default_rng(SEED)

# Random stand-in data: images are uniform in [0, 1) with shape (M, B, 64, 64, 3);
# labels are integers in [0, num_classes) with shape (M, B).
imgs = jnp.array(rng.uniform(0, 1, (M, B, 64, 64, 3)))
labels = jnp.array(rng.integers(0, num_classes, (M, B)))

labels

DeviceArray([[0, 0, 6, 0],
             [6, 3, 0, 9]], dtype=int32)

In [3]:
# Build the ensemble under test, sized by the M and num_classes settings above.
net = models.EnsembleNet(num_models=M, num_classes=num_classes)

In [4]:
# Inspect the untrained net: raw logits, rounded class probabilities and
# predicted classes, with single_result switched off.
net.single_result = False

untrained_logits = net.logits(imgs)
print(untrained_logits)

untrained_probs = net.predict_proba(imgs)
print(np.around(untrained_probs, 2))

untrained_preds = net.predict(imgs)
print(untrained_preds)

[[[-0.00060502 -0.02492469  0.01806218  0.0351722   0.00936958
    0.03283712  0.02363573  0.0255007   0.00064794 -0.0013887 ]
  [-0.00107526 -0.02720002  0.02295999  0.02994362  0.00158176
    0.02029154  0.01993767  0.02784456  0.00132418  0.003337  ]
  [ 0.00396852 -0.02832424  0.02437153  0.03219391 -0.00064286
    0.01604812  0.02979654  0.03123377  0.00260794  0.00651032]
  [-0.00140848 -0.02888088  0.02452599  0.02666255  0.00024076
    0.02799064  0.02278895  0.03351299  0.00549321  0.00242322]]

 [[ 0.029436    0.00570479  0.0025988  -0.03024629 -0.01737218
    0.00398215 -0.04979104  0.03857547 -0.0251733   0.00706351]
  [ 0.03239074  0.00062699  0.00475765 -0.0314671  -0.01216873
    0.0089255  -0.03798135  0.03374011 -0.03052855 -0.00420806]
  [ 0.03099996 -0.00027723  0.00685071 -0.02714815 -0.01296634
    0.00968622 -0.0363331   0.03101299 -0.03117795 -0.00611702]
  [ 0.02973027  0.00409792  0.00167987 -0.02920377 -0.01169661
    0.01441652 -0.04370023  0.03600362 -0.0257

Run training.

In [5]:
from objax.functional.loss import cross_entropy_logits_sparse

# setup loss fn and optimiser
def cross_entropy(imgs, labels):
    """Mean sparse cross-entropy of the net's logits over all given examples.

    Args:
        imgs: input batch forwarded unchanged to ``net.logits``.
        labels: integer class ids, one per logit row once both are flattened.

    Returns:
        Scalar mean of ``cross_entropy_logits_sparse`` over every example.
    """
    # Force the non-single-result mode before computing logits.
    net.single_result = False
    # Flatten every leading axis into one so the loss does not depend on the
    # notebook-level M and B; for the original (M*B, num_classes) case the
    # result is identical.
    logits = net.logits(imgs).reshape((-1, num_classes))
    labels = labels.reshape((-1,))
    return jnp.mean(cross_entropy_logits_sparse(logits, labels))

# Wrap the loss so that one call returns both the gradients with respect to
# net.vars() and the loss value(s).
gradient_loss = objax.GradValues(cross_entropy, net.vars())
# Adam optimiser over the same set of variables.
optimiser = objax.optimizer.Adam(net.vars())

# create jitted training step
learning_rate = 1e-3


def _single_update(imgs, labels):
    """Run one optimiser update and return the loss reported by gradient_loss."""
    grads, loss = gradient_loss(imgs, labels)
    optimiser(learning_rate, grads)
    return loss


# The jitted step is given the variables of both the gradient function and
# the optimiser.
step_vars = gradient_loss.vars() + optimiser.vars()
train_step = objax.Jit(_single_update, step_vars)

In [6]:
# Run a fixed number of updates on the same batch, printing the loss each time.
num_steps = 20
for step in range(num_steps):
    step_loss = train_step(imgs, labels)
    print(step, step_loss)

0 [DeviceArray(2.309785, dtype=float32)]
1 [DeviceArray(2.286853, dtype=float32)]
2 [DeviceArray(2.2610347, dtype=float32)]
3 [DeviceArray(2.2312253, dtype=float32)]
4 [DeviceArray(2.1942642, dtype=float32)]
5 [DeviceArray(2.1467886, dtype=float32)]
6 [DeviceArray(2.0853095, dtype=float32)]
7 [DeviceArray(2.0025678, dtype=float32)]
8 [DeviceArray(1.8818042, dtype=float32)]
9 [DeviceArray(1.7152119, dtype=float32)]
10 [DeviceArray(1.5230902, dtype=float32)]
11 [DeviceArray(1.3439355, dtype=float32)]
12 [DeviceArray(1.2068758, dtype=float32)]
13 [DeviceArray(1.1142126, dtype=float32)]
14 [DeviceArray(1.0544429, dtype=float32)]
15 [DeviceArray(1.023199, dtype=float32)]
16 [DeviceArray(0.9841548, dtype=float32)]
17 [DeviceArray(0.9533998, dtype=float32)]
18 [DeviceArray(0.93385005, dtype=float32)]
19 [DeviceArray(0.9242128, dtype=float32)]


In [9]:
# predictions etc after training
# NOTE(review): this cell was labelled "single result mode", but the flag below
# is False -- the same setting the next cell describes as "one result per
# model". Confirm whether `True` was intended here.
net.single_result = False

# Ground-truth labels, for comparison with the predictions printed below.
print("labels")
print(labels)

# Raw logits, then class probabilities rounded to two decimals.
print(net.logits(imgs))
print(np.around(net.predict_proba(imgs), 2))

print("predictions")
print(net.predict(imgs))

labels
[[0 0 6 0]
 [6 3 0 9]]
[[[19.159555   -0.86604476 -0.75640136  0.5455518  -6.518797
    8.060124   18.033514    2.9550166   4.7997613  -7.1923504 ]
  [19.040037   -0.86604476 -0.7461923   0.5716469  -6.459628
    7.9231243  17.953642    2.9516416   4.798876   -7.168936  ]
  [18.957855   -0.91597223 -0.75199693  0.5271764  -6.4322305
    7.925053   18.044647    3.1206622   4.7976646  -7.234948  ]
  [19.241756   -0.8600795  -0.75236124  0.55760574 -6.551737
    8.029525   18.101171    2.9622502   4.848707   -7.201585  ]]

 [[ 2.784457   -0.47082624 -3.8567884   2.8172765  -2.7583883
   -1.6924964   3.5291965  -3.9809694  -2.883482    3.234063  ]
  [ 2.869034   -0.3944275  -3.7418451   2.962      -2.695875
   -1.7287304   3.2085233  -3.8212414  -2.8644915   3.241995  ]
  [ 2.9364972  -0.44032833 -3.6626866   2.7597113  -2.5655537
   -1.8206162   2.984578   -3.7365205  -2.755283    3.205778  ]
  [ 2.8760614  -0.33751926 -3.9145925   2.7517211  -2.7939067
   -1.8161858   3.0897684  -

In [None]:
# Post-training inspection with one result per model (single_result off):
# logits, rounded probabilities, then predicted classes.
net.single_result = False

trained_logits = net.logits(imgs)
print(trained_logits)

trained_probs = net.predict_proba(imgs)
print(np.around(trained_probs, 2))

trained_preds = net.predict(imgs)
print(trained_preds)