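## Custom TF training loop

Trains a small L2-regularized regression network on the Boston housing data with a hand-written `tf.GradientTape` loop instead of `model.fit`, so every step (batching, loss, gradients, constraints, metrics) is visible.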

In [None]:
import tensorflow as tf
import numpy as np
import tensorflow.keras as K
import pandas as pd
import matplotlib.pyplot as plt

## Boston housing dataset

In [2]:
# Keras returns ((X_train, y_train), (X_test, y_test)) as numpy arrays.
boston_housing = tf.keras.datasets.boston_housing.load_data(
    path="boston_housing.npz", test_split=0.2, seed=113
)

(X_train_full, y_train_full), (X_test_raw, y_test) = boston_housing

# Store targets as column vectors so they match the model's (batch, 1) output.
# A 1-D target can broadcast against (batch, 1) predictions inside
# tf.keras.losses.mean_squared_error and yield a (batch, batch) error matrix,
# i.e. a silently wrong loss.
y_train_full = y_train_full.reshape(-1, 1)
y_test = y_test.reshape(-1, 1)

# Hold out the last 50 training rows for validation.
X_train_raw, y_train = X_train_full[:-50], y_train_full[:-50]
X_valid_raw, y_valid = X_train_full[-50:], y_train_full[-50:]

# Standardize features using statistics from the training split only, so no
# information leaks from validation/test. Unscaled inputs on very different
# scales make plain SGD (used below) prone to diverging.
X_mean = X_train_raw.mean(axis=0)
X_std = X_train_raw.std(axis=0)
X_std[X_std == 0] = 1.0  # guard against constant columns (avoid division by zero)
X_train = (X_train_raw - X_mean) / X_std
X_valid = (X_valid_raw - X_mean) / X_std
X_test = (X_test_raw - X_mean) / X_std

## Model definition

In [None]:
# A single L2 penalty object shared by both Dense layers; each layer then
# contributes its own regularization term to `model.losses`.
l2_reg = tf.keras.regularizers.l2(0.05)

# Two-layer regressor: a 30-unit ReLU hidden layer followed by one linear output.
model = tf.keras.models.Sequential()
model.add(
    tf.keras.layers.Dense(
        30,
        activation='relu',
        kernel_initializer='he_normal',
        kernel_regularizer=l2_reg,
    )
)
model.add(tf.keras.layers.Dense(1, kernel_regularizer=l2_reg))

## Manual batching function

In [4]:
def random_batch(X, y, batch_size=32):
    """Draw a random mini-batch of rows from ``X`` and ``y``.

    Row indices are drawn uniformly *with replacement* from the global NumPy
    RNG, so a batch may contain the same row more than once.

    Args:
        X: Indexable feature array; rows are selected along the first axis.
        y: Target array aligned row-for-row with ``X``.
        batch_size: Number of rows to draw.

    Returns:
        Tuple ``(X_batch, y_batch)`` built from the same sampled row indices.
    """
    rows = np.random.randint(0, len(X), batch_size)
    X_batch = X[rows]
    y_batch = y[rows]
    return X_batch, y_batch

## Progress bar function

In [5]:
def print_status_bar(step, total, loss, metrics=None):
    """Render a one-line, in-place progress bar for a custom training loop.

    Args:
        step: Current (1-based) step within the epoch.
        total: Number of steps in the epoch. When ``step`` reaches it the line
            is terminated with a newline so the next epoch starts on a new line.
        loss: Metric-like object exposing ``name`` and ``result()`` (e.g. a
            ``tf.keras.metrics.Mean`` tracking the running loss); it is shown
            first. ``None`` shows only ``metrics``.
        metrics: Optional iterable of metric-like objects shown after the loss.
    """
    # Show the running loss ahead of the other metrics.
    tracked = ([loss] if loss is not None else []) + list(metrics or [])
    summary = " . ".join(f"{m.name}: {m.result():.4f}" for m in tracked)
    # "\r" rewinds to the line start so each step overwrites the previous one.
    end = "" if step < total else "\n"
    print(f"\r{step}/{total} - " + summary, end=end)

## Hyperparameters and training components (optimizer, loss, metrics)

In [6]:
# --- Training schedule ---
n_epochs, batch_size = 15, 32
# Number of random mini-batches drawn per epoch (integer division, so roughly
# one epoch's worth of samples from X_train).
n_steps = len(X_train) // batch_size

# --- Training components ---
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.mean_squared_error  # per-instance squared error
mean_loss = tf.keras.metrics.Mean(name="mean_loss")  # running mean of the total loss
metrics = [tf.keras.metrics.MeanAbsoluteError()]

## Training loop

In [None]:
# Each epoch draws `n_steps` random mini-batches via `random_batch`
# (sampled with replacement), i.e. roughly one pass's worth of X_train.
for epoch in range(1, n_epochs + 1):
    print("Epoch: {}/{}".format(epoch, n_epochs))
    for step in range(1, n_steps + 1):
        # With a tf.data pipeline this would be:
        # for X_batch, y_batch in train_set: ...
        X_batch, y_batch = random_batch(X_train, y_train)
        # Give targets the model's (batch, 1) output shape (the model has a
        # single output unit). With a 1-D y_batch, mean_squared_error may
        # broadcast (batch,) against (batch, 1) into a (batch, batch) matrix
        # and silently compute the wrong loss.
        y_batch = np.reshape(y_batch, (-1, 1))
        with tf.GradientTape() as tape:
            y_pred = model(X_batch, training=True)
            # loss_fn returns one loss per instance; reduce_mean makes it a scalar.
            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
            # The kernel regularizers add one scalar penalty per layer to
            # model.losses; add_n sums them with main_loss into a scalar total.
            loss = tf.add_n([main_loss] + model.losses)
        # Differentiate the *total* loss: using main_loss here would leave the
        # L2 penalties out of the update and silently disable regularization.
        gradients = tape.gradient(loss, model.trainable_variables)

        # Optional: transform/clip gradients here before applying them.

        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Apply weight constraints after the update (a no-op for this model,
        # whose layers declare no constraints).
        for variable in model.variables:
            if variable.constraint is not None:
                variable.assign(variable.constraint(variable))

        mean_loss(loss)
        for metric in metrics:
            metric(y_batch, y_pred)

        print_status_bar(step, n_steps, mean_loss, metrics)
    # Reset running statistics so each epoch reports its own averages.
    # reset_state() replaces the deprecated reset_states() (removed in Keras 3).
    for metric in [mean_loss] + metrics:
        metric.reset_state()