In [1]:
import jax.numpy as jnp
import numpy as np
import objax
from objax.functional.loss import cross_entropy_logits_sparse

import models

Create some dummy data: a small batch of random images with random integer class labels.

In [2]:
# Draw from a dedicated, seeded generator so the dummy batch is identical on
# every run of the notebook (the unseeded global RNG gave different data each time).
SEED = 0
rng = np.random.default_rng(SEED)

B = 4  # batch size
# images: (B, 64, 64, 3) floats drawn uniformly from [0, 1)
imgs = jnp.array(rng.uniform(0, 1, (B, 64, 64, 3)))
# labels: (B,) integer class ids drawn from [0, 10)
labels = jnp.array(rng.integers(0, 10, (B,)))

In [3]:
# bare expression: the cell displays the label array generated above
labels

DeviceArray([9, 9, 3, 6], dtype=int32)

## Non-ensemble net
Create a non-ensemble net and train it on the dummy batch.

In [4]:
# build the non-ensemble classifier from the project-local `models` module;
# num_classes=10 matches the [0, 10) range the dummy labels were drawn from
net = models.NonEnsembleNet(num_classes=10)

In [5]:
# untrained network: raw logits, predict_proba output rounded to 2 d.p.,
# and the predicted class per image -- one per line

print(
    net.logits(imgs),
    np.around(net.predict_proba(imgs), 2),
    net.predict(imgs),
    sep="\n",
)

[[ 0.05931298 -0.01358349  0.04821177 -0.08526853 -0.04139313 -0.07316041
   0.05848585  0.00936744 -0.13244155 -0.11536108]
 [-0.04136755 -0.04382023  0.07807877 -0.0895777  -0.06413972 -0.09614965
   0.05369993 -0.03820778 -0.09625124 -0.10016692]
 [-0.09322944 -0.07778878  0.05125523 -0.0582838  -0.07931533 -0.17146161
  -0.00603996  0.02536846 -0.09874791 -0.07051281]
 [-0.07716401 -0.07319045  0.09903803 -0.13362144 -0.10105438 -0.12014139
   0.14281966  0.02960072 -0.16343127 -0.15003654]]
[[0.11 0.1  0.11 0.09 0.1  0.1  0.11 0.1  0.09 0.09]
 [0.1  0.1  0.11 0.1  0.1  0.09 0.11 0.1  0.09 0.09]
 [0.1  0.1  0.11 0.1  0.1  0.09 0.11 0.11 0.1  0.1 ]
 [0.1  0.1  0.12 0.09 0.09 0.09 0.12 0.11 0.09 0.09]]
[0 2 2 6]


In [6]:
# setup loss fn and optimiser
# (cross_entropy_logits_sparse is imported in the notebook's top import cell)
def cross_entropy(imgs, labels):
    """Mean sparse cross entropy of the current `net` on one batch.

    Args:
        imgs: batch of images, passed unchanged to `net.logits`.
        labels: integer class id per image.

    Returns:
        The mean of the per-example cross-entropy losses.
    """
    logits = net.logits(imgs)
    return jnp.mean(cross_entropy_logits_sparse(logits, labels))


# gradients of the loss w.r.t. the net's variables, and an Adam optimiser over
# the same variables
gradient_loss = objax.GradValues(cross_entropy, net.vars())
optimiser = objax.optimizer.Adam(net.vars())

# create jitted training step
learning_rate = 1e-3


def _train_step(imgs, labels):
    """Apply one Adam update computed from a single batch.

    Returns:
        The loss values reported by `gradient_loss`; GradValues returns them as
        a list, so this is a one-element list holding the scalar loss.
    """
    grads, loss = gradient_loss(imgs, labels)
    optimiser(learning_rate, grads)
    return loss


# Jit is given the variables of both the gradient function and the optimiser,
# i.e. everything the step reads or updates. The un-jitted function keeps its
# own name so it stays available for debugging.
train_step = objax.Jit(_train_step,
                       gradient_loss.vars() + optimiser.vars())

In [7]:
# take 20 optimisation steps on the same fixed batch, printing the step index
# and the value returned by train_step after each one
for step in range(20):
    print(step, train_step(imgs, labels))

