## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm.auto import tqdm

## Set Device

In [2]:
# Pick the compute backend by priority: Apple-silicon MPS first, then CUDA,
# and plain CPU as the fallback when no accelerator is available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Create the RNN model

In [3]:
class RNN(nn.Module):
    """Multi-layer RNN followed by a linear classifier over all time steps.

    The input is a batch-first sequence of shape
    ``(batch, sequence_length, input_size)``. The hidden states of every time
    step are flattened and fed to a single linear layer, so the sequence
    length is fixed at construction time.

    Args:
        input_size: Number of features per time step.
        sequence_length: Number of time steps in every input sequence.
        hidden_size: Width of the RNN hidden state.
        num_layers: Number of stacked RNN layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=28, sequence_length=28, hidden_size=256, num_layers=2, num_classes=10):
        super().__init__()

        self.input_size = input_size
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # The classifier consumes the hidden state of *every* time step,
        # hence the hidden_size * sequence_length input width.
        self.fc = nn.Linear(hidden_size * sequence_length, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``.

        Args:
            x: Tensor of shape ``(batch, sequence_length, input_size)``.
        """
        # Allocate the initial hidden state directly on the input's device and
        # in the input's dtype, so the module also works after .double()/.half()
        # and avoids a CPU allocation followed by a device copy.
        h0 = torch.zeros(
            self.num_layers,
            x.size(0),
            self.hidden_size,
            device=x.device,
            dtype=x.dtype,
        )
        out, _ = self.rnn(x, h0)
        # Flatten (batch, seq, hidden) -> (batch, seq * hidden) for the classifier.
        out = out.reshape(out.size(0), -1)
        return self.fc(out)

In [4]:
# Smoke test: push a random batch of 64 sequences (28 steps x 28 features)
# through a default-configured model and confirm the logits shape.
model = RNN()
x = torch.randn(64, 28, 28)

logits = model(x)
print(logits.shape)

torch.Size([64, 10])


## Hyperparameters

In [5]:
# Data geometry: each image is fed to the RNN as 28 time steps of 28 features.
input_size, sequence_length = 28, 28
num_classes = 10

# Model capacity.
hidden_size = 256

# Optimisation settings.
lr = 1e-3
batch_size = 64
num_epochs = 20
grad_accum = 8  # number of batches whose gradients are accumulated per optimizer step

