# Checkpointing with TensorFlow/Keras
In this notebook we will go through checkpointing your model with TensorFlow/Keras.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64) and fewer images (100,000 training images). For this dataset we will use a variant of the ResNet architecture, which is a type of convolutional neural network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

In [95]:
# Copy and extract the dataset into $TMPDIR when that variable is set, so that
# `data_path` points at the local copy; otherwise point `data_path` at the
# dataset on /cephyr/NOBACKUP directly.
import os

if "TMPDIR" in os.environ:
    data_path = os.path.join(os.environ["TMPDIR"], "tiny-imagenet-200/")
    # Skip the copy/extract when the extracted folder already exists, so
    # re-running this cell does not repeat the work.
    if not os.path.isdir(data_path):
        !cp "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip" "$TMPDIR"
        # `unzip -n` never overwrites existing files.
        # NOTE(review): assumes the archive extracts into a top-level
        # tiny-imagenet-200/ folder (matching data_path above) — confirm.
        !unzip -n "$TMPDIR/tiny-imagenet-200.zip" -d "$TMPDIR"
else:
    data_path = "/cephyr/NOBACKUP/Datasets/tiny-imagenet-200"


In [96]:
import csv
from typing import Iterable

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, layers, Sequential
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.data import Dataset
from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator

In [97]:
class TinyImageNetIterator(DirectoryIterator):
    '''Helper class for loading the Tiny ImageNet train or validation split.

    The training split (``<parent>/train/<class>/...``) is handed to the stock
    ``DirectoryIterator``. The validation split is labelled by
    ``<parent>/val/val_annotations.txt`` rather than by sub-folders, so it is
    indexed by hand here. It reuses class indices derived from the training
    folder names so both splits agree on the label order.

    Args:
        parent_directory: Dataset root containing ``train`` and ``val``.
        subset: ``"training"`` or ``"validation"``.
        image_data_generator: Generator used to transform each batch.
        Remaining arguments mirror ``DirectoryIterator``.

    Raises:
        ValueError: If ``subset`` or ``class_mode`` is invalid, or if the
            validation annotation file lists no images.
    '''

    def __init__(
        self,
        parent_directory,
        subset,
        image_data_generator,
        target_size=(64, 64),
        color_mode='rgb',
        classes=None,
        class_mode='categorical',
        batch_size=32,
        shuffle=True,
        seed=None,
        data_format='channels_last',
        save_to_dir=None,
        save_prefix='',
        save_format='png',
        follow_links=False,
        interpolation='nearest',
        dtype='float32',
    ):
        train_directory = os.path.join(parent_directory, "train")
        if subset == "training":
            # One sub-folder per class: exactly what DirectoryIterator expects.
            super().__init__(
                train_directory,
                image_data_generator,
                target_size=target_size,
                color_mode=color_mode,
                classes=classes,
                class_mode=class_mode,
                batch_size=batch_size,
                shuffle=shuffle,
                seed=seed,
                data_format=data_format,
                save_to_dir=save_to_dir,
                save_prefix=save_prefix,
                save_format=save_format,
                follow_links=follow_links,
                subset=subset,
                interpolation=interpolation,
                dtype=dtype,
            )
            return
        if subset != "validation":
            raise ValueError(f'Value of subset should be "training" or "validation",  not {subset}.')
        directory = os.path.join(parent_directory, "val")

        # Modified DirectoryIterator.__init__: labels come from the annotation
        # file instead of from sub-folder names.
        super(DirectoryIterator, self).set_processing_attrs(
            image_data_generator=image_data_generator,
            target_size=target_size,
            color_mode=color_mode,
            data_format=data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            subset=subset,
            interpolation=interpolation,
        )
        self.directory = directory
        self.classes = classes
        if class_mode not in self.allowed_class_modes:
            raise ValueError(f'Invalid class_mode: {class_mode}; expected one of: {self.allowed_class_modes}')
        self.class_mode = class_mode
        self.dtype = dtype

        # Class indices come from the (sorted) training folder names so the
        # validation labels line up with the training labels.
        class_names = classes
        if not class_names:
            class_names = sorted(
                subdir
                for subdir in os.listdir(train_directory)
                if os.path.isdir(os.path.join(train_directory, subdir))
            )
        self.num_classes = len(class_names)
        self.class_indices = dict(zip(class_names, range(len(class_names))))

        # Each tab-separated row holds: filename, class name, and four more
        # columns (bounding-box coordinates in the official files) that are ignored.
        annotations_path = os.path.join(directory, "val_annotations.txt")
        filenames, labels = [], []
        with open(annotations_path, "r", newline="") as f:
            for fn, class_name, _, _, _, _ in csv.reader(f, delimiter="\t"):
                filenames.append(fn)
                labels.append(self.class_indices[class_name])
        if not filenames:
            raise ValueError(f'No validation images listed in {annotations_path}.')

        self.filenames = filenames
        self.samples = len(self.filenames)
        self.classes = np.array(labels, dtype='int32')

        print(f'Found {self.samples} images belonging to {self.num_classes} classes.')
        # The official tiny-imagenet-200 archive stores validation images in
        # val/images/; fall back to val/ itself for copies without that folder.
        image_directory = os.path.join(directory, "images")
        if not os.path.isdir(image_directory):
            image_directory = directory
        self._filepaths = [os.path.join(image_directory, fn) for fn in self.filenames]

        # Initialise the base batch Iterator directly, skipping
        # DirectoryIterator.__init__ (which would re-scan the folder). The
        # super() anchor is taken from TinyImageNetIterator's own MRO rather
        # than type(self)'s, so subclasses resolve to the same target.
        iterator_anchor = TinyImageNetIterator.__mro__[2]
        super(iterator_anchor, self).__init__(self.samples, batch_size, shuffle, seed)

    def __len__(self):
        '''Number of batches per epoch, i.e. ceil(n / batch_size).

        A final partial batch is counted, but no extra (empty) batch is added
        when ``n`` is an exact multiple of ``batch_size``.
        '''
        return (self.n + self.batch_size - 1) // self.batch_size
        