0 [DeviceArray(2.2919455, dtype=float32)]
1 [DeviceArray(1.873941, dtype=float32)]
2 [DeviceArray(1.4620001, dtype=float32)]
3 [DeviceArray(1.1369451, dtype=float32)]
4 [DeviceArray(0.7597535, dtype=float32)]
5 [DeviceArray(0.69046897, dtype=float32)]
6 [DeviceArray(0.461321, dtype=float32)]
7 [DeviceArray(0.3427558, dtype=float32)]
8 [DeviceArray(0.21064234, dtype=float32)]
9 [DeviceArray(0.1665225, dtype=float32)]
10 [DeviceArray(0.05688453, dtype=float32)]
11 [DeviceArray(0.04338741, dtype=float32)]
12 [DeviceArray(0.03254128, dtype=float32)]
13 [DeviceArray(0.01887846, dtype=float32)]
14 [DeviceArray(0.00653529, dtype=float32)]
15 [DeviceArray(0.00273657, dtype=float32)]
16 [DeviceArray(0.00191426, dtype=float32)]
17 [DeviceArray(0.00160408, dtype=float32)]
18 [DeviceArray(0.00154567, dtype=float32)]
19 [DeviceArray(0.00191951, dtype=float32)]


In [8]:
# trained network: raw logits, predict_proba output rounded to 2 d.p.,
# and the predicted class per image -- one per line

print(
    net.logits(imgs),
    np.around(net.predict_proba(imgs), 2),
    net.predict(imgs),
    sep="\n",
)

[[-14.160842   -14.632252    -9.272084     7.4294176   -3.2836404
   -7.300902     5.7107778   -0.31011692   0.5309466   14.6887865 ]
 [-14.213045   -13.986172    -9.046066     7.701859    -2.982193
   -7.13353      4.4228563   -0.03412117   0.9953393   15.0208645 ]
 [-15.175803   -14.582171    -8.989858    14.487215    -7.0131364
  -10.25056      9.682943     1.4085907    0.04058924   3.5313115 ]
 [-13.3289795  -13.750578    -8.279888     7.632069    -7.2934456
  -10.801613    21.055822     0.2977613   -2.3938494    0.73247755]]
[[0.   0.   0.   0.   0.   0.   0.   0.   0.   1.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   1.  ]
 [0.   0.   0.   0.99 0.   0.   0.01 0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.   1.   0.   0.   0.  ]]
[9 9 3 6]


## Create ensemble net

In [10]:
# build a 2-member ensemble from the project-local `models` module.
# NOTE(review): this rebinds `net`, the same name used for the non-ensemble net
# above, so every later cell operates on the ensemble only once this cell has run.
net = models.EnsembleNet(num_models=2, num_classes=10)

In [13]:
# untrained ensemble, single_result=True:
# logits, predict_proba output rounded to 2 d.p., predicted classes -- one per line

print(
    net.logits(imgs, single_result=True),
    np.around(net.predict_proba(imgs, single_result=True), 2),
    net.predict(imgs, single_result=True),
    sep="\n",
)

[[ 3.4129202e-02 -2.4621822e-02  3.3511866e-02  5.1432177e-03
  -9.0999128e-03  3.7883043e-02 -1.9921180e-02  6.5821514e-02
  -4.3066598e-02  2.6647185e-03]
 [ 2.9673504e-02 -1.4374586e-02  2.5476806e-02  5.6982096e-03
  -7.7494602e-03  3.6203891e-02 -1.2926057e-02  6.4915992e-02
  -2.5038097e-02 -5.1552895e-05]
 [ 3.1153135e-02 -2.4204625e-02  2.9672267e-02  2.0377897e-04
  -2.2542020e-03  3.8102139e-02 -2.6493456e-02  5.7798468e-02
  -3.0413117e-02 -3.2157935e-03]
 [ 2.3621442e-02 -2.4489410e-02  2.2022069e-02 -8.6232498e-03
  -1.1887665e-02  4.9225710e-02 -1.1554819e-02  6.0569778e-02
  -3.1225886e-02 -2.1867373e-03]]
