## MNIST Training with MXNet and Gluon - Local

MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28-pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes, one for each digit. This tutorial shows how to train and test an MNIST model on SageMaker using MXNet and the Gluon API.



In [None]:
import time

import mxnet as mx
from mxnet import autograd, gluon

import mnist  # helper module that provides the data loaders, the network, and the test routine

## Set the hyperparameters

In [None]:
batch_size = 100
epochs = 3
learning_rate = 0.01
momentum = 0.9
log_interval = 100
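
As a quick sanity check on these values, the arithmetic below works out how many batches one epoch contains and at which batch indices the training loop will print progress, given the 60,000-image training split. It is a minimal sketch that assumes the loader yields every training image once per epoch and keeps a final partial batch; the names `num_train_images`, `batches_per_epoch`, and `log_batches` are illustrative and are not part of the `mnist` helper module.

```python
import math

# Values from the hyperparameter cell.
batch_size = 100
epochs = 3
log_interval = 100

# Size of the MNIST training split.
num_train_images = 60000

# Batches per epoch, assuming a final partial batch is kept (hence the ceiling).
batches_per_epoch = math.ceil(num_train_images / batch_size)

# Batch indices at which the training loop prints progress
# (condition: i % log_interval == 0 and i > 0).
log_batches = [i for i in range(batches_per_epoch)
               if i % log_interval == 0 and i > 0]

print(batches_per_epoch)           # 600
print(log_batches)                 # [100, 200, 300, 400, 500]
print(batches_per_epoch * epochs)  # 1800 parameter updates in total
```

Under these assumptions, each epoch prints five progress lines, followed by the end-of-epoch training and validation accuracy.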

## Download training and test data

In [None]:
train_data = mnist.get_train_data('./data/train', batch_size)
val_data = mnist.get_val_data('./data/test', batch_size)

## Set up the network

In [None]:
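# Use a GPU if one is available; otherwise fall back to the CPU.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
# Define the network.
net = mnist.define_network()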

# Collect all parameters from net and its children, then initialize them.
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

# The Trainer updates the parameters using their gradients.
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': learning_rate, 'momentum': momentum})
metric = mx.metric.Accuracy()
loss = gluon.loss.SoftmaxCrossEntropyLoss()
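
To make the roles of the loss and the trainer more concrete, here is a minimal pure-Python sketch of the two ideas: softmax cross-entropy for a single sample, and one textbook SGD-with-momentum update in which the gradient is divided by the batch size. The functions `softmax_cross_entropy` and `sgd_momentum_step` are hypothetical illustrations and are not MXNet's implementation, which works on batched arrays and may differ in details such as weight decay or gradient clipping.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy between softmax(logits) and an integer class label."""
    # Subtract the maximum logit for numerical stability before exponentiating.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label]) equals log_sum_exp - logits[label].
    return log_sum_exp - logits[label]


def sgd_momentum_step(weights, grads, velocity, learning_rate, momentum, batch_size):
    """One textbook SGD-with-momentum update; returns (new_weights, new_velocity)."""
    # Assumption: gradients are summed over the batch, so divide by batch_size.
    new_velocity = [momentum * v - learning_rate * (g / batch_size)
                    for v, g in zip(velocity, grads)]
    new_weights = [w + v for w, v in zip(weights, new_velocity)]
    return new_weights, new_velocity


# Uniform scores over the 10 digit classes give a loss of log(10).
uniform_loss = softmax_cross_entropy([0.0] * 10, label=3)
print(uniform_loss)

# One update on a toy weight, using the notebook's learning rate and momentum.
w, v = sgd_momentum_step([1.0], [100.0], [0.0],
                         learning_rate=0.01, momentum=0.9, batch_size=100)
print(w, v)
```

A loss near log(10), roughly 2.30, is what an untrained 10-class classifier with uniform predictions would produce, which makes it a useful baseline when reading early training output.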

In [None]:
# Train the model
for epoch in range(epochs):
    # Reset the metric at the beginning of each epoch.
    metric.reset()
    btic = time.time()
    for i, (data, label) in enumerate(train_data):
        # Copy data to ctx if necessary.
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        # Record the forward pass so that it can be differentiated with backward().
        with autograd.record():
            output = net(data)
            L = loss(output, label)
        # Compute the gradients outside the recording scope.
        L.backward()
        # Take a gradient step, normalizing the gradient by the batch size data.shape[0].
        trainer.step(data.shape[0])
        # Update the training accuracy with this batch.
        metric.update([label], [output])

        if i % log_interval == 0 and i > 0:
            name, acc = metric.get()
            print('[Epoch %d Batch %d] Training: %s=%f, %f samples/s' %
                  (epoch, i, name, acc, batch_size / (time.time() - btic)))

        btic = time.time()

    name, acc = metric.get()
    print('[Epoch %d] Training: %s=%f' % (epoch, name, acc))

    name, val_acc = mnist.test(ctx, net, val_data)
    print('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))
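
The accuracy values printed above are running fractions of correctly classified samples: a prediction counts as correct when the highest-scoring class matches the label, and the counts accumulate across batches until the metric is reset. The sketch below illustrates that bookkeeping in plain Python. `RunningAccuracy` is a hypothetical stand-in written for this example; it is not MXNet's `mx.metric.Accuracy` and makes no claim about how `mnist.test` is implemented.

```python
class RunningAccuracy:
    """Hypothetical stand-in for an accuracy metric; not MXNet's implementation."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the counts, as the training loop does at the start of each epoch.
        self.num_correct = 0
        self.num_samples = 0

    def update(self, labels, outputs):
        # labels: list of integer classes; outputs: list of per-class score lists.
        for label, scores in zip(labels, outputs):
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self.num_correct += int(predicted == label)
            self.num_samples += 1

    def get(self):
        # Return NaN when nothing has been counted yet.
        if self.num_samples == 0:
            return 'accuracy', float('nan')
        return 'accuracy', self.num_correct / self.num_samples


metric = RunningAccuracy()
metric.update([2, 0], [[0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])  # first right, second wrong
metric.update([1], [[0.0, 0.9, 0.1]])                      # right
name, acc = metric.get()
print('%s=%f' % (name, acc))  # accuracy=0.666667
```

Because the counts accumulate, the per-batch log lines in the training loop report accuracy over all batches seen so far in the epoch, not over the most recent batch alone.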