In [10]:
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [8]:
# MNIST training split, stored under the relative directory "data" and fetched
# there because download=True; each sample's image is converted with ToTensor().
# The recorded training output below reports 60000 samples for this split.
training_data = datasets.MNIST("data", train=True, download=True, transform=ToTensor())

In [14]:
class Model(nn.Module):
    """Fully connected (MLP) classifier over flattened inputs.

    Layout: Flatten -> Linear(prod(image_shape), hidden_dim) -> ReLU
    -> Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, n_classes).
    No softmax is applied, so ``forward`` returns raw logits (the training
    cell pairs this model with ``nn.CrossEntropyLoss``).

    Args:
        image_shape: shape of one input sample, excluding the batch dim. Only
            its product is used, so e.g. (28, 28) and (1, 28, 28) are
            interchangeable.
        n_classes: number of output logits.
        hidden_dim: width of both hidden layers (default 512, the previously
            hard-coded value).
    """

    def __init__(self, image_shape=(28, 28), n_classes=10, hidden_dim=512):
        super().__init__()
        # np.product was deprecated in NumPy 1.25 and removed in 2.0; np.prod is
        # the supported spelling. int() hands nn.Linear a plain Python int.
        in_features = int(np.prod(image_shape))
        # Flatten keeps the batch dim (start_dim=1) and merges the rest.
        self.flatten = nn.Flatten()
        # Attribute name and layer order unchanged so state_dict keys
        # (ffn.0, ffn.2, ffn.4) stay compatible with saved checkpoints.
        self.ffn = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, x):
        """Flatten ``x`` to (batch, features) and return (batch, n_classes) logits."""
        x = self.flatten(x)
        return self.ffn(x)

In [33]:
def train(data_loader, model, loss_fn, optimizer, log_every=100):
    """Run one epoch of supervised training over ``data_loader``.

    Each batch goes through a forward pass, loss computation, backward pass
    and optimizer step. Every ``log_every`` batches the current batch loss is
    printed with the number of samples consumed before that batch.

    Args:
        data_loader: DataLoader yielding ``(X, y)`` batches. Its ``dataset``
            must support ``len()``, which is used as the progress denominator.
        model: ``nn.Module`` mapping ``X`` to predictions ``y_hat``.
        loss_fn: callable ``loss_fn(y_hat, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_every: print progress every this many batches (default 100).
            Pass 0 or None to disable logging.
    """
    size = len(data_loader.dataset)
    # Ensure training-mode behaviour (dropout, batch-norm statistics) even if
    # an earlier cell left the model in eval mode.
    model.train()
    # Count samples actually seen. batch * len(X) is wrong whenever a batch is
    # shorter than the configured batch size (e.g. the final partial batch).
    seen = 0
    for batch, (X, y) in enumerate(data_loader):
        # forward pass
        y_hat = model(X)
        loss = loss_fn(y_hat, y)

        # backward pass: PyTorch accumulates gradients, so clear the previous
        # step's gradients before computing new ones and updating parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and batch % log_every == 0:
            print(f"loss: {loss.item():.7f}, [{seen}/{size}]")
        seen += len(X)

In [34]:
# Seed torch's global RNG before building the model and loader so weight
# initialisation and the shuffled batch order are reproducible across
# Restart-&-Run-All.
SEED = 0
torch.manual_seed(SEED)

# shuffle=True reshuffles every epoch; without it SGD sees the exact same
# batch sequence each epoch.
train_data_loader = DataLoader(training_data, batch_size=64, shuffle=True)
model = Model()

loss_fn = nn.CrossEntropyLoss()
# NOTE: plain SGD at lr=1e-3 converges slowly here (recorded loss is still
# ~2.08 after 3 epochs); a larger lr or momentum is worth trying.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
epochs = 10
for i in range(epochs):
    print(f"Epoch: {i}")
    train(train_data_loader, model, loss_fn, optimizer)

Epoch: 0
loss: 2.3060932, [0/60000]
loss: 2.2994914, [6400/60000]
loss: 2.3031862, [12800/60000]
loss: 2.2807732, [19200/60000]
loss: 2.2948902, [25600/60000]
loss: 2.2802453, [32000/60000]
loss: 2.2823880, [38400/60000]
loss: 2.2778502, [44800/60000]
loss: 2.2637348, [51200/60000]
loss: 2.2625818, [57600/60000]
Epoch: 1
loss: 2.2586756, [0/60000]
loss: 2.2515054, [6400/60000]
loss: 2.2645481, [12800/60000]
loss: 2.2204280, [19200/60000]
loss: 2.2469618, [25600/60000]
loss: 2.2331285, [32000/60000]
loss: 2.2228000, [38400/60000]
loss: 2.2332675, [44800/60000]
loss: 2.2036207, [51200/60000]
loss: 2.1960955, [57600/60000]
Epoch: 2
loss: 2.1932604, [0/60000]
loss: 2.1833849, [6400/60000]
loss: 2.2095344, [12800/60000]
loss: 2.1307061, [19200/60000]
loss: 2.1726534, [25600/60000]
loss: 2.1588840, [32000/60000]
loss: 2.1294587, [38400/60000]
loss: 2.1568196, [44800/60000]
loss: 2.1056566, [51200/60000]
loss: 2.0882933, [57600/60000]
Epoch: 3
loss: 2.0831454, [0/60000]
loss: 2.0679202, [6400