# High-level MXNet MNIST Example

In [1]:
import os
import sys
import numpy as np
import mxnet as mx
from common.params import *
from common.utils import *

In [2]:
# Record library versions so the timings/accuracy below can be reproduced.
print(mx.__version__, np.__version__, sep="\n")

0.10.0
1.13.1


In [3]:
def create_lenet(n_classes=None):
    """Build a LeNet-style CNN symbol ending in a softmax classifier.

    Layers: two stages of (5x5 convolution -> tanh -> 2x2/stride-2 max-pool)
    with 20 and 50 filters, then flatten -> FC(500) -> tanh ->
    FC(n_classes) -> SoftmaxOutput.

    Parameters
    ----------
    n_classes : int, optional
        Width of the final fully-connected layer. When omitted, the
        module-level ``N_CLASSES`` is read at call time, so existing
        ``create_lenet()`` calls behave exactly as before.

    Returns
    -------
    mx.symbol.Symbol
        The ``SoftmaxOutput`` symbol named "softmax". Its label input is the
        variable "softmax_label", which is the label name MXNet's Module and
        NDArrayIter use by default.
    """
    if n_classes is None:
        n_classes = N_CLASSES
    data = mx.symbol.Variable('data')
    # Stage 1: 20 filters, 5x5, tanh, 2x2 stride-2 max-pool.
    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=20)
    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
    # Stage 2: 50 filters, same pattern.
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
    # Classifier head.
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
    fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=n_classes)
    input_y = mx.symbol.Variable('softmax_label')
    return mx.symbol.SoftmaxOutput(data=fc2, label=input_y, name="softmax")

In [4]:
def init_model(symbol=None):
    """Wrap a network symbol in an ``mx.mod.Module`` on the configured device.

    Parameters
    ----------
    symbol : mx.symbol.Symbol, optional
        Network to wrap. Defaults to a fresh ``create_lenet()`` when omitted,
        so existing ``init_model()`` calls behave as before.

    Returns
    -------
    mx.mod.Module
        A module that has not yet been bound; ``fit`` binds it and
        initialises its parameters. It targets GPU 0 when the global ``GPU``
        flag is truthy, otherwise the CPU.
    """
    # Use a list of contexts in both branches so the type is consistent
    # (Module wraps a bare Context in a list anyway). This also makes it easy
    # to extend to several devices later.
    ctx = [mx.gpu(0)] if GPU else [mx.cpu()]
    if symbol is None:
        symbol = create_lenet()
    return mx.mod.Module(context=ctx, symbol=symbol)

In [5]:
%%time
# Data into format for library
x_train, x_test, y_train, y_test = mnist_for_library(channel_first=True)

CPU times: user 176 ms, sys: 240 ms, total: 416 ms
Wall time: 416 ms


In [6]:
%%time
# Initialise model
model = init_model()

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 1.56 ms


In [7]:
%%time
# Train model
model.fit(train_data = mx.io.NDArrayIter(x_train, y_train, batch_size=BATCHSIZE, shuffle=True),
          optimizer = 'sgd',
          optimizer_params = {'learning_rate':LR, 'momentum':MOMENTUM},
          eval_metric = 'acc',
          batch_end_callback = mx.callback.Speedometer(BATCHSIZE, 10), 
          num_epoch = EPOCHS)     

CPU times: user 1min 13s, sys: 16.5 s, total: 1min 29s
Wall time: 1min 6s


In [8]:
%%time
# Test model
acc = mx.metric.Accuracy()
model.score(mx.io.NDArrayIter(x_test, y_test, batch_size=BATCHSIZE), acc)
print("Accuracy ", acc)

Accuracy  EvalMetric: {'accuracy': 0.98911741214057503}
CPU times: user 304 ms, sys: 168 ms, total: 472 ms
Wall time: 417 ms
