[View in Colaboratory](https://colab.research.google.com/github/brucecmd/learn_gluon/blob/master/dropout_gluon.ipynb)

In [0]:
from mxnet.gluon import data as gdata
from mxnet.gluon import loss as gloss
from mxnet import nd, autograd, gluon
from mxnet.gluon import nn
from mxnet import init

In [0]:
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

In [0]:
batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=True)

In [0]:
loss_func = gloss.SoftmaxCrossEntropyLoss()

In [0]:
# 先不管模型结构合不合理，先让模型复杂起来，这样才能看出来是不是dropout有用

# 有时候模型的不合理，会导致模型一直训不出来，一直没法优化。就像上面256下加一个128，再用下面的64，就会出现这种情况。


# 不带dropout的版本
#net = nn.Sequential()
#net.add(nn.Flatten())
#net.add(nn.Dense(256))
#net.add(nn.Dense(256))
#net.add(nn.Dense(64))
#net.add(nn.Dense(10))


# 带dropout的版本
net = nn.Sequential()
net.add(nn.Flatten())
net.add(nn.Dense(256, activation='relu'))
drop_prob1 = 0.5
net.add(nn.Dropout(drop_prob1))
net.add(nn.Dense(256, activation='relu'))
drop_prob2 = 0.3
net.add(nn.Dropout(drop_prob2))
net.add(nn.Dense(64, activation='relu'))
net.add(nn.Dense(10))

net.initialize(init.Normal(sigma=0.01))

In [0]:
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate':0.1})

In [0]:
def accuracy(y_hat, y):
    return (y_hat.argmax(axis=1)==y.astype('float32')).mean().asscalar()
  
def estimate_accuracy(data_iter, net):
    total_acc = 0
    for data, label in data_iter:
        y_hat = net(data)
        acc = accuracy(y_hat, label)
        total_acc += acc
    return total_acc / len(data_iter)

In [129]:
epochs = 20
lr = 0.1
batch_size = 256

for i in range(epochs):
    for data, label in train_iter:
        with autograd.record():
            y_hat = net(data)
            l = loss_func(y_hat, label)
        l.backward()
        trainer.step(batch_size)
    train_acc = estimate_accuracy(train_iter, net)
    test_acc = estimate_accuracy(test_iter, net)
    print('epoch[%d], train acc[%f], test acc[%f]'%(i,train_acc,test_acc))

epoch[0], train acc[0.105491], test acc[0.105664]
epoch[1], train acc[0.169886], test acc[0.173633]
epoch[2], train acc[0.219105], test acc[0.220996]
epoch[3], train acc[0.486303], test acc[0.487891]
epoch[4], train acc[0.584852], test acc[0.579395]
epoch[5], train acc[0.750609], test acc[0.756348]
epoch[6], train acc[0.760871], test acc[0.758008]
epoch[7], train acc[0.789478], test acc[0.789062]
epoch[8], train acc[0.801042], test acc[0.795703]
epoch[9], train acc[0.814716], test acc[0.815820]
epoch[10], train acc[0.823543], test acc[0.826074]
epoch[11], train acc[0.840204], test acc[0.834375]
epoch[12], train acc[0.835428], test acc[0.833594]
epoch[13], train acc[0.853059], test acc[0.846680]
epoch[14], train acc[0.859929], test acc[0.859766]
epoch[15], train acc[0.866977], test acc[0.861816]
epoch[16], train acc[0.869642], test acc[0.866211]
epoch[17], train acc[0.872828], test acc[0.869238]
epoch[18], train acc[0.871936], test acc[0.868848]
epoch[19], train acc[0.879266], test acc[