In [32]:
from mxnet import npx, np, autograd
from mxnet.gluon import nn, Trainer
from mxnet.gluon.loss import SoftmaxCrossEntropyLoss
import mxnet.initializer as init
from d2l import mxnet as d2l
npx.set_np()

## Load Data

In [30]:
num_classes = 10
batch_size = 200
train_dl, test_dl = d2l.load_data_fashion_mnist(batch_size)

In [6]:
abc = next(iter(train_dl))

In [10]:
abc[1].shape

(2, 1, 28, 28)

## Define Architecture

In [22]:
model = nn.Sequential()
layers = [nn.Dense(256, 'relu'),
          nn.Dropout(0.1),
          nn.Dense(256, 'relu'),
          nn.Dropout(0.3),
          nn.Dense(num_classes)]
for layer in layers:
    model.add(layer)
model.initialize(init.Normal(sigma=0.01))

## Define Loss

In [None]:
def loss_fn(yhat, y):
    loss_cls = SoftmaxCrossEntropyLoss()
    return loss_cls(yhat, y)

## Define Trainer

In [44]:
num_epochs = 3
lr = 0.1
optimizer = Trainer(model.collect_params(), 'sgd', {'learning_rate': lr})
for epoch in range(num_epochs):
    for X, y in iter(train_dl):
        batch_size = X.shape[0]
        with autograd.record():
            yhat = model(X)
            loss = loss_fn(yhat, y)
        loss.backward()
        optimizer.step(batch_size)
    print(f'Epoch: {epoch}, Loss: {loss.mean()}')

Epoch: 0, Loss: 0.46718895
Epoch: 1, Loss: 0.35136956
Epoch: 2, Loss: 0.41938508


## Get Accuracy

In [38]:
def softmax(output_activations):
    exponentiated_activations = np.exp(output_activations)
    partition_function = exponentiated_activations.sum(axis=1, keepdims=True)
    return exponentiated_activations/partition_function

In [39]:
def get_labels_from_softmax(probs):
    return np.argmax(probs, axis=1)

In [42]:
def get_accuracy(dl, model):
    num_samples = 0
    num_correct = 0
    for X, y in iter(dl):
        logit_yhat = model(X)
        yhat = softmax(logit_yhat)
        preds = get_labels_from_softmax(yhat)
        num_correct += (preds.astype('int32') == y).sum()
        num_samples += X.shape[0]
    return num_correct/num_samples

In [45]:
get_accuracy(test_dl, model)

array(0.8651)