# Lab SGD


The training setup is based on the [PyTorch CIFAR-10 tutorial](https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

## Boilerplate

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

data_path = "./data"

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# otherwise fall back to the CPU. `torch.backends.mps.is_available()` is the
# documented MPS availability check; `torch.mps.is_available` may not exist on
# older PyTorch releases.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Request deterministic cuDNN kernels so seeded runs are comparable on CUDA.
torch.backends.cudnn.deterministic = True

## Prepare Dataset

In [None]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Map every RGB channel from [0, 1] to [-1, 1]: (x - 0.5) / 0.5.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

batch_size = 64

# Pinned host memory only speeds up host-to-device copies for CUDA; on other
# devices it brings no benefit (PyTorch warns that it is unsupported on MPS).
_pin_memory = DEVICE.type == "cuda"

# Use the shared `data_path` instead of repeating the directory literal.
trainset = datasets.CIFAR10(
    root=data_path, train=True,
    download=True, transform=transform)
TRAINLOADER = DataLoader(
    trainset, batch_size=batch_size,
    shuffle=True, num_workers=4,
    pin_memory=_pin_memory)

testset = datasets.CIFAR10(
    root=data_path, train=False,
    download=True, transform=transform)
TESTLOADER = DataLoader(
    testset, batch_size=batch_size,
    shuffle=False, num_workers=4,
    pin_memory=_pin_memory)

# Human-readable label names, indexed by the integer class id (0-9).
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

## Declare Neural Network

In [None]:
class Net(nn.Module):
    """Small LeNet-style CNN mapping 3x32x32 images to 10 class logits.

    Two (conv -> leaky ReLU -> 2x2 max-pool) stages extract features, followed
    by three fully connected layers. The final layer returns raw, unnormalized
    scores (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in a fixed order so that a given
        # torch.manual_seed always produces the same initial weights.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Spatial size: 32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # Feature extractor: conv -> leaky ReLU -> pool, applied twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.leaky_relu(conv(x)))
        # Keep the batch axis; collapse channels, height and width.
        x = x.flatten(start_dim=1)
        # Hidden fully connected layers.
        for fc in (self.fc1, self.fc2):
            x = F.leaky_relu(fc(x))
        # Output layer: raw logits.
        return self.fc3(x)

## Check Examples & Test Network

In [None]:
def _imshow(img: torch.Tensor) -> None:
    """Display a normalized channels-first image tensor with matplotlib.

    The tensor is mapped back with ``x / 2 + 0.5`` (the inverse of a
    ``Normalize(mean=0.5, std=0.5)`` transform), reordered from (C, H, W) to
    (H, W, C), copied to a NumPy array and shown.
    """
    # Unnormalize, then move the channel axis last as matplotlib expects.
    pixels = (img / 2 + 0.5).permute(1, 2, 0)
    # force=True detaches and copies to host memory when necessary.
    plt.imshow(pixels.numpy(force=True))
    plt.show()


def check_examples(net: Net, testloader: DataLoader, n_examples=4) -> None:
    """Show a few test images with their true and predicted class names.

    Takes the first batch of ``testloader``, keeps its first ``n_examples``
    samples, displays them as an image grid, and prints the ground-truth
    labels followed by ``net``'s predictions.

    Args:
        net: Classifier to inspect; switched to eval mode as a side effect.
        testloader: Loader yielding ``(inputs, labels)`` batches; only the
            first batch is used.
        n_examples: Number of samples to show (capped at the batch size).
    """
    net.eval()
    inputs, labels = next(iter(testloader))
    inputs = inputs[:n_examples].to(DEVICE)
    # Labels are only used to look up names, so keep them as host ints.
    label_ids = labels[:n_examples].tolist()

    _imshow(make_grid(inputs))
    print("GroundTruth: ", ' '.join(f"{classes[label]:5s}" for label in label_ids))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = net(inputs)
    _, predictions = torch.max(outputs, 1)

    print("Predicted: ", ' '.join(f"{classes[prediction]:5s}" for prediction in predictions.tolist()))


def test_network(net: Net, testloader: DataLoader) -> None:
    """Print per-class and overall accuracy of ``net`` on ``testloader``.

    Args:
        net: Classifier to evaluate; switched to eval mode as a side effect.
        testloader: Loader yielding ``(images, labels)`` batches whose labels
            are integer class ids indexing ``classes``.
    """
    net.eval()
    # Per-class counters: correct predictions and samples seen.
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # Evaluation only: no gradients needed.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images.to(DEVICE))
            _, predictions = torch.max(outputs, 1)
            # Compare as plain host-side ints.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                classname = classes[label]
                if label == prediction:
                    correct_pred[classname] += 1
                total_pred[classname] += 1

    for classname, correct_count in correct_pred.items():
        class_total = total_pred[classname]
        if class_total == 0:
            # A loader without samples of this class would otherwise divide by zero.
            print(f"Accuracy for class: {classname:5s} is n/a (no samples)")
            continue
        accuracy = 100 * correct_count / class_total
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")

    correct = sum(correct_pred.values())
    total = sum(total_pred.values())
    if total == 0:
        print("Accuracy of the network: n/a (no test images)")
        return
    # Report the number of images actually evaluated, with the same precision
    # as the per-class lines, so small differences between runs stay visible.
    print(f"Accuracy of the network on the {total} test images: {100 * correct / total:.1f} %")

## Manual training with basic SGD

In [None]:
def train_sgd_manual(net: Net, n_epochs: int, lr: float) -> None:
    """Train ``net`` on ``TRAINLOADER`` with a hand-written plain SGD step.

    For every mini-batch, gradients are computed with ``torch.autograd.grad``
    (the parameters' ``.grad`` fields are not used), and each parameter is
    updated as ``param <- param - lr * grad``. After each epoch, the mean
    mini-batch loss is printed.

    Args:
        net: Model trained in place; switched to train mode.
        n_epochs: Number of passes over ``TRAINLOADER``.
        lr: Step size.
    """
    net.train()
    params = list(net.parameters())

    for epoch in range(n_epochs):
        running_loss = 0.0
        n_batches = 0

        for inputs, labels in TRAINLOADER:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels)

            grads = torch.autograd.grad(loss, params)

            # Update in place outside autograd so the step is not recorded.
            with torch.no_grad():
                for param, grad in zip(params, grads):
                    param -= lr * grad

            running_loss += loss.item()
            n_batches += 1

        # Average over the batches actually seen in this epoch.
        print(f"epoch {epoch + 1:2d} loss: {running_loss / max(n_batches, 1):.4f}")

In [None]:
# Plain manual SGD with a large learning rate (lr=1.0).
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_sgd_manual(net=net, n_epochs=20, lr=1.0)
test_network(net=net, testloader=TESTLOADER)

**The learning rate above is too high: training diverges and the loss values become invalid (e.g. NaN).**

In [None]:
# Plain manual SGD with a small learning rate (lr=0.001).
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_sgd_manual(net=net, n_epochs=20, lr=0.001)
test_network(net=net, testloader=TESTLOADER)

**The learning rate above is quite low, so training makes little progress.**

In [None]:
# Plain manual SGD with an intermediate learning rate (lr=0.01).
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_sgd_manual(net=net, n_epochs=20, lr=0.01)
test_network(net=net, testloader=TESTLOADER)

## Manual training with added momentum

In [None]:
def train_momentum_manual(
        net: Net, n_epochs: int, lr: float, momentum: float) -> None:
    """Train ``net`` on ``TRAINLOADER`` with hand-written SGD plus momentum.

    Each parameter keeps a velocity buffer that starts at zero and persists
    across all epochs. Every mini-batch applies
    ``v <- momentum * v + grad`` and then ``param <- param - lr * v``.
    After each epoch, the mean mini-batch loss is printed.

    Args:
        net: Model trained in place; switched to train mode.
        n_epochs: Number of passes over ``TRAINLOADER``.
        lr: Step size.
        momentum: Decay factor applied to the velocity each step.
    """
    net.train()
    params = list(net.parameters())

    velocities = [torch.zeros_like(p) for p in params]

    for epoch in range(n_epochs):
        running_loss = 0.0
        n_batches = 0

        for inputs, labels in TRAINLOADER:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels)

            grads = torch.autograd.grad(loss, params)

            # Update velocities and parameters in place outside autograd.
            with torch.no_grad():
                for param, grad, velocity in zip(params, grads, velocities):
                    velocity.mul_(momentum).add_(grad)
                    param -= lr * velocity

            running_loss += loss.item()
            n_batches += 1

        # Average over the batches actually seen in this epoch.
        print(f"epoch {epoch + 1:2d} loss: {running_loss / max(n_batches, 1):.4f}")

**With momentum, a lower base learning rate may be needed, but here we can keep it the same (0.01), which lets us see the effect of momentum alone.**

In [None]:
# Manual SGD with momentum 0.9 at the same base learning rate (lr=0.01).
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_momentum_manual(net=net, n_epochs=20, lr=0.01, momentum=0.9)
# Qualitative check on a few test images, then the full evaluation.
check_examples(net=net, testloader=TESTLOADER)
test_network(net=net, testloader=TESTLOADER)

## Standard training with a torch optimizer

In [None]:
def train_opt(
        net: Net, n_epochs: int, optimizer: torch.optim.Optimizer) -> None:
    """Train ``net`` on ``TRAINLOADER`` using a ``torch.optim`` optimizer.

    Runs the standard zero_grad -> forward -> backward -> step loop and
    prints the mean mini-batch loss after each epoch.

    Args:
        net: Model trained in place; switched to train mode.
        n_epochs: Number of passes over ``TRAINLOADER``.
        optimizer: Optimizer that presumably holds ``net``'s parameters
            (as in the calls in this notebook).
    """
    net.train()

    for epoch in range(n_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        n_batches = 0

        for inputs, labels in TRAINLOADER:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Clear gradients accumulated by the previous step.
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        # Average over the batches actually seen in this epoch.
        print(f"epoch {epoch + 1:2d} loss: {running_loss / max(n_batches, 1):.4f}")

In [None]:
# torch.optim.SGD without momentum, matching the manual SGD run (lr=0.01).
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_opt(net, n_epochs=20, optimizer=torch.optim.SGD(net.parameters(), lr=0.01))
test_network(net=net, testloader=TESTLOADER)

In [None]:
# torch.optim.SGD with momentum 0.9, matching the manual momentum run.
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_opt(net, n_epochs=20, optimizer=torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9))
test_network(net=net, testloader=TESTLOADER)

In [None]:
# AdamW with a smaller learning rate (lr=0.001).
torch.manual_seed(26)  # seed the global RNG before building the model
net = Net().to(device=DEVICE)
train_opt(net, n_epochs=20, optimizer=torch.optim.AdamW(net.parameters(), lr=0.001))
test_network(net=net, testloader=TESTLOADER)

**AdamW needs a lower learning rate than SGD with momentum, but it also reaches a lower loss (and slightly higher accuracy).**

**For other (more complex) network types, AdamW often has an even stronger advantage, including on the test set.**