https://www.youtube.com/watch?v=OMDn66kM9Qc&t=2005s

In [7]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl

model = nn.Sequential(
    nn.Linear(28 * 28, 64), 
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64,10)
)

In [3]:
# define optimizer
optimizer = optim.SGD(model.parameters(), lr=1e-2)

In [4]:
# define loss
loss = nn.CrossEntropyLoss()

In [5]:
# download MNIST and split into training and testing
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
train, val = random_split(train_data, [55000, 5000])
train_loader = DataLoader(train, batch_size=32)
val_loader = DataLoader(val, batch_size=32)

In [6]:
# training and validation loops
nb_epochs = 5
for epoch in range(nb_epochs):
    losses = list()
    for batch in train_loader:
        x, y = batch
        # x: batch x 1 x 28 x 28
        # flatten each image into a vector
        b = x.size(0)
        x = x.view(b,-1)
        
        # forward
        l = model(x) #l: logit
        
        # compute the objective function
        J = loss(l, y)
        
        # cleaning the gradients
        model.zero_grad()
        
        # accumulate the partial derivatives of J wrt params
        J.backward()
        
        # step in the opposite direction of the gradient
        optimizer.step()
        
        losses.append(J.item())
        
    print(f'Epoch {epoch + 1}, train loss: {torch.tensor(losses).mean():.2f}')
        
    losses = list()
    for batch in val_loader:
        x, y = batch
        # x: batch x 1 x 28 x 28
        # flatten each image into a vector
        b = x.size(0)
        x = x.view(b,-1)

        # forward
        with torch.no_grad():
            l = model(x) #l: logit

        # compute the objective function
        J = loss(l, y)

        losses.append(J.item())
    print(f'Epoch {epoch + 1}, train loss: {torch.tensor(losses).mean():.2f}')

Epoch 1, train loss: 1.20
Epoch 1, train loss: 0.48
Epoch 2, train loss: 0.40
Epoch 2, train loss: 0.35
Epoch 3, train loss: 0.32
Epoch 3, train loss: 0.30
Epoch 4, train loss: 0.28
Epoch 4, train loss: 0.27
Epoch 5, train loss: 0.26
Epoch 5, train loss: 0.25
