# Checkpointing with PyTorch
In this notebook we will go through checkpointing your model with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has a lower resolution (64x64) and fewer images (100k). For this dataset we will use a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

In [1]:
import torch
from torch import nn

In [22]:
# This code is adapted from https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html

class ResidualBlock(nn.Module):
    """Two 3x3 convolution + batch-norm layers wrapped by a skip connection.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution; the second one uses the
            default stride.
        downsample: Optional module applied to the block input before it is
            added to the convolutional branch. With ``None`` the input is
            added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, input):
        """Return ``relu(branch(input) + shortcut(input))``."""
        # Convolutional branch: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(input)))
        out = self.bn2(self.conv2(out))

        # Skip connection: the raw input, or its projection when a
        # downsample module was supplied.
        if self.downsample is None:
            shortcut = input
        else:
            shortcut = self.downsample(input)

        return self.relu(out + shortcut)


class ResNet(nn.Module):
    """ResNet-style image classifier assembled from ``ResidualBlock`` stages.

    The network is a stem (7x7 stride-2 convolution, batch norm, ReLU and a
    3x3 stride-2 max pool), four residual stages with 64, 128, 256 and 512
    channels (stages 2-4 start with a stride-2 block), an adaptive average
    pool to 1x1 and a final linear layer.

    Args:
        layers: Sequence of four ints; ``layers[i]`` is the number of
            residual blocks in stage ``i``.
        num_classes: Size of the output of the final linear layer.
        zero_init_residual: If true, the weight of the last batch norm
            (``bn2``) of every residual block is initialised to zero so the
            block starts out as an identity mapping.
        groups: Stored on ``self.groups``; not otherwise used by this class.
        downsample: Accepted for signature compatibility; unused.
    """

    def __init__(
        self,
        layers,
        num_classes=200,
        zero_init_residual=False,
        groups=1,
        downsample=None,
    ):
        super().__init__()
        self.block = ResidualBlock
        norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # ``self.channels`` tracks the width of the feature map flowing into
        # the next stage; ``_make_layer`` updates it as stages are built.
        self.channels = 64
        self.dilation = 1
        self.groups = groups
        self.conv1 = nn.Conv2d(
            3, self.channels, kernel_size=7, stride=2, padding=3, bias=False,
        )
        self.bn1 = norm_layer(self.channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], stride=2)
        self.layer3 = self._make_layer(256, layers[2], stride=2)
        self.layer4 = self._make_layer(512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # Match the block class this network was actually built from
                # rather than a name that is not defined in this scope.
                if isinstance(m, self.block):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block uses ``stride`` and, when the stride or channel
        count changes, a 1x1 convolution + norm projection on its skip
        connection. Updates ``self.channels`` to ``out_channels``.

        Returns:
            An ``nn.Sequential`` containing the blocks of the stage.
        """
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        if stride != 1 or self.channels != out_channels:
            # The identity path must match the main branch in both spatial
            # size and channel count before the two can be added.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.channels, out_channels, 1, stride=stride, bias=False,
                ),
                norm_layer(out_channels),
            )

        layers = []
        layers.append(block(self.channels, out_channels, stride=stride, downsample=downsample))
        self.channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, input):
        """Run the stem, the four stages, pooling and the linear head.

        Args:
            input: Image batch with 3 channels in dimension 1.

        Returns:
            Tensor with one row per batch element and ``num_classes`` columns.
        """
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse the pooled 1x1 spatial dimensions, keeping the batch dimension.
        x = torch.flatten(x, 1)
        return self.fc(x)


In [23]:
# Smoke test: build the network with two residual blocks per stage and push a
# random batch through it. Only the output shape is of interest (one row per
# image, one column per class), so display `.shape` instead of the full tensor.
resnet = ResNet([2, 2, 2, 2])
x = torch.randn(10, 3, 1024, 1024)
resnet(x).shape

torch.Size([10, 200])

In [28]:
# Cross-check against torchvision's reference implementation on the same batch `x`.
# NOTE: this import rebinds the name `ResNet` (and defines `BasicBlock`), so from
# here on `ResNet` refers to torchvision's class, not the class defined earlier
# in this file.
from torchvision.models.resnet import BasicBlock, ResNet
ResNet(BasicBlock, [2, 2, 2, 2], num_classes=200)(x).shape

torch.Size([10, 200])