Adapted, with only minor changes, from the MXNet MNIST tutorial: https://mxnet.incubator.apache.org/tutorials/python/mnist.html

In [1]:
import os

import boto3
import mxnet as mx
import numpy as np

batch_size = 100

# Fix the RNG seeds so Restart & Run All gives repeatable results.
# numpy is seeded because (presumably) NDArrayIter's shuffle draws from it, and
# mxnet's RNG is seeded for parameter initialization.
# NOTE: some GPU kernels are non-deterministic, so metrics may still differ slightly.
SEED = 42
np.random.seed(SEED)
mx.random.seed(SEED)

## Load Data

In [2]:
# Fetch MNIST via MXNet's test utilities. The result is indexed below by the keys
# 'train_data' and 'train_label'.
mnist = mx.test_utils.get_mnist()

## Split Train/Validation (80/20 of the MNIST training set)

In [3]:
# Keep the first 80% of the training images for fitting and hold out the last
# 20% for validation. Only the training iterator is shuffled.
train_images = mnist['train_data']
train_labels = mnist['train_label']
ntrain = int(train_images.shape[0] * 0.8)
train_iter = mx.io.NDArrayIter(train_images[:ntrain], train_labels[:ntrain],
                               batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(train_images[ntrain:], train_labels[ntrain:],
                             batch_size)

## Define Network

In [4]:
def _conv_tanh_pool(inputs, num_filter):
    """Return one LeNet feature stage: 5x5 conv -> tanh -> 2x2 max-pool with stride 2."""
    conv = mx.sym.Convolution(data=inputs, kernel=(5, 5), num_filter=num_filter)
    act = mx.sym.Activation(data=conv, act_type="tanh")
    return mx.sym.Pooling(data=act, pool_type="max", kernel=(2, 2), stride=(2, 2))


data = mx.sym.var('data')
# Two convolutional feature stages: 20 filters, then 50 filters.
features = _conv_tanh_pool(_conv_tanh_pool(data, 20), 50)
# Classifier head: flatten -> 500 tanh units -> 10 outputs.
# Nested calls create the ops in the same order as a step-by-step version would.
hidden = mx.sym.Activation(
    data=mx.sym.FullyConnected(data=mx.sym.flatten(data=features), num_hidden=500),
    act_type="tanh",
)
logits = mx.sym.FullyConnected(data=hidden, num_hidden=10)
# Softmax output with its cross-entropy loss; the name 'softmax' is kept explicitly.
lenet = mx.sym.SoftmaxOutput(data=logits, name='softmax')

## Train Model

In [5]:
# create a trainable module on GPU 0
# Train on the first GPU when one is visible; otherwise fall back to the CPU
# so the notebook still runs (slowly) on machines without CUDA.
train_ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
lenet_model = mx.mod.Module(symbol=lenet, context=train_ctx)
# Fit with plain SGD, report accuracy ('acc') on the validation iterator, and log
# throughput through a Speedometer (frequency 100).
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)

## Evaluate Model

In [6]:
# predict accuracy for lenet
# Score the trained network on the held-out validation split (not the MNIST test set).
val_accuracy = mx.metric.Accuracy()
lenet_model.score(val_iter, val_accuracy)
print(val_accuracy)

EvalMetric: {'accuracy': 0.98466666666666669}


## Save Model

In [8]:
# Local artifact path -> S3 object key (keys match the file names).
SYMBOL_PATH = './mnist_symbol.mxnet'
PARAMS_PATH = './mnist_module.mxnet'
ARTIFACTS = {
    SYMBOL_PATH: 'mnist_symbol.mxnet',
    PARAMS_PATH: 'mnist_module.mxnet',
}
# Destination bucket. Override with MNIST_S3_BUCKET; the default keeps the original
# (personal, temporary-looking) bucket. NOTE(review): confirm the default is intended.
S3_BUCKET = os.environ.get('MNIST_S3_BUCKET', 'jakechenawstemp')

# Save the network graph and the trained parameters as separate files.
lenet.save(SYMBOL_PATH)
lenet_model.save_params(PARAMS_PATH)

s3 = boto3.client('s3')
for local_path, s3_key in ARTIFACTS.items():
    s3.upload_file(local_path, S3_BUCKET, s3_key)