In [22]:
import torch
import numpy as np

In [23]:
def dropout(x, prob):
    x = x.float()
    assert 0 <= prob <= 1
    if 1 == prob:
        return torch.zeros_like(x)
    mask = (torch.rand(x.shape) < (1 - prob)).float()
    return mask * x / (1 - prob)

In [24]:
x = torch.arange(16).view(2, 8)
print(dropout(x, 0))
print(dropout(x, 0.5))
print(dropout(x, 1))

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  2.,  4.,  6.,  0., 10., 12., 14.],
        [16., 18., 20., 22.,  0.,  0., 28.,  0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [25]:
# 定义模型参数
num_inputs, num_hidden1, num_hidden2, num_outputs = 784, 256, 256, 10
w1 = torch.tensor(
    np.random.normal(0, 0.01, size=(num_inputs, num_hidden1)),
    dtype=torch.float32,
    requires_grad=True
)
b1 = torch.zeros(num_hidden1, requires_grad=True)
w2 = torch.tensor(
    np.random.normal(0, 0.01, size=(num_hidden1, num_hidden2)),
    dtype=torch.float32,
    requires_grad=True
)
b2 = torch.zeros(num_hidden2, requires_grad=True)
w3 = torch.tensor(
    np.random.normal(0, 0.01, size=(num_hidden2, num_outputs)),
    dtype=torch.float32,
    requires_grad=True
)
b3 = torch.zeros(num_outputs, requires_grad=True)
params = [w1, b1, w2, b2, w3, b3]

In [26]:
# 定义模型
prob1, prob2 = 0.2, 0.5
def net(x, is_training = True):
    x = x.view(-1, num_inputs)
    H1 = (torch.matmul(x, w1) + b1).relu()
    if is_training:
        H1 = dropout(H1, prob1)
    
    H2 = (torch.matmul(H1, w2) + b2).relu()
    if is_training:
        H2 = dropout(H2, prob2)
    return (torch.matmul(H2, w3) + b3)

In [27]:
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for x, y in data_iter:
        n += y.shape[0]
        if isinstance(net, torch.nn.Module):
            net.eval()      # 评估模式，这会关闭dropout
            acc_sum += (net(x).argmax(dim = 1) == y).float().sum().item()
            net.train()     # 改回训练模式
            continue
        if ('is_training' in net.__code__.co_varnames):
            acc_sum += (net(x, is_training = False).argmax(dim = 1) == y).float().sum().item()
            continue
        acc_sum += (net(x).argmax(dim = 1) == y).float().sum().item()
    return acc_sum / n

In [28]:
import torchvision
# 下载数据集
mnist_train = torchvision.datasets.FashionMNIST(
    root = './data/FashionMNIST',
    train=True, download=True,
    transform=torchvision.transforms.ToTensor()     # 自动转为torch张量
)
mnist_test = torchvision.datasets.FashionMNIST(
    root = './data/FashionMNIST',
    train=False, download=True,
    transform=torchvision.transforms.ToTensor()     # 自动转为torch张量
)
train_iter = torch.utils.data.DataLoader(
    mnist_train, batch_size=256,
    shuffle=True,
    num_workers=0       # 开启num_workers个线程
)
test_iter = torch.utils.data.DataLoader(
    mnist_test, batch_size=256,
    shuffle=True,
    num_workers=0       # 开启num_workers个线程
)
def sgd(params, lr, batch_size):
    for param in params:
        param.data -= param.grad * lr / batch_size
def train(net, train_iter, test_iter, loss, epochs, batch_size, params, lr):
    for epoch in range(epochs):
        train_loss_sum, train_acc_sum, n = 0.0, 0.0, 0
        for x, y in train_iter:
            y_hat = net(x)
            l = loss(y_hat, y).sum()

            l.backward()
            sgd(params, lr, batch_size)

            for param in params:
                param.grad.data.zero_()
            
            train_loss_sum += l.sum().item()
            train_acc_sum += (y_hat.argmax(dim = 1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' %
                (epoch+1, train_loss_sum/n, train_acc_sum/n, test_acc))



In [29]:
# 训练
num_epochs, lr, batch_size = 5, 100, 256
loss = torch.nn.CrossEntropyLoss()
train(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

epoch 1, loss 0.0045, train acc 0.549, test acc 0.736
epoch 2, loss 0.0023, train acc 0.784, test acc 0.771
epoch 3, loss 0.0019, train acc 0.821, test acc 0.777
epoch 4, loss 0.0018, train acc 0.837, test acc 0.794
epoch 5, loss 0.0016, train acc 0.848, test acc 0.838
