# Checkpointing with PyTorch
In this notebook we will go through checkpointing your model with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64), fewer images (100k) and fewer labels (200). For this dataset we will use a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

In [None]:
import os
import torch
from torchvision.models import resnet18
from pytorch_dataset import TinyImageNetDataset 
from torch import nn, optim
from torch.utils.data import DataLoader

# For performance, let float32 matrix multiplications use faster,
# reduced-precision internal formats (the "high" setting); see
# https://www.c3se.chalmers.se/documentation/applications/pytorch/#performance-and-precision
torch.set_float32_matmul_precision("high")

In [None]:
# Use the custom dataset class from pytorch_dataset.py.
# Both splits are read from the same archive, so the path is defined once.
DATASET_PATH = '/mimer/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip'
train_dataset = TinyImageNetDataset(path_to_dataset=DATASET_PATH, split='train')
val_dataset = TinyImageNetDataset(path_to_dataset=DATASET_PATH, split='val')


# Keyword arguments shared by the training and validation loaders.
# NOTE(review): prefetch_factor counts *batches per worker*, so up to
# 4 x 512 batches may be buffered in host memory -- confirm this is intended.
load_kws = dict(
    num_workers=4,
    batch_size=512,
    prefetch_factor=512,
)

# Only the training data is shuffled; validation order does not matter.
train_loader = DataLoader(train_dataset, shuffle=True, **load_kws)
val_loader = DataLoader(val_dataset, shuffle=False, **load_kws)

# ResNet-18 setup: 200 output classes to match Tiny ImageNet.
pretrained = False
model = resnet18(weights=None, num_classes=200)
if pretrained:
    # torchvision ships only IMAGENET1K_V1 weights for ResNet-18
    # (IMAGENET1K_V2 exists for e.g. ResNet-50 and raises an error here).
    pretrained_state_dict = resnet18(weights="IMAGENET1K_V1", num_classes=1000).state_dict()
    # The 1000-class classification head does not fit the 200-class model,
    # so drop it and keep the randomly initialised head.
    for key in ["fc.weight", "fc.bias"]:
        del pretrained_state_dict[key]
    # strict=False because the fc.* keys are intentionally missing.
    model.load_state_dict(pretrained_state_dict, strict=False)

# Optimizer setup
opt = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)


Now we come to the important part, the training. In this part we will have to include the checkpointing steps.

In [None]:
# Loss function shared by the training and validation loops below.
loss_func = nn.CrossEntropyLoss()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _save_checkpoint(state, checkpoint_path):
    """Write ``state`` to ``checkpoint_path`` without risking the old checkpoint.

    When ``checkpoint_path`` is a filesystem path, the data is first written to
    a temporary file next to it and then moved into place with ``os.replace``.
    If the process is interrupted while saving, the previous checkpoint is
    therefore still intact. File-like objects are written to directly.
    """
    if not isinstance(checkpoint_path, (str, os.PathLike)):
        # File-like target: nothing to rename, just serialise into it.
        torch.save(state, checkpoint_path)
        return

    tmp_path = f"{os.fspath(checkpoint_path)}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only present if saving or renaming failed; do not leave it behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def train(model, opt, n_epochs, checkpoint_path, device=device):
    """Train ``model`` and save a checkpoint after every epoch.

    Relies on the notebook-level ``train_loader``, ``loss_func`` and
    ``validate``.

    Args:
        model: Network to train; it is moved to ``device``.
        opt: Optimizer that updates the parameters of ``model``.
        n_epochs: Number of passes over ``train_loader``.
        checkpoint_path: Where the checkpoint is written. It is overwritten
            at the end of each epoch, so it always holds the latest epoch.
        device: Device to train on.

    The checkpoint is a dict with the keys ``"epoch"``,
    ``"model_state_dict"`` and ``"optimizer_state_dict"``.
    """
    model = model.to(device)

    n_batches = len(train_loader)
    total_steps = n_epochs * n_batches
    counter = 0

    for epoch in range(n_epochs):

        # Training epoch
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            opt.zero_grad()

            # Forward pass
            est = model(images)

            # Calculate loss and backpropagate
            loss = loss_func(est, labels)
            loss.backward()
            opt.step()
            train_loss += loss.item()

            counter += 1
            # "\r" rewrites the same output line instead of flooding the cell.
            print(f"\rProgress: {100 * counter / total_steps:4.1f} %  ({counter}/{total_steps})", end="")

        # Average training loss
        train_loss /= n_batches

        # Validation
        val_loss, val_acc = validate(model, device=device)
        print(f"\rEpoch {epoch}, Train loss {train_loss}, Val loss {val_loss}, Val acc {val_acc}")

        # Save checkpoint: the optimizer state is stored together with the
        # model so that training can later be resumed, not only evaluated.
        _save_checkpoint({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
        }, checkpoint_path)


def validate(model, device=device):
    """Evaluate ``model`` on the notebook-level ``val_loader``.

    The model is moved to ``device`` and switched to evaluation mode, and
    gradients are disabled while the batches are processed.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean of the per-batch losses and
        the fraction of correctly classified samples.
    """
    model.to(device)
    model.eval()

    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Forward pass
            logits = model(batch_images)
            running_loss += loss_func(logits, batch_labels).item()

            # The predicted class is the index of the largest output per sample
            predictions = torch.max(logits, 1).indices
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    # Average over batches for the loss, over samples for the accuracy
    mean_loss = running_loss / len(val_loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


In [None]:
%%time
# Train for five epochs; checkpoint.pt is overwritten at the end of every epoch.
train(model, opt, n_epochs=5, checkpoint_path="checkpoint.pt")

In [None]:
# Rebuild the architecture with fresh weights, then restore the trained
# parameters from the checkpoint file.
model = resnet18(weights=None, num_classes=200)
# map_location="cpu": the tensors may have been saved from a GPU; mapping them
# to the CPU lets the checkpoint load on machines without one as well.
# weights_only=True: only tensors and plain Python containers are unpickled,
# which is all this checkpoint contains.
checkpoint = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])


In [None]:
# No explicit model.eval() call is needed here:
# validate() switches the model to evaluation mode itself.

In [None]:
# Evaluate the restored model on the validation set and report both metrics
# with four decimals.
loss, acc = validate(model)
print(f'''
Validation loss: {loss:.4f}
Accuracy:        {acc:.4f}''')

## Exercises
1. Write a `train_from_checkpoint` function below that, given the path to a checkpoint, continues training from there.
2. Modify the `train_from_checkpoint` function to also save the best checkpoint so far.