## Download and prepare data

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

## Prepare model

In [3]:
from dataclasses import dataclass

import torch
import torch.optim as optim

from model import ModelArgs, SSMModel

config = ModelArgs()
device = torch.device(config.device)
model = SSMModel(config)
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.004)

## Train and test

In [4]:
for epoch in range(10):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28, 28).to(device)  # Reshape data to (batch_size, seq_length, input_size)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)[:, 27, :]
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}'
        )
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(-1, 28, 28).to(device)
            target = target.to(device)
            output = model(data)[:, 27, :]
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print(
            f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n'
        )


Test set: Average loss: 0.0039, Accuracy: 9257/10000 (93%)


Test set: Average loss: 0.0037, Accuracy: 9295/10000 (93%)


Test set: Average loss: 0.0034, Accuracy: 9336/10000 (93%)


Test set: Average loss: 0.0034, Accuracy: 9363/10000 (94%)


Test set: Average loss: 0.0033, Accuracy: 9361/10000 (94%)


Test set: Average loss: 0.0035, Accuracy: 9331/10000 (93%)


Test set: Average loss: 0.0032, Accuracy: 9371/10000 (94%)


Test set: Average loss: 0.0031, Accuracy: 9391/10000 (94%)


Test set: Average loss: 0.0034, Accuracy: 9334/10000 (93%)


Test set: Average loss: 0.0032, Accuracy: 9373/10000 (94%)

