# Profiling with TensorFlow
In this notebook we will walk through profiling model training with TensorFlow and TensorBoard.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has a lower resolution (64x64) and fewer images (100k). For this dataset we will use a variant of the ResNet architecture, a type of convolutional neural network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

In [None]:
from typing import Iterable

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, layers, Sequential
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.data import Dataset

In [None]:
class ResidualBlock(layers.Layer):
    """Two-convolution residual block with an optional shortcut projection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``, where
    ``shortcut`` is the block input itself, or ``downsample(input)`` when a
    ``downsample`` layer is supplied.

    Args:
        filters: Number of output channels of both 3x3 convolutions.
        strides: Stride of the first convolution (the second always uses 1).
        downsample: Optional layer applied to the block input before the
            residual addition. Its output must match the shape of the main
            path, so one is needed whenever ``strides`` is not 1 or the
            channel count changes.
    """

    def __init__(self, filters, strides=1, downsample=None):
        super().__init__()
        self.downsample = downsample

        # ReLU has no weights, so a single instance serves both activations.
        self.relu = layers.ReLU()
        # use_bias=False: each conv is immediately followed by BatchNormalization,
        # whose learned offset makes a conv bias redundant.
        # epsilon=1e-5 (the Keras default is 1e-3) presumably mirrors the
        # PyTorch BatchNorm default.
        self.conv1 = layers.Conv2D(filters, 3, strides=strides, padding="same", use_bias=False)
        self.bn1 = layers.BatchNormalization(epsilon=1e-5)
        self.conv2 = layers.Conv2D(filters, 3, padding="same", use_bias=False)
        self.bn2 = layers.BatchNormalization(epsilon=1e-5)

    @tf.function
    def call(self, inputs):
        """Apply the block to ``inputs`` and return the activated residual sum."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        # Identity shortcut unless a projection was provided (stride/channel change).
        identity = inputs if self.downsample is None else self.downsample(inputs)

        return self.relu(x + identity)


class ResNet(keras.Model):
    """ResNet classifier built from ``ResidualBlock`` stages.

    Layout: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool, then
    four residual stages with 64/128/256/512 filters (stages 2-4 start with
    stride 2), followed by pooling, flattening and a dense head that returns
    raw logits (no activation).

    Args:
        n_layers: Sequence of four ints, the number of residual blocks in each
            stage (e.g. ``[2, 2, 2, 2]`` for ResNet-18).
        num_classes: Number of output logits.
        zero_init_residual: Accepted for API compatibility; currently unused.
        groups: Accepted for API compatibility; currently unused
            (``self.groups`` is always set to 1).
        downsample: Accepted for API compatibility; currently unused.
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        n_layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        downsample=None,
        name="resnet",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.block = ResidualBlock

        # Running state read/updated by _make_layer while the stages are built.
        self.in_filters = 64
        self.dilation = 1
        self.groups = 1

        # Stem.
        self.relu = layers.ReLU()
        self.conv1 = layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="same", use_bias=False)
        self.bn1 = layers.BatchNormalization(epsilon=1e-5)
        self.maxpool = layers.MaxPool2D(pool_size=3, strides=2, padding="same")
        # Residual stages.
        self.layer1 = self._make_layer(64, n_layers[0])
        self.layer2 = self._make_layer(128, n_layers[1], strides=2)
        self.layer3 = self._make_layer(256, n_layers[2], strides=2)
        self.layer4 = self._make_layer(512, n_layers[3], strides=2)
        # Head.
        # NOTE(review): pool_size=1 makes this layer an identity op. With 64x64
        # inputs the feature map here is presumably 2x2x512, so Flatten feeds
        # 2048 features to fc. Standard ResNets use global average pooling at
        # this point; confirm intent before changing (it alters fc's size).
        self.avgpool = layers.AveragePooling2D(pool_size=1)
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(num_classes)

        # He-style initialisation (VarianceScaling, scale=2, fan_out) for every
        # convolution. `self.layers` only lists the model's direct children, so
        # iterating it reaches the stem conv only. The convs inside the residual
        # blocks and the downsample shortcuts are nested inside Sequential /
        # ResidualBlock layers, so walk the whole layer tree instead.
        # Overwriting the attribute takes effect because no layer has been
        # built yet (kernels are created on first call).
        for layer in self._flatten_layers(include_self=False, recursive=True):
            if isinstance(layer, layers.Conv2D):
                layer.kernel_initializer = keras.initializers.VarianceScaling(
                    scale=2.0,
                    mode="fan_out",
                )

    @staticmethod
    def _has_non_unit_strides(strides):
        """Return True when ``strides`` (an int or an iterable of ints) is not 1 / (1, 1)."""
        if isinstance(strides, Iterable):
            return tuple(strides) != (1, 1)
        return strides != 1

    def _make_layer(self, filters, n_blocks, strides=1):
        """Build one residual stage as a ``keras.Sequential`` of blocks.

        The first block uses ``strides``. When the stage changes the spatial
        size or the channel count, the first block also gets a 1x1 conv + BN
        shortcut so the residual addition has matching shapes. The remaining
        ``n_blocks - 1`` blocks use stride 1 and an identity shortcut.

        Args:
            filters: Number of output channels for every block in the stage.
            n_blocks: Number of residual blocks in the stage (the first block
                is always added).
            strides: Stride of the first block (int or pair of ints).

        Returns:
            A ``keras.Sequential`` containing the stage's blocks.

        Side effect: sets ``self.in_filters`` to ``filters`` so the next stage
        knows its input channel count.
        """
        block = self.block
        downsample = None
        if self._has_non_unit_strides(strides) or self.in_filters != filters:
            downsample = keras.Sequential([
                layers.Conv2D(filters, 1, strides=strides, use_bias=False),
                layers.BatchNormalization(epsilon=1e-5),
            ])

        stage = keras.Sequential()
        stage.add(block(filters, strides=strides, downsample=downsample))
        self.in_filters = filters
        for _ in range(1, n_blocks):
            stage.add(block(filters))

        return stage

    # NOTE(review): Keras already traces the train/test step into a graph
    # during fit(); this decorator is likely redundant — confirm with the profiler.
    @tf.function
    def call(self, inputs):
        """Run the forward pass and return class logits."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        return self.fc(x)


In [None]:
# ResNet-18 layout: two residual blocks in each of the four stages; 200 output logits.
resnet18 = ResNet(n_layers=[2, 2, 2, 2], num_classes=200)

In [None]:
# Preparing datasets
from functools import partial

import tensorflow as tf

from tf_dataset import (
    tiny_imagenet_generator,
    tiny_imagenet_signature,
    tiny_imagenet_train_size,
    tiny_imagenet_val_size,
)


n_epochs = 2
batch_size = 512


def _make_tiny_imagenet_dataset(split, epochs, batch):
    """Wrap the Tiny ImageNet generator for ``split`` in a tf.data pipeline.

    The generator is called with ``shuffle=True``; its output is repeated
    ``epochs`` times and grouped into batches of ``batch`` examples.
    """
    source = tf.data.Dataset.from_generator(
        generator=partial(tiny_imagenet_generator, split=split, shuffle=True),
        output_signature=tiny_imagenet_signature,
    )
    return source.repeat(epochs).batch(batch)


train_dataset = _make_tiny_imagenet_dataset('train', n_epochs, batch_size)
val_dataset = _make_tiny_imagenet_dataset('val', n_epochs, batch_size)

## Profiling training

Now we come to the important part: profiling. Profiling is enabled through the TensorBoard callback, which we have to pass to `fit`.

In [None]:
# Profiling is done via the TensorBoard callback: it profiles batches 15
# through 25 and writes the results (plus per-epoch histograms) to log_dir.
profiling_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs/base-tf',
    profile_batch="15,25",
    histogram_freq=1,
)

# Compile model as usual: SGD with momentum and a cross-entropy loss that
# expects raw logits from the model.
base_optimizer = keras.optimizers.SGD(learning_rate=0.005, momentum=0.9)
base_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
resnet18.compile(optimizer=base_optimizer, loss=base_loss, metrics=["accuracy"])

In [None]:
# Remember to add profiling callback: nothing is profiled unless the
# TensorBoard callback is passed to fit().
train_steps = tiny_imagenet_train_size // batch_size
val_steps = tiny_imagenet_val_size // batch_size
resnet18.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=n_epochs,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=[profiling_callback],
    verbose=1,
)

## Exercises
1. Look at the profiling results in TensorBoard. To do this, follow the instructions in README.md.
2. Follow the Performance Recommendation shown by the profiler and try again by modifying the code below.

In [None]:
# Profiling is done via callback; this run logs to its own directory so it can
# be compared with the baseline run in TensorBoard.
profiling_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs/improved-tf',
    histogram_freq=1,
    profile_batch="15,25",
)

# Compile model as usual.
# NOTE(review): this recompiles the already-trained `resnet18`, so this run
# continues from the baseline's weights (with a fresh optimizer). Create a new
# ResNet here if both runs should start from scratch.
resnet18.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Remember to add profiling callback.
# Use n_epochs rather than a literal so the epoch count stays in sync with the
# `.repeat(n_epochs)` applied to the datasets; a mismatch could make fit() run
# out of data.
resnet18.fit(
    train_dataset,
    epochs=n_epochs,
    steps_per_epoch=(tiny_imagenet_train_size // batch_size),
    callbacks=[profiling_callback],
    validation_data=val_dataset,
    validation_steps=(tiny_imagenet_val_size // batch_size),
    verbose=1,
)