# Using TensorBoard with a custom training loop

The API for `tf.keras.callbacks.TensorBoard` is designed for use with the `Model.fit` method. The GAN, however, has a custom training step that does not go through `Model.fit`, so the callback's hooks are never triggered automatically. Instead, we record the losses and images we are interested in ourselves, by writing them with `tf.summary`.

As an example, we train on the MNIST training set.

In [2]:
import h5py
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from pathlib import Path
import os
import began
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt

In [11]:
# Project directory. Defaults to the author's checkout; set the BEGAN_PROJ_DIR
# environment variable to run the notebook from another location.
PROJ_DIR = Path(os.environ.get("BEGAN_PROJ_DIR", "/home/bthorne/projects/gan/began"))
# File the trained model is saved to.
MODEL_PATH = PROJ_DIR / "model" / "mnist_dcgan_NTRAIN5000.h5"

# Seed numpy and TensorFlow before the models are built, so that weight
# initialisation and the batch / noise sampling in the training loop are
# repeatable between runs (GPU kernels may still add some nondeterminism).
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Network architecture
DEPTH = 16        # base filter count; layer i uses DEPTH * 2**i filters
IMG_DIM = 28      # MNIST images are 28 x 28 pixels
CHANNELS = 1      # grayscale
KERNELS = [5, 5, 5]
STRIDES = [2, 2, 2]
# One filter count per layer, doubling from DEPTH: [16, 32, 64].
FILTERS = [DEPTH * 2 ** i for i in range(len(KERNELS))]
LATENT_DIM = 32   # length of the noise vector fed to the generator

# Derived parameters
SHAPE = (IMG_DIM, IMG_DIM, CHANNELS)

# Training parameters
TRAIN_STEPS = 5000
BATCH_SIZE = 32

# Build the individual and joint models.
DIS = began.dcgan.build_discriminator(FILTERS, KERNELS, STRIDES, SHAPE)
GEN = began.dcgan.build_generator(DIS, FILTERS, KERNELS, STRIDES, LATENT_DIM, SHAPE)
ADV = began.dcgan.build_adversarial_model(DIS, GEN)
# `summary()` prints the table itself and returns None, so it must not be
# wrapped in print() (that would print a stray "None" after the table).
GEN.summary()

