## Learning MNIST with the LeNet CNN

First, let's download the data set.

In [None]:
!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
!gzip -d train*.gz t10k*.gz
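
The decompressed files use the IDX format described on the MNIST page: a big-endian header (a magic number, then one 32-bit size per dimension) followed by raw unsigned bytes. As a quick sanity check on the download, the header can be parsed with the standard library. This is a minimal sketch; `read_idx_header` is a hypothetical helper and not part of MXNet.

```python
import struct

def read_idx_header(raw):
    """Parse the header of an IDX file from its first bytes.

    Returns (ndim, dims). Hypothetical helper, for illustration only.
    """
    # Magic number: two zero bytes, a data type code, the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:  # 0x08 means unsigned byte
        raise ValueError("not an unsigned-byte IDX file")
    # One big-endian 32-bit unsigned integer per dimension.
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    return ndim, dims

# Synthetic header shaped like the training images file: 60000 images of 28x28.
example = struct.pack(">HBBIII", 0, 0x08, 3, 60000, 28, 28)
print(read_idx_header(example))
```

To check a real file, pass its first 16 bytes, for instance `read_idx_header(open("train-images-idx3-ubyte", "rb").read(16))`.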

In [None]:
import mxnet as mx
import logging

logging.basicConfig(level=logging.INFO)

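MXNet provides a convenient iterator for MNIST. We use it to build the training and validation iterators. The training iterator relies on the default file paths (`./train-images-idx3-ubyte` and `./train-labels-idx1-ubyte`); both iterators use the default batch size of 128.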

In [None]:
nb_epochs = 25

train_iter = mx.io.MNISTIter(shuffle=True)
val_iter = mx.io.MNISTIter(image="./t10k-images-idx3-ubyte", label="./t10k-labels-idx1-ubyte")

We build the LeNet network (http://yann.lecun.com/exdb/lenet/), replacing the tanh activation function with the ReLU function.

In [None]:
data = mx.sym.Variable('data')
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
relu1 = mx.sym.Activation(data=conv1, act_type="relu")
pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
relu2 = mx.sym.Activation(data=conv2, act_type="relu")
pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
relu3 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=10)
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
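
To see how the layers fit together, it helps to trace the feature map size through the network. The sketch below assumes 28x28 single-channel inputs and the layer defaults used above (stride 1 and no padding for convolutions, biases enabled); `conv_out` is a hypothetical helper written for this illustration.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * pad - kernel) // stride + 1

size = 28                      # MNIST images are 28x28
size = conv_out(size, 5)       # conv1, 5x5 kernel        -> 24
size = conv_out(size, 2, 2)    # pool1, 2x2 with stride 2 -> 12
size = conv_out(size, 5)       # conv2, 5x5 kernel        -> 8
size = conv_out(size, 2, 2)    # pool2, 2x2 with stride 2 -> 4
flat_features = 50 * size * size  # 50 feature maps of 4x4 feed the first dense layer

# Weights plus biases for each trainable layer.
params = {
    "conv1": 20 * 1 * 5 * 5 + 20,
    "conv2": 50 * 20 * 5 * 5 + 50,
    "fc1": flat_features * 500 + 500,
    "fc2": 500 * 10 + 10,
}
total_params = sum(params.values())
print(flat_features, total_params)  # 800 431080
```

Most of the parameters sit in the first fully connected layer, which is typical for this kind of architecture.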

In [None]:
mx.viz.plot_network(lenet)

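Now, we need to:
- bind the model to the input shapes provided by the training iterator, which allocates memory for the network,
- initialize the parameters, i.e. set initial values for all weights,
- pick an optimizer and a learning rate, which control how weights are updated from the gradients computed by backpropagation.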

In [None]:
# Train on the first GPU. To train on CPU instead, use: mod = mx.mod.Module(lenet)
mod = mx.mod.Module(lenet, context=mx.gpu(0))
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Xavier())
mod.init_optimizer('adam', optimizer_params=(('learning_rate', 0.1),))
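
To make the role of the optimizer and learning rate concrete, here is a minimal sketch of one Adam update for a single scalar weight. It follows the textbook formulation; the `adam_step` function and the values of `beta1`, `beta2` and `eps` are assumptions chosen for illustration, and MXNet's own implementation may differ in details.

```python
import math

def adam_step(w, grad, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam update for a scalar weight (illustrative only)."""
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# First step (t=1) starting from w=1.0 with a gradient of 0.5.
w, m, v = adam_step(w=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(w)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the weight moves by roughly the learning rate regardless of the gradient's magnitude. This is worth keeping in mind when choosing the learning rate.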

Time to train!

In [None]:
# Speedometer(128, 100): batch size of 128, log training speed every 100 batches
mod.fit(train_iter, eval_data=val_iter, num_epoch=nb_epochs,
        batch_end_callback=mx.callback.Speedometer(128, 100))

In [None]:
mod.save_checkpoint("lenet", nb_epochs)
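
The checkpoint is written as two files: the network definition and the parameters for the given epoch. The sketch below only reproduces the naming scheme; `checkpoint_files` is a hypothetical helper, not an MXNet function.

```python
def checkpoint_files(prefix, epoch):
    """Return the (symbol, params) file names for a checkpoint (illustrative helper)."""
    return "%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch)

symbol_file, params_file = checkpoint_files("lenet", 25)
print(symbol_file, params_file)  # lenet-symbol.json lenet-0025.params

# With MXNet installed, the model can later be reloaded with:
#   sym, arg_params, aux_params = mx.model.load_checkpoint("lenet", 25)
```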

Let's measure validation accuracy.

In [None]:
metric = mx.metric.Accuracy()
mod.score(val_iter, metric)
print(metric.get())
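
`metric.get()` returns the metric name together with its value. As a rough illustration of what the accuracy metric measures, the sketch below computes the fraction of samples whose highest-probability class matches the label. The `accuracy` function and the sample values are made up for this example and are not taken from the model above.

```python
def accuracy(probabilities, labels):
    """Fraction of samples whose most probable class equals the label."""
    correct = 0
    for probs, label in zip(probabilities, labels):
        # Index of the largest probability, i.e. the predicted class.
        predicted = max(range(len(probs)), key=probs.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)

# Four made-up samples over three classes; the third one is misclassified.
sample_probs = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.5, 0.3],
]
sample_labels = [1, 0, 1, 1]
print(accuracy(sample_probs, sample_labels))  # 0.75
```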