[[0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.09 0.1 ]
 [0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]
 [0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]
 [0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]]
[7 7 7 7]


In [12]:
# untrained ensemble, single_result=False (one result per model):
# logits, predict_proba output rounded to 2 d.p., predicted classes -- one per line
print(
    net.logits(imgs, single_result=False),
    np.around(net.predict_proba(imgs, single_result=False), 2),
    net.predict(imgs, single_result=False),
    sep="\n",
)

[[[-2.73406948e-03 -2.59164907e-02  2.79843397e-02  2.92396769e-02
    2.06236821e-03  2.87132636e-02  1.77456252e-02  2.99331434e-02
   -4.69963811e-03  2.77550891e-03]
  [ 2.76120147e-03 -2.06806511e-02  1.75689496e-02  3.09492834e-02
    7.07814330e-03  2.40985043e-02  2.20532641e-02  2.57022865e-02
    7.64222536e-03  1.94611121e-03]
  [-3.65521060e-03 -2.37140693e-02  2.03380361e-02  2.71367505e-02
    7.05849379e-06  2.68153306e-02  1.28787383e-02  1.88253522e-02
   -1.99775398e-03  6.41586445e-03]
  [-4.79126209e-03 -2.75755059e-02  1.87845957e-02  2.63013616e-02
    4.21540532e-03  3.04049253e-02  1.97929852e-02  2.44712159e-02
    2.05561146e-03 -3.38256359e-05]]

 [[ 3.68632711e-02  1.29466970e-03  5.52752614e-03 -2.40964592e-02
   -1.11622810e-02  9.16977972e-03 -3.76668051e-02  3.58883739e-02
   -3.83669622e-02 -1.10790446e-04]
  [ 2.69123018e-02  6.30606525e-03  7.90785532e-03 -2.52510738e-02
   -1.48276035e-02  1.21053876e-02 -3.49793211e-02  3.92137058e-02
   -3.26803215

Run training on the ensemble, using its `single_result=True` logits for the loss.

In [14]:
# setup loss fn and optimiser
# (cross_entropy_logits_sparse is imported in the notebook's top import cell)
def cross_entropy(imgs, labels):
    """Mean sparse cross entropy of the ensemble's single-result logits on one batch.

    Args:
        imgs: batch of images, passed unchanged to `net.logits`.
        labels: integer class id per image.

    Returns:
        The mean of the per-example cross-entropy losses.
    """
    # single_result=True yields one row of logits per image rather than one per
    # model; how the members are combined is defined in models.EnsembleNet
    logits = net.logits(imgs, single_result=True)
    return jnp.mean(cross_entropy_logits_sparse(logits, labels))


# gradients of the loss w.r.t. the ensemble's variables, and an Adam optimiser
# over the same variables
gradient_loss = objax.GradValues(cross_entropy, net.vars())
optimiser = objax.optimizer.Adam(net.vars())

# create jitted training step
learning_rate = 1e-3


def _train_step(imgs, labels):
    """Apply one Adam update computed from a single batch.

    Returns:
        The loss values reported by `gradient_loss`; GradValues returns them as
        a list, so this is a one-element list holding the scalar loss.
    """
    grads, loss = gradient_loss(imgs, labels)
    optimiser(learning_rate, grads)
    return loss


# Jit is given the variables of both the gradient function and the optimiser,
# i.e. everything the step reads or updates. The un-jitted function keeps its
# own name so it stays available for debugging.
train_step = objax.Jit(_train_step,
                       gradient_loss.vars() + optimiser.vars())

In [18]:
# bare expression: re-display the target labels for comparison with the predictions below
labels

DeviceArray([9, 9, 3, 6], dtype=int32)

In [19]:
# take 20 optimisation steps on the same fixed batch, printing the step index
# and the value returned by train_step after each one
for step in range(20):
    print(step, train_step(imgs, labels))

0 [DeviceArray(0.731488, dtype=float32)]
1 [DeviceArray(0.57339025, dtype=float32)]
2 [DeviceArray(0.59187627, dtype=float32)]
3 [DeviceArray(0.5514784, dtype=float32)]
4 [DeviceArray(0.45597744, dtype=float32)]
5 [DeviceArray(0.4815061, dtype=float32)]
6 [DeviceArray(0.39145184, dtype=float32)]
7 [DeviceArray(0.3601992, dtype=float32)]
8 [DeviceArray(0.3616085, dtype=float32)]
9 [DeviceArray(0.27379584, dtype=float32)]
10 [DeviceArray(0.27995467, dtype=float32)]
11 [DeviceArray(0.23937178, dtype=float32)]
12 [DeviceArray(0.19763088, dtype=float32)]
13 [DeviceArray(0.19724083, dtype=float32)]
14 [DeviceArray(0.1548667, dtype=float32)]
15 [DeviceArray(0.14162683, dtype=float32)]
16 [DeviceArray(0.12479448, dtype=float32)]
17 [DeviceArray(0.0987072, dtype=float32)]
18 [DeviceArray(0.09665203, dtype=float32)]
19 [DeviceArray(0.07891273, dtype=float32)]


In [20]:
# trained ensemble, single_result=True:
# logits, predict_proba output rounded to 2 d.p., predicted classes -- one per line

print(
    net.logits(imgs, single_result=True),
    np.around(net.predict_proba(imgs, single_result=True), 2),
    net.predict(imgs, single_result=True),
    sep="\n",
)

[[-6.603551    0.37143588 -4.534133    7.069108   -7.405176   -0.4433905
   5.1664753  -5.6680164  -9.199497   10.220631  ]
 [-6.4254494   0.29847622 -4.3883266   7.1322136  -7.276828   -0.50329196
   4.7432365  -5.5368357  -8.891392    9.909617  ]
 [-7.9383154  -1.6123307  -2.475781   11.297846   -6.965103   -0.8226026
   6.001742   -4.5295353  -8.8568325   9.082139  ]
 [-8.067325   -1.2837391  -4.15346     6.953418   -6.146677   -0.5298461
  12.060303   -3.4403589  -7.8452845   7.6725025 ]]
[[0.   0.   0.   0.04 0.   0.   0.01 0.   0.   0.95]
 [0.   0.   0.   0.06 0.   0.   0.01 0.   0.   0.94]
 [0.   0.   0.   0.9  0.   0.   0.   0.   0.   0.1 ]
 [0.   0.   0.   0.01 0.   0.   0.98 0.   0.   0.01]]
[9 9 3 6]


In [21]:
# trained ensemble, single_result=False (one result per model):
# logits, predict_proba output rounded to 2 d.p., predicted classes -- one per line

print(
    net.logits(imgs, single_result=False),
    np.around(net.predict_proba(imgs, single_result=False), 2),
    net.predict(imgs, single_result=False),
    sep="\n",
)

[[[-2.3221781  -2.0916784   2.7177093   8.96912    -2.360869
    1.7116941   0.28142864 -2.9245174  -5.511493    6.116011  ]
  [-2.2849324  -2.1026242   2.628871    8.842983   -2.2701795
    1.5917121   0.2608282  -2.9614723  -5.3426733   5.90769   ]
  [-3.7604282  -4.130998    4.171939   12.305879   -2.217454
    1.0168165   2.0247235  -2.6345284  -5.7080894   5.525353  ]
  [-3.049157   -3.631826    3.3355174   9.654741   -2.0997205
    1.0848416   4.8208237  -0.21601519 -4.5838995   4.615962  ]]

 [[-4.2813725   2.4631143  -7.2518425  -1.900012   -5.044307
   -2.1550846   4.8850465  -2.7434988  -3.6880038   4.10462   ]
  [-4.1405168   2.4011004  -7.0171976  -1.7107697  -5.006648
   -2.095004    4.4824085  -2.5753632  -3.5487187   4.0019274 ]
  [-4.1778874   2.5186675  -6.64772    -1.008033   -4.747649
   -1.8394191   3.9770184  -1.895007   -3.148743    3.5567863 ]
  [-5.0181675   2.3480868  -7.4889774  -2.7013233  -4.0469565
   -1.6146877   7.239479   -3.2243438  -3.2613852   3.05654