# PyTorch Lightning

Lightning is a high-level library built on top of PyTorch. It reduces boilerplate code and simplifies multi-GPU training, mixed-precision training, checkpointing, logging, and other common tasks.

The three main abstractions in Lightning are `LightningModule`, `LightningDataModule` and `Trainer`. The user defines the first two and uses `Trainer` for training and evaluation.

`LightningModule` is a subclass of `torch.nn.Module` with additional methods such as `configure_optimizers`, `training_step` and `validation_step`.

`LightningDataModule` creates dataloaders for the training, validation and test splits of the dataset.

`Trainer` is usually used as is, without subclassing.

In this notebook we will reproduce the results from Part 1 using PyTorch Lightning.

In [None]:
# Uncomment to install PyTorch Lightning.
# ! pip install pytorch_lightning

In [None]:
import pytorch_lightning as pl
import torch
import torchvision

from matplotlib import pyplot as plt

In [None]:
class Data(pl.LightningDataModule):
    def __init__(self, num_workers=4, batch_size=32):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        
    def train_dataloader(self):
        dataset = torchvision.datasets.CIFAR10(root="cifar10", train=True, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,  # The number of images in the batch.
            num_workers=self.num_workers,  # The number of concurrent readers and preprocessors.
            drop_last=True,  # Drop the truncated last batch during training.
            pin_memory=torch.cuda.is_available(),  # Optimize CUDA data transfer.
        )

    def test_dataloader(self):
        dataset = torchvision.datasets.CIFAR10(root="cifar10", train=False, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,  # The number of images in the batch.
            num_workers=self.num_workers,  # The number of concurrent readers and preprocessors.
            pin_memory=torch.cuda.is_available(),  # Optimize CUDA data transfer.
        )

data_module = Data()
x, y = next(iter(data_module.test_dataloader()))  # Test loader.
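
The transform above first converts each image to a tensor with values in [0, 1] and then normalizes every channel with mean 0.5 and std 0.5. A minimal sketch of that arithmetic in plain Python, assuming the standard formula `(value - mean) / std`; the helper name `normalize` is hypothetical and only illustrates the computation:

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel normalization formula: (value - mean) / std."""
    return (value - mean) / std


# ToTensor yields pixel values in [0, 1]; check the extremes and the midpoint.
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

With these settings the inputs seen by the model lie in the range [-1, 1].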

In [None]:
from torch import nn

class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # The same model as in Part 1.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, bias=False, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, bias=False, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        # The same as in Part 1.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.head(x)
        return x

    # Optimizer is now defined in Module.
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    # Process batch and compute loss.
    def training_step(self, batch):
        # No need to call self.train(). Lightning does it.
        x, y = batch  # No need to move to GPU. Lightning does it.
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.log("train/loss", loss, prog_bar=True)
        with torch.no_grad():
            predictions = logits.argmax(1)  # (B).
            correct = (predictions == y).sum().item()
            accuracy = correct / y.numel()
            self.log("train/accuracy", accuracy)
        return loss  # No need to manually do optimizer step.

    # Log test metrics for each batch.
    def test_step(self, batch):
        # No need to call self.eval(). Lightning does it.
        x, y = batch  # No need to move to GPU. Lightning does it.
        logits = self(x)  # (B, 10).
        predictions = logits.argmax(1)  # (B).
        correct = (predictions == y).sum().item()
        accuracy = correct / y.numel()
        # Compute accuracy for each batch and average over batches at the end of the epoch.
        # A plain mean of per-batch values is slightly biased, as the last batch can be smaller than the rest.
        self.log("test/accuracy", accuracy, on_epoch=True)

model = Module()
data = Data()
trainer = pl.Trainer(max_epochs=10)  # Trainer can choose devices automatically.
trainer.fit(model, data)
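
To see why the head can use `nn.Linear(64, 10)`, it helps to trace the spatial size of a CIFAR-10 image through the three convolutions. The sketch below assumes the usual output-size formula for a convolution with dilation 1, `(size + 2 * padding - kernel) // stride + 1`; the helper `conv_output_size` is a hypothetical name, not part of PyTorch:

```python
def conv_output_size(size, kernel=3, stride=1, padding=0):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# CIFAR-10 images are 32x32; the three 3x3 convolutions use strides 1, 2 and 2.
size = 32
sizes = [size]
for stride in (1, 2, 2):
    size = conv_output_size(size, stride=stride)
    sizes.append(size)

print(sizes)  # [32, 30, 14, 6]
```

The last feature map is 64 channels of 6x6, which `nn.AdaptiveAvgPool2d((1, 1))` reduces to 64 values per image before the linear layer.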

In [None]:
trainer.test(model, data)
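
The comment in `test_step` notes that averaging per-batch accuracies can be slightly off when the last batch is smaller. The sketch below illustrates only that arithmetic in plain Python. The per-batch counts are made-up numbers, and whether Lightning's own epoch-level reduction weights batches by size may depend on the version, so this is not a statement about what `self.log` does internally:

```python
import math

# The CIFAR-10 test set has 10000 images; with batch_size=32 the last batch is partial.
num_batches = math.ceil(10000 / 32)
last_batch_size = 10000 - (num_batches - 1) * 32
print(num_batches, last_batch_size)  # 313 16

# Hypothetical (correct, batch_size) pairs: two full batches and one half-size batch.
batches = [(24, 32), (24, 32), (16, 16)]

# Plain mean of per-batch accuracies: every batch counts equally.
unweighted = sum(correct / size for correct, size in batches) / len(batches)

# Exact accuracy: total correct predictions over total samples.
weighted = sum(correct for correct, _ in batches) / sum(size for _, size in batches)

print(round(unweighted, 4), weighted)  # 0.8333 0.8
```

Accumulating the number of correct predictions and the number of samples, then dividing once at the end, avoids the bias.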

**Visualization.** PyTorch Lightning writes TensorBoard logs by default. We can view them inline in the notebook.

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

# Assignment

1. Implement `validation_step` to log test set metrics after each epoch.
2. (Advanced) Add a learning rate scheduler to `configure_optimizers` and improve the results. In this case the method must return two values: the optimizer and the scheduler.