In [8]:
import matplotlib.pyplot as plt
from mxnet import gluon, init, nd, autograd
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

In [4]:
def get_dataloader(batch_size, num_workers):
    transformer = transforms.Compose([
        transforms.ToTensor()
    ])
    train = gluon.data.vision.datasets.FashionMNIST(train=True)
    train = train.transform_first(transformer)
    test = gluon.data.vision.datasets.FashionMNIST(train=False)
    test = test.transform_first(transformer)
    train_iter = gluon.data.DataLoader(train, batch_size, shuffle=True, num_workers=num_workers)
    test_iter = gluon.data.DataLoader(test, batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter

In [12]:
def softmax(X):
    X_exp = X.exp()
    normalization_constant = X_exp.sum(axis=1, keepdims=True)
    return X_exp / normalization_constant

In [19]:
def evaluate_acc(net, data_iter):
    accumulator = 0
    size = 0
    for X, y in data_iter:
        y_hat = net(X)
        y_hat = softmax(y_hat)
        accumulator += (y_hat.argmax(axis=1)==y.astype('float32')).sum()
        size += len(y)
    return accumulator / size

In [20]:
batch_size = 256
num_workers = 4
train_iter, test_iter = get_dataloader(batch_size, num_workers)

In [21]:
net = nn.Sequential()
net.add(nn.Dense(10))
net.initialize(init.Normal(0.01))

In [22]:
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
loss = gluon.loss.SoftmaxCrossEntropyLoss()

In [23]:
epochs = 5
for epoch in range(epochs):
    for X, y in train_iter:
        with autograd.record():
            y_hat = net(X)
            l = loss(y_hat, y)
        l.backward()
        trainer.step(batch_size)
    epoch_acc = evaluate_acc(net, train_iter)
    print("Epoch %d, acc: %f" % (epoch, epoch_acc.asscalar()))

Epoch 0, acc: 0.805600
Epoch 1, acc: 0.822417
Epoch 2, acc: 0.826817
Epoch 3, acc: 0.831850
Epoch 4, acc: 0.838033


In [6]:
for X, y in train_iter:
    print(X.shape)
    print(y.shape)
    break

(10, 1, 28, 28)
(10,)
