In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_data(train:bool):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),])
    data = torchvision.datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=transform)
    return data

trainset = load_data(train=True)
testset = load_data(train=False)

trainset, testset

(Dataset MNIST
     Number of datapoints: 60000
     Root location: data
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
            ),
 Dataset MNIST
     Number of datapoints: 10000
     Root location: data
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
            ))

In [3]:
def train(model, dataloader, loss_fn, optimizer):
    model.train()
    n_samples = len(dataloader.dataset)

    for batch_i, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y_ = model(X)

        loss = loss_fn(y_, y)
        optimizer.zero_grad()  # init gradients to 0
        loss.backward()
        optimizer.step()

        if batch_i % 100 == 0:
            loss = loss.item()
            prog = batch_i * len(X)
            print(f"loss: {loss:>7f}, progress: {prog:>5d}/{n_samples:>5d}")

In [4]:
def test(model, dataloader, loss_fn):
    model.eval()
    n_samples = len(dataloader.dataset)

    total_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_ = model(X)

            total_loss += loss_fn(y_, y).item()
            correct += (y_.argmax(1) == y).type(torch.float).sum().item()
    
    accuracy = correct / n_samples
    avg_loss = total_loss / n_samples
    print(f"accuracy: {accuracy:>7f}, avg_loss: {avg_loss:>7f}")

In [5]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # input_shape: (?, 1, 28, 28)
        self.conv_stack_1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),  # -> (?, 32, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2, 2))  # -> (?, 32, 14, 14)
        self.conv_stack_2 = nn.Sequential(
            nn.Conv2d(32, 64, 3),  # -> (?, 64, 12, 12)
            nn.ReLU(),
            nn.MaxPool2d(2, 2))  # -> (?, 64, 6, 6)
        self.fc_stack_1 = nn.Sequential(
            nn.Flatten(),  # -> (?, 2304(=64*6*6))
            nn.Linear(2304, 512),
            nn.Dropout(0.25))
        self.fc_stack_2 = nn.Sequential(
            nn.Linear(512, 128),
            nn.Linear(128, 10))  # classes: 10

    def forward(self, x):
        x = self.conv_stack_1(x)
        x = self.conv_stack_2(x)
        x = self.fc_stack_1(x)
        logits = self.fc_stack_2(x)
        return logits

In [6]:
model = SimpleCNN().to(device)

epochs = 10
batch_size = 64
lr = 0.001  # 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

trainset_loader = DataLoader(trainset, batch_size)
testset_loader = DataLoader(testset, batch_size)

for epoch in range(epochs):
    print(f"Epoch: {epoch+1} -----")
    train(model, trainset_loader, loss_fn, optimizer)
    test(model, testset_loader, loss_fn)
print("Done.")

Epoch: 1 -----
loss: 2.307543, progress:     0/60000
loss: 0.161911, progress:  6400/60000
loss: 0.163251, progress: 12800/60000
loss: 0.161326, progress: 19200/60000
loss: 0.079105, progress: 25600/60000
loss: 0.079394, progress: 32000/60000
loss: 0.113373, progress: 38400/60000
loss: 0.113345, progress: 44800/60000
loss: 0.253830, progress: 51200/60000
loss: 0.048867, progress: 57600/60000
accuracy: 0.981400, avg_loss: 0.000913
Epoch: 2 -----
loss: 0.064371, progress:     0/60000
loss: 0.137929, progress:  6400/60000
loss: 0.093271, progress: 12800/60000
loss: 0.030975, progress: 19200/60000
loss: 0.040645, progress: 25600/60000
loss: 0.008967, progress: 32000/60000
loss: 0.069860, progress: 38400/60000
loss: 0.043793, progress: 44800/60000
loss: 0.147654, progress: 51200/60000
loss: 0.049785, progress: 57600/60000
accuracy: 0.986300, avg_loss: 0.000677
Epoch: 3 -----
loss: 0.014115, progress:     0/60000
loss: 0.035507, progress:  6400/60000
loss: 0.014104, progress: 12800/60000
los