In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

In [2]:
data_train = torchvision.datasets.MNIST(
    root="./data/mnist", 
    train=True, 
    download=True, 
    transform=torchvision.transforms.ToTensor())

data_test = torchvision.datasets.MNIST(
    root="./data/mnist", 
    train=False, 
    download=True, 
    transform=torchvision.transforms.ToTensor())

In [4]:
class SequenceClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, device='cpu'):
        super(SequenceClassifier, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size * input_size, num_classes)
        
    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)

        # Forward propagate LSTM
        out, _ = self.lstm(
            x, (h0, c0)
        )  # out: tensor of shape (batch_size, seq_length, hidden_size)
        out = out.reshape(out.shape[0], -1)

        # Decode the hidden state of the last time step
        out = self.fc(out)
        return out


In [5]:
device = 'cpu'
num_epochs = 10
batch_size = 64
learning_rate = 10e-4

train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, shuffle=True, pin_memory=True)
model = SequenceClassifier(28,64,2,10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

In [6]:
for epoch in range(num_epochs):
    print(f"Epoch #{epoch} training...")
    data_loop = tqdm(train_loader)
    for data in data_loop:
        
        x, y = data
        x = x.to(device).squeeze(1)
        y = y.to(device)    
        
        print(y.shape)    
        
        predictions = model(x)
        loss = criterion(predictions, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        data_loop.set_postfix(loss=loss.item())
            
    total = 0
    num_corrects = 0
    
    for data in tqdm(test_loader):
        
        x, y = data
        x = x.to(device).squeeze(1)
        y = y.float().to(device)
        model.eval()
        
        with torch.no_grad():
        
            predictions = model(x)
            _, predictions = predictions.max(1)
            num_corrects += (predictions == y).float().sum()
            total += predictions.size(0)
        
        model.train()
            
    print(f"Epoch #{epoch} test acc is: {(num_corrects/total)*100}")

Epoch #0 training...


  0%|          | 0/938 [00:00<?, ?it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


KeyboardInterrupt: 