# Checkpointing with PyTorch
In this notebook we will go through checkpointing your model with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview) which is similar to ImageNet but lower resolution (64x64), fewer images (100 k) and fewer labels (200). For this dataset we will use a variant of the ResNet architecture wich is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details about the model or the dataset.

In [11]:
import os
import zipfile
from fnmatch import fnmatch
from typing import Tuple

import numpy as np
import torch
from PIL import Image

# Custom Type Hints https://peps.python.org/pep-0484/
LoadedFromZip = Tuple[str, object]  
DataPoint = Tuple[torch.FloatTensor, int]

In [12]:
def build_dataset(split='train'):
    '''Construct a dataset for the tiny-imagenet-200 dataset'''
    path_to_dataset = '/mimer/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip'

    # Open the zip file
    ziphandle = zipfile.ZipFile(path_to_dataset)
    
    # Filter data images based on the split (train and val)
    filenames = [
        filename for filename in ziphandle.namelist()
        if f'/{split}/' in filename and filename.endswith('.JPEG')
    ]
    
    # Set length of dataset
    dataset_len = len(filenames)

    # Read wnids.txt to create label mapping
    for filename in ziphandle.namelist():
        if filename.endswith('wnids.txt'):
            with ziphandle.open(filename) as txtfile:
                wnids = txtfile.read().decode('utf-8').split()
                break
    wnid2label = {wnid: label for label, wnid in enumerate(wnids)}

    # Utility function for getting wnid from filename
    if split == 'train':
        def get_wnid(filename: str) -> str:
            return filename.split("/")[-1].split('_')[0]
    elif split == 'val':
        # Parse annotations in validation set
        filename2wnid = {}
        for filename in ziphandle.namelist():
            if filename.endswith('val_annotations.txt'):
                with ziphandle.open(filename) as txtfile:
                    for line in txtfile.read().decode('utf-8').split('\n'):
                        if line.startswith('val'):
                            fname, wnid, *_ = line.split('\t')
                            filename2wnid[fname] = wnid
                break
        
        def get_wnid(filename: str) -> str:
            return filename2wnid.get(os.path.basename(filename), None)

    else:
        raise NotImplementedError(f"Can't determine labels for split {split}.")

    # Convert stream to image tensor and label
    def parse_tiny_imagenet(filename: str) -> DataPoint:
        '''Parse filename and image stream into label and image tensors'''
        wnid = get_wnid(filename)
        label = wnid2label.get(wnid, -1)

        with ziphandle.open(filename) as stream:
            img_array = np.array(Image.open(stream))
            if img_array.ndim < 3:
                # Greyscale to RGB
                img_array = np.repeat(img_array[..., np.newaxis], 3, -1)

            img_tensor = torch.from_numpy(img_array)
            img_tensor = img_tensor.permute(2, 0, 1)  # Convert to (C, H, W)
            return img_tensor.float(), label

    # Apply the parse function to all filenames
    dataset = [parse_tiny_imagenet(filename) for filename in filenames]
    
    return dataset



In [15]:
import csv
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18

In [16]:
train_dataset = build_dataset(split="train")
val_dataset = build_dataset(split="val")

load_kws = dict(
    num_workers=4,
    batch_size=512,
    prefetch_factor=512,
)

train_loader = DataLoader(train_dataset, shuffle=True, **load_kws)
val_loader = DataLoader(val_dataset, shuffle=False, **load_kws)

# ResNet-18
pretrained = False
model = resnet18(weights=None, num_classes=200)
if pretrained:
    # If we like we can use weights trained on ImageNet 1000
    pretrained_state_dict = resnet18(weights="IMAGENET1K_V2", num_classes=1000).state_dict()
    # However, the last fully connected layer is the wrong shape    
    for key in ["fc.weight", "fc.bias"]:
        del pretrained_state_dict[key]
    model.load_state_dict(pretrained_state_dict, strict=False)

# Optimizer
opt = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)


Now we come to the important part, the training. In this part we will have to include the checkpointing steps.

In [17]:
loss_func = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, opt, n_epochs, checkpoint_path, device=device):
    model = model.to(device)
    
    n_batches = len(train_loader)
    total_steps = n_epochs * n_batches
    counter = 0
    
    for epoch in range(n_epochs):
        
        # Training epoch
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            opt.zero_grad()
            
            # Forward pass
            est = model(images)
            
            # Calculate loss and backpropagate
            loss = loss_func(est, labels)
            loss.backward()
            opt.step()
            train_loss += loss.item()
            
            counter += 1
            print(f"\rProgress: {100 * counter / total_steps:4.1f} %  ({counter}/{total_steps})", end="")
        
        # Average training loss
        train_loss /= n_batches
        
        # Validation
        val_loss, val_acc = validate(model, device=device)
        print(f"\rEpoch {epoch}, Train loss {train_loss}, Val loss {val_loss}, Val acc {val_acc}")

        # Save checkpoint
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
        }, checkpoint_path)


def validate(model, device=device):
    model.to(device)
    model.eval()
    with torch.no_grad():
        loss = 0.0
        n_batches = len(val_loader)
        correct = 0
        total = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            est = model(images)
            loss += loss_func(est, labels).item()
            acc = (labels == est.argmax(1)).float().mean().item()
            
            # Accuracy calculation
            #_, predicted = torch.max(est, 1)
            #correct += (predicted == labels).sum().item()
            #total += labels.size(0)
        
        # Average loss and accuracy
        loss /= n_batches
        #accuracy = correct / total

        return loss, acc


In [18]:
%%time
train(model, opt, n_epochs=5, checkpoint_path="checkpoint.pt")

Epoch 0, Train loss 5.111716720522667, Val loss 4.836109328269958, Val acc 0.04411764815449715
Epoch 1, Train loss 4.5856974319535855, Val loss 4.557662415504455, Val acc 0.06617647409439087
Epoch 2, Train loss 4.188210615089962, Val loss 4.187456750869751, Val acc 0.12132353335618973
Epoch 3, Train loss 3.843514551921767, Val loss 3.8874372959136965, Val acc 0.17279411852359772
Epoch 4, Train loss 3.5770948091331793, Val loss 3.718472933769226, Val acc 0.16544117033481598
CPU times: user 2min 26s, sys: 10.2 s, total: 2min 36s
Wall time: 2min 45s


In [19]:
model = resnet18(weights=None, num_classes=200)
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])


<All keys matched successfully>

In [None]:
# Ensure the model is in evaluation mode
# model.eval()

In [20]:
loss, acc = validate(model)
print(f'''
Validation loss: {loss:.4f}
Accuracy:        {acc:.4f}''')


Validation loss: 3.7185
Accuracy:        0.1654


## Excercises
1. Write a `train_from_checkpoint` function below that given the path to a checkpoint continues training from there
2. Modify the `train_from_checkpoint` function to also save the best checkpoint so far