Copyright (c) 2023 Graphcore Ltd. All rights reserved.

This Notebook takes you through the simple steps to use Lightning (https://github.com/Lightning-AI/lightning/) to run on the IPU. It creates a simple model and wraps the model in Lightning's standard processes for defining training, validation and optimiser behaviour. The only IPU-specific code is the dataloader and the instructions telling Lightning to run on the IPU.

This Notebook assumes you are running in a Docker container which needs to be updated to include all the required Linux packages.

The code in this Notebook shares requirements and dependencies with the adjacent PyTorch models. Install all requirements from the pytorch directory.

In [None]:
import os

import_location = os.getenv("POPTORCH_CNN_IMPORTS", "../pytorch")
number_of_ipus = 4
dataset_directory = os.getenv("DATASET_DIR", "fashionmnist_data/")

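The cell above reads two paths from environment variables and falls back to defaults when they are unset. Here is a minimal sketch of that lookup. The helper `resolve_paths` is hypothetical and not part of the Notebook. It takes a mapping of environment variables, so you can check which paths would be used before running the install and training cells:

```python
import os

# Defaults the Notebook uses when the environment variables are not set
DEFAULTS = {
    "POPTORCH_CNN_IMPORTS": "../pytorch",
    "DATASET_DIR": "fashionmnist_data/",
}


def resolve_paths(env=None):
    """Hypothetical helper: return each path, preferring the environment value."""
    env = os.environ if env is None else env
    return {name: env.get(name, default) for name, default in DEFAULTS.items()}


print(resolve_paths({"DATASET_DIR": "/data/fashionmnist"}))
# {'POPTORCH_CNN_IMPORTS': '../pytorch', 'DATASET_DIR': '/data/fashionmnist'}
```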
In [None]:
%cd ../../

If you are running this Notebook in a Docker container, you will have the sudo rights needed to execute the cell below. If not, you may need to run these commands separately with sudo.

In [None]:
!apt update
!apt-get install -y $(< {import_location}/required_apt_packages.txt)

Install PyTorch requirements

In [None]:
!make install -C {import_location}
!make install-turbojpeg -C {import_location}

In [None]:
import torch
import torchvision
import poptorch
import pytorch_lightning as pl
from pytorch_lightning.strategies import IPUStrategy
from torch import nn

This Notebook trains a small model, defined below.
Start by writing a basic block with a residual connection, which the model will use later.

In [None]:
class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, stride_inp=(1, 1), downsample_key=False
    ):
        super().__init__()
        self.block = torch.nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                3,
                stride=stride_inp,
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=(1, 1), bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.downsample = None
        if downsample_key:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride_inp, bias=False),
                nn.BatchNorm2d(out_channels),
            )

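    def forward(self, x):
        out = self.block(x)
        # Shortcut path: the input itself, or a 1x1 projection when the shape changes
        identity = x if self.downsample is None else self.downsample(x)
        # Add the residual connection, then apply the final activation
        return torch.relu(out + identity)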
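To make the block's size concrete, the sketch below counts its trainable parameters by hand. It assumes the layer configuration above: bias-free convolutions and `BatchNorm2d` layers with a learnable scale and shift per channel. The helper names are hypothetical. BatchNorm running statistics are buffers rather than parameters, so they are not counted. The results should match `sum(p.numel() for p in block.parameters())` for a `ResidualBlock` built with the same arguments.

```python
def conv_params(in_ch, out_ch, kernel):
    # bias=False in ResidualBlock, so only the weight tensor is counted
    return in_ch * out_ch * kernel * kernel


def bn_params(ch):
    # BatchNorm2d learns one scale (weight) and one shift (bias) per channel
    return 2 * ch


def residual_block_params(in_ch, out_ch, downsample=False):
    """Hypothetical helper: trainable parameters in one ResidualBlock."""
    total = conv_params(in_ch, out_ch, 3) + bn_params(out_ch)   # first conv + BN
    total += conv_params(out_ch, out_ch, 3) + bn_params(out_ch)  # second conv + BN
    if downsample:
        total += conv_params(in_ch, out_ch, 1) + bn_params(out_ch)  # 1x1 shortcut + BN
    return total


print(residual_block_params(64, 64))         # 73984
print(residual_block_params(64, 128, True))  # 230144
```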

Now define a small ResNet model that outputs 10 classes, one for each FashionMNIST category. It uses the residual block defined above and specifies both the structure of the model and what a forward pass looks like.

In [None]:
class ResNetFromScratch(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.model = torch.nn.Sequential(
            nn.Conv2d(1, 64, 7),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, 2, 1),
            nn.Sequential(ResidualBlock(64, 64), ResidualBlock(64, 64)),
            nn.Sequential(
                ResidualBlock(64, 128, (2, 2), True), ResidualBlock(128, 128)
            ),
            nn.Sequential(
                ResidualBlock(128, 256, (2, 2), True), ResidualBlock(256, 256)
            ),
            nn.Sequential(
                ResidualBlock(256, 512, (2, 2), True), ResidualBlock(512, 512)
            ),
            nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            nn.Linear(512, 10),
            nn.LogSoftmax(1),
        )

    def forward(self, x):
        x = self.model(x)
        return x
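The sketch below traces the spatial size of a 28×28 FashionMNIST image through the layers above. It uses the standard output-size formula floor((size + 2·padding − kernel) / stride) + 1, which applies to both `Conv2d` and `MaxPool2d` with their defaults `dilation=1` and `ceil_mode=False`. The function names are illustrative. The trace also checks that the 1×1 stride-2 shortcut in each downsampling block gives the same size as the main 3×3 path, which is what allows the two outputs to be added.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size for Conv2d / MaxPool2d with dilation=1 and floor rounding
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial_size(size=28):
    """Illustrative helper: follow one spatial dimension through ResNetFromScratch."""
    sizes = {"input": size}
    size = conv_out(size, 7)  # nn.Conv2d(1, 64, 7): no padding, stride 1
    sizes["stem_conv"] = size
    size = conv_out(size, 3, stride=2, padding=1)  # nn.MaxPool2d(3, 2, 1)
    sizes["maxpool"] = size
    size = conv_out(size, 3, stride=1, padding=1)  # 3x3 conv with padding 1 keeps size
    sizes["stage1"] = size
    for stage in ("stage2", "stage3", "stage4"):
        main = conv_out(size, 3, stride=2, padding=1)  # first 3x3 conv, stride 2
        shortcut = conv_out(size, 1, stride=2)  # 1x1 downsample conv, stride 2
        assert main == shortcut, "residual paths must have matching sizes"
        size = main
        sizes[stage] = size
    return sizes


print(trace_spatial_size())
# {'input': 28, 'stem_conv': 22, 'maxpool': 11, 'stage1': 11, 'stage2': 6, 'stage3': 3, 'stage4': 2}
```

`AdaptiveAvgPool2d((1, 1))` then reduces the final 2×2 maps to 1×1, leaving the 512 features that the `Linear(512, 10)` layer expects.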

The following code shows how you can use a PyTorch Lightning module to wrap your model class and describe the behaviour for training and (optionally) validation steps.
We also use the `LightningModule`'s built-in methods to configure the optimiser.
For more information, see the [PyTorch Lightning Documentation for pl.LightningModule](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html).

In [None]:
# Lightning wrapper around the ResNet18-style model (built from scratch and adapted to FashionMNIST)
class ResNetClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, _):
        x, y = batch
        output = self.forward(x)
        loss = torch.nn.functional.nll_loss(output, y)
        return loss

    def validation_step(self, batch, _):
        x, y = batch
        output = self.forward(x)
        preds = torch.argmax(output, dim=1)
        acc = torch.sum(preds == y).float() / len(y)
        return acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
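`ResNetFromScratch` ends in `LogSoftmax`, so `training_step` uses `nll_loss`. The combination is equivalent to applying cross-entropy to the raw logits. The pure-Python sketch below uses hypothetical helper names and needs no PyTorch. It reproduces the loss with the default mean reduction, and the accuracy computed in `validation_step`, to show what each step returns:

```python
import math


def log_softmax(logits):
    # Subtract the max before exponentiating, a common numerical-stability trick
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def nll_loss(log_probs_batch, targets):
    # Mean negative log-probability of each sample's target class
    return -sum(lp[t] for lp, t in zip(log_probs_batch, targets)) / len(targets)


def accuracy(log_probs_batch, targets):
    # argmax over classes, then the fraction of correct predictions
    preds = [max(range(len(lp)), key=lp.__getitem__) for lp in log_probs_batch]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


uniform = log_softmax([0.0] * 10)
print(nll_loss([uniform], [3]))  # log(10), about 2.3026
batch = [log_softmax([2.0, 0.0, 0.0]), log_softmax([0.0, 3.0, 0.0])]
print(accuracy(batch, [0, 2]))  # 0.5
```

A model that spreads probability evenly over the 10 classes therefore starts with a loss of about log(10). This gives a quick sanity check for the first logged losses.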

This class defines how to feed data to the model. It downloads FashionMNIST into `dataset_directory` (by default `fashionmnist_data/`). It then declares one dataloader for training and one for validation, both based on the IPU-specific PopTorch dataloader (https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/reference.html#poptorch.DataLoade). The data module is passed to the trainer below.

In [None]:
class FashionMNIST(pl.LightningDataModule):
    def __init__(self, options, batch_size=4):
        super().__init__()
        self.batchsize = batch_size
        self.options = options

    def setup(self, stage=None):
        # Same preprocessing for both splits: convert to tensors, then normalise
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Training split (train=True) is used to fit the model
        self.train_data = torchvision.datasets.FashionMNIST(
            dataset_directory,
            train=True,
            download=True,
            transform=transform,
        )

        # Test split (train=False) is used for validation
        self.validation_data = torchvision.datasets.FashionMNIST(
            dataset_directory,
            train=False,
            download=True,
            transform=transform,
        )

    def train_dataloader(self):
        return poptorch.DataLoader(
            dataset=self.train_data,
            batch_size=self.batchsize,
            options=self.options,
            shuffle=True,
            drop_last=True,
            mode=poptorch.DataLoaderMode.Async,
            num_workers=64,
        )

    def val_dataloader(self):
        return poptorch.DataLoader(
            dataset=self.validation_data,
            batch_size=self.batchsize,
            options=self.options,
            drop_last=True,
            mode=poptorch.DataLoaderMode.Async,
            num_workers=64,
        )
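The transform used in `setup` does two things. `ToTensor` first converts each 8-bit pixel to a float in [0, 1]. `Normalize` then standardises it as (x − mean) / std. The sketch below reproduces that arithmetic for single pixel values, and its helper names are hypothetical. The constants 0.1307 and 0.3081 are the commonly used MNIST mean and standard deviation. The `pixel_stats` helper shows one way to compute statistics from your own data instead; here it runs on a toy list rather than the real dataset.

```python
import statistics


def to_unit_range(pixel):
    # ToTensor scales uint8 values in [0, 255] to floats in [0.0, 1.0]
    return pixel / 255.0


def normalize(value, mean=0.1307, std=0.3081):
    # Normalize computes (x - mean) / std for each channel
    return (value - mean) / std


def pixel_stats(pixels):
    # Mean and population standard deviation of pixels after scaling to [0, 1]
    scaled = [to_unit_range(p) for p in pixels]
    return statistics.fmean(scaled), statistics.pstdev(scaled)


print(round(normalize(to_unit_range(0)), 4))    # black pixel: -0.4242
print(round(normalize(to_unit_range(255)), 4))  # white pixel: 2.8215
```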

Set up training. This uses the `ResNetFromScratch` model defined above, but you could use the alternative `TorchVisionBackbone` model instead.

In [None]:
model = ResNetFromScratch()

Set the number of epochs to train for

In [None]:
num_epochs = 2

Pass the model to the PyTorch Lightning classifier and create a trainer with some IPU-specific options

In [None]:
model = ResNetClassifier(model)

options = poptorch.Options()
options.deviceIterations(250)
options.replicationFactor(number_of_ipus)

datamodule = FashionMNIST(options)

trainer = pl.Trainer(
    accelerator="ipu",
    devices=number_of_ipus,
    max_epochs=num_epochs,
    log_every_n_steps=1,
    accumulate_grad_batches=8,
    strategy=IPUStrategy(inference_opts=options, training_opts=options),
    # enable_progress_bar=False,
)
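PopTorch's `DataLoader` does not yield just `batch_size` samples per iteration. According to the PopTorch documentation, it combines `batch_size` with the device iterations, replication factor and gradient accumulation set in the options. The sketch below works through that arithmetic for the values used in this Notebook. It assumes that the gradient-accumulation factor in the options ends up equal to `accumulate_grad_batches=8` for training and 1 for validation. How `IPUStrategy` propagates that setting may depend on your Lightning version, so treat these numbers as an estimate. FashionMNIST's training split has 60,000 images and its test split has 10,000.

```python
def combined_batch_size(batch_size, device_iterations, replication_factor, gradient_accumulation=1):
    # Samples yielded per iteration of a PopTorch DataLoader
    return batch_size * device_iterations * replication_factor * gradient_accumulation


def steps_per_epoch(num_samples, combined, drop_last=True):
    # drop_last=True discards the final incomplete combined batch
    if drop_last:
        return num_samples // combined
    return -(-num_samples // combined)  # ceiling division


# Assumed values: batch_size=4, deviceIterations(250), 4 replicas
train_combined = combined_batch_size(4, 250, 4, gradient_accumulation=8)
val_combined = combined_batch_size(4, 250, 4)
print(train_combined, steps_per_epoch(60_000, train_combined))  # 32000 1
print(val_combined, steps_per_epoch(10_000, val_combined))      # 4000 2
```

Under these assumptions each training epoch is a single step that uses 32,000 of the 60,000 training images. If you want more steps per epoch, reducing `deviceIterations` or `batch_size` is one option.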

Now train the model

In [None]:
trainer.fit(model, datamodule)