# Checkpointing with PyTorch
In this notebook we will go through checkpointing your model with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has a lower resolution (64x64), fewer images (100k) and fewer labels (200). For this dataset we will use a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

### Datapipe
First we construct a utility function that builds the datapipes we will later use in our DataLoader.

In [1]:
import functools
import os
import zipfile
from fnmatch import fnmatch
from typing import Dict, Iterator, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data.datapipes.utils.common import StreamWrapper
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import FileOpener, IterDataPipe

# Custom Type Hints https://peps.python.org/pep-0484/
# (filename, byte stream) pairs, as yielded by `load_from_zip`
LoadedFromZip = Tuple[str, StreamWrapper]
# (image tensor, integer class label) pairs, as produced by the parsing step
DataPoint = Tuple[torch.FloatTensor, int]

In [2]:
# Manually set the length reported by len() (does not affect how many elements
# can actually be yielded). In future use
# https://pytorch.org/data/main/generated/torchdata.datapipes.iter.LengthSetter.html
class LengthSetterIterDataPipe(IterDataPipe):
    '''Pass-through datapipe that reports a fixed, user-supplied length.

    Args:
        source_datapipe: Datapipe whose elements are yielded unchanged.
        length: Value returned by len(); must be non-negative.

    Raises:
        ValueError: If `length` is negative.
    '''

    def __init__(self, source_datapipe: IterDataPipe, length: int) -> None:
        # An explicit check instead of `assert`, which is stripped under `python -O`
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")
        self.source_datapipe = source_datapipe
        self.length = length

    def __iter__(self) -> Iterator:
        yield from self.source_datapipe

    def __len__(self) -> int:
        return self.length


# Register the `.set_length(...)` functional form only if the name is still free:
# re-running this cell (or a torchdata version that already provides `set_length`)
# would otherwise raise "Unable to add DataPipe function name ... already taken".
if "set_length" not in IterDataPipe.functions:
    functional_datapipe("set_length")(LengthSetterIterDataPipe)


In [3]:
DEFAULT_TINY_IMAGENET_ZIP = '/mimer/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip'


def _is_split_image(path: str, split: str) -> bool:
    '''True if `path` is a .JPEG file inside a folder named exactly `split`.'''
    return f'/{split}/' in path and path.endswith('.JPEG')


def _entry_in_split(entry: LoadedFromZip, split: str) -> bool:
    '''Datapipe filter: keep (filename, stream) entries that are images of `split`.'''
    filename, _ = entry
    return _is_split_image(filename, split)


def _read_zip_text(ziphandle: zipfile.ZipFile, suffix: str) -> str:
    '''Return the UTF-8 text of the first archive member whose name ends with `suffix`.

    Raises:
        FileNotFoundError: If no member name ends with `suffix`.
    '''
    for member in ziphandle.namelist():
        if member.endswith(suffix):
            return ziphandle.read(member).decode('utf-8')
    raise FileNotFoundError(f'No member ending in {suffix!r} in {ziphandle.filename}')


def _parse_tiny_imagenet(
    entry: LoadedFromZip,
    wnid2label: Dict[str, int],
    filename2wnid: Optional[Dict[str, str]] = None,
) -> DataPoint:
    '''Parse a (filename, image stream) entry into an image tensor and a label.

    Train images are named '<wnid>_<index>.JPEG', so the word name id is taken from
    the file name when `filename2wnid` is None; otherwise (validation) it is looked
    up by base name. The image is returned as a float tensor of shape (3, Px, Py).
    '''
    filename, stream = entry
    basename = os.path.basename(filename)
    wnid = basename.split('_')[0] if filename2wnid is None else filename2wnid[basename]
    label = wnid2label[wnid]

    # convert('RGB') turns greyscale (and any other mode) into three channels
    try:
        with Image.open(stream) as img:
            img_array = np.array(img.convert('RGB'))
    finally:
        # PIL does not close file objects it did not open itself
        stream.close()

    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
    return img_tensor.float(), label


def build_datapipe(split: str = 'train', path_to_dataset: str = DEFAULT_TINY_IMAGENET_ZIP) -> IterDataPipe:
    '''Construct a datapipe for the tiny-imagenet-200 dataset.

    Args:
        split: 'train' or 'val'; selects both the images and how labels are found.
        path_to_dataset: Path to the tiny-imagenet-200 zip archive. Assumed not to
            contain a '/train/' or '/val/' directory component itself.

    Returns:
        A shuffled, worker-sharded IterDataPipe of (image tensor, label) pairs whose
        len() equals the number of .JPEG images in `split`.

    Raises:
        NotImplementedError: For any split other than 'train' or 'val'.
        FileNotFoundError: If the archive lacks wnids.txt (or val_annotations.txt).
    '''
    if split not in ('train', 'val'):
        raise NotImplementedError(f"Can't determine labels for split {split}.")

    # Read all metadata directly from the archive; `with` closes the handle.
    with zipfile.ZipFile(path_to_dataset) as ziphandle:
        # Prefix '/' so top-level folders match the same predicate as full paths
        dataset_len = sum(
            _is_split_image('/' + name, split) for name in ziphandle.namelist()
        )
        # Give word name ids numeric labels 0-199
        wnids = _read_zip_text(ziphandle, 'wnids.txt').split()

        filename2wnid = None
        if split == 'val':
            # Annotation columns: filename, wnid, ?, ?, ?, ?
            annotations = _read_zip_text(ziphandle, 'val_annotations.txt')
            filename2wnid = dict(
                tuple(line.split('\t')[:2])
                for line in annotations.split('\n')
                if line.startswith('val')
            )
    wnid2label = {wnid: label for label, wnid in enumerate(wnids)}

    # Stream images from the zip, keeping only those of the requested split.
    # Module-level functions + functools.partial keep the pipe picklable for workers.
    datapipe = FileOpener([path_to_dataset], mode='b').load_from_zip()
    datapipe = datapipe.filter(functools.partial(_entry_in_split, split=split))

    # Enable shuffle and multiple workers
    datapipe = datapipe.shuffle()
    datapipe = datapipe.sharding_filter()

    # Convert stream to image tensor and label
    datapipe = datapipe.map(functools.partial(
        _parse_tiny_imagenet, wnid2label=wnid2label, filename2wnid=filename2wnid,
    ))

    # Set length once, on the outermost pipe, so len(DataLoader) is defined
    return datapipe.set_length(dataset_len)

