[View in Colaboratory](https://colab.research.google.com/github/brucecmd/learn_gluon/blob/master/softmax_classification_gluon.ipynb)

In [0]:
from mxnet import nd, autograd
from mxnet.gluon import data as gdata
from mxnet.gluon import loss as gloss
from mxnet import gluon

In [0]:
# load data from gluon.data
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

In [0]:
# prepare data iter
batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=True)

In [0]:
from mxnet.gluon import nn


In [0]:
# define net structure
net = nn.Sequential()
net.add(nn.Flatten())
net.add(nn.Dense(10)) # 10 outputs. You can ignore the inputs num

In [0]:
# initial the parameter
from mxnet import init
net.initialize(init.Normal(sigma=0.01))

In [0]:
# define the loss function
loss_func = gloss.SoftmaxCrossEntropyLoss()

In [0]:
# define the trainer
batch_size = 256
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate':0.01})

In [0]:
# define accuracy func
def accuracy(y, y_hat):
    return (y_hat.argmax(axis=1)==y.astype('float32')).mean().asscalar()

# estimate accuracy
def estimate_accuracy(data_iter, net):
    acc = 0
    for data, label in data_iter:
        acc += accuracy(label, net(data))
    return acc / len(data_iter)

In [123]:
# now train and test
epochs = 10
for i in range(epochs):
    for feature, label in train_iter:
        with autograd.record():
            y_hat = net(feature)
            l = loss_func(y_hat, label) # attention!!! The first argument is y_hat, the second is label, do not reverse it.
        l.backward()
        trainer.step(batch_size)
    train_acc = estimate_accuracy(train_iter, net)
    test_acc = estimate_accuracy(test_iter, net)
    print('epoch [%d, train accuracy[%f], test accuracy[%f]'%(i, train_acc, test_acc))

epoch [0, train accuracy[0.688348], test accuracy[0.690527]
epoch [1, train accuracy[0.739744], test accuracy[0.737988]
epoch [2, train accuracy[0.759469], test accuracy[0.760156]
epoch [3, train accuracy[0.770872], test accuracy[0.768750]
epoch [4, train accuracy[0.782691], test accuracy[0.784668]
epoch [5, train accuracy[0.790653], test accuracy[0.791406]
epoch [6, train accuracy[0.795191], test accuracy[0.793750]
epoch [7, train accuracy[0.799767], test accuracy[0.800195]
epoch [8, train accuracy[0.803341], test accuracy[0.808203]
epoch [9, train accuracy[0.806521], test accuracy[0.809961]
