# Profiling with PyTorch
In this notebook we will go through profiling the training of your model with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64) and fewer images (100 k). For this dataset we will use a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

In [1]:
# Here we copy the dataset to TMPDIR if one is available; otherwise the
# dataset is read directly from its shared location.
import os

if "TMPDIR" in os.environ:
    # Location the archive is expected to unpack to inside TMPDIR
    data_path = os.path.join(os.environ["TMPDIR"], "tiny-imagenet-200/")
    # Skip the copy and extraction when the unpacked directory already exists
    if not os.path.isdir(data_path):
        # IPython shell escapes: copy the zip archive into $TMPDIR ...
        !cp "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip" "$TMPDIR"
        # ... and unpack it there (-n: never overwrite files that already exist)
        !unzip -n "$TMPDIR/tiny-imagenet-200.zip" -d "$TMPDIR"
else:
    # No TMPDIR set: use the dataset directory in place
    data_path = "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200"


In [2]:
import csv

import torch
from torch import nn, optim, profiler
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import ImageFolder

In [4]:
# Dataset
class TinyImageNet(ImageFolder):
    """ImageFolder variant that can also read the Tiny ImageNet "val" split.

    The "train" split is delegated to the regular ImageFolder logic. For any
    other split the samples are listed in ``<root>/<type>_annotations.txt``
    and the image files are looked up under ``<root>/images``. The class list
    is always taken from the ``train`` directory next to the split directory.
    """

    def __init__(self, root_parent, type, *args, **kwargs):
        """Open the split ``type`` found inside the directory ``root_parent``.

        All remaining positional and keyword arguments are forwarded
        unchanged to ``ImageFolder.__init__``.
        """
        # Stored before delegating to the parent constructor because the
        # overrides below read ``self.type`` (assumes the parent constructor
        # invokes them -- TODO confirm against the installed torchvision).
        self.root_parent = root_parent
        self.type = type
        split_dir = os.path.join(self.root_parent, self.type)
        super().__init__(split_dir, *args, **kwargs)

    def make_dataset(self, directory, class_to_idx, *args, **kwargs):
        """Build the list of ``(path_to_sample, class_index)`` pairs.

        Raises:
            ValueError: if ``class_to_idx`` is None, or (from unpacking) if a
                row of the annotation file does not have exactly six fields.
        """
        if class_to_idx is None:
            raise ValueError("The parameter class_to_idx cannot be None.")

        # The training split has the layout ImageFolder expects
        if self.type == "train":
            return super().make_dataset(directory, class_to_idx, *args, **kwargs)

        # Other splits: read the tab-separated annotation file, of which only
        # the first two fields (file name and class name) are used
        annotation_path = os.path.join(self.root, self.type + "_annotations.txt")
        samples = []
        with open(annotation_path, "r") as annotation_file:
            rows = csv.reader(annotation_file, delimiter="\t")
            for file_name, class_name, _, _, _, _ in rows:
                image_path = os.path.join(self.root, "images", file_name)
                samples.append((image_path, class_to_idx[class_name]))
        return samples

    def find_classes(self, directory):
        """Return the class list and class-to-index mapping of the train split.

        The lookup is done on ``<directory>/../train`` rather than on
        ``directory`` itself.
        """
        train_dir = os.path.join(directory, "..", "train")
        # Fall back to the underscore-prefixed parent method when ImageFolder
        # does not provide ``find_classes``
        if not hasattr(ImageFolder, "find_classes"):
            return super()._find_classes(train_dir)
        return super().find_classes(train_dir)

    def _find_classes(self, directory):
        """Backwards compatibility alias, see find_classes."""
        return self.find_classes(directory)

    
# Images are only converted to tensors
transform = transforms.Compose([transforms.ToTensor()])
val_set = TinyImageNet(data_path, "val", transform=transform)
train_set = TinyImageNet(data_path, "train", transform=transform)

