In [24]:
from torch import torch
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt


Load Data

In [None]:
# Hyperparameters: mini-batch size, SGD learning rate, number of training epochs.
bs = 64
# NOTE(review): the recorded validation loss jumps around between epochs
# (e.g. 0.41 -> 0.82 -> 0.40 -> 1.17), which suggests lr=0.2 may be too high
# for plain SGD here — consider a smaller value.
lr = 0.2
epochs = 30

# Seed torch's global RNG so weight initialisation and the training-loader
# shuffle order are reproducible under Restart & Run All.
torch.manual_seed(42);

In [None]:
from pathlib import Path

# Reuse an existing MNIST download in ./ or ../ if one is present; otherwise download into ./.
data_dirs = [Path("./"), Path("../")]

tf = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor of shape (1, 28, 28) with values in [0, 1]
    # 0.1307 / 0.3081 are the mean / standard deviation of the MNIST pixels
    transforms.Normalize((0.1307,), (0.3081,)),
    # flatten(0) collapses (1, 28, 28) into a flat (784,) vector for the linear model
    transforms.Lambda(lambda x: x.flatten(0)),
])

for data_dir in data_dirs:
    if (data_dir / "MNIST").exists():
        train_data = datasets.MNIST(data_dir, train=True, transform=tf)
        test_data = datasets.MNIST(data_dir, train=False, transform=tf)
        break
else:  # loop finished without `break`: no local copy found
    train_data = datasets.MNIST("./", train=True, download=True, transform=tf)
    test_data = datasets.MNIST("./", train=False, download=True, transform=tf)

# Hold out a fixed validation set; the seeded generator makes the split reproducible.
n_val = 10_000
g = torch.Generator().manual_seed(42)
train_data, val_data = random_split(
    train_data, [len(train_data) - n_val, n_val], generator=g
)

print(train_data[0][0].shape)

train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
# Evaluation data is not shuffled: order does not matter for evaluation, and a
# fixed order keeps the per-epoch validation metrics deterministic.
val_loader = DataLoader(val_data, batch_size=bs*2, shuffle=False)
test_loader = DataLoader(test_data, batch_size=bs*2)

len(train_data), len(val_data), len(test_data)

torch.Size([784])


(50000, 10000, 10000)

Model


In [76]:
# Loss used for training and evaluation: takes raw logits and integer class labels.
loss_func = F.cross_entropy


def accuracy(input, target):
    """Return the fraction of samples whose highest-scoring class equals ``target``.

    Args:
        input: score/logit tensor; the predicted class is the argmax over dim 1,
            e.g. shape (batch, num_classes).
        target: integer class labels, one per row of ``input``.

    Returns:
        A 0-dim float tensor holding the mean accuracy.
    """
    hits = input.argmax(dim=1).eq(target)
    return hits.float().mean()

In [77]:
class Logistic(nn.Module):
    """Multinomial logistic regression: one affine layer mapping inputs to class logits.

    No softmax is applied; the raw logits are passed to ``F.cross_entropy``,
    which applies log-softmax internally.

    Args:
        in_features: size of each flattened input vector (784 for 28x28 MNIST).
        num_classes: number of output classes / logits (10 for MNIST digits).
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()
        # Attribute name kept as ``lin`` so state_dict keys stay "lin.weight"/"lin.bias".
        self.lin = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # Expects x already flattened to (batch, in_features), as the data transform produces.
        return self.lin(x)


In [None]:

def get_model():
    """Create a freshly initialised ``Logistic`` model and an SGD optimizer for it.

    The learning rate is taken from the notebook-level ``lr`` at call time.

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    net = Logistic()
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd

# Smoke test: one forward pass of the untrained model on a single training batch,
# to check tensor shapes and that loss/accuracy start out near chance level.
model, optimizer = get_model()
x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape, y_batch.shape)
# Call the module instance rather than .forward() so nn.Module hooks are honoured.
preds = model(x_batch)
assert preds.shape == (len(y_batch), 10), preds.shape  # one logit per class
print(preds.shape, y_batch.shape)
loss = loss_func(preds, y_batch)
print(loss)
print(accuracy(preds, y_batch))

torch.Size([64, 784]) torch.Size([64])
torch.Size([64, 10]) torch.Size([64])
tensor(2.4183, grad_fn=<NllLossBackward0>)
tensor(0.1094)


Train

In [92]:

# Train `model` with `optimizer` (both created by get_model() above) and report
# validation loss/accuracy after every epoch.
# NOTE: re-running this cell keeps training the existing model; re-run the
# get_model() cell first to start again from fresh weights.
num_val = len(val_loader.dataset)

for epoch in range(epochs):
    model.train()
    for x_batch, y_batch in train_loader:
        preds = model(x_batch)  # call the module, not .forward(), so hooks run
        loss = loss_func(preds, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    model.eval()
    with torch.no_grad():
        total_loss = 0.0
        total_correct = 0.0
        for x_batch, y_batch in val_loader:
            preds = model(x_batch)
            loss = loss_func(preds, y_batch)
            # loss_func and accuracy return per-batch means; weight them by batch
            # size so a smaller final batch (len(val_data) not a multiple of the
            # batch size) does not skew the epoch average.
            batch_n = len(y_batch)
            total_loss += loss.item() * batch_n
            total_correct += accuracy(preds, y_batch).item() * batch_n
        print(f"epoch {epoch+1} loss: {total_loss/num_val:.2f}, accuracy: {total_correct/num_val:.2f}")


epoch 1 loss: 0.99, accuracy: 0.75
epoch 2 loss: 0.51, accuracy: 0.87
epoch 3 loss: 0.41, accuracy: 0.90
epoch 4 loss: 0.82, accuracy: 0.83
epoch 5 loss: 0.40, accuracy: 0.90
epoch 6 loss: 0.45, accuracy: 0.89
epoch 7 loss: 1.17, accuracy: 0.80
epoch 8 loss: 0.49, accuracy: 0.88
epoch 9 loss: 0.42, accuracy: 0.90
epoch 10 loss: 0.51, accuracy: 0.88
epoch 11 loss: 0.42, accuracy: 0.90
epoch 12 loss: 0.46, accuracy: 0.89
epoch 13 loss: 0.51, accuracy: 0.88
epoch 14 loss: 0.49, accuracy: 0.88
epoch 15 loss: 0.45, accuracy: 0.89
epoch 16 loss: 0.52, accuracy: 0.88
epoch 17 loss: 0.46, accuracy: 0.89
epoch 18 loss: 0.74, accuracy: 0.82
epoch 19 loss: 0.79, accuracy: 0.82
epoch 20 loss: 0.45, accuracy: 0.89
epoch 21 loss: 0.58, accuracy: 0.87
epoch 22 loss: 0.57, accuracy: 0.86
epoch 23 loss: 0.52, accuracy: 0.88
epoch 24 loss: 0.44, accuracy: 0.90
epoch 25 loss: 0.49, accuracy: 0.88
epoch 26 loss: 0.61, accuracy: 0.87
epoch 27 loss: 0.43, accuracy: 0.90
epoch 28 loss: 0.84, accuracy: 0.80
e