# Putting it together: Using TensorBoard to log the training of a subclassed model with Keras metrics and a custom training loop

We will define a subclassed feed forward fully connected model and store loss and accuracy for both training and validation data to the TensorBoard. 

To do this in a clean way, we implement the keras metrics that keep track of loss and accuracy in each epoch for us as part of the model. We also define the train and test steps as methods inside the model rather than as external functions. Doing so will move us one step closer to being able to use the in-built training and evaluation methods that come with Tensorflow/Keras, that is the compile and fit methods.

To use train_step and test_step as methods of the model, we need to have the loss-function, the metrics, and the optimizer as parts of the model, which is why we define them in the init method.

Note that we need to update the metrics after each training batch (each call to train_step or test_step) and reset the metrics after each epoch and before evaluating our model on the validation data set.

Also note that the metrics_list contains a mean metric for the loss, which does not take targets and predictions as arguments in its update_state method, but just a scalar. For this reason, we treat it differently from the remaining metrics.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import math
import datetime

# in a notebook, load the tensorboard extension, not needed for scripts
%load_ext tensorboard

In [None]:
class FFN(tf.keras.Model):
    """Fully connected feed-forward classifier that carries its own training logic.

    The optimizer, the loss function and the Keras metrics are attributes of
    the model, so `train_step` and `test_step` can be methods instead of
    external functions.

    The network flattens its input and passes it through four ReLU Dense
    layers (32, 64, 128, 256 units) and a 10-unit Dense output layer without
    activation, i.e. it returns logits.
    """

    def __init__(self):
        super().__init__()

        self.optimizer = tf.keras.optimizers.Adam()

        # Index 0 has to stay the loss tracker: a Mean metric is updated with a
        # scalar, while the remaining metrics are updated with
        # (targets, predictions).
        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss"),
            tf.keras.metrics.CategoricalAccuracy(name="acc"),
            tf.keras.metrics.TopKCategoricalAccuracy(3, name="top-3-acc"),
        ]

        # from_logits=True because the output layer applies no activation.
        self.loss_function = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

        # define layers
        self.flatten = tf.keras.layers.Flatten()
        self.layer1 = tf.keras.layers.Dense(32, activation="relu")
        self.layer2 = tf.keras.layers.Dense(64, activation="relu")
        self.layer3 = tf.keras.layers.Dense(128, activation="relu")
        self.layer4 = tf.keras.layers.Dense(256, activation="relu")
        self.output_layer = tf.keras.layers.Dense(10, activation=None)

    @property
    def metrics(self):
        """Return the metrics in the fixed order of `metrics_list` (loss first).

        Declaring the property explicitly means the order that the step
        methods and external training loops see does not depend on how Keras
        collects metric attributes internally.
        """
        return list(self.metrics_list)

    def call(self, x, training=False):
        """Compute the logits for a batch of inputs.

        Args:
            x: batch of inputs; it is flattened to one vector per example.
            training: part of the Keras call signature, unused by this model.

        Returns:
            The output of the final Dense layer (10 values per example, no
            activation applied).
        """
        # flatten images to vectors
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        out = self.output_layer(x)

        return out

    def reset_metrics(self):
        """Reset every metric in `metrics_list` to its initial state."""
        for metric in self.metrics_list:
            # reset_state() replaces the deprecated reset_states()
            metric.reset_state()

    def _track_batch(self, loss, targets, predictions):
        """Update all metrics with one batch and return their current values.

        Args:
            loss: scalar loss of the batch, fed to the Mean loss tracker.
            targets: targets of the batch.
            predictions: model outputs for the batch.

        Returns:
            Dictionary mapping each metric name to its current result.
        """
        loss_tracker, *score_metrics = self.metrics_list

        # the loss tracker only takes the scalar loss
        loss_tracker.update_state(loss)

        # all other metrics (accuracy etc.) compare targets and predictions
        for metric in score_metrics:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics_list}

    # NOTE: train_step can be decorated with @tf.function; it is left
    # undecorated here, as in the original notebook.
    def train_step(self, data):
        """Run one optimization step on a batch and update the metrics.

        Args:
            data: tuple `(x, targets)` for one batch.

        Returns:
            Dictionary mapping metric names to their current values.
        """
        x, targets = data

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)

            # self.losses holds additional layer losses (e.g. regularization);
            # their sum is added to the main loss.
            loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._track_batch(loss, targets, predictions)

    # NOTE: test_step can be decorated with @tf.function as well.
    def test_step(self, data):
        """Evaluate one batch without updating weights and update the metrics.

        Args:
            data: tuple `(x, targets)` for one batch.

        Returns:
            Dictionary mapping metric names to their current values.
        """
        x, targets = data
        predictions = self(x, training=False)
        loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        return self._track_batch(loss, targets, predictions)

# Preparing the training and validation data

In [None]:
def _scale_and_one_hot(image, label):
    """Divide the image by 255 and one-hot encode the label over 10 classes."""
    return image / 255, tf.one_hot(label, 10, dtype=tf.float32)