Model: "Generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Dense_G (Dense)              (None, 1024)              33792     
_________________________________________________________________
Reshape (Reshape)            (None, 4, 4, 64)          0         
_________________________________________________________________
BNorm_G1 (BatchNormalization (None, 4, 4, 64)          256       
_________________________________________________________________
LRelu_G1 (LeakyReLU)         (None, 4, 4, 64)          0         
_________________________________________________________________
UpSample_1 (UpSampling2D)    (None, 8, 8, 64)          0         
_________________________________________________________________
Conv2D_G1 (Conv2D)           (None, 8, 8, 32)          51232     
_________________________________________________________________
BN_G2 (BatchNormalization)   (None, 8, 8, 32)          12

In [12]:
# Load the raw MNIST training images; labels and the test split are not needed.
(X_TRAIN, _), (_, _) = mnist.load_data()

# Add a trailing channel axis and rescale pixel values from [0, 255] to [-1, 1].
# The division yields float64; cast to float32 (Keras' default float dtype) to
# halve the memory held by the full training set.
X_TRAIN = (X_TRAIN[..., None] / 255. * 2. - 1.).astype(np.float32)

# Sanity checks: images match the network input shape and lie in [-1, 1].
assert X_TRAIN.shape[1:] == SHAPE, X_TRAIN.shape
assert X_TRAIN.min() >= -1. and X_TRAIN.max() <= 1.

In [None]:
# Set up logging. The custom loop below never calls `Model.fit`, so the losses and
# a sample image are written directly with `tf.summary` rather than via callbacks.
logdir = os.path.join("logs", "scalars", datetime.now().strftime("%Y%m%d-%H%M%S"))
file_writer = tf.summary.create_file_writer(os.path.join(logdir, "metrics"))
# NOTE(review): none of this callback's batch/epoch hooks are called below, so it
# records no per-step data itself; confirm whether it is still needed.
tboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
tboard_callback.set_model(DIS)

# Steps between sample-image summaries and progress prints.
LOG_EVERY = 10

# Fixed latent vector, so the logged image tracks the same draw through training.
image_lat = np.random.randn(1, LATENT_DIM)
# Adversarial-step target: generated images labelled as real. Constant, so built once.
y_lie = np.ones([BATCH_SIZE, 1])

with file_writer.as_default():
    try:
        for step in range(TRAIN_STEPS):
            # --- Discriminator update ---
            # Noisy labels: roughly 1% of "real" targets are 0 and 1% of "fake" targets are 1.
            y_real = np.random.binomial(1, 0.99, size=[BATCH_SIZE, 1])
            y_fake = np.random.binomial(1, 0.01, size=[BATCH_SIZE, 1])
            # Randomly select a batch of training images.
            idx = np.random.randint(0, X_TRAIN.shape[0], size=BATCH_SIZE)
            images_real = X_TRAIN[idx, ...]
            # Rotate each image by a random integer multiple of 90 degrees. This should
            # probably be moved out of this loop, into the data-preparation stage.
            # NOTE(review): MNIST digits are not rotation invariant (a rotated 6 looks
            # like a 9); confirm this augmentation is intended for this dataset.
            images_real = np.array([np.rot90(im, np.random.randint(0, 4)) for im in images_real])

            # Use the generator to create fake images.
            noise = np.random.normal(loc=0., scale=1., size=[BATCH_SIZE, LATENT_DIM])
            images_fake = GEN.predict(noise)

            # Train the discriminator on the real and the fake batch.
            dloss_real = DIS.train_on_batch(images_real, y_real)
            dloss_fake = DIS.train_on_batch(images_fake, y_fake)
            dloss = 0.5 * (dloss_real + dloss_fake)

            # --- Adversarial update ---
            # Fresh noise through the joint model with "real" targets: the network is
            # told the generated images are valid ("lie" labels). Presumably the
            # discriminator is frozen inside ADV so only the generator learns here;
            # confirm in began.dcgan.
            noise = np.random.normal(loc=0., scale=1., size=[BATCH_SIZE, LATENT_DIM])
            a_loss = ADV.train_on_batch(noise, y_lie)

            tf.summary.scalar('aloss', a_loss, step=step)
            tf.summary.scalar('dloss', dloss, step=step)
            if not step % LOG_EVERY:
                # tf.summary.image clips float data to [0, 1]. The training images were
                # scaled to [-1, 1], so map the generator output back to [0, 1] first
                # (assumes the generator's output range matches that scaling).
                sample = (GEN.predict(image_lat) + 1.) / 2.
                tf.summary.image('random_draw', sample, step=step)
                print("Step {:05d}, dloss {:.03f}".format(step, dloss))
    finally:
        # Flush even if training is interrupted, so everything logged so far
        # is visible in TensorBoard.
        file_writer.flush()

Step 00000, dloss 0.921
Step 00010, dloss 0.214
Step 00020, dloss 0.150
Step 00030, dloss 0.495
Step 00040, dloss 0.528
Step 00050, dloss 0.367
Step 00060, dloss 0.333
Step 00070, dloss 0.245
Step 00080, dloss 0.207
Step 00090, dloss 0.156
Step 00100, dloss 0.274
Step 00110, dloss 0.161
Step 00120, dloss 0.106
Step 00130, dloss 0.111
Step 00140, dloss 0.105
Step 00150, dloss 0.103
Step 00160, dloss 0.121
Step 00170, dloss 0.084
Step 00180, dloss 0.040
Step 00190, dloss 0.085
Step 00200, dloss 0.052
Step 00210, dloss 0.116
Step 00220, dloss 0.058
Step 00230, dloss 0.149
Step 00240, dloss 0.095
Step 00250, dloss 0.079
Step 00260, dloss 0.169
Step 00270, dloss 0.086
Step 00280, dloss 0.041
Step 00290, dloss 0.130
Step 00300, dloss 0.026
Step 00310, dloss 0.084
Step 00320, dloss 0.033
Step 00330, dloss 0.035
Step 00340, dloss 0.022
Step 00350, dloss 0.019
Step 00360, dloss 0.022
Step 00370, dloss 0.092
Step 00380, dloss 0.018
Step 00390, dloss 0.009
Step 00400, dloss 0.015
Step 00410, dlos

In [None]:
# Draw one random sample from the generator and show it. The latent vector must
# have LATENT_DIM entries to match the generator's input layer.
sample = ADV.get_layer('Generator').predict(np.random.randn(1, LATENT_DIM))
fig, ax = plt.subplots()
ax.imshow(sample[0, :, :, 0], cmap="gray")
ax.set_title("Random generator sample")
ax.axis("off");

In [None]:
# Save the adversarial model, which is built from both DIS and GEN.
# NOTE(review): if only the generator is needed downstream, save GEN instead.
# Create the model directory first so saving does not fail on a fresh checkout.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
ADV.save(os.fspath(MODEL_PATH))