# Checkpointing with TensorFlow
In this notebook we will go through checkpointing your model with TensorFlow.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64) and fewer images (100k). For this dataset we will use a variant of the ResNet architecture, a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset, but you can read more about the data loading in task `3_loading_data`.

In [None]:
from typing import Iterable

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, layers, Sequential
from tensorflow.keras.callbacks import TensorBoard

In [None]:
class ResidualBlock(layers.Layer):
    """Two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where the
    shortcut is the input itself or ``downsample(x)`` when a projection is given.

    Args:
        filters: Number of output channels of both 3x3 convolutions.
        strides: Stride of the first convolution. Values other than 1 change the
            spatial size, so a matching ``downsample`` is then needed for the
            shortcut to have the same shape as the main branch.
        downsample: Optional layer applied to the input on the shortcut path,
            or ``None`` for an identity shortcut.
    """

    def __init__(self, filters, strides=1, downsample=None):
        super().__init__()
        self.filters = filters
        self.strides = strides
        self.downsample = downsample

        # ReLU has no weights, so one instance is reused for both activations.
        self.relu = layers.ReLU(name='relu')
        self.conv1 = layers.Conv2D(filters, 3, strides=strides, padding="same", use_bias=False, name='conv1')
        self.bn1 = layers.BatchNormalization(epsilon=1e-5, name='bn1')
        self.conv2 = layers.Conv2D(filters, 3, padding="same", use_bias=False, name='conv2')
        self.bn2 = layers.BatchNormalization(epsilon=1e-5, name='bn2')

    def call(self, inputs):
        """Apply the block to ``inputs`` and return the activated residual sum."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        identity = inputs if self.downsample is None else self.downsample(inputs)
        return self.relu(x + identity)


class ResNet(keras.Model):
    """Residual network assembled from ``ResidualBlock`` stages.

    Layout:
    1. Stem: 7x7 stride-2 conv, batch norm, ReLU, then a 3x3 stride-2 max pool.
    2. Four residual stages with 64, 128, 256 and 512 filters; stages 2-4
       start with stride 2.
    3. Head: global average pooling, flatten, and a dense layer that outputs
       ``num_classes`` logits (no softmax).

    All convolution kernels, including those inside the residual blocks and
    the projection shortcuts, use a fan-out He initializer.

    Args:
        n_layers: Four ints giving the number of residual blocks per stage;
            ``[2, 2, 2, 2]`` gives ResNet-18.
        num_classes: Number of output logits.
        zero_init_residual: If True, the last batch-norm scale (``gamma``) of
            every residual block is initialised to zero. Each block's residual
            branch then contributes nothing at the start of training.
        groups: Stored as ``self.groups``; not used by the layers built here.
        downsample: Accepted for signature compatibility but unused. Each stage
            builds its own projection shortcut when one is needed.
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        n_layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        downsample=None,
        name="resnet",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.block = ResidualBlock

        # These must be set before the stages are built: _make_layer reads them.
        self.in_filters = 64
        self.dilation = 1
        self.groups = groups
        self.zero_init_residual = zero_init_residual

        # Stem
        self.relu = layers.ReLU(name='relu')
        self.conv1 = layers.Conv2D(
            filters=64, kernel_size=7, strides=2, padding="same", use_bias=False,
            kernel_initializer=self._he_fan_out(), name='conv1',
        )
        self.bn1 = layers.BatchNormalization(epsilon=1e-5, name='bn1')
        self.maxpool = layers.MaxPool2D(pool_size=3, strides=2, padding="same", name='maxpool')

        # Residual stages
        self.layer1 = self._make_layer(64, n_layers[0], name='layer1')
        self.layer2 = self._make_layer(128, n_layers[1], strides=2, name='layer2')
        self.layer3 = self._make_layer(256, n_layers[2], strides=2, name='layer3')
        self.layer4 = self._make_layer(512, n_layers[3], strides=2, name='layer4')

        # Head. Global pooling averages over the whole remaining spatial grid;
        # an AveragePooling2D with pool_size=1 would be an identity op.
        self.avgpool = layers.GlobalAveragePooling2D(name='avgpool')
        # Flatten is a no-op on the pooled (batch, channels) tensor.
        self.flatten = layers.Flatten(name='flatten')
        self.fc = layers.Dense(num_classes, name='fc')

    @staticmethod
    def _he_fan_out():
        """Return a fresh fan-out He (VarianceScaling, scale=2) kernel initializer."""
        # A new instance per layer: Keras warns against reusing one initializer
        # instance across several layers.
        return keras.initializers.VarianceScaling(scale=2.0, mode="fan_out")

    @staticmethod
    def _is_strided(strides):
        """Return True if ``strides`` (an int or a pair) differs from 1."""
        if isinstance(strides, Iterable):
            return tuple(strides) != (1, 1)
        return strides != 1

    def _init_block(self, block):
        """Apply this model's weight initialisation to a freshly built block.

        Must run before the block is built. Until then the initializers stored
        on its layers have not been used yet.
        """
        block.conv1.kernel_initializer = self._he_fan_out()
        block.conv2.kernel_initializer = self._he_fan_out()
        if self.zero_init_residual:
            block.bn2.gamma_initializer = keras.initializers.Zeros()
        return block

    def _make_layer(self, filters, n_blocks, strides=1, **kwargs):
        """Build one residual stage as a ``keras.Sequential`` of ``n_blocks`` blocks.

        The first block applies ``strides``. When the stride is not 1 or the
        channel count changes, it also gets a 1x1 conv + batch-norm projection
        shortcut. The remaining blocks use stride 1 and an identity shortcut.
        ``kwargs`` (e.g. ``name``) are forwarded to ``keras.Sequential``.
        Sets ``self.in_filters`` to ``filters`` for the next stage.
        """
        downsample = None
        if self._is_strided(strides) or self.in_filters != filters:
            downsample = keras.Sequential(
                [
                    layers.Conv2D(
                        filters, 1, strides=strides, use_bias=False,
                        kernel_initializer=self._he_fan_out(),
                    ),
                    layers.BatchNormalization(epsilon=1e-5),
                ],
            )

        stage = keras.Sequential(**kwargs)
        stage.add(self._init_block(self.block(filters, strides=strides, downsample=downsample)))
        self.in_filters = filters
        for _ in range(1, n_blocks):
            stage.add(self._init_block(self.block(filters)))

        return stage

    def call(self, inputs):
        """Run the stem, the four residual stages and the head; return logits."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        return self.fc(x)


In [None]:
# ResNet-18 layout (two residual blocks in each of the four stages), 200 output classes.
resnet18 = ResNet(n_layers=[2, 2, 2, 2], num_classes=200)

In [None]:
from functools import partial

import tensorflow as tf

from tf_dataset import (
    tiny_imagenet_generator,
    tiny_imagenet_signature,
    tiny_imagenet_train_size,
    tiny_imagenet_val_size,
)


n_epochs = 1
batch_size = 512


def make_dataset(split, n_epochs, batch_size):
    """Build a repeated, batched and prefetched Tiny ImageNet pipeline.

    Args:
        split: Split name passed to ``tiny_imagenet_generator`` ('train' or 'val').
        n_epochs: How many times the underlying generator is repeated.
        batch_size: Number of examples per batch.
    """
    dataset = tf.data.Dataset.from_generator(
        generator=partial(tiny_imagenet_generator, split=split, shuffle=True),
        output_signature=tiny_imagenet_signature,
    )
    # prefetch lets tf.data prepare the next batch while the current one is
    # being consumed, instead of stalling the training step on input.
    return dataset.repeat(n_epochs).batch(batch_size).prefetch(tf.data.AUTOTUNE)


train_dataset = make_dataset('train', n_epochs, batch_size)
val_dataset = make_dataset('val', n_epochs, batch_size)

## Training with checkpoints
Now we come to the important part: training. This is where we add the checkpointing step.

In [None]:
# Checkpointing is done via callback
checkpoint_path = "checkpoints-tf/cp-{epoch:04d}.ckpt"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    verbose=1,
    save_best_only=False,
    save_weights_only=True,  # will not save entire model
    mode='auto',
    save_freq='epoch',
    options=None,
)

# Compile model as usual
resnet18.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
);

In [None]:
# Remember to add checkpoint callback
resnet18.fit(
    train_dataset,
    epochs=n_epochs,
    steps_per_epoch=(tiny_imagenet_train_size // batch_size),
    callbacks=[checkpoint_callback],
    validation_data=val_dataset,
    validation_steps=(tiny_imagenet_val_size // batch_size),
    verbose=1,
);

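Notice from the run above that, as expected, we get one checkpoint per epoch. (`n_epochs` is set to 1 in the code above; the original run used 8 epochs, and increasing `n_epochs` produces one checkpoint for each epoch.)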

(As a side note, the results can be significantly improved if trained from a pretrained ResNet that is available from torchvision, but converting weights from PyTorch is a bit out of scope for this tutorial.)

In this example we decided to save only the weights during checkpointing, but we can also save the entire model. Here we save the trained model in the SavedModel format (instead of HDF5, which is the other alternative).

In [None]:
# Save the whole trained model, not only its weights. A path without an `.h5`
# suffix selects the SavedModel directory format.
resnet18.save(filepath="model-tf")

Now we can compare the directory structures produced by checkpointing and by saving the model separately.

In [None]:
%%bash
# List the files written by ModelCheckpoint, then those written by model.save.
for dir in checkpoints-tf model-tf; do
    tree "$dir"
done

Note that, in addition to the saved models, we also get metadata.

## Loading from checkpoint
Now that we have created a checkpoint, we want to load it. We also add a check that the loading went as planned.

In [None]:
# Rebuild the same architecture and restore the most recent checkpointed weights.
ckpt_model = ResNet([2, 2, 2, 2], num_classes=200)
ckpt_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
latest_ckpt = tf.train.latest_checkpoint("checkpoints-tf")
if latest_ckpt is None:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # fail here with a clear message instead of inside load_weights.
    raise FileNotFoundError(
        "No checkpoint found in 'checkpoints-tf'; run the training cell first."
    )
# expect_partial() silences warnings about checkpointed values that are not
# restored into this model (presumably e.g. optimizer slot variables).
ckpt_model.load_weights(latest_ckpt).expect_partial()

# Load the full model written by `resnet18.save(...)`.
loaded_model = tf.keras.models.load_model('model-tf')

In [None]:
# Take a single validation batch (labels unused). An empty dataset raises
# instead of silently skipping the checks.
x, _ = next(iter(val_dataset))
y_saved = resnet18(x)       # the trained model still in memory
y_ckpt = ckpt_model(x)      # rebuilt from the latest weights checkpoint
y_loaded = loaded_model(x)  # restored from the SavedModel directory

# Check that the models are reproduced. assert_near requires, elementwise,
# |a - b| <= atol + rtol * |b|. The SavedModel comparison gets an explicit
# absolute tolerance of 1e-3; the in-memory vs checkpoint comparison uses
# the default (tight) tolerances.
tf.debugging.assert_near(y_ckpt, y_loaded, atol=1e-3)
tf.debugging.assert_near(y_saved, y_ckpt)

## Exercises
1. Create a cell below that continues training from the latest checkpoint
2. Modify the training to only save the best model so far