In [1]:
import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import MNIST

In [2]:
# Data pipeline: turn each image into a tensor, then standardize it.
data_tf = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.5], std=[0.5]),  # standardize: (x - 0.5) / 0.5
])

# MNIST train / test splits; downloaded into ./data when not already present.
train_set = MNIST(root='./data', train=True, transform=data_tf, download=True)
test_set = MNIST(root='./data', train=False, transform=data_tf, download=True)

# Shuffle the training batches only; keep the test order fixed.
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

In [3]:
class SimpleLSTM(nn.Module):
    """Single-layer LSTM followed by a linear classifier.

    The final hidden state of the LSTM is fed to a fully connected layer
    that produces one score per class.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the LSTM hidden state.
        num_classes: number of output scores (defaults to 10).
    """

    def __init__(self, input_size, hidden_size, num_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, 1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes).

        Args:
            x: input of shape (batch, seq_len, input_size) (batch_first=True).
        """
        _, (ht, _) = self.lstm(x)
        # ht is (num_layers, batch, hidden). Index the last layer instead of
        # calling squeeze(): a bare squeeze() would also remove the batch
        # axis when batch == 1 and return a 1-D score vector.
        last_hidden = ht[-1]
        y_pred = self.fc(last_hidden)
        return y_pred

In [3]:
class SimpleRNN(nn.Module):
    """Single-layer vanilla RNN followed by a linear classifier.

    The final hidden state of the RNN is fed to a fully connected layer
    that produces one score per class.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the RNN hidden state.
        num_classes: number of output scores (defaults to 10).
    """

    def __init__(self, input_size, hidden_size, num_classes=10):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, 1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes).

        Args:
            x: input of shape (batch, seq_len, input_size) (batch_first=True).
        """
        _, ht = self.rnn(x)
        # ht is (num_layers, batch, hidden). Index the last layer instead of
        # calling squeeze(): a bare squeeze() would also remove the batch
        # axis when batch == 1 and return a 1-D score vector.
        last_hidden = ht[-1]
        y_pred = self.fc(last_hidden)
        return y_pred

In [5]:
# Each 28x28 image is read as a sequence of 28 rows with 28 features per row.
lstm = SimpleLSTM(input_size=28, hidden_size=100)
# Cross-entropy over the class scores produced by the model.
criterion = nn.CrossEntropyLoss()
# Plain SGD over all model parameters.
optimizer = torch.optim.SGD(params=lstm.parameters(), lr=0.1)

In [7]:
# Train the model, evaluating on the test set after every epoch.
Epochs = 5
for epoch in range(Epochs):
    # ---- training pass ----
    # Switch back to training mode: the evaluation pass at the end of the
    # previous epoch leaves the model in eval mode.
    lstm.train()
    loss_sum = 0.0
    correct_sum = 0
    sample_count = 0
    for image, label in train_data:
        # Drop only the channel axis (dim 1) so each image becomes a
        # (rows, cols) sequence; a bare squeeze() would also drop the batch
        # axis for a batch of size 1.
        image = image.squeeze(1)
        batch_size = image.shape[0]
        # Forward pass.
        y_pred = lstm(image)
        loss = criterion(y_pred, label)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        loss_sum += loss.item() * batch_size
        _, out = y_pred.max(1)  # index of the largest score per row = predicted class
        correct_sum += (out == label).sum().item()  # number of correct predictions
        sample_count += batch_size
    ave_train_loss = loss_sum / sample_count
    ave_train_acc = correct_sum / sample_count

    # ---- evaluation pass on the test set ----
    # eval() puts layers such as BatchNorm and Dropout into inference
    # behaviour (they use their trained values rather than batch statistics).
    lstm.eval()
    loss_sum = 0.0
    correct_sum = 0
    sample_count = 0
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for image, label in test_data:
            image = image.squeeze(1)
            batch_size = image.shape[0]
            y_pred = lstm(image)
            loss = criterion(y_pred, label)
            loss_sum += loss.item() * batch_size
            _, out = y_pred.max(1)
            correct_sum += (out == label).sum().item()
            sample_count += batch_size
    ave_test_loss = loss_sum / sample_count
    ave_test_acc = correct_sum / sample_count
    print('epoch: {:2d}, train loss: {:.4f}, train acc: {:.4f}, test loss: {:.4f}, test acc: {:.4f}'.format(
        epoch + 1, ave_train_loss, ave_train_acc, ave_test_loss, ave_test_acc))

epoch:  1, train loss: 0.2080, train acc: 0.9389, test loss: 0.1797, test acc: 0.9418
epoch:  2, train loss: 0.1307, train acc: 0.9599, test loss: 0.1020, test acc: 0.9699
epoch:  3, train loss: 0.1007, train acc: 0.9695, test loss: 0.0937, test acc: 0.9731
epoch:  4, train loss: 0.0830, train acc: 0.9755, test loss: 0.1038, test acc: 0.9700
epoch:  5, train loss: 0.0681, train acc: 0.9794, test loss: 0.0920, test acc: 0.9744
