# Checkpointing with PyTorch
In this notebook we will go through checkpointing your model with PyTorch.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64) and fewer images (100k). For this dataset we will use a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models.resnet import BasicBlock, ResNet
from tqdm.auto import tqdm

In [3]:
# Dataset: images are only converted to tensors, no other preprocessing.
transform = transforms.Compose([transforms.ToTensor()])
data_path = "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200/"
test_set = ImageFolder(f"{data_path}test", transform=transform)
val_set = ImageFolder(f"{data_path}val", transform=transform)
train_set = ImageFolder(f"{data_path}train", transform=transform)

# Loader settings shared by all three splits.
load_kws = {
    "num_workers": 4,
    "batch_size": 512,
    "prefetch_factor": 128,
}
# Only the training split is shuffled.
test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **load_kws)
val_loader = DataLoader(val_set, shuffle=False, **load_kws)
train_loader = DataLoader(train_set, shuffle=True, **load_kws)

# ResNet-18 with one output per class (200 classes)
model = ResNet(BasicBlock, layers=[2, 2, 2, 2], num_classes=200)

# Optimizer: Adam with its default hyperparameters
opt = optim.Adam(model.parameters())

Now we come to the important part, the training. In this part we will have to include the checkpointing steps.

In [4]:
# Cross-entropy loss, used by both train() and validate() below.
loss_func = nn.CrossEntropyLoss()

def train(model, opt, n_epochs, checkpoint_path):
    """Train ``model`` on ``train_loader`` for ``n_epochs`` epochs.

    After every epoch the model is evaluated with ``validate`` and the mean
    training loss, validation loss and validation accuracy are printed.

    Args:
        model: Network to train; it is called on each batch of images.
        opt: Optimizer used to update the model. It is expected to have been
            built from ``model``'s parameters (this is not checked here).
        n_epochs: Number of passes over ``train_loader``.
        checkpoint_path: Where checkpoints should be stored. Currently
            unused; checkpointing is still a TODO in the epoch loop.

    Returns:
        The same ``model`` object that was passed in, after training.
    """
    n_batches = len(train_loader)
    progress_bar = tqdm(total=n_epochs * n_batches)

    # Make sure the bar is closed even if training raises or is interrupted.
    try:
        for epoch in range(n_epochs):

            # Training epoch
            model.train()
            train_loss = 0.0
            for images, labels in train_loader:

                opt.zero_grad()

                est = model(images)

                loss = loss_func(est, labels)
                loss.backward()
                # Apply the gradients; without this step the weights never change.
                opt.step()
                train_loss += loss.item()

                # Advance the progress bar by one batch.
                progress_bar.update()

            train_loss /= n_batches

            # Validation
            val_loss, val_acc = validate(model)
            print(f"Epoch {epoch}, Train loss {train_loss}, Val loss {val_loss}, Val acc {val_acc}")

            # TODO add checkpointing below
    finally:
        progress_bar.close()
    return model
    
def validate(model):
    """Evaluate ``model`` on ``val_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking.

    Args:
        model: Network to evaluate; it is called on each batch of images.

    Returns:
        tuple: ``(loss, acc)`` where ``loss`` is the mean of the per-batch
        ``loss_func`` values and ``acc`` is the fraction of samples whose
        highest-scoring class equals the label.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, labels in val_loader:
            est = model(images)
            total_loss += loss_func(est, labels).item()
            # Predicted class = index of the largest score along dim 1.
            n_correct += (est.argmax(dim=1) == labels).sum().item()
            n_samples += labels.size(0)

    loss = total_loss / len(val_loader)
    # Accuracy is taken over all samples, so a smaller last batch is weighted correctly.
    acc = n_correct / n_samples
    return loss, acc


In [None]:
# Train for a single epoch; no checkpoint path is passed yet.
train(model, opt, 1, None)