In [None]:
# Pick the best available accelerator instead of assuming Apple's Metal
# backend, so the notebook also runs on CUDA machines and on CPU-only hosts.
# torch is imported here because this cell runs before the notebook's main
# import cell.
import torch

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt


In [None]:
# Download (if needed) the QMNIST train/test splits into ./data and expose
# them as tensors; both splits share the same stateless ToTensor transform.
to_tensor = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.QMNIST(
    root='data', train=True, download=True, transform=to_tensor,
)
test_dataset = torchvision.datasets.QMNIST(
    root='data', train=False, download=True, transform=to_tensor,
)

# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

print(f"Training with {len(train_dataset)} samples")
print(f"Testing with {len(test_dataset)} samples")


In [None]:
# Two-layer fully connected classifier: 784 flattened pixels -> 128 hidden
# units (ReLU) -> 10 output logits.
linear_model = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
linear_model = linear_model.to(device)

# Cross-entropy on the raw logits, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(linear_model.parameters(), lr=0.001)

# Metric histories filled in by the training cell and read by the plots.
stats = {
    metric: []
    for metric in ('train_loss', 'test_loss', 'test_accuracy')
}


In [None]:
# Count every scalar held in the model's parameter tensors, then show the
# layer structure.
param_count = 0
for param in linear_model.parameters():
    param_count += param.numel()

print(f"Model has {param_count:,} parameters")
print(linear_model)


In [None]:
# Train the classifier and evaluate it on the test split after every epoch.
# Per-batch training losses and per-epoch test metrics are appended to `stats`.
num_epochs = 9  # same epoch count as the previous range(1, 10)

for epoch in range(1, num_epochs + 1):
    # ---- training pass -------------------------------------------------
    linear_model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        x = x.view(x.size(0), -1)  # flatten each image into one row vector
        optimizer.zero_grad()
        y_hat = linear_model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        stats['train_loss'].append(loss.item())

    # ---- evaluation pass -----------------------------------------------
    linear_model.eval()
    with torch.no_grad():
        loss_sum = 0.0
        correct = 0
        total = 0
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            x = x.view(x.size(0), -1)
            y_hat = linear_model(x)
            batch_size = y.size(0)
            # The criterion returns the mean over the batch; weight it by the
            # batch size so the epoch figure is a true per-sample average and
            # does not depend on the number (or size) of batches.
            loss_sum += criterion(y_hat, y).item() * batch_size
            # Use .indices rather than unpacking into `_`, which would shadow
            # IPython's "last output" variable.
            predicted = y_hat.max(1).indices
            correct += predicted.eq(y).sum().item()
            total += batch_size
        test_loss = loss_sum / total
        stats['test_loss'].append(test_loss)
        stats['test_accuracy'].append(correct / total)

    print(f"Epoch {epoch}, test loss {test_loss:.4f}, test accuracy {100*correct/total:.2f}%")


In [None]:
# Plot the recorded metrics side by side. The training loss is logged once
# per batch, while the test metrics are logged once per epoch, so each panel
# gets explicit axis labels to make the differing x-axes unambiguous.
fig, (ax_train, ax_test_loss, ax_test_acc) = plt.subplots(1, 3, figsize=(12, 4))

ax_train.plot(stats['train_loss'])
ax_train.set(title='Train loss', xlabel='Batch', ylabel='Loss')

# Epochs are numbered from 1, matching the training cell's printed log.
loss_epochs = range(1, len(stats['test_loss']) + 1)
ax_test_loss.plot(loss_epochs, stats['test_loss'])
ax_test_loss.set(title='Test loss', xlabel='Epoch', ylabel='Loss')

acc_epochs = range(1, len(stats['test_accuracy']) + 1)
ax_test_acc.plot(acc_epochs, stats['test_accuracy'])
ax_test_acc.set(title='Test accuracy', xlabel='Epoch', ylabel='Accuracy')

fig.tight_layout()  # keep the axis labels of neighbouring panels from overlapping
plt.show()
