In [None]:
import mxnet as mx
import numpy as np

# Select a fixed random seed for reproducibility. MXNet and NumPy keep
# separate global generators and later cells draw from `np.random`, so
# both are seeded with the same value.
# NOTE(review): DataLoader shuffling presumably also goes through NumPy's
# global generator -- confirm against the installed MXNet version.
SEED = 42
mx.random.seed(SEED)
np.random.seed(SEED)

def data_xform(data):
    """Move axis 2 to the front, cast to float32, and divide by 255.

    The ``mx.nd`` namespace is used explicitly so that this function only
    depends on the ``mx`` import of its own cell, not on ``nd`` being
    imported by a cell that runs later.

    Parameters
    ----------
    data : NDArray
        Image array. Presumably laid out as height x width x channel with
        values in [0, 255]; neither is checked here.

    Returns
    -------
    NDArray
        float32 array with the former axis 2 moved to position 0 and all
        values scaled by 1/255.
    """
    return mx.nd.moveaxis(data, 2, 0).astype('float32') / 255

# Build the training and validation MNIST datasets in one go; both get the
# same `data_xform` registered through `transform_first`.
train_data, val_data = (
    mx.gluon.data.vision.MNIST(train=is_train).transform_first(data_xform)
    for is_train in (True, False)
)

In [None]:
batch_size = 100

# Both loaders share the batch size; only the training loader shuffles.
train_loader, val_loader = (
    mx.gluon.data.DataLoader(dataset, shuffle=do_shuffle, batch_size=batch_size)
    for dataset, do_shuffle in ((train_data, True), (val_data, False))
)

In [None]:
from __future__ import print_function  # only relevant for Python 2
import mxnet as mx
from mxnet import nd, gluon, autograd
from mxnet.gluon import nn

In [None]:
# Multilayer perceptron: flatten the input, two ReLU hidden layers
# (128 and 64 units), then a 10-unit output layer.
net = nn.HybridSequential(prefix='MLP_')
with net.name_scope():
    net.add(nn.Flatten())
    net.add(*[nn.Dense(units, activation='relu') for units in (128, 64)])
    # No activation on the output layer: the loss function used for
    # training includes the softmax already.
    net.add(nn.Dense(10, activation=None))

In [None]:
# Run on the first GPU when at least one is reported, otherwise on the CPU.
if mx.context.num_gpus() > 0:
    ctx = mx.gpu(0)
else:
    ctx = mx.cpu(0)

# Initialize the MLP's parameters with the Xavier scheme on that device.
net.initialize(mx.init.Xavier(), ctx=ctx)

In [None]:
# SGD with learning rate 0.04 over every parameter collected from `net`.
sgd_settings = {'learning_rate': 0.04}
trainer = gluon.Trainer(net.collect_params(), 'sgd', sgd_settings)

In [None]:
# Accuracy is what gets reported per epoch; softmax cross-entropy is the
# quantity that is actually minimized.
metric, loss_function = mx.metric.Accuracy(), gluon.loss.SoftmaxCrossEntropyLoss()

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Place the batch on the device the network lives on.
        inputs, labels = inputs.as_in_context(ctx), labels.as_in_context(ctx)

        # Forward pass and loss must run inside `record()` so that the
        # computational graph is kept for the backward pass.
        with autograd.record():
            outputs = net(inputs)
            loss = loss_function(outputs, labels)

        # Backpropagate, then fold this batch into the running accuracy.
        loss.backward()
        metric.update(labels, outputs)

        # The trainer normalizes the gradients by `1 / batch_size`, so it
        # is given the actual number of samples in this batch.
        samples_in_batch = inputs.shape[0]
        trainer.step(batch_size=samples_in_batch)

    # Report the accuracy accumulated over the epoch, then start over.
    name, acc = metric.get()
    print('After epoch {}: {} = {}'.format(epoch + 1, name, acc))
    metric.reset()

In [None]:
# Score the trained MLP on the validation set. A fresh metric object is
# used so that nothing accumulated during training can affect the result.
metric = mx.metric.Accuracy()
for inputs, labels in val_loader:
    # Possibly copy inputs and labels to the GPU
    inputs = inputs.as_in_context(ctx)
    labels = labels.as_in_context(ctx)
    metric.update(labels, net(inputs))

# Read the metric once and use the same values for printing and checking.
name, acc = metric.get()
print('Validation: {} = {}'.format(name, acc))
assert acc > 0.96, 'MLP validation accuracy too low: {}'.format(acc)

In [None]:
def get_mislabeled(loader, model=None):
    """Return list of ``(input, pred_lbl, true_lbl)`` for mislabeled samples.

    Parameters
    ----------
    loader : iterable
        Yields ``(inputs, labels)`` batches; both are moved to ``ctx``
        before the forward pass.
    model : callable, optional
        Network to evaluate. Defaults to the global ``net`` so existing
        calls keep their behavior.

    Returns
    -------
    list of tuple
        One ``(input, pred_lbl, true_lbl)`` entry per sample whose predicted
        label differs from its true label, in loader order. ``input`` is a
        NumPy array, the two labels are Python ints.
    """
    if model is None:
        model = net

    mislabeled = []
    for inputs, labels in loader:
        inputs = inputs.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        outputs = model(inputs)
        # The predicted label is the index at which the output is maximal.
        # Predictions and labels are copied to host memory once per batch
        # instead of calling `asscalar()` on every single sample.
        preds = nd.argmax(outputs, axis=1).asnumpy().astype('int64').reshape(-1)
        truth = labels.asnumpy().astype('int64').reshape(-1)
        for idx in (preds != truth).nonzero()[0]:
            idx = int(idx)
            # Only the misclassified inputs are copied to NumPy.
            mislabeled.append((inputs[idx].asnumpy(), int(preds[idx]), int(truth[idx])))
    return mislabeled

