# Checkpointing with PyTorch
This notebook walks through saving and loading model checkpoints with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64) and fewer images (100k). On this dataset we will train a variant of the ResNet architecture, which is a type of convolutional neural network with residual connections. For this tutorial you do not need to understand the details of the model or the dataset.

In [4]:
# Copy and unpack the dataset into TMPDIR if one is available
import os

if "TMPDIR" in os.environ:
    data_path = os.path.join(os.environ["TMPDIR"], "tiny-imagenet-200/")
    if not os.path.isdir(data_path):
        !cp "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip" "$TMPDIR"
        !unzip -qn "$TMPDIR/tiny-imagenet-200.zip" -d "$TMPDIR"
else:
    data_path = "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200"


In [5]:
import csv

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18

In [6]:
# Dataset
class TinyImageNet(ImageFolder):
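    '''Tiny ImageNet dataset for the "train" and "val" splits.
    
    The val split keeps all images in a single directory and lists their
    classes in val_annotations.txt, so the default ImageFolder directory
    scan does not work for it. This subclass handles that case.
    '''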
    
    def __init__(self, root_parent, type, *args, **kwargs):
        self.type = type
        self.root_parent = root_parent
        super().__init__(os.path.join(self.root_parent, self.type), *args, **kwargs)
    
    def make_dataset(self, directory, class_to_idx, *args, **kwargs):
        """Generates a list of samples of a form (path_to_sample, class)."""
        if class_to_idx is None:
            raise ValueError("The parameter class_to_idx cannot be None.")
        
        if self.type == "train":
            return super().make_dataset(directory, class_to_idx, *args, **kwargs)
        
        with open(os.path.join(self.root, self.type + "_annotations.txt"), "r") as f:
            return [
                (os.path.join(self.root, "images", fn), class_to_idx[class_name])
                for fn, class_name, _, _, _, _
                in csv.reader(f, delimiter="\t")
            ]
    
    def _find_classes(self, directory):
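        """List of all classes and dictionary mapping each class to an index."""
        # Always read the classes from the train directory, since val has no class folders.
        # Note: newer torchvision releases name this hook find_classes instead.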
        train_dir = os.path.join(directory, "..", "train")
        return super()._find_classes(train_dir)

transform = transforms.Compose([
    transforms.ToTensor(),
])
val_set    = TinyImageNet(data_path, "val",   transform=transform)
train_set  = TinyImageNet(data_path, "train", transform=transform)

load_kws = dict(
    num_workers = 4,        # worker processes that load data in parallel
    batch_size = 512,
    prefetch_factor = 512,  # batches loaded in advance by each worker
)
val_loader   = DataLoader(val_set,   shuffle=False, **load_kws)
train_loader = DataLoader(train_set, shuffle=True,  **load_kws)

# ResNet-18
pretrained = False
model = resnet18(pretrained=False, num_classes=200)
if pretrained:
    # Optionally start from weights trained on ImageNet (1000 classes)
    pretrained_state_dict = resnet18(pretrained=pretrained, num_classes=1000).state_dict()
    # The last fully connected layer has the wrong shape (1000 outputs, not 200), so drop it
    for key in ["fc.weight", "fc.bias"]:
        del pretrained_state_dict[key]
    model.load_state_dict(pretrained_state_dict, strict=False)

# Optimizer
opt = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

Now we come to the important part: the training loop, which must also include the checkpointing step. Here a checkpoint is written at the end of every epoch.

In [7]:
loss_func = nn.CrossEntropyLoss()
device = torch.device("cuda")

def train(model, opt, n_epochs, checkpoint_path, device=device):
    model = model.to(device)
    
    n_batches = len(train_loader)
    total_steps = n_epochs * n_batches
    counter = 0
    
    for epoch in range(n_epochs):
        
        # Training epoch
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            opt.zero_grad()
            
            est = model(images)
            
            loss = loss_func(est, labels)
            loss.backward()
            opt.step()
            train_loss += loss.item()
            
            counter += 1
            print(f"\rProgress: {100 * counter / total_steps:4.1f} %  ({counter}/{total_steps})", end="")
            
        train_loss /= n_batches
        
        # Validation
        val_loss, val_acc = validate(model, device=device)
        print(f"\rEpoch {epoch}, Train loss {train_loss}, Val loss {val_loss}, Val acc {val_acc}")

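        # Save checkpoint (overwrites the previous one): the epoch number plus
        # the model and optimizer state needed to continue training later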
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
        }, checkpoint_path)
        
        
def validate(model, device=device):
    model.to(device)
    model.eval()
    with torch.no_grad():
        loss = 0.0
        n_batches = len(val_loader)
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            est = model(images)
            loss += loss_func(est, labels).item()
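            # Note: acc is overwritten on every iteration, so the value returned
            # below is the accuracy of the last validation batch only
            acc = (labels == est.argmax(1)).float().mean().item()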
        
        loss /= n_batches
        
        return loss, acc
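
The `validate` function above reports the accuracy of the last batch it sees, not of the whole validation set. As a minimal sketch of one alternative, the snippet below counts correct predictions across all batches and divides by the total number of samples. The helper name `dataset_accuracy` and the toy tensors are assumptions made for illustration and are not part of the notebook.

```python
import torch


def dataset_accuracy(batches):
    """Fraction of correct predictions over all (logits, labels) batches.

    Hypothetical helper for illustration; counts samples rather than
    averaging per-batch accuracies, so uneven batch sizes are handled.
    """
    n_correct = 0
    n_total = 0
    for logits, labels in batches:
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_total += labels.numel()
    return n_correct / n_total


# Two toy batches of different size: 2 of 3 correct, then 0 of 1 correct
batches = [
    (torch.tensor([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]]), torch.tensor([0, 1, 1])),
    (torch.tensor([[0.0, 2.0]]), torch.tensor([0])),
]
accuracy = dataset_accuracy(batches)
print(accuracy)
```

Counting samples instead of averaging batch means matters here because the final batch of a `DataLoader` is usually smaller than the others.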


In [8]:
%%time
train(model, opt, 5, checkpoint_path="checkpoint.pt")

Epoch 0, Train loss 5.124734632822932, Val loss 4.875583219528198, Val acc 0.0625
Epoch 1, Train loss 4.619282299158525, Val loss 4.440532326698303, Val acc 0.12867647409439087
Epoch 2, Train loss 4.199817411753596, Val loss 4.246251940727234, Val acc 0.16544117033481598
Epoch 3, Train loss 3.871709726294693, Val loss 4.439205241203308, Val acc 0.14705882966518402
Epoch 4, Train loss 3.6147326303988088, Val loss 3.8597879290580748, Val acc 0.1875
CPU times: user 2min 56s, sys: 1min 12s, total: 4min 9s
Wall time: 4min 32s
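
Because `train` overwrites `checkpoint.pt` every epoch, a job that is killed while `torch.save` is still writing could leave a truncated file behind. One common precaution, which this notebook does not use, is to write to a temporary file and then rename it into place. The sketch below assumes a hypothetical helper `save_checkpoint_atomic` and uses a toy `nn.Linear` model in place of the ResNet.

```python
import os
import tempfile

import torch
from torch import nn, optim


def save_checkpoint_atomic(state, path):
    """Write `state` to a temporary file, then rename it over `path`.

    Hypothetical helper for illustration. The temporary file is created in
    the same directory as `path` so that the rename stays on one filesystem.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)  # replaces any existing checkpoint in one step
    except BaseException:
        # Clean up the partial file and re-raise the original error
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Toy stand-ins for the model and optimizer used in the notebook
toy_model = nn.Linear(4, 2)
toy_opt = optim.SGD(toy_model.parameters(), lr=0.005, momentum=0.9)

demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "checkpoint.pt")
save_checkpoint_atomic({
    "epoch": 0,
    "model_state_dict": toy_model.state_dict(),
    "optimizer_state_dict": toy_opt.state_dict(),
}, demo_path)
```

With this approach the previous checkpoint stays intact until the new one has been written completely.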


## Loading from checkpoint
Now that we have created a checkpoint, we want to load it and check how it performs on the validation set again.

In [9]:
model = resnet18(pretrained=False, num_classes=200)
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [10]:
loss, acc = validate(model)
print(f'''
Validation loss: {loss:.4f}
Accuracy:        {acc:.4f}''')


Validation loss: 3.8598
Accuracy:        0.1875
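
The checkpoint holds more than the model weights used above: it also stores the epoch number and the optimizer state. The following is a minimal sketch of inspecting such a file, using a toy `nn.Linear` model and a temporary path (both assumptions for illustration). Passing `map_location="cpu"` to `torch.load` places the tensors on the CPU, which can help when a checkpoint written on a GPU node is opened on a machine without one.

```python
import os
import tempfile

import torch
from torch import nn, optim

# Toy stand-ins for the model and optimizer used in the notebook
toy_model = nn.Linear(4, 2)
toy_opt = optim.SGD(toy_model.parameters(), lr=0.005, momentum=0.9)

# Take one optimizer step so that there are momentum buffers to save
toy_model(torch.randn(8, 4)).sum().backward()
toy_opt.step()

demo_path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.pt")
torch.save({
    "epoch": 4,
    "model_state_dict": toy_model.state_dict(),
    "optimizer_state_dict": toy_opt.state_dict(),
}, demo_path)

# Load onto the CPU regardless of the device the tensors were saved from
restored = torch.load(demo_path, map_location="cpu")
restored_keys = sorted(restored.keys())

# The SGD momentum buffers are stored per parameter in the optimizer state
has_momentum = "momentum_buffer" in restored["optimizer_state_dict"]["state"][0]
print(restored_keys, restored["epoch"], has_momentum)
```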


## Exercises
1. Write a `train_from_checkpoint` function below that, given the path to a checkpoint, continues training from there.
2. Modify the `train_from_checkpoint` function to also save the best checkpoint so far.