In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.nn.functional import one_hot, cross_entropy
import numpy as np
import time

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

def loader(data, batchsize=1):
    return DataLoader(data, batch_size=batchsize, shuffle=True)

train_loader = loader(train_data, batchsize=100)
test_loader = loader(test_data, batchsize=100)

class RNNModel(nn.Module):
    def __init__(self):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(14 * 14, 64, batch_first=True, nonlinearity='tanh')
        self.fc = nn.Linear(64, 10)
    
    def forward(self, x):
        x = x.view(-1, 4, 14 * 14)
        h0 = torch.zeros(1, x.size(0), 64).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out

net = RNNModel()

settings = {
    'eta': 15e-3,
    'epochs': 5,
    'batchsize': 100
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
optimizer = optim.SGD(net.parameters(), lr=settings['eta'])

train_log = []

for epoch in range(1, settings['epochs'] + 1):
    start_time = time.time()
    net.train()
    correct = 0
    total = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = net(x)
        loss = cross_entropy(y_pred, y)
        loss.backward()
        optimizer.step()
        
        correct += (y_pred.argmax(dim=1) == y).sum().item()
        total += y.size(0)
    
    train_acc = 100.0 * correct / total
    epoch_time = time.time() - start_time
    print(f'Epoch: {epoch},  Train Acc: {train_acc:.2f}%, Time: {epoch_time:.2f}s')

Epoch: 1,  Train Acc: 63.90%, Time: 18.36s
Epoch: 2,  Train Acc: 82.87%, Time: 18.86s
Epoch: 3,  Train Acc: 88.40%, Time: 20.24s
Epoch: 4,  Train Acc: 90.57%, Time: 18.30s
Epoch: 5,  Train Acc: 91.81%, Time: 19.01s
