In [1]:
import models
import numpy as np
import jax.numpy as jnp
import objax

Create some dummy data: a small batch of random images with random integer class labels.

In [2]:
# Fixed seed so the dummy batch is identical on every Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

B = 4  # batch size
# (B, 64, 64, 3) array of uniform values in [0, 1); presumably HWC images — confirm
# against what models.* expects.
imgs = jnp.array(rng.uniform(0, 1, (B, 64, 64, 3)))
# One integer class label in [0, 10) per image.
labels = jnp.array(rng.integers(0, 10, (B,)))

In [3]:
# Integer target class for each image in the dummy batch.
labels

DeviceArray([7, 9, 9, 6], dtype=int32)

## Non-ensemble net
Create a non-ensemble net and train it.

In [4]:
# Baseline single-model classifier; num_classes=10 matches the dummy labels,
# which are drawn from [0, 10).
net = models.NonEnsembleNet(num_classes=10)

In [5]:
# Untrained baseline: raw logits, probabilities rounded to 2 decimals, and the
# predicted class for each image, printed in that order.
for describe in (net.logits,
                 lambda batch: np.around(net.predict_proba(batch), 2),
                 net.predict):
    print(describe(imgs))

[[-0.00987746  0.00590655  0.11114901 -0.06037663 -0.07429887 -0.07786602
   0.01010637  0.02298879 -0.11467976 -0.06140005]
 [-0.00063246 -0.11024021  0.09258822 -0.04025409 -0.13343385 -0.2358858
   0.03928605  0.08934309 -0.12842166 -0.17080614]
 [-0.03406183 -0.04649381  0.06919234 -0.08676191 -0.1077292  -0.1698364
   0.02133491 -0.01074979 -0.14941065 -0.07346772]
 [-0.0427481  -0.03550828  0.06617799 -0.0595531  -0.11695981 -0.20628926
   0.12021049  0.03066423 -0.19432637 -0.11461633]]
[[0.1  0.1  0.11 0.1  0.09 0.09 0.1  0.1  0.09 0.1 ]
 [0.11 0.09 0.12 0.1  0.09 0.08 0.11 0.12 0.09 0.09]
 [0.1  0.1  0.11 0.1  0.09 0.09 0.11 0.1  0.09 0.1 ]
 [0.1  0.1  0.11 0.1  0.09 0.09 0.12 0.11 0.09 0.09]]
[2 2 2 6]


In [6]:
from objax.functional.loss import cross_entropy_logits_sparse

# Loss, gradient transform, optimiser and jitted training step for `net`.
# NOTE(review): `learning_rate` is a closed-over Python float, so it is presumably
# fixed when the jitted step is first traced; pass it as an argument to
# train_step if it ever needs to change between calls.
learning_rate = 1e-3


def cross_entropy(imgs, labels):
    """Mean sparse cross-entropy of `net`'s logits against integer `labels`.

    Reads the global `net` at call time.
    """
    per_example = cross_entropy_logits_sparse(net.logits(imgs), labels)
    return per_example.mean()


def train_step(imgs, labels):
    """Apply one Adam update on (imgs, labels) and return GradValues' loss values."""
    gradients, loss_values = gradient_loss(imgs, labels)
    optimiser(learning_rate, gradients)
    return loss_values


gradient_loss = objax.GradValues(cross_entropy, net.vars())
optimiser = objax.optimizer.Adam(net.vars())
# Jit is given every variable the step touches: the model's (via gradient_loss)
# plus the optimiser's own variables.
train_step = objax.Jit(train_step, gradient_loss.vars() + optimiser.vars())

In [7]:
# 20 full-batch Adam steps on the same dummy batch; each line shows the step
# index and the loss list returned by train_step.
for step in range(20):
    step_loss = train_step(imgs, labels)
    print(step, step_loss)

0 [DeviceArray(2.2820106, dtype=float32)]
1 [DeviceArray(1.8173239, dtype=float32)]
2 [DeviceArray(1.3605423, dtype=float32)]
3 [DeviceArray(1.0651532, dtype=float32)]
4 [DeviceArray(0.75211626, dtype=float32)]
5 [DeviceArray(0.5546689, dtype=float32)]
6 [DeviceArray(0.57570505, dtype=float32)]
7 [DeviceArray(0.30007482, dtype=float32)]
8 [DeviceArray(0.26520765, dtype=float32)]
9 [DeviceArray(0.1556958, dtype=float32)]
10 [DeviceArray(0.13710928, dtype=float32)]
11 [DeviceArray(0.05432296, dtype=float32)]
12 [DeviceArray(0.02429605, dtype=float32)]
13 [DeviceArray(0.02594113, dtype=float32)]
14 [DeviceArray(0.02457786, dtype=float32)]
15 [DeviceArray(0.01280928, dtype=float32)]
16 [DeviceArray(0.00496435, dtype=float32)]
17 [DeviceArray(0.00184798, dtype=float32)]
18 [DeviceArray(0.00085592, dtype=float32)]
19 [DeviceArray(0.00059962, dtype=float32)]


In [8]:
# Trained baseline: raw logits, probabilities rounded to 2 decimals, and the
# predicted class for each image, printed in that order.
for describe in (net.logits,
                 lambda batch: np.around(net.predict_proba(batch), 2),
                 net.predict):
    print(describe(imgs))

[[-11.260939    -9.178541    -5.797422     1.139078    -2.6345367
   -8.705333     3.0127332   14.504839    -0.48101178   7.1435156 ]
 [-13.178496   -10.179994    -5.2081566    0.09942062  -0.45658046
   -6.731013    -2.3418913    6.1493683    3.306394    14.688072  ]
 [-13.189714   -10.067861    -5.499957     0.2943295   -0.80111545
   -6.7577844   -2.3381944    5.942384     3.381113    14.461624  ]
 [ -8.911894    -8.131757    -5.3439293    0.3700996   -3.8190906
   -9.2558155   11.9163265    5.293655    -1.2534811    1.8530917 ]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]]
[7 9 9 6]


## Create an ensemble net

In [9]:
# Two-member ensemble over the same 10 classes. This rebinds the global `net`, so
# anything built from the previous model's variables (gradient_loss, optimiser,
# the jitted train_step) is stale and must be recreated before training.
net = models.EnsembleNet(num_models=2, num_classes=10)

In [10]:
# Untrained ensemble with single_result=True: one row of logits, probabilities
# (rounded to 2 decimals) and predicted class per image.
net.single_result = True
for describe in (net.logits,
                 lambda batch: np.around(net.predict_proba(batch), 2),
                 net.predict):
    print(describe(imgs))

[[ 0.03365174 -0.02876222  0.02771349  0.0011859  -0.02194074  0.03224387
  -0.01268185  0.06088937 -0.02741139  0.00741585]
 [ 0.02965121 -0.02007229  0.02645916 -0.00013432 -0.01158906  0.02800233
  -0.0206742   0.06373619 -0.02560225  0.0036592 ]
 [ 0.02518359 -0.03246357  0.03012501  0.00206306 -0.01214064  0.03849406
  -0.01365265  0.06651808 -0.03627107  0.00289541]
 [ 0.0348612  -0.0301028   0.02714048 -0.0068539  -0.01385327  0.04159915
  -0.02380633  0.06953844 -0.01681299  0.00606694]]
[[0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]
 [0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]
 [0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]
 [0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1 ]]
[7 7 7 7]


In [11]:
# Untrained ensemble with single_result=False: the same three outputs, but one
# result per ensemble member (an extra leading axis in the printed arrays).
# NOTE(review): in the captured outputs, the single_result=True logits equal the
# element-wise SUM of the per-member logits printed here. Confirm that
# models.EnsembleNet intends summing (rather than averaging) member logits.
net.single_result = False
for describe in (net.logits,
                 lambda batch: np.around(net.predict_proba(batch), 2),
                 net.predict):
    print(describe(imgs))

[[[ 0.00075775 -0.0299944   0.02346599  0.02971933 -0.00609219
    0.02266714  0.01816031  0.0254093   0.00248097  0.00829709]
  [-0.00354719 -0.02591672  0.01902036  0.03000627  0.00317005
    0.01907494  0.02232362  0.02507329  0.0099296   0.00489009]
  [-0.00283656 -0.0291962   0.02323768  0.03080092  0.00151774
    0.01911904  0.01703912  0.02529165 -0.00227295  0.00560843]
  [-0.00239164 -0.0211306   0.01940439  0.02698449 -0.00194042
    0.03024101  0.01830228  0.03399596  0.00843064  0.00085138]]

 [[ 0.03289399  0.00123218  0.00424749 -0.02853342 -0.01584855
    0.00957673 -0.03084216  0.03548007 -0.02989236 -0.00088124]
  [ 0.0331984   0.00584443  0.0074388  -0.0301406  -0.01475911
    0.00892739 -0.04299783  0.03866289 -0.03553185 -0.00123089]
  [ 0.02802015 -0.00326737  0.00688733 -0.02873786 -0.01365838
    0.01937502 -0.03069177  0.04122643 -0.03399812 -0.00271302]
  [ 0.03725284 -0.00897219  0.00773609 -0.03383839 -0.01191286
    0.01135815 -0.04210861  0.03554248 -0.0252

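Run training. The loss is computed with `net.single_result = True`, i.e. on the ensemble's single combined output.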

In [12]:
from objax.functional.loss import cross_entropy_logits_sparse

# Rebuild loss, gradient transform, optimiser and jitted step for the ensemble:
# the earlier GradValues/Adam objects were built from the previous model's variables.
# NOTE(review): this duplicates the non-ensemble setup cell; consider a shared
# helper that takes the model explicitly instead of reading the global `net`.
net.single_result = True  # train on the combined (single-result) logits
learning_rate = 1e-3


def cross_entropy(imgs, labels):
    """Mean sparse cross-entropy of the global `net`'s logits against `labels`."""
    per_example = cross_entropy_logits_sparse(net.logits(imgs), labels)
    return per_example.mean()


def train_step(imgs, labels):
    """Apply one Adam update on (imgs, labels) and return GradValues' loss values."""
    gradients, loss_values = gradient_loss(imgs, labels)
    optimiser(learning_rate, gradients)
    return loss_values


gradient_loss = objax.GradValues(cross_entropy, net.vars())
optimiser = objax.optimizer.Adam(net.vars())
train_step = objax.Jit(train_step, gradient_loss.vars() + optimiser.vars())

In [13]:
# Targets again, for comparison with the ensemble's predictions after training.
labels

DeviceArray([7, 9, 9, 6], dtype=int32)

In [14]:
# Train the ensemble for 20 full-batch steps; each line shows the step index
# and the loss list returned by train_step.
for step in range(20):
    step_loss = train_step(imgs, labels)
    print(step, step_loss)

0 [DeviceArray(2.2997265, dtype=float32)]
1 [DeviceArray(2.234067, dtype=float32)]
2 [DeviceArray(2.1484475, dtype=float32)]
3 [DeviceArray(2.0272722, dtype=float32)]
4 [DeviceArray(1.8524725, dtype=float32)]
5 [DeviceArray(1.6296611, dtype=float32)]
6 [DeviceArray(1.4355035, dtype=float32)]
7 [DeviceArray(1.2841322, dtype=float32)]
8 [DeviceArray(1.1069925, dtype=float32)]
9 [DeviceArray(1.0841106, dtype=float32)]
10 [DeviceArray(1.0210012, dtype=float32)]
11 [DeviceArray(0.9948845, dtype=float32)]
12 [DeviceArray(0.95197046, dtype=float32)]
13 [DeviceArray(0.9241742, dtype=float32)]
14 [DeviceArray(0.8746195, dtype=float32)]
15 [DeviceArray(0.8463354, dtype=float32)]
16 [DeviceArray(0.7521789, dtype=float32)]
17 [DeviceArray(0.7270386, dtype=float32)]
18 [DeviceArray(0.62908864, dtype=float32)]
19 [DeviceArray(0.62393093, dtype=float32)]


In [15]:
# Trained ensemble with single_result=True: one row of logits, probabilities
# (rounded to 2 decimals) and predicted class per image.
net.single_result = True
for describe in (net.logits,
                 lambda batch: np.around(net.predict_proba(batch), 2),
                 net.predict):
    print(describe(imgs))

[[-7.3012743   0.05483365 -1.2917566  -0.4802773  -4.422536   -0.1108427
   9.259617    9.656329   -6.931133    9.926528  ]
 [-7.3997726   0.4566697  -1.7491815  -1.0619249  -4.6235304  -0.4573785
   8.163191    7.709463   -7.3533382  10.369416  ]
 [-7.5261555   0.4267962  -1.7695405  -1.0861535  -4.653118   -0.4121461
   8.220253    7.8918076  -7.5399246  10.548212  ]
 [-7.5030556   0.01869178 -1.4042704  -0.53091216 -4.267513   -0.18526506
   9.936641    8.795258   -6.8832426   9.636349  ]]
[[0.   0.   0.   0.   0.   0.   0.23 0.34 0.   0.44]
 [0.   0.   0.   0.   0.   0.   0.09 0.06 0.   0.85]
 [0.   0.   0.   0.   0.   0.   0.08 0.06 0.   0.86]
 [0.   0.   0.   0.   0.   0.   0.49 0.16 0.   0.36]]
[9 9 9 6]


In [16]:
# Trained ensemble with single_result=False: the same three outputs, but one
# result per ensemble member (an extra leading axis in the printed arrays).
net.single_result = False
for describe in (net.logits,
                 lambda batch: np.around(net.predict_proba(batch), 2),
                 net.predict):
    print(describe(imgs))

[[[-2.3167975 -2.1795099  2.5778768  2.9290826 -1.5126336  2.238614
    4.6613755  5.9442506 -3.9882505  3.2060611]
  [-2.371347  -1.7866195  2.1497276  2.4553292 -1.6442361  1.9316167
    3.4391456  4.22048   -4.3927827  3.5414066]
  [-2.4113905 -1.8348057  2.1896472  2.4806967 -1.6677935  1.991399
    3.446853   4.315103  -4.5160894  3.6170614]
  [-2.3165443 -2.1029356  2.4102602  2.936252  -1.5291467  2.0461354
    4.8045607  5.5739236 -4.027545   3.180185 ]]

 [[-4.9844766  2.2343435 -3.8696334 -3.40936   -2.9099026 -2.3494568
    4.598242   3.7120788 -2.9428823  6.7204666]
  [-5.0284257  2.2432892 -3.898909  -3.517254  -2.9792943 -2.3889952
    4.724045   3.4889834 -2.9605553  6.8280096]
  [-5.1147647  2.261602  -3.9591877 -3.5668502 -2.9853249 -2.4035451
    4.7734003  3.5767043 -3.0238354  6.9311504]
  [-5.1865115  2.1216273 -3.8145306 -3.4671643 -2.7383661 -2.2314005
    5.13208    3.221334  -2.8556974  6.456164 ]]]
[[[0.   0.   0.02 0.03 0.   0.02 0.19 0.69 0.   0.04]
  [0.   