## Learning MNIST with the LeNet CNN

First, let's download the MNIST data set and decompress it.

In [None]:
!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
!gzip -d train*.gz t10k*.gz
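
Before going further, it can be worth checking that the files decompressed correctly. The sketch below reads the header of an IDX file, the format used by the MNIST files: a big-endian magic number whose last byte gives the number of dimensions, followed by one 32-bit size per dimension. The helper name `read_idx_header` is ours and is not part of MXNet.

```python
import struct

def read_idx_header(path):
    """Return (magic, dims) read from the header of an uncompressed IDX file."""
    with open(path, "rb") as f:
        # The magic number is a big-endian 32-bit integer; its lowest byte
        # holds the number of dimensions stored in the file.
        magic = struct.unpack(">I", f.read(4))[0]
        ndim = magic & 0xFF
        # Each dimension size is a big-endian 32-bit unsigned integer.
        dims = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
    return magic, dims

# Usage, assuming the files sit in the current directory:
# read_idx_header("train-images-idx3-ubyte")
# read_idx_header("train-labels-idx1-ubyte")
```

For the training images we expect a magic number of 2051 and dimensions (60000, 28, 28); for the training labels, 2049 and (60000,).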

In [12]:
import mxnet as mx
import logging

logging.basicConfig(level=logging.INFO)

MXNet provides a convenient iterator for MNIST. We use it to build the training and the validation iterators.

In [13]:
nb_epochs = 25

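# With no file arguments, MNISTIter reads the training files from the
# current directory; the default batch size is 128.
train_iter = mx.io.MNISTIter(shuffle=True)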
val_iter = mx.io.MNISTIter(image="./t10k-images-idx3-ubyte", label="./t10k-labels-idx1-ubyte")

We build the LeNet network (http://yann.lecun.com/exdb/lenet/), replacing the tanh activation function with the ReLU function.

In [14]:
data = mx.sym.Variable('data')
# First convolution block: 20 filters of 5x5, ReLU, 2x2 max pooling
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
relu1 = mx.sym.Activation(data=conv1, act_type="relu")
pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))
# Second convolution block: 50 filters of 5x5, ReLU, 2x2 max pooling
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
relu2 = mx.sym.Activation(data=conv2, act_type="relu")
pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))
# Fully connected layers: 500 hidden units, then 10 outputs (one per digit)
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
relu3 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=10)
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
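
To see how the layers above fit together, we can follow the size of a 28x28 MNIST image through the network with plain arithmetic. This is a sketch that assumes the settings used above for every layer (no padding, stride 1 for convolutions, stride 2 for pooling); the helper `conv_out` is ours, not an MXNet function.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output width (or height) of a convolution or pooling layer."""
    return (size + 2 * pad - kernel) // stride + 1

size = 28                              # MNIST images are 28x28, 1 channel
size = conv_out(size, 5)               # conv1 -> 24x24, 20 channels
size = conv_out(size, 2, stride=2)     # pool1 -> 12x12
size = conv_out(size, 5)               # conv2 -> 8x8, 50 channels
size = conv_out(size, 2, stride=2)     # pool2 -> 4x4
flat = 50 * size * size                # flatten -> 800 values per image

# Weights plus biases for each trainable layer
params = {
    "conv1": 20 * 1 * 5 * 5 + 20,
    "conv2": 50 * 20 * 5 * 5 + 50,
    "fc1": flat * 500 + 500,
    "fc2": 500 * 10 + 10,
}
total_params = sum(params.values())
print(size, flat, total_params)        # 4 800 431080
```

Most of the parameters live in the first fully connected layer, which maps the 800 flattened values to 500 hidden units.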

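Now, we need to:
- bind the model to the training set, so that memory is allocated for the data and label shapes,
- initialize the parameters, i.e. set initial values for all weights,
- pick an optimizer and a learning rate, which control how weights are adjusted from the gradients computed during backpropagation.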

In [15]:
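# To train on CPU instead, use: mod = mx.mod.Module(lenet)
mod = mx.mod.Module(lenet, context=mx.gpu(0))
# Allocate memory for the shapes produced by the training iterator
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# Set initial values for all weights
mod.init_params(initializer=mx.init.Xavier())
# Use the Adam optimizer with a learning rate of 0.1
mod.init_optimizer('adam', optimizer_params=(('learning_rate', 0.1),))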
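
To make the role of the optimizer and the learning rate more concrete, here is a minimal sketch of one Adam update for a single scalar weight. It follows the textbook formulation of Adam with its usual default coefficients; we do not claim it matches MXNet's internal implementation line for line, and the function name `adam_step` is ours.

```python
import math

def adam_step(w, g, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one textbook Adam update to weight w, given gradient g at step t."""
    m = beta1 * m + (1 - beta1) * g          # running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g      # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = 1.0, 0.0, 0.0
w, m, v = adam_step(w, g=2.0, m=m, v=v, t=1)
print(w)   # very close to 0.9
```

On the first step, the update has a magnitude close to the learning rate whatever the size of the gradient, which is why the learning rate directly sets how far each weight can move.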

Time to train!

In [16]:
# Log training speed and accuracy every 100 batches (batch size is 128)
mod.fit(train_iter, eval_data=val_iter, num_epoch=nb_epochs,
        batch_end_callback=mx.callback.Speedometer(128, 100))

INFO:root:Epoch[0] Batch [100]	Speed: 24162.63 samples/sec	accuracy=0.747447
INFO:root:Epoch[0] Batch [200]	Speed: 29782.38 samples/sec	accuracy=0.938047
INFO:root:Epoch[0] Batch [300]	Speed: 31277.74 samples/sec	accuracy=0.955391
INFO:root:Epoch[0] Batch [400]	Speed: 31242.05 samples/sec	accuracy=0.966797
INFO:root:Epoch[0] Train-accuracy=0.970266
INFO:root:Epoch[0] Time cost=2.064
INFO:root:Epoch[0] Validation-accuracy=0.976963
INFO:root:Epoch[1] Batch [100]	Speed: 31202.11 samples/sec	accuracy=0.974087
INFO:root:Epoch[1] Batch [200]	Speed: 31228.84 samples/sec	accuracy=0.974844
INFO:root:Epoch[1] Batch [300]	Speed: 31183.60 samples/sec	accuracy=0.977578
INFO:root:Epoch[1] Batch [400]	Speed: 31248.98 samples/sec	accuracy=0.981797
INFO:root:Epoch[1] Train-accuracy=0.979827
INFO:root:Epoch[1] Time cost=1.922
INFO:root:Epoch[1] Validation-accuracy=0.984375
INFO:root:Epoch[2] Batch [100]	Speed: 31173.24 samples/sec	accuracy=0.982751
...

INFO:root:Epoch[18] Batch [100]	Speed: 31160.43 samples/sec	accuracy=0.999613
INFO:root:Epoch[18] Batch [200]	Speed: 31333.70 samples/sec	accuracy=0.999609
INFO:root:Epoch[18] Batch [300]	Speed: 31233.93 samples/sec	accuracy=0.999141
INFO:root:Epoch[18] Batch [400]	Speed: 31332.51 samples/sec	accuracy=0.999453
INFO:root:Epoch[18] Train-accuracy=0.999883
INFO:root:Epoch[18] Time cost=1.918
INFO:root:Epoch[18] Validation-accuracy=0.991086
INFO:root:Epoch[19] Batch [100]	Speed: 31077.26 samples/sec	accuracy=0.999691
INFO:root:Epoch[19] Batch [200]	Speed: 31280.28 samples/sec	accuracy=0.999609
INFO:root:Epoch[19] Batch [300]	Speed: 31225.37 samples/sec	accuracy=0.999609
INFO:root:Epoch[19] Batch [400]	Speed: 31300.32 samples/sec	accuracy=0.999609
INFO:root:Epoch[19] Train-accuracy=0.999883
INFO:root:Epoch[19] Time cost=1.922
INFO:root:Epoch[19] Validation-accuracy=0.990785
INFO:root:Epoch[20] Batch [100]	Speed: 31210.39 samples/sec	accuracy=0.999768
...

In [19]:
mod.save_checkpoint("lenet", nb_epochs)

INFO:root:Saved checkpoint to "lenet-0025.params"
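
The saved checkpoint can later be reloaded for prediction. The sketch below assumes the checkpoint was saved with the prefix and epoch number used above, and that we run inference on CPU with the same batch size and input shape as during training. The helper names `checkpoint_files` and `load_lenet` are ours; `mx.model.load_checkpoint`, `bind` and `set_params` are MXNet calls.

```python
def checkpoint_files(prefix, epoch):
    """Names of the symbol and parameter files written for a checkpoint."""
    return "%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch)

def load_lenet(prefix="lenet", epoch=25, batch_size=128):
    """Rebuild an inference-only module from a saved checkpoint."""
    # Imported here so that checkpoint_files() works without MXNet installed
    import mxnet as mx
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    mod = mx.mod.Module(sym, context=mx.cpu())
    # No label shapes are needed when the module is only used for prediction
    mod.bind(for_training=False,
             data_shapes=[("data", (batch_size, 1, 28, 28))])
    mod.set_params(arg_params, aux_params)
    return mod

print(checkpoint_files("lenet", 25))
```

The parameter file name matches the one reported in the log above; the symbol file stores the network definition.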


Let's measure validation accuracy.

In [20]:
metric = mx.metric.Accuracy()
mod.score(val_iter, metric)
print(metric.get())

('accuracy', 0.99128605769230771)
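
We can read this number as a count of correctly classified images. The sketch below assumes that the validation iterator only yields full batches of 128 images, so 9984 of the 10000 test images are scored; the reported accuracy is consistent with that assumption, but we have not checked the iterator's internals.

```python
batch_size = 128
test_images = 10000
reported = 0.99128605769230771                # accuracy printed above

full_batches = test_images // batch_size      # 78 full batches
scored = full_batches * batch_size            # 9984 images scored
correct = round(reported * scored)            # 9897 images classified correctly
errors = scored - correct                     # 87 misclassified images

print(full_batches, scored, correct, errors)  # 78 9984 9897 87
```

Dividing 9897 by 9984 gives back the reported accuracy, so the model misclassifies 87 of the scored test images.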
