In [24]:
from torch import torch
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt


Load Data

In [None]:
bs = 64  # mini-batch size for the training DataLoader (validation/test loaders use bs*2)
lr = 0.2  # learning rate handed to optim.SGD inside get_model()
epochs = 30  # number of full passes over the training loader

In [None]:
from pathlib import Path

# Directories that may already hold a downloaded "MNIST" folder.
data_dirs = [Path("./"), Path("../")]

# Per-sample preprocessing applied by the dataset on access.
tf = transforms.Compose([
    transforms.ToTensor(),
    # 0.1307 is the mean of the MNIST dataset, 0.3081 is the standard deviation
    transforms.Normalize((0.1307,), (0.3081,)),
    # flatten(0) collapses every dimension, turning the (1, 28, 28) image into
    # a (784,) vector; flatten(1) would instead give (1, 784).
    transforms.Lambda(lambda x: x.flatten(0)),
])

# Reuse an existing copy of the dataset when one is found ...
for data_dir in data_dirs:
    if (data_dir / "MNIST").exists():
        train_data = datasets.MNIST(data_dir, train=True, transform=tf)
        test_data = datasets.MNIST(data_dir, train=False, transform=tf)
        break
else:
    # ... the for/else branch runs only when no candidate directory matched,
    # in which case the dataset is downloaded into the current directory.
    train_data = datasets.MNIST("./", train=True, download=True, transform=tf)
    test_data = datasets.MNIST("./", train=False, download=True, transform=tf)

# Fixed seed so the train/validation partition is the same on every run.
g = torch.Generator().manual_seed(42)

# Hold out 10,000 of the 60,000 training samples for validation.
train_data, val_data = random_split(train_data, [50000, 10000], generator=g)

print(train_data[0][0].shape)

# Only the training loader is shuffled. Evaluation loaders keep a fixed order
# so validation/test metrics are reproducible from run to run.
train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
val_loader = DataLoader(val_data, batch_size=bs*2, shuffle=False)
test_loader = DataLoader(test_data, batch_size=bs*2)

len(train_data), len(val_data), len(test_data)

Model


In [104]:
import numpy as np
from torch import Tensor

# Loss used for both training and evaluation throughout the notebook.
loss_func = F.cross_entropy


def accuracy(input: Tensor, target: Tensor):
    """Return the fraction of rows whose highest-scoring column equals the target.

    input:  2-D tensor of per-class scores, one row per sample.
    target: 1-D tensor of class indices, one per sample.
    return: 0-dim float tensor holding the mean of the per-sample hit indicators.
    """
    predicted_classes = input.argmax(dim=1)
    hits = predicted_classes.eq(target)
    return hits.float().mean()

In [None]:
class Logistic(nn.Module):
    """Logistic-regression classifier: a single linear layer mapping 784 inputs to 10 logits.

    No softmax is applied in `forward`; the raw logits are returned.
    """

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, x):
        """Return the output of the linear layer for `x`."""
        return self.lin(x)

    def evaluate(self, loader:DataLoader):
        """
        Evaluate the model on the given data.

        The per-batch loss and accuracy are weighted by the number of samples
        in each batch before averaging, so a smaller final batch does not skew
        the result (a plain mean over batches would over-weight it).

        loader: iterable yielding (inputs, targets) batches
        return: (loss, accuracy), each averaged over all samples seen
        raises: ZeroDivisionError if the loader yields no samples

        Note: the module is switched to eval mode and left in it.
        """
        self.eval()
        total_loss = 0
        total_acc = 0
        n_samples = 0
        with torch.no_grad():
            for x_batch, y_batch in loader:
                batch_size = len(y_batch)
                # Call the module itself rather than .forward() so that any
                # registered hooks are honoured.
                preds = self(x_batch)
                # accuracy() returns a batch mean, and loss_func is
                # F.cross_entropy with its default mean reduction; scale both
                # back to per-batch sums before accumulating.
                total_acc += accuracy(preds, y_batch) * batch_size
                total_loss += loss_func(preds, y_batch) * batch_size
                n_samples += batch_size
        return total_loss / n_samples, total_acc / n_samples


In [None]:

def get_model():
    """Create a fresh Logistic model together with an SGD optimizer for it.

    The optimizer covers all of the model's parameters and takes its learning
    rate from the module-level `lr` at call time.

    return: (model, optimizer)
    """
    net = Logistic()
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd

# Build the model/optimizer pair used by the training cell, then push a single
# batch through the untrained model to check shapes, loss and accuracy.
model, optimizer = get_model()
x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape, y_batch.shape)
# Call the module (not .forward()) so that registered hooks are run.
preds = model(x_batch)
print(preds.shape, y_batch.shape)
loss = loss_func(preds, y_batch)
print(loss)
# NOTE(review): presumably close to chance level here since nothing has been
# trained yet - confirm against the printed value.
print(accuracy(preds, y_batch))

Train

In [None]:

# Mini-batch SGD: one pass over train_loader per epoch, followed by a
# validation report. Re-running this cell continues training the same `model`.
for epoch in range(epochs):
    model.train()  # evaluate() below leaves the module in eval mode
    for x_batch, y_batch in train_loader:
        # Clear gradients before the backward pass so that stale .grad values
        # left over from earlier cells or an interrupted run cannot be folded
        # into this update.
        optimizer.zero_grad()
        # Call the module (not .forward()) so that registered hooks are run.
        preds = model(x_batch)
        loss = loss_func(preds, y_batch)
        loss.backward()
        optimizer.step()

    # `loss` is rebound here to the validation loss for the report below.
    loss, acc = model.evaluate(val_loader)
    print(f"epoch {epoch+1} loss: {loss:.2f}, accuracy: {acc:.2f}")

Test

In [111]:
# Final evaluation on the test loader; evaluate() returns (loss, accuracy).
# NOTE(review): the output recorded below this cell (loss 2.37, accuracy 0.12)
# is about chance level for 10 classes, and the cell execution counts in this
# notebook are not sequential - this looks like it was run against an untrained
# model. Confirm with Restart Kernel -> Run All.
loss, acc = model.evaluate(test_loader)
print(f"loss: {loss:.2f}, accuracy: {acc:.2f}")

loss: 2.37, accuracy: 0.12
