In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

class RNNNet(nn.Module):
    def __init__(self):
        super(RNNNet, self).__init__()
        self.rnn = nn.RNN(input_size=14*14, hidden_size=64, batch_first=True, nonlinearity='tanh')
        self.fc = nn.Linear(64, 10)
    
    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), 64).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out

def get_data_loaders(batch_size_train=100, batch_size_test=1000, shuffle_train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=shuffle_train)
    test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False)

    return train_loader, test_loader

def train_model(model, train_loader, test_loader, criterion, optimizer, epochs=5):
    train_log = []
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        
        model.train()
        for x, y in train_loader:
            x = x.view(-1, 4, 14*14)
            optimizer.zero_grad()
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
        
        train_loss, train_acc = loss_and_accuracy(model, train_loader, criterion)
        test_loss, test_acc = loss_and_accuracy(model, test_loader, criterion)
        train_log.append((epoch, train_loss, train_acc, test_loss, test_acc))
        
        epoch_time = time.time() - start_time
        
        print(f'Epoch {epoch}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.2f}, Test Loss {test_loss:.4f}, Test Acc {test_acc:.2f}')
        print(f'Epoch Time: {epoch_time:.2f} seconds')
    
    return train_log

def loss_and_accuracy(model, data_loader, criterion):
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in data_loader:
            x = x.view(-1, 4, 14*14)
            y_hat = model(x)
            loss += criterion(y_hat, y).item()
            pred = y_hat.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
    
    loss /= len(data_loader.dataset)
    accuracy = 100. * correct / len(data_loader.dataset)
    return loss, accuracy

if __name__ == "__main__":
    train_loader, test_loader = get_data_loaders()
    model = RNNNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.015)
    train_log = train_model(model, train_loader, test_loader, criterion, optimizer)

    model.eval()
    x1, y1 = next(iter(train_loader))
    x1 = x1.view(-1, 4, 14*14)
    y1_hat = model(x1)
    print(f'Predicted: {y1_hat.argmax(dim=1)}, Actual: {y1}')


Epoch 1: Train Loss 0.0077, Train Acc 78.18, Test Loss 0.0008, Test Acc 78.76
Epoch Time: 37.62 seconds
Epoch 2: Train Loss 0.0049, Train Acc 86.88, Test Loss 0.0005, Test Acc 87.23
Epoch Time: 46.15 seconds
Epoch 3: Train Loss 0.0037, Train Acc 89.64, Test Loss 0.0004, Test Acc 90.09
Epoch Time: 59.32 seconds
Epoch 4: Train Loss 0.0032, Train Acc 91.23, Test Loss 0.0003, Test Acc 91.52
Epoch Time: 48.34 seconds
Epoch 5: Train Loss 0.0028, Train Acc 92.20, Test Loss 0.0003, Test Acc 92.35
Epoch Time: 42.76 seconds
Predicted: tensor([9, 0, 9, 7, 4, 7, 1, 3, 1, 6, 6, 9, 5, 4, 1, 5, 9, 1, 3, 5, 7, 9, 7, 2,
        8, 7, 7, 2, 7, 2, 1, 7, 0, 4, 4, 0, 6, 1, 6, 2, 2, 3, 5, 9, 7, 3, 9, 6,
        4, 9, 8, 0, 2, 4, 6, 5, 2, 3, 6, 7, 1, 3, 1, 5, 4, 0, 5, 5, 9, 8, 9, 8,
        1, 4, 4, 4, 3, 2, 7, 9, 7, 3, 8, 3, 1, 0, 0, 2, 9, 3, 8, 8, 6, 1, 2, 1,
        3, 0, 3, 8]), Actual: tensor([9, 0, 9, 7, 4, 7, 1, 3, 1, 6, 6, 9, 5, 4, 8, 5, 9, 1, 3, 5, 7, 7, 7, 2,
        8, 7, 7, 2, 7, 2, 1, 7, 0, 7, 4