def _input_pipeline(dataset):
    """Apply the shared preprocessing, caching, shuffling, batching and prefetching."""
    dataset = dataset.map(_scale_and_one_hot, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()
    dataset = dataset.shuffle(5000)
    dataset = dataset.batch(32)
    return dataset.prefetch(tf.data.AUTOTUNE)


# Fashion-MNIST as (image, label) pairs; the "test" split serves as validation data.
ds = tfds.load("fashion_mnist", as_supervised=True)

train_ds = _input_pipeline(ds["train"])
val_ds = _input_pipeline(ds["test"])

In [None]:
# create the network
model = FFN()

# call the model once on a symbolic 28x28x1 input so that the layers are built,
# then print the overview of layers and parameters
build_input = tf.keras.Input((28, 28, 1))
model(build_input)
model.summary()

# Instantiate the file-writers for the training

We store the TensorBoard logs in a folder with a meaningful name (e.g. name of the training run + date and time). Additionally, when running experiments, you want to save a config file that can be associated with these logs, containing all information about the architecture and hyperparameters that were used. To be extra sure, you could also make a copy of the code that was used. Not knowing which settings led to which results should be avoided by all means. A good tool for configurations is the [Hydra library](https://hydra.cc/), which also allows you to include objects and their arguments in a config file (e.g. activation functions, optimizers etc.).

- We create a train writer and a validation writer

In [None]:
# A run is identified by its config name plus the time this cell was executed
config_name = "Run-42"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# separate sub-folders for training and validation logs
train_log_path = f"logs/{config_name}/{current_time}/train"
val_log_path = f"logs/{config_name}/{current_time}/val"

# one log writer per split: the first for training metrics, the second for validation metrics
train_summary_writer, val_summary_writer = (
    tf.summary.create_file_writer(log_path)
    for log_path in (train_log_path, val_log_path)
)

# Writing the training loop

Note that you need to re-run the above cell (and hence update the time-stamp) if you don't want to over-write the data of the previous training-run.

If you use keras metrics, do not forget to reset the states between train and validation and between epochs.
We use metric.update_state(...) to update a metric. For a Mean metric, this updates the running average with the new value. There are also Keras metrics that compute scores from targets and predictions, such as CategoricalAccuracy and TopKCategoricalAccuracy.

We use TQDM to see the progress of each epoch and the estimate of how much time it will take.

Instead of looking at the printed losses and accuracies, we can look at the TensorBoard plots which will be updated after every epoch. This requires us to open and load the tensorboard *before* starting the training or to open the tensorboard from a terminal.

In [None]:
import pprint
import tqdm

def _log_epoch_metrics(model, summary_writer, step):
    """Write the current value of every model metric to `summary_writer` at `step`.

    Returns:
        Dictionary mapping metric names to the values that were logged.
    """
    results = {metric.name: metric.result() for metric in model.metrics}
    with summary_writer.as_default():
        for name, value in results.items():
            tf.summary.scalar(name, value, step=step)
    return results


def training_loop(model, train_ds, val_ds, epochs, train_summary_writer, val_summary_writer):
    """Train and validate `model` for `epochs` epochs and log the metrics per epoch.

    Args:
        model: model providing `train_step`, `test_step`, `metrics` and
            `reset_metrics`.
        train_ds: iterable of training batches passed to `model.train_step`.
        val_ds: iterable of validation batches passed to `model.test_step`.
        epochs: number of epochs; the epoch index (starting at 0) is used as
            the TensorBoard step.
        train_summary_writer: summary writer receiving the training metrics.
        val_summary_writer: summary writer receiving the validation metrics.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch}:")

        # Training:
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            model.train_step(data)

        # Log the training metrics once per epoch. The metrics accumulate over
        # all batches, so the value after the last batch is the epoch result;
        # writing inside the batch loop would emit one event per batch for the
        # same step.
        metrics = _log_epoch_metrics(model, train_summary_writer, epoch)

        # print the metrics
        print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics so validation starts from a clean state
        # (requires a reset_metrics method in the model)
        model.reset_metrics()

        # Validation:
        for data in val_ds:
            model.test_step(data)

        # log the validation metrics once per epoch
        metrics = _log_epoch_metrics(model, val_summary_writer, epoch)

        print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics before the next epoch
        model.reset_metrics()
        print("\n")

In [None]:
%tensorboard --logdir logs/

In [None]:
# train the freshly built model for 15 epochs, logging to both writers
training_loop(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    epochs=15,
    train_summary_writer=train_summary_writer,
    val_summary_writer=val_summary_writer,
)

# Saving and loading a subclassed model

Because training deep neural networks can take multiple days, weeks or even months, we want to save checkpoints in between. This is especially useful if you use Google Colab and you save the model directly to your Google Drive folder. That way you don't lose any progress if your runtime gets closed.

In [None]:
# save the model weights with a meaningful name
model.save_weights(f"saved_model_{config_name}", save_format="tf")

# load the model:
# instantiate a new model from our FFN class
loaded_model = FFN()

# build the model
inp = tf.keras.Input((28, 28, 1))
loaded_model(inp)

# load the model weights to continue training.
loaded_model.load_weights(f"saved_model_{config_name}")

# continue training (but: optimizer state is lost)

# training_loop numbers its epochs from 0 again. Writing to the writers of the
# first run would therefore put a second value on steps that already have
# one, so the continued run gets its own time-stamped log folder.
resumed_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
resumed_train_writer = tf.summary.create_file_writer(
    f"logs/{config_name}/{resumed_time}/train"
)
resumed_val_writer = tf.summary.create_file_writer(
    f"logs/{config_name}/{resumed_time}/val"
)

# run the training loop
training_loop(
    model=loaded_model,
    train_ds=train_ds,
    val_ds=val_ds,
    epochs=10,
    train_summary_writer=resumed_train_writer,
    val_summary_writer=resumed_val_writer,
)