In [None]:
import numpy as np

sample_size = 8
wrong_train = get_mislabeled(train_loader)
wrong_val = get_mislabeled(val_loader)


def sample_mislabeled(mislabeled, size, rng):
    """Pick up to ``size`` distinct entries of ``mislabeled`` at random.

    Parameters
    ----------
    mislabeled : list
        Candidates to choose from.
    size : int
        Maximum number of entries to return.
    rng : numpy.random.RandomState
        Generator used for the draw.

    Returns
    -------
    list
        ``min(size, len(mislabeled))`` entries without repetition; an empty
        list when there is nothing to sample from.
    """
    count = min(size, len(mislabeled))
    if count == 0:
        return []
    picks = rng.choice(len(mislabeled), size=count, replace=False)
    return [mislabeled[i] for i in picks]


# A dedicated, seeded generator keeps the displayed sample the same from
# run to run. Drawing without replacement avoids showing one image twice.
sample_rng = np.random.RandomState(42)
wrong_train_sample = sample_mislabeled(wrong_train, sample_size, sample_rng)
wrong_val_sample = sample_mislabeled(wrong_val, sample_size, sample_rng)

import matplotlib.pyplot as plt

def plot_mislabeled(samples, title, ncols):
    """Show mislabeled samples side by side in a single row of panels.

    Parameters
    ----------
    samples : list of tuple
        ``(img, pred, lbl)`` entries; ``img[0]`` is drawn as a grayscale
        image and the two labels go into the panel title.
    title : str
        Title for the whole figure.
    ncols : int
        Number of panels to create.

    Returns
    -------
    (Figure, array of Axes)
        The created figure and its one-dimensional row of axes.
    """
    # `squeeze=False` keeps the axes in an array even when `ncols == 1`.
    fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(18, 4))
    axs = axs[0]
    # Figure-level settings are applied once, not once per panel.
    fig.suptitle(title, fontsize=20)

    for ax, (img, pred, lbl) in zip(axs, samples):
        ax.imshow(img[0], cmap="gray")
        ax.set_title("Predicted: {}\nActual: {}".format(pred, lbl))
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    # Hide any panel that did not receive a sample.
    for ax in axs[len(samples):]:
        ax.set_visible(False)
    return fig, axs


fig, axs = plot_mislabeled(
    wrong_train_sample, "Sample of wrong predictions in the training set", sample_size
)
fig, axs = plot_mislabeled(
    wrong_val_sample, "Sample of wrong predictions in the validation set", sample_size
)

In [None]:
# Stand-alone 3x3 convolution with 16 input and 32 output channels and a
# ReLU activation; its parameter dictionary is printed for inspection.
conv_layer = nn.Conv2D(
    channels=32,
    kernel_size=(3, 3),
    in_channels=16,
    activation='relu',
)
print(conv_layer.params)

In [None]:
# LeNet-style network: two (5x5 tanh convolution -> 2x2 max-pool) stages
# with 20 and 50 filters, followed by a 500-unit tanh layer and a 10-unit
# output layer without activation.
lenet = nn.HybridSequential(prefix='LeNet_')
with lenet.name_scope():
    for n_filters in (20, 50):
        lenet.add(
            nn.Conv2D(channels=n_filters, kernel_size=(5, 5), activation='tanh'),
            nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
        )
    lenet.add(
        nn.Flatten(),
        nn.Dense(500, activation='tanh'),
        nn.Dense(10, activation=None),
    )

In [None]:
# Xavier-initialize LeNet on the selected device.
lenet.initialize(mx.init.Xavier(), ctx=ctx)

# Feed an all-zero batch holding one 1x28x28 input to print the summary.
dummy_batch = nd.zeros((1, 1, 28, 28), ctx=ctx)
lenet.summary(dummy_batch)

In [None]:
# Train LeNet with the same optimizer settings and loss as the MLP.
trainer = gluon.Trainer(
    params=lenet.collect_params(),
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.04},
)
metric = mx.metric.Accuracy()
num_epochs = 10

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Possibly copy inputs and labels to the GPU
        inputs = inputs.as_in_context(ctx)
        labels = labels.as_in_context(ctx)

        # Record the forward pass and the loss for automatic differentiation
        with autograd.record():
            outputs = lenet(inputs)
            loss = loss_function(outputs, labels)

        # Backpropagate and fold this batch into the running accuracy
        loss.backward()
        metric.update(labels, outputs)

        # Pass the actual number of samples in the batch to the trainer
        trainer.step(batch_size=inputs.shape[0])

    # Print the training accuracy of this epoch and reset it for the next
    name, acc = metric.get()
    print('After epoch {}: {} = {}'.format(epoch + 1, name, acc))
    metric.reset()

# Validate with a fresh metric object, as done for the MLP, so the score
# does not depend on the training metric having been reset.
metric = mx.metric.Accuracy()
for inputs, labels in val_loader:
    inputs = inputs.as_in_context(ctx)
    labels = labels.as_in_context(ctx)
    metric.update(labels, lenet(inputs))

# Read the metric once and use the same values for printing and checking.
name, acc = metric.get()
print('Validation: {} = {}'.format(name, acc))
assert acc > 0.985, 'LeNet validation accuracy too low: {}'.format(acc)