In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

torch.manual_seed(1)

EPOCH = 10
BATCH_SIZE = 64
LR = 0.001
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
transform = transforms.ToTensor()
train_set = torchvision.datasets.MNIST(root='./mnist/', transform=transform, train=True, download=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=32)
test_set = torchvision.datasets.MNIST(root='./mnist/', transform=transform, train=False, download=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=32)
print(train_set.data.size())
print(train_set.targets.size())
print(test_set.data.size())
print(test_set.targets.size())

torch.Size([60000, 28, 28])
torch.Size([60000])
torch.Size([10000, 28, 28])
torch.Size([10000])


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn = nn.LSTM(
            input_size=28,     # image width
            hidden_size=64,
            num_layers=1,
            batch_first=True   # True : (batch, time_step, input_size)   False : (time_step, batch, input_size)
        )
        self.fc = nn.Linear(64, 10)
    
    def forward(self, x):
        r_out, (h_n, h_c) = self.rnn(x, None)
        print('x_size:', x.size())
        print('r_out_size:', r_out.size())
        print('h_n_size:', h_n.size())
        print('h_c_size:', h_c.size())
        # x  (batch, time_step, input_size)    None 初始无输入，后续为记忆内容
        # (h_n, h_c) -> 记忆内容， r_out -> 当前时间输出
        # r_out  (batch, time_step, hidden_size) ，即对所有batch的每个time_step保存对于输入的hidden_size个隐层输出
        # h_n 短期记忆，直接递给下一次输入   h_c 长期记忆，进行储存
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        out = self.fc(r_out[:, -1, :])  # 仅需要最后一个time_step的输出用于全连接的输入
        return out
    
rnn = Net()
print(rnn)

Net(
  (rnn): LSTM(28, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=10, bias=True)
)


In [4]:
def test(net, device):
    total = 0
    correct = 0
    with torch.no_grad():
        net.to(device)
        for (inputs, labels) in test_loader:
            inputs = inputs.to(device)
            inputs = inputs.view(-1, 28, 28)
            outputs = net(inputs).cpu()
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    print('test accuracy : %.2f %%' % (correct * 100 / total))
    
def train(net, device, EPOCH, LR):
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=LR)
    
    for epoch in range(EPOCH):
        running_loss = 0.0
        for t, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.view(-1, 28, 28)
            net.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            break
            running_loss += loss.item()
            if t % 200 == 199:
                print('epoch %d step %d train loss : %.3f' % (epoch+1, t+1, running_loss / 200))
                running_loss = 0.0
                test(net, device)

In [5]:
train(rnn, device, EPOCH, LR)

x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r_out_size: torch.Size([64, 28, 64])
h_n_size: torch.Size([1, 64, 64])
h_c_size: torch.Size([1, 64, 64])
x_size: torch.Size([64, 28, 28])
r