class TinyImageNetGenerator(ImageDataGenerator, TinyImageNetIterator):
    '''ImageDataGenerator whose ``flow_from_directory`` returns TinyImageNetIterator objects.'''

    def __bool__(self):
        # Always truthy. Without this override, bool() would fall back to the
        # __len__ inherited from TinyImageNetIterator, which reads self.n and
        # self.batch_size — iterator attributes a generator presumably never sets.
        return True

    def flow_from_directory(self, parent_directory, subset, *, target_size=(64, 64), color_mode="rgb", **kwargs):
        '''Create a TinyImageNetIterator over one split, using this generator for transforms.

        Args:
            parent_directory: Dataset root containing ``train`` and ``val``.
            subset: ``"training"`` or ``"validation"``.
            target_size, color_mode, **kwargs: Forwarded to TinyImageNetIterator.
        '''
        iterator_options = dict(kwargs, target_size=target_size, color_mode=color_mode)
        return TinyImageNetIterator(parent_directory, subset, self, **iterator_options)

    

In [98]:
# Use `data_path` from the first cell so the $TMPDIR copy is read when one was
# made, instead of always reading the shared /cephyr/NOBACKUP location.
dir_parent = data_path

image_generator = TinyImageNetGenerator()
train_batches = image_generator.flow_from_directory(dir_parent, "training", batch_size=128)
val_batches = image_generator.flow_from_directory(dir_parent, "validation", batch_size=128)


def batches_to_dataset(batches):
    """Wrap a Keras batch iterator in a ``tf.data.Dataset`` of (images, one-hot labels).

    Element shapes are taken from the iterator (``image_shape``, ``num_classes``)
    instead of being hard-coded. The batch dimension stays ``None`` because the
    last batch of an epoch can be smaller than ``batch_size``.
    """
    return Dataset.from_generator(
        lambda: batches,
        output_types=(tf.float32, tf.float32),
        output_shapes=([None, *batches.image_shape], [None, batches.num_classes]),
    )


