# Image classification with PyTorch

In this notebook we will:
1. review PyTorch basics
2. train a simple neural network with PyTorch
3. apply PyTorch Lightning to simplify the training loop

## PyTorch Basics

PyTorch is similar to NumPy, but adds GPU support and automatic gradient computation.

See installation notes: https://pytorch.org/get-started/locally/

The notebook also needs `torchvision`, `tqdm`, and `matplotlib` installed.

In [None]:
import torch
import torchvision
from tqdm import tqdm
from matplotlib import pyplot as plt

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch version is", torch.__version__)
print("Device is", DEVICE)

**Working with tensors**

Tensors work much like NumPy arrays.

In [None]:
a = torch.tensor([3, 4, 5])  # Create a new tensor (on CPU).
print(a)

In [None]:
a - 3

In [None]:
a[None]  # Add a leading dimension: shape (3,) -> (1, 3).

In [None]:
a[None].T  # Transpose: shape (1, 3) -> (3, 1).
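The cells above follow NumPy semantics. As a small sketch using only standard `torch` and NumPy calls, the snippet below checks the shapes produced by `a[None]` and `.T`, confirms that `a[None]` is equivalent to `a.unsqueeze(0)`, combines a row and a column with NumPy-style broadcasting, and converts between tensors and NumPy arrays.

```python
import numpy as np
import torch

a = torch.tensor([3, 4, 5])

row = a[None]      # Same as a.unsqueeze(0): shape (1, 3).
col = a[None].T    # Shape (3, 1).
assert torch.equal(row, a.unsqueeze(0))

# Broadcasting works as in NumPy: (1, 3) + (3, 1) -> (3, 3), with table[i, j] = a[i] + a[j].
table = row + col
print(row.shape, col.shape, table.shape)
print(table.tolist())  # [[6, 7, 8], [7, 8, 9], [8, 9, 10]]

# A CPU tensor and the array returned by .numpy() share memory.
arr = a.numpy()
arr[0] = 30
print(a[0].item())  # 30

# torch.from_numpy also shares memory; torch.tensor(some_array) would copy instead.
t = torch.from_numpy(np.array([1.0, 2.0]))
print(t.dtype)  # torch.float64
```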

**Working with devices**

In [None]:
b = a.to(DEVICE)
b

In [None]:
# a + b  # ERROR if DEVICE is cuda.

In [None]:
a.to(DEVICE) + b

In [None]:
a + b.cpu()
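A tensor's device is available as its `.device` attribute, and factory functions such as `torch.zeros` accept a `device=` argument, so a tensor can be created directly on the target device without a separate copy. A minimal sketch that runs on either CPU or GPU, re-deriving `DEVICE` the same way as at the top of the notebook:

```python
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.tensor([3, 4, 5])                       # Created on the CPU by default.
c = torch.zeros(3, dtype=a.dtype, device=DEVICE)  # Created directly on DEVICE.
print(a.device, c.device)

# Binary operations need both operands on the same device, so move one of them first.
s = a.to(c.device) + c
print(s.device, s.tolist())

# .to() returns the tensor itself (no copy) when it is already on the requested device.
print(c.to(c.device) is c)  # True
```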

## Gradient computation

In [None]:
xs = torch.linspace(0, 5, 100, requires_grad=True)  # Allow gradient computation for this tensor.
ys = torch.sin(xs)
with torch.no_grad():  # Temporarily disable gradient tracking; this also lets matplotlib convert the tensors to NumPy.
    plt.plot(xs, ys)

ys.sum().backward()  # Compute d(sum(ys))/d(xs); since ys[i] = sin(xs[i]), xs.grad[i] = cos(xs[i]).
with torch.no_grad():
    for i in range(0, len(xs), len(xs) // 10):
        grad_value = xs.grad[i]  # Take i-th component of the gradient vector.
        # Draw an arrow along the tangent direction (1, dy/dx) at (xs[i], ys[i]).
        plt.arrow(xs[i], ys[i], 1, grad_value, color="r", head_width=0.05)
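Because each `ys[i]` depends only on `xs[i]`, the gradient of `ys.sum()` should equal `cos(xs)` elementwise, which is why the red arrows follow the tangent of the curve. The sketch below checks this numerically. It also shows that a second `backward()` call adds to `.grad` instead of overwriting it; this accumulation is why the training loop later calls `optimizer.zero_grad()` before each `backward()`.

```python
import torch

xs = torch.linspace(0, 5, 100, requires_grad=True)
ys = torch.sin(xs)
ys.sum().backward()

# d/dx sin(x) = cos(x), so xs.grad should match cos(xs) elementwise.
expected = torch.cos(xs.detach())
first_matches = torch.allclose(xs.grad, expected, atol=1e-6)
print(first_matches)  # True

# backward() frees the graph, so build a new one; its gradient is ADDED to xs.grad.
torch.sin(xs).sum().backward()
accumulated_matches = torch.allclose(xs.grad, 2 * expected, atol=1e-6)
print(accumulated_matches)  # True

# Reset before the next gradient computation; optimizer.zero_grad() performs
# the same kind of reset for all parameters the optimizer manages.
xs.grad = None
```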

# Simple CNN

Every PyTorch model is an instance of a subclass of `nn.Module`. The base class provides useful tools:
1. Moving models between devices, e.g. `model.to(DEVICE)`, `model.cuda()`, and `model.cpu()`.
2. Switching the model between training and evaluation modes, e.g. `model.train()` and `model.eval()`. This is required for correct BatchNorm and Dropout behavior, since both layers act differently during training and evaluation.
3. Iteration across model parameters, e.g. `model.parameters()` and `model.named_parameters()`.
4. Checkpoint creation and loading, e.g. `model.state_dict()` and `model.load_state_dict()`.
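Items 2 and 4 can be illustrated with a tiny sketch. It uses a standalone `nn.Dropout` and a small `nn.Linear` as stand-ins (not the CNN defined below), and writes the checkpoint to an in-memory `io.BytesIO` buffer instead of a file so the example stays self-contained; with a real checkpoint you would pass a file path to `torch.save` and `torch.load`.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)

# Item 2: train() / eval() switch the behavior of layers such as Dropout.
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)
drop.train()
y_train = drop(x)   # Random entries zeroed; survivors scaled by 1 / (1 - p) = 2.
drop.eval()
y_eval = drop(x)    # Identity in evaluation mode.
print(sorted(set(y_train.tolist())), torch.equal(y_eval, x))

# Item 4: checkpoint round trip through state_dict() / load_state_dict().
model = nn.Linear(4, 2)
buffer = io.BytesIO()              # In-memory stand-in for a checkpoint file.
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = nn.Linear(4, 2)         # Freshly (randomly) initialized copy.
restored.load_state_dict(torch.load(buffer))
inputs = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(inputs), restored(inputs))
print(same_output)  # True
```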

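`nn.Module` automatically tracks parameters and submodules assigned as attributes, as well as registered buffers. `nn.Parameter` is a tensor wrapper that marks a tensor as a trainable parameter. Buffers hold non-trainable state, such as BatchNorm running statistics, and are created with `self.register_buffer()`. Together, parameters and persistent buffers make up the model checkpoint returned by `model.state_dict()`.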

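The only method an `nn.Module` subclass must implement is `forward()`, which defines the computation. Apply the model by calling it, e.g. `model(x)`, rather than calling `model.forward(x)` directly, so that registered hooks also run.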
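A minimal sketch of how a module tracks its attributes. `ScaledLinear` and its `num_calls` counter are hypothetical, made up for illustration: the module holds a submodule, an `nn.Parameter`, and a buffer. Parameters and buffers both appear in `state_dict()`, but only parameters are returned by `parameters()` and are therefore updated by an optimizer.

```python
import torch
from torch import nn


class ScaledLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)  # Submodule: its parameters are tracked too.
        self.scale = nn.Parameter(torch.ones(out_features))  # Trainable tensor.
        self.register_buffer("num_calls", torch.zeros((), dtype=torch.long))  # Non-trainable state.

    def forward(self, x):
        self.num_calls += 1  # Buffers can be updated in place like any tensor.
        return self.linear(x) * self.scale


m = ScaledLinear(3, 2)
out = m(torch.randn(5, 3))  # Call the module, not m.forward(...), so hooks run.

print(out.shape)                                   # torch.Size([5, 2])
print([name for name, _ in m.named_parameters()])  # ['scale', 'linear.weight', 'linear.bias']
print([name for name, _ in m.named_buffers()])     # ['num_calls']
print(list(m.state_dict().keys()))                 # ['scale', 'num_calls', 'linear.weight', 'linear.bias']
print(sum(p.numel() for p in m.parameters()))      # 10 = 2 (scale) + 6 (weight) + 2 (bias)
```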

**Tensor format**

PyTorch layers such as `nn.Conv2d` expect image batches in the (B, C, H, W) format, where B is the batch size, C is the number of channels, and H and W are the height and width. Note that channels are the second dimension, not the last, unlike the (H, W, C) layout of PIL images and `plt.imshow`.
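Converting between the two layouts only requires reordering dimensions. A small sketch with a random tensor standing in for an image (in this notebook, `torchvision.transforms.ToTensor()` already performs the (H, W, C) to (C, H, W) conversion for PIL images):

```python
import torch

hwc = torch.rand(32, 32, 3)          # (H, W, C): the layout of PIL images and plt.imshow input.
chw = hwc.permute(2, 0, 1)           # (C, H, W): a single image in PyTorch layout.
bchw = chw.unsqueeze(0)              # (1, C, H, W): a batch containing one image.
print(bchw.shape)                    # torch.Size([1, 3, 32, 32])

# Going back, e.g. to display the first image of a batch with plt.imshow.
restored = bchw[0].permute(1, 2, 0)  # (H, W, C)
print(torch.equal(restored, hwc))    # True
```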

In [None]:
from torch import nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Input is a 32x32 3-channel image, i.e. (B, 3, 32, 32) tensor.
        self.layer1 = nn.Sequential(  # Combines multiple layers in a single block.
            # No bias: the BatchNorm that follows subtracts the mean (cancelling any bias) and has its own learnable shift.
            nn.Conv2d(3, 16, 3, bias=False),  # Input: 3 channels, output: 16 channels, kernel size 3.
            nn.BatchNorm2d(16),  # 16 is the number of channels.
            nn.ReLU(inplace=True)  # Inplace operations, when available, reduce memory usage.
        )  # Output is (B, 16, 30, 30). Image size is reduced because there is no padding.
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, bias=False, stride=2),  # Stride 2 roughly halves tensor height and width.
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )  # Output is (B, 32, 14, 14) because of stride.
        
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, bias=False, stride=2),  # Stride 2 roughly halves tensor height and width.
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )  # Output is (B, 64, 6, 6) because of stride.
        
        # Add fully-connected neural head.
        self.head = nn.Sequential(
            # Head input is (B, C, H, W).
            nn.AdaptiveAvgPool2d((1, 1)),  # (B, C, 1, 1).
            nn.Flatten(),  # (B, C).
            nn.Linear(64, 10)  # 10 is the number of classes.
        )  # Output is (B, 10).

    def forward(self, x):
        # Input: (B, 3, 32, 32).
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.head(x)
        # Output: (B, 10).
        return x
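The shape comments in `SimpleCNN` follow the standard output-size formula for a convolution with dilation 1: out = floor((in + 2 * padding - kernel_size) / stride) + 1. The sketch below applies it to the three convolutional layers and cross-checks the result with real `nn.Conv2d` layers on a dummy batch; `conv_out_size` is a hypothetical helper written for this example, not a PyTorch function.

```python
import torch
from torch import nn


def conv_out_size(size, kernel_size, stride=1, padding=0):
    # Output height/width of nn.Conv2d with dilation=1.
    return (size + 2 * padding - kernel_size) // stride + 1


# (in_channels, out_channels, stride) for layer1..layer3 of SimpleCNN; all use kernel_size=3, padding=0.
specs = [(3, 16, 1), (16, 32, 2), (32, 64, 2)]

sizes = [32]
for _, _, stride in specs:
    sizes.append(conv_out_size(sizes[-1], kernel_size=3, stride=stride))
print(sizes)  # [32, 30, 14, 6]

x = torch.zeros(2, 3, 32, 32)  # Dummy batch with B=2.
for c_in, c_out, stride in specs:
    x = nn.Conv2d(c_in, c_out, 3, stride=stride, bias=False)(x)
print(tuple(x.shape))  # (2, 64, 6, 6)
```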

**Create dataset and loader**

We will use the standard CIFAR10 dataset from torchvision. CIFAR10 contains 32x32 color images from 10 classes (airplanes, ships, etc.), split into 50,000 training and 10,000 test images.

A map-style PyTorch dataset is an instance of `torch.utils.data.Dataset` that implements two methods, `__len__` and `__getitem__`. These methods are sufficient for indexing the dataset and iterating over it.

`torch.utils.data.DataLoader` groups samples into batches, optionally shuffles them, and can load data in parallel worker processes. It accepts a dataset and loading parameters such as the batch size and the number of workers.
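A minimal sketch of both pieces, using a hypothetical toy dataset instead of CIFAR10: any class with `__len__` and `__getitem__` can be indexed directly, and `DataLoader` collates consecutive samples into batch tensors. `num_workers` is left at its default of 0 (loading in the main process) to keep the example simple.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i]), i ** 2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        x = torch.tensor([float(index)])  # Input "features" with shape (1,).
        y = index ** 2                    # Integer "label".
        return x, y


dataset = SquaresDataset(10)
print(len(dataset), dataset[3])  # 10 (tensor([3.]), 9)

# batch_size=4 with drop_last=True yields 2 full batches; samples 8 and 9 are dropped.
loader = DataLoader(dataset, batch_size=4, drop_last=True)
batches = list(loader)
for x, y in batches:
    # The default collate function stacks x into shape (4, 1) and turns the int labels into a tensor.
    print(x.shape, y.tolist())
```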

In [None]:
def get_cifar10_loader(part="train", transform=True):
    if part == "train":
        train = True
    elif part == "test":
        train = False
    else:
        raise ValueError(f"Unknown dataset part: {part}")
    # Transform specifies how to create a tensor from an image (usually PIL).
    if transform:
        transform = torchvision.transforms.Compose([
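            torchvision.transforms.ToTensor(),  # PIL image -> float (C, H, W) tensor in [0, 1].
            # Per channel (x - 0.5) / 0.5, which maps [0, 1] to [-1, 1].
            torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))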
        ])
    else:
        transform = None
    dataset = torchvision.datasets.CIFAR10(root="cifar10", train=train, download=True, transform=transform)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=32,  # The number of images in the batch.
        shuffle=train,  # Reshuffle the training set at every epoch.
        num_workers=4,  # The number of concurrent readers and preprocessors.
        drop_last=train,  # Drop the truncated last batch during training.
        pin_memory=(DEVICE == "cuda"),  # Optimize CUDA data transfer.
    )

loader = get_cifar10_loader("test")

In [None]:
image, label = get_cifar10_loader("test", transform=False).dataset[10]
print("Example image with label", label)
plt.imshow(image)

In [None]:
x, y = next(iter(loader))
print("Input size:", x.shape)
print("Input range:", x.min(), x.max())
print("Sample labels:", y[:10])

In [None]:
def train_epoch(model, optimizer, loader):
    model.train()  # Switch BatchNorm (and Dropout, if any) to training mode.
    n = 0
    n_correct = 0
    total_loss = 0.0
    for x, y in tqdm(loader):
        # x: Images with shape (B, 3, 32, 32).
        # y: Labels with shape (B).
        x, y = x.to(DEVICE), y.to(DEVICE)  # Move data to GPU if necessary.
        logits = model(x)  # (B, 10).
        loss = torch.nn.functional.cross_entropy(logits, y)  # Mean loss over the batch.
        optimizer.zero_grad()  # Clear gradients from the previous step.
        loss.backward()  # Compute gradients.
        optimizer.step()  # Update parameters.
        # Update metrics.
        with torch.no_grad():
            predictions = logits.argmax(1)  # (B).
            n_correct += (predictions == y).sum().item()  # item() converts tensor to a Python scalar.
            total_loss += loss.item() * y.numel()  # Undo the batch mean to get a per-sample sum.
            n += y.numel()
    accuracy = n_correct / n
    mean_loss = total_loss / n  # Average over the whole epoch, not just the last batch.
    print("Train set accuracy:", accuracy)
    return mean_loss, accuracy


def test_epoch(model, loader):
    model.eval()  # Switch BatchNorm to the evaluation mode.
    n = 0
    n_correct = 0
    for x, y in tqdm(loader):
        x, y = x.to(DEVICE), y.to(DEVICE)  # Move data to GPU if necessary.
        logits = model(x)  # (B, 10).
        predictions = logits.argmax(1)  # (B).
        n_correct += (predictions == y).sum().item()  # item() converts tensor to a Python scalar.
        n += y.numel()
    accuracy = n_correct / n
    print("Test set accuracy:", accuracy)
    return accuracy


def train_cifar10(model, optimizer, num_epochs=10):
    model.to(DEVICE)
    train_loader = get_cifar10_loader("train")
    test_loader = get_cifar10_loader("test")
    losses = []
    train_accuracies = []
    test_accuracies = []
    for epoch in range(num_epochs):
        print("EPOCH", epoch)
        epoch_loss, epoch_accuracy = train_epoch(model, optimizer, train_loader)
        with torch.no_grad():  # Disable gradient tracking during evaluation.
            test_accuracy = test_epoch(model, test_loader)
        losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)
        test_accuracies.append(test_accuracy)
    return losses, train_accuracies, test_accuracies

model = SimpleCNN().to(DEVICE)  # Move the model to DEVICE before creating its optimizer.
# Create Adam optimizer.
optimizer = torch.optim.Adam(
    model.parameters(),  # Pass model parameters.
    lr=0.001  # Learning rate
)
losses, train_accuracies, test_accuracies = train_cifar10(model, optimizer)
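`train_epoch` passes raw logits to `torch.nn.functional.cross_entropy`, which applies `log_softmax` internally, so the model ends with a plain `nn.Linear` and needs no softmax layer. The small sketch below, with made-up logits for two samples and three classes, checks that equivalence and computes accuracy the same way as `train_epoch` and `test_epoch`.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 3.0, 0.2]])  # (B=2, num_classes=3), unnormalized scores.
y = torch.tensor([0, 2])                  # Target class indices, shape (B,).

loss = F.cross_entropy(logits, y)

# Same value spelled out: mean negative log-probability of the target class.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(y)), y].mean()
print(torch.allclose(loss, manual))  # True

# Accuracy as in the training loop: argmax over classes, then compare with targets.
predictions = logits.argmax(1)
accuracy = (predictions == y).sum().item() / y.numel()
print(predictions.tolist(), accuracy)  # [0, 1] 0.5
```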

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
axs[0].plot(losses)
axs[0].set_xlabel("Epoch")
axs[0].set_title("Train loss")
axs[1].plot(train_accuracies, label="Train")
axs[1].plot(test_accuracies, label="Test")
axs[1].legend()
axs[1].set_xlabel("Epoch")
axs[1].set_title("Accuracy")
plt.show()

# Assignment 1
1. Implement a new model whose head flattens the feature map directly with `nn.Flatten`, without `AdaptiveAvgPool2d`.
2. Adjust the input size of the final linear layer accordingly.
3. Train the model and compare results.

# Assignment 2
1. Try a different optimizer, such as SGD.
2. Tune the optimizer's parameters (e.g. the learning rate) to improve test set accuracy.