# MNIST with CNN using MXNet

Code mostly from http://mxnet.io/tutorials/python/mnist.html

## Setup

In [1]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./.data/', one_hot=True)

Extracting ./.data/train-images-idx3-ubyte.gz
Extracting ./.data/train-labels-idx1-ubyte.gz
Extracting ./.data/t10k-images-idx3-ubyte.gz
Extracting ./.data/t10k-labels-idx1-ubyte.gz


In [2]:
import time
import mxnet as mx # for some reason, loading mxnet before the dataset from tensorflow crashes the kernel
import logging
logging.getLogger().setLevel(logging.DEBUG)

In [3]:
X_train = mnist.train.images
X_test = mnist.train.images
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_train.shape[0], 1, 28, 28)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

y_train = mnist.train.labels
y_test = mnist.train.labels

## Define network

In [8]:
data = mx.symbol.Variable('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fullc layer
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

## Train

In [9]:
mod = mx.mod.Module(lenet, context=mx.gpu(0))

t0 = time.time()
mod.fit(mx.io.NDArrayIter(X_train, label=y_train, batch_size=100, shuffle=True),
        num_epoch=10,
        optimizer='sgd',
        batch_end_callback = mx.callback.Speedometer(100, 200)
       )

print "training time = {}".format(time.time() - t0)

INFO:root:Epoch[0] Batch [200]	Speed: 19910.40 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[0] Batch [400]	Speed: 23027.34 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[0] Train-accuracy=0.900000
INFO:root:Epoch[0] Time cost=2.544
INFO:root:Epoch[1] Batch [200]	Speed: 22951.81 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[1] Batch [400]	Speed: 22985.78 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[1] Train-accuracy=0.900000
INFO:root:Epoch[1] Time cost=2.396
INFO:root:Epoch[2] Batch [200]	Speed: 22966.49 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[2] Batch [400]	Speed: 22911.84 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[2] Train-accuracy=0.900000
INFO:root:Epoch[2] Time cost=2.400
INFO:root:Epoch[3] Batch [200]	Speed: 22959.74 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[3] Batch [400]	Speed: 23002.86 samples/sec	Train-accuracy=0.900000
INFO:root:Epoch[3] Train-accuracy=0.900000
INFO:root:Epoch[3] Time cost=2.396
INFO:root:Epoch[4] Batch

KeyboardInterrupt: 

## Evaluate

In [None]:
metric = mx.metric.Accuracy()
mod.score(mx.io.NDArrayIter(X_test, label=y_test, batch_size=100), metric)