train_set = batches_to_dataset(train_batches)
val_set = batches_to_dataset(val_batches)


Found 100000 images belonging to 200 classes.
Found 10000 images belonging to 200 classes.


In [99]:
class ResidualBlock(layers.Layer):
    """Residual block with two 3x3 convolutions.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``, where the
    shortcut is ``x`` itself, or ``downsample(x)`` when a downsample layer is given.

    Args:
        filters: Number of output channels of both convolutions.
        strides: Stride of the first convolution. If this (or the channel
            count) changes the shape, pass a matching ``downsample`` so the
            addition is well-defined.
        downsample: Optional layer applied to the input to form the shortcut.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters, strides=1, downsample=None, **kwargs):
        super().__init__(**kwargs)
        self.downsample = downsample

        # A single ReLU layer is reused for both activations.
        self.relu = layers.ReLU()
        self.conv1 = layers.Conv2D(filters, 3, strides=strides, padding="same", use_bias=False)
        self.bn1 = layers.BatchNormalization(epsilon=1e-5)
        self.conv2 = layers.Conv2D(filters, 3, padding="same", use_bias=False)
        self.bn2 = layers.BatchNormalization(epsilon=1e-5)

    def call(self, inputs):
        residual = self.relu(self.bn1(self.conv1(inputs)))
        residual = self.bn2(self.conv2(residual))

        shortcut = inputs if self.downsample is None else self.downsample(inputs)
        return self.relu(residual + shortcut)


class ResNet(keras.Model):
    """ResNet made of four stages of ``ResidualBlock`` layers.

    Architecture: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool stem, four residual
    stages (64, 128, 256, 512 filters; stages 2-4 use stride 2), global average
    pooling, and a dense layer producing logits (no softmax).

    Args:
        n_layers: Number of residual blocks per stage, e.g. ``[2, 2, 2, 2]``.
        num_classes: Number of output logits.
        zero_init_residual: If True, initialise the scale (gamma) of the second
            BatchNormalization in every residual block to zero.
        groups: Accepted for API compatibility; grouped convolutions are not implemented.
        downsample: Accepted for API compatibility; unused.
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        n_layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        downsample=None,
        name="resnet",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.block = ResidualBlock

        self.in_filters = 64
        self.dilation = 1
        self.groups = 1

        # Stem
        self.relu = layers.ReLU()
        self.conv1 = layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="same", use_bias=False)
        self.bn1 = layers.BatchNormalization(epsilon=1e-5)
        self.maxpool = layers.MaxPool2D(pool_size=3, strides=2, padding="same")
        # Residual stages
        self.layer1 = self._make_layer(64, n_layers[0])
        self.layer2 = self._make_layer(128, n_layers[1], strides=2)
        self.layer3 = self._make_layer(256, n_layers[2], strides=2)
        self.layer4 = self._make_layer(512, n_layers[3], strides=2)
        # Head. Global average pooling reduces each channel to one value; the
        # previous AveragePooling2D(pool_size=1) was an identity op.
        self.avgpool = layers.GlobalAveragePooling2D()
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(num_classes)

        # Initializers are only read when a layer is built, so setting them here
        # (before the first call) takes effect. `self.layers` lists only direct
        # children, so the nested convolutions are visited explicitly.
        for conv in self._conv_layers():
            conv.kernel_initializer = keras.initializers.VarianceScaling(
                scale=2.0,
                mode="fan_out",
            )
        if zero_init_residual:
            for blk in self._residual_blocks():
                blk.bn2.gamma_initializer = keras.initializers.Zeros()

    def _residual_blocks(self):
        """Yield the residual blocks of all four stages, in order."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for blk in stage.layers:
                if isinstance(blk, self.block):
                    yield blk

    def _conv_layers(self):
        """Yield every Conv2D: the stem conv and those inside each residual block."""
        yield self.conv1
        for blk in self._residual_blocks():
            yield blk.conv1
            yield blk.conv2
            if blk.downsample is not None:
                for sub_layer in blk.downsample.layers:
                    if isinstance(sub_layer, layers.Conv2D):
                        yield sub_layer

    @staticmethod
    def _has_non_unit_stride(strides):
        """Return True unless ``strides`` (an int or a 2-element sequence) is 1 everywhere."""
        if isinstance(strides, Iterable):
            return tuple(strides) != (1, 1)
        return strides != 1

    def _make_layer(self, filters, n_blocks, strides=1):
        """Build one stage of ``n_blocks`` residual blocks as a ``keras.Sequential``.

        The first block applies ``strides``. When the stride is not 1 or the
        channel count changes, that block gets a 1x1 conv + BN shortcut so the
        residual addition has matching shapes. Updates ``self.in_filters``.
        """
        downsample = None
        if self._has_non_unit_stride(strides) or self.in_filters != filters:
            downsample = keras.Sequential([
                layers.Conv2D(filters, 1, strides=strides, use_bias=False),
                layers.BatchNormalization(epsilon=1e-5),
            ])

        stage = keras.Sequential()
        stage.add(self.block(filters, strides=strides, downsample=downsample))
        self.in_filters = filters
        for _ in range(1, n_blocks):
            stage.add(self.block(filters))

        return stage

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        return self.fc(x)


In [100]:
# ResNet-18 layout: two residual blocks in each of the four stages, 200 output classes.
resnet18 = ResNet(n_layers=[2, 2, 2, 2], num_classes=200)

Now we come to the important part: training. This is where the checkpointing steps have to be included.

In [101]:
# Save the weights at the end of every epoch, so an interrupted run can be
# resumed from the latest checkpoint instead of starting over.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath="checkpoint.weights.h5",
    save_weights_only=True,
    verbose=1,
)

resnet18.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

resnet18.fit(
    train_set,
    epochs=5,
    steps_per_epoch=len(train_batches),
    callbacks=[checkpoint_callback],
)

Epoch 1/5
  3/782 [..............................] - ETA: 35:21 - loss: 6.8403 - accuracy: 0.0000e+00 

KeyboardInterrupt: 

## Loading from checkpoint
Now that we have created a checkpoint, we want to load it and check how it performs on the validation set.

In [None]:
# Rebuild the same architecture and restore its weights from the checkpoint file.
model = ResNet([2, 2, 2, 2], num_classes=200)
# A subclassed Keras model only creates its variables on its first call, and
# HDF5 weights can only be loaded into existing variables: run one dummy batch first.
model(tf.zeros((1, 64, 64, 3)))
model.load_weights("checkpoint.weights.h5")

In [None]:
def validate(model, dataset=None, steps=None):
    """Return ``(mean loss, accuracy)`` of ``model`` on a labelled dataset.

    Args:
        model: Keras model producing logits.
        dataset: Dataset of (images, one-hot labels); defaults to ``val_set``.
        steps: Number of batches to evaluate; defaults to ``len(val_batches)``,
            i.e. one pass over the validation images (the same bound training
            uses via ``steps_per_epoch``).
    """
    if dataset is None:
        dataset = val_set
    if steps is None:
        steps = len(val_batches)

    mean_loss = keras.metrics.Mean()
    accuracy = keras.metrics.CategoricalAccuracy()
    for images, labels in dataset.take(steps):
        logits = model(images, training=False)
        # Per-sample cross-entropy on logits, matching the training loss.
        mean_loss.update_state(keras.losses.categorical_crossentropy(labels, logits, from_logits=True))
        accuracy.update_state(labels, logits)
    return float(mean_loss.result()), float(accuracy.result())


loss, acc = validate(model)
print(f'''
Validation loss: {loss:.4f}
Accuracy:        {acc:.4f}''')

## Exercises
1. Write a `train_from_checkpoint` function below that, given the path to a checkpoint, continues training from there.
2. Modify the `train_from_checkpoint` function to also save the best checkpoint so far