In [4]:
import csv

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18

In [5]:
valpipe    = build_datapipe(split="val")
trainpipe  = build_datapipe(split="train")

# prefetch_factor counts *batches* loaded ahead per worker. With batch_size=512,
# one batch is 512 float32 images of 3x64x64 (~25 MB), so a prefetch factor of
# 512 would let the workers buffer essentially the whole dataset in RAM.
load_kws = dict(
    num_workers = 4,
    batch_size = 512,
    prefetch_factor = 2,
)
val_loader   = DataLoader(valpipe,   shuffle=False, **load_kws)
train_loader = DataLoader(trainpipe, shuffle=True,  **load_kws)

# ResNet-18
pretrained = False
model = resnet18(weights=None, num_classes=200)
if pretrained:
    # If we like we can use weights trained on ImageNet-1k.
    # torchvision only provides IMAGENET1K_V1 weights for ResNet-18.
    pretrained_state_dict = resnet18(weights="IMAGENET1K_V1", num_classes=1000).state_dict()
    # However, the last fully connected layer has the wrong shape (1000 vs 200
    # classes), so drop it and keep the freshly initialised one.
    for key in ["fc.weight", "fc.bias"]:
        del pretrained_state_dict[key]
    model.load_state_dict(pretrained_state_dict, strict=False)

# Optimizer
opt = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

Now we come to the important part: training. This is where we add the checkpointing steps.

In [6]:
loss_func = nn.CrossEntropyLoss()
device = torch.device("cuda")


def train(model, opt, n_epochs, checkpoint_path, device=device):
    '''Train `model` with `opt` on `train_loader`, checkpointing after every epoch.

    Each epoch is followed by a pass over `val_loader` (see `validate`). Then a
    checkpoint dict with the keys "epoch", "model_state_dict",
    "optimizer_state_dict" and "loss" (mean training loss of that epoch) is
    written to `checkpoint_path`, replacing the previous checkpoint.

    Args:
        model: Network to train; moved to `device`.
        opt: Optimizer over the parameters of `model`.
        n_epochs: Number of epochs to train for.
        checkpoint_path: File the latest checkpoint is written to.
        device: Device to train on.

    Raises:
        ValueError: If `train_loader` yields no batches in an epoch.
    '''
    model = model.to(device)

    # len(train_loader) comes from the length set on the datapipe and is only an
    # estimate for iterable datasets with several workers; use it for progress
    # display only and average over the batches actually seen.
    total_steps = n_epochs * len(train_loader)
    counter = 0

    for epoch in range(n_epochs):

        # Training epoch
        model.train()
        train_loss = 0.0
        n_batches = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            opt.zero_grad()

            est = model(images)

            loss = loss_func(est, labels)
            loss.backward()
            opt.step()
            train_loss += loss.item()
            n_batches += 1

            counter += 1
            print(f"\rProgress: {100 * counter / total_steps:4.1f} %  ({counter}/{total_steps})", end="")

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")
        train_loss /= n_batches

        # Validation
        val_loss, val_acc = validate(model, device=device)
        print(f"\rEpoch {epoch}, Train loss {train_loss}, Val loss {val_loss}, Val acc {val_acc}")

        # Save checkpoint: write a temporary file, then rename it over the old one
        # (os.replace is atomic on POSIX), so an interruption mid-save cannot
        # leave a truncated checkpoint behind.
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "loss": train_loss,
        }, tmp_path)
        os.replace(tmp_path, checkpoint_path)


def validate(model, device=device):
    '''Evaluate `model` on `val_loader`.

    Returns:
        (loss, acc): the per-sample mean cross-entropy loss and the fraction of
        correctly classified samples over the whole validation set.

    Raises:
        ValueError: If `val_loader` yields no samples.
    '''
    model.to(device)
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            est = model(images)
            batch_size = labels.size(0)
            # loss_func averages over its batch; re-weight by batch size so a
            # smaller final batch does not count as much as a full one.
            total_loss += loss_func(est, labels).item() * batch_size
            n_correct += (labels == est.argmax(1)).sum().item()
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return total_loss / n_samples, n_correct / n_samples


In [7]:
%%time
train(model, opt, 5, checkpoint_path="checkpoint.pt")

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

## Loading from checkpoint
Now that we have created a checkpoint, we want to load it and check how it performs on the validation set again.

In [8]:
model = resnet18(weights=None, num_classes=200)
# The checkpoint was saved from a CUDA model; map_location="cpu" lets it load on
# machines without a GPU and avoids a second copy on the GPU. validate() moves
# the model to the evaluation device afterwards.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoint.pt'

In [None]:
loss, acc = validate(model)
# Blank line first, then the two aligned metric rows
print("\n".join([
    "",
    f"Validation loss: {loss:.4f}",
    f"Accuracy:        {acc:.4f}",
]))

## Exercises
1. Write a `train_from_checkpoint` function below that, given the path to a checkpoint, continues training from there.
2. Modify the `train_from_checkpoint` function to also save the best checkpoint so far.