# Data loaders; only shuffling is specified, everything else is left at its default
val_loader = DataLoader(val_set, shuffle=False)
train_loader = DataLoader(train_set, shuffle=True)

# ResNet-18 with a 200-class output layer
pretrained = True
model = resnet18(pretrained=False, num_classes=200)
if pretrained:
    # Optionally start from weights trained on ImageNet 1000
    pretrained_state_dict = resnet18(
        pretrained=pretrained, num_classes=1000, progress=False
    ).state_dict()
    # The last fully connected layer has the wrong shape (1000 vs 200 classes),
    # so its parameters are removed before loading
    for key in ("fc.weight", "fc.bias"):
        pretrained_state_dict.pop(key)
    # strict=False because the state dict no longer contains the fc parameters
    model.load_state_dict(pretrained_state_dict, strict=False)

# Optimizer
opt = optim.SGD(model.parameters(), momentum=0.9, lr=0.005)

# Loss function and the device the training runs on
loss_func = nn.CrossEntropyLoss()
device = torch.device("cuda")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /cephyr/users/vikren/Alvis/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


Having taken care of these initialisations we are ready to take a look at profiling.

In [5]:
# The model has to live on the same device the batches are moved to
model.to(device)


def train_step(images, labels):
    """Run one optimisation step on a single batch and return its loss.

    Uses the notebook-level ``model``, ``opt``, ``loss_func`` and ``device``.

    Args:
        images: input batch, moved to ``device`` before the forward pass.
        labels: targets for ``images``, moved to ``device`` as well.

    Returns:
        The batch loss as a Python number (``loss.item()``).
    """
    images, labels = images.to(device), labels.to(device)

    # Clear the gradients left over from the previous step
    opt.zero_grad()

    # Forward pass and loss
    loss = loss_func(model(images), labels)

    # Backward pass and parameter update
    loss.backward()
    opt.step()

    return loss.item()

In [6]:
# One schedule cycle is 10 + 5 + 10 = 25 steps and it is repeated twice,
# which matches the 50-step cut-off used in the loop below.
base_schedule = profiler.schedule(wait=10, warmup=5, active=10, repeat=2)
# Traces are handed to the TensorBoard trace handler for this log directory
base_trace_handler = torch.profiler.tensorboard_trace_handler('./logs/base.ptb')

with profiler.profile(
    schedule=base_schedule,
    on_trace_ready=base_trace_handler,
    record_shapes=False,
    profile_memory=False,
    with_stack=False,
) as prof:
    for images, labels in train_loader:
        loss = train_step(images, labels)

        # Tell the profiler that one training step has finished
        prof.step()

        print(f"\rStep: {prof.step_num}/50", end="")
        if prof.step_num >= 50:
            break

Step: 50/50

## Exercises
1. Look at the profiling results in TensorBoard. To do this, follow the instructions in README.md
2. Try to follow the Performance Recommendation and try again with the code below

In [6]:
# Recreate the training loader (this is the place to apply the recommendations)
# and make sure the model is on the training device
train_loader = DataLoader(train_set, shuffle=True)
model.to(device)

# One schedule cycle is 10 + 5 + 10 = 25 steps and it is repeated twice,
# which matches the 50-step cut-off used in the loop below.
improved_schedule = profiler.schedule(wait=10, warmup=5, active=10, repeat=2)
# Written to a separate log directory from the baseline run
improved_trace_handler = torch.profiler.tensorboard_trace_handler('./logs/improved.ptb')

with profiler.profile(
    schedule=improved_schedule,
    on_trace_ready=improved_trace_handler,
    record_shapes=False,
    profile_memory=False,
    with_stack=False,
) as prof:
    for images, labels in train_loader:
        loss = train_step(images, labels)

        # Tell the profiler that one training step has finished
        prof.step()

        print(f"\rStep: {prof.step_num}/50", end="")
        if prof.step_num >= 50:
            # Part of an epoch may be enough information for us
            break

Step: 50/50