In [6]:
# Training split of MNIST (downloaded to data/ on first use), served in
# shuffled mini-batches.
train_dataset = datasets.MNIST(
    root="data/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_dataset, train_dataloader

(Dataset MNIST
     Number of datapoints: 60000
     Root location: data/
     Split: Train
     StandardTransform
 Transform: ToTensor(),
 <torch.utils.data.dataloader.DataLoader at 0x19091bf10>)

In [7]:
# Held-out test split of MNIST; order is kept fixed (no shuffling) for evaluation.
test_dataset = datasets.MNIST(
    root="data/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
test_dataset, test_dataloader

(Dataset MNIST
     Number of datapoints: 10000
     Root location: data/
     Split: Test
     StandardTransform
 Transform: ToTensor(),
 <torch.utils.data.dataloader.DataLoader at 0x19091bca0>)

## Initialize Network

In [8]:
# Build the classifier from the hyperparameters above and place it on the
# selected compute device in one step.
model = RNN(
    input_size=input_size,
    sequence_length=sequence_length,
    hidden_size=hidden_size,
    num_layers=2,
    num_classes=num_classes,
).to(device)

## Loss and Optimizer

In [9]:
# Adam updates every model parameter with the configured learning rate;
# cross-entropy is the loss applied to the model's raw logits.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

## Train the Model

In [10]:
# Train with gradient accumulation and report the test loss after every epoch.
num_train_batches = len(train_dataloader)

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    optimizer.zero_grad()
    train_progress = tqdm(train_dataloader, total=num_train_batches, desc=f"Epoch {epoch}")
    for step, (x, y) in enumerate(train_progress, start=1):
        # Batches are expected as (B, 1, 28, 28); drop only the channel axis.
        # A bare squeeze() would also drop the batch axis when B == 1.
        x = x.squeeze(1).to(device)
        y = y.to(device)
        y_hat = model(x)

        # Scale so the accumulated gradient is the mean over `grad_accum` batches.
        loss = criterion(y_hat, y) / grad_accum
        loss.backward()

        # Step every `grad_accum` batches, and also on the final batch so the
        # gradients of a trailing partial group are applied instead of being
        # discarded by the next epoch's zero_grad().
        if step % grad_accum == 0 or step == num_train_batches:
            optimizer.step()
            optimizer.zero_grad()

    # ---- evaluation pass ----
    model.eval()
    test_loss_sum = 0.0
    num_test_samples = 0
    with torch.no_grad():
        for x, y in tqdm(test_dataloader, total=len(test_dataloader), desc=f"Epoch {epoch}"):
            x = x.squeeze(1).to(device)
            y = y.to(device)
            y_hat = model(x)

            # Weight each batch's mean loss by its size so the smaller final
            # batch does not skew the average; .item() keeps the running sum
            # as a plain Python float rather than a device tensor.
            batch_samples = y.size(0)
            test_loss_sum += criterion(y_hat, y).item() * batch_samples
            num_test_samples += batch_samples

    print(f"Epoch {epoch}: Test loss {test_loss_sum / max(num_test_samples, 1):.2f}")

Epoch 0:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 0: Test loss 0.21


Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 1: Test loss 0.11


Epoch 2:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2: Test loss 0.09


Epoch 3:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3: Test loss 0.06


Epoch 4:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 4: Test loss 0.07


Epoch 5:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 5: Test loss 0.07


Epoch 6:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 6: Test loss 0.06


Epoch 7:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 7: Test loss 0.05


Epoch 8:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 8: Test loss 0.07


Epoch 9:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 9: Test loss 0.04


Epoch 10:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 10: Test loss 0.04


Epoch 11:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 11:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 11: Test loss 0.04


Epoch 12:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 12:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 12: Test loss 0.06


Epoch 13:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 13:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 13: Test loss 0.04


Epoch 14:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 14:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 14: Test loss 0.07


Epoch 15:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 15:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 15: Test loss 0.05


Epoch 16:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 16:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 16: Test loss 0.06


Epoch 17:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 17:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 17: Test loss 0.06


Epoch 18:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 18:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 18: Test loss 0.05


Epoch 19:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 19:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 19: Test loss 0.05


In [13]:
def check_accuracy(loader, model):
    """Print the classification accuracy (in percent) of ``model`` on ``loader``.

    The model is switched to eval mode and left in it. Inputs are moved to the
    notebook-level ``device`` before the forward pass.

    Args:
        loader: Iterable of ``(x, y)`` batches; ``x`` is expected as
            ``(B, 1, 28, 28)`` and ``y`` as integer class labels.
        model: Module returning per-class logits for ``(B, 28, 28)`` input.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # The label no longer reads the global `epoch` left over from the
        # training cell, so this function also works on a fresh kernel.
        for x, y in tqdm(loader, total=len(loader), desc="Checking accuracy"):
            # Drop only the channel axis; a bare squeeze() would also remove
            # the batch axis for a batch of size 1.
            x = x.squeeze(1).to(device)
            y = y.to(device)
            logits = model(x)

            # Softmax is monotonic, so the argmax of the logits is already the
            # predicted class.
            pred_idx = logits.argmax(dim=-1)
            correct += (pred_idx == y).sum().item()
            total += y.numel()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = correct / total * 100 if total else float("nan")
    print(f"Accuracy: {accuracy:.4f}")

In [14]:
# Report accuracy on the training split first, then on the held-out test split.
for loader in (train_dataloader, test_dataloader):
    check_accuracy(loader, model)

Epoch 19:   0%|          | 0/938 [00:00<?, ?it/s]

Accuracy: 99.5317


Epoch 19:   0%|          | 0/157 [00:00<?, ?it/s]

Accuracy: 98.8300
