 ## Import necessary libraries

In [None]:
import tensorflow as tf
import os
import pathlib
import time
import datetime
import tqdm
import scipy.io
from matplotlib import pyplot as plt
from IPython import display

# Expose only GPU 0 to TensorFlow.
# NOTE(review): this is most reliable when set before TensorFlow initializes its
# devices (i.e. before the first GPU op runs) — confirm no GPU op precedes it.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

## Load the training and test data sets

The data for each set consists of 909 images of 3D aortic anatomy, which were obtained using either CT angiography or 4D flow MRI. The peak systolic velocity for each set was also measured using 4D flow MRI. The anatomy and flow data have been pre-processed and are each sized at 192 x 64 x 64 voxels, representing the height, width, and slice dimensions, respectively.

As described in the [pix2pix paper](https://arxiv.org/abs/1611.07004), apply random jittering and mirroring to preprocess the training set.

Define several functions that:

1. Resize each `192 x 64` image to a larger height and width—`222 x 94`.
2. Randomly crop it back to `192 x 64`.
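3. Randomly mirror the volume (random mirroring). Note that `tf.image.flip_left_right` is applied to the cropped `192 x 64 x 64 x 1` tensor, which it treats as `[batch, height, width, channels]`; the flip therefore reverses the third (slice) axis rather than the width axis.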

Note that, unlike the original pix2pix, this implementation does not normalize the images to the `[-1, 1]` range, because the model must predict the exact values of the aortic hemodynamics.

In [None]:
# Batch size and the (height, width, depth) size of each anatomy / flow volume.
BATCH_SIZE = 1  # a batch size of 1 has given the best results
DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH = 192, 64, 64

# Enlarge anatomy and flow (later randomly cropped back to the original size).
def resize(anatomy, flow, height, width):
    """Resize ``anatomy`` and ``flow`` to ``height`` x ``width``.

    Nearest-neighbour interpolation is used for both tensors. For a 3-D
    ``[H, W, D]`` input, ``tf.image.resize`` treats the last axis as channels,
    so only the first two axes change size.

    Returns:
        ``(resized_anatomy, resized_flow)``.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    return (tf.image.resize(anatomy, target_size, method=nearest),
            tf.image.resize(flow, target_size, method=nearest))

# Randomly crop the enlarged data back to its original size.
def random_crop(anatomy, flow):
    """Crop ``anatomy`` and ``flow`` at the same random location.

    Both tensors are squeezed and stacked on a new leading axis, then cropped
    with a single ``tf.image.random_crop`` call to
    ``[2, DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH]``. Because there is one crop
    window, the two results stay spatially aligned.

    Returns:
        ``(anatomy_crop, flow_crop)``, each with a trailing channel axis of size 1.
    """
    pair = tf.stack([tf.squeeze(anatomy), tf.squeeze(flow)], axis=0)
    crop_shape = [2, DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH]
    cropped_pair = tf.image.random_crop(pair, size=crop_shape)
    anatomy_crop, flow_crop = cropped_pair[0], cropped_pair[1]
    return tf.expand_dims(anatomy_crop, axis=-1), tf.expand_dims(flow_crop, axis=-1)

# Random jitter (pix2pix-style augmentation).
@tf.function()
def random_jitter(anatomy, flow):
    """Augment an (anatomy, flow) pair identically.

    1. Enlarge both by 30 in height and width (``resize``).
    2. Randomly crop back to ``(DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH)`` with a
       shared window (``random_crop``), adding a trailing channel axis.
    3. With probability 0.5, mirror both with ``tf.image.flip_left_right``.

    NOTE(review): after step 2 the tensors are ``[H, W, D, 1]``.
    ``tf.image.flip_left_right`` treats a 4-D tensor as
    ``[batch, height, width, channels]``, so it reverses axis 2 (the slice axis)
    rather than the width axis. Confirm this is the intended mirroring axis.
    """
    enlarged_anatomy, enlarged_flow = resize(anatomy, flow, DATA_HEIGHT + 30, DATA_WIDTH + 30)
    anatomy, flow = random_crop(enlarged_anatomy, enlarged_flow)
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:  # autograph turns this tensor condition into tf.cond
        anatomy, flow = tf.image.flip_left_right(anatomy), tf.image.flip_left_right(flow)
    return anatomy, flow

# Crop images back to a target size around the center.
@tf.function
def center_crop(image, size):
    """Crop ``image`` to ``size`` around its center.

    Args:
        image: tensor following the ``tf.image.crop_to_bounding_box``
            convention — 3-D ``[H, W, C]`` or 4-D ``[B, H, W, C]``. Only the H
            and W axes are cropped.
        size: ``(height, width)`` pair, or a single int meaning a square crop.

    Returns:
        The center crop. When the size difference is odd, the extra row/column
        is dropped from the bottom/right (offsets use floor division).
    """
    if not isinstance(size, (tuple, list)):
        # A scalar size means a square crop of that side length.
        size = [size, size]
    target_height, target_width = size[0], size[1]
    offset_height = (tf.shape(image)[-3] - target_height) // 2
    offset_width = (tf.shape(image)[-2] - target_width) // 2
    return tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                         target_height, target_width)

# TFRecord parsers for the anatomy -> flow data sets.
def _decode_anatomy_flow(tfrecord):
    """Decode one serialized example into center-cropped (anatomy, flow) volumes.

    The example holds two raw float32 buffers, ``'A'`` (anatomy) and ``'B'``
    (flow), plus their ``height``/``width``/``depth``. Each buffer is reshaped
    to ``[height, width, depth]`` and center-cropped to
    ``DATA_HEIGHT x DATA_WIDTH`` in the first two axes. The depth axis is left
    as stored.
    """
    feature = tf.io.parse_single_example(tfrecord, {
        'A': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        'B': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        'height': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        'width': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        'depth': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
    })
    volume_shape = [tf.cast(feature[key], tf.int32) for key in ("height", "width", "depth")]

    def _volume(key):
        # Decode one raw buffer, restore its 3-D shape, and center-crop it.
        raw = tf.io.decode_raw(feature[key], tf.float32)
        return center_crop(tf.reshape(raw, volume_shape), [DATA_HEIGHT, DATA_WIDTH])

    return _volume('A'), _volume('B')


def parser_train(tfrecord):
    """Parse a training example and apply ``random_jitter`` augmentation.

    Returns:
        ``(anatomy, flow)``, each ``[DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1]``.
    """
    A, B = _decode_anatomy_flow(tfrecord)
    return random_jitter(A, B)


def parser_test(tfrecord):
    """Parse a test example. Test data deliberately gets NO random jittering.

    Returns:
        ``(anatomy, flow)``, center-cropped, with a trailing channel axis of size 1.
    """
    A, B = _decode_anatomy_flow(tfrecord)
    return tf.expand_dims(A, axis=-1), tf.expand_dims(B, axis=-1)


tfrecord_path = 'anatomay2flow_train_pix2pix.tfrecords'  # specify the path to the training set
dataset_train = tf.data.TFRecordDataset(tfrecord_path)
# Count the records by streaming through the file. ``len(list(...))`` would hold
# every serialized record in memory at once just to count them.
DATA_SIZE = sum(1 for _ in dataset_train)  # number of training examples
BUFFER_SIZE = DATA_SIZE  # shuffle over the whole training set
# NOTE(review): a full-size shuffle buffer after ``map`` keeps every decoded pair
# in memory; reduce BUFFER_SIZE if host RAM is tight.
dataset_train = (dataset_train
                 .map(map_func=parser_train, num_parallel_calls=tf.data.AUTOTUNE)
                 .shuffle(buffer_size=BUFFER_SIZE)
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.AUTOTUNE))

# Create the test pipeline. ``repeat()`` makes the iterator cycle. The training
# loop draws one test example per epoch with ``next(dataset_test)``, which would
# otherwise raise StopIteration once the test set (minus the monitoring samples
# below) is used up.
tfrecord_path = 'anatomay2flow_test_pix2pix.tfrecords'
dataset_test = tf.data.TFRecordDataset(tfrecord_path)
dataset_test = (dataset_test
                .map(map_func=parser_test, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat())
dataset_test = iter(dataset_test)

# Keep the first ten test batches to monitor the model while it trains.
# (If the test set has fewer than ten batches, the cycling iterator repeats some.)
NUM_MONITOR_SAMPLES = 10
Asample = []
Bsample = []
for count, (a, b) in enumerate(dataset_test, start=1):
    Asample.append(a)
    Bsample.append(b)
    if count >= NUM_MONITOR_SAMPLES:
        print(a.shape)
        break

## Build the generator

The generator is a modified [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler).

Each block in the encoder is: Convolution -> Batch normalization -> Leaky ReLU

Each block in the decoder is: Transposed convolution -> Batch normalization -> Dropout (applied to the first 2 blocks) -> ReLU

There are skip connections between the encoder and decoder (as in the U-Net).

In [None]:
# Generator / discriminator encoder block.
def downsample(filters, size, apply_batchnorm=True):
    """Return an encoder block: strided Conv3D -> optional BatchNorm -> LeakyReLU.

    The convolution uses stride 2 with 'same' padding, so each spatial axis
    shrinks by a factor of 2 (rounded up). It has no bias, and its weights are
    drawn from N(0, 0.02).
    """
    conv = tf.keras.layers.Conv3D(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0, 0.02),
        use_bias=False,
    )
    block_layers = [conv]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(block_layers)

# Sanity-check a 3-filter downsampler on the first monitoring sample.
A = Asample[0]
down_model = downsample(filters=3, size=4)
print("Shape of the inputs of a downsampler with 3 filters: ", A.shape)
down_result = down_model(A)
print("Shape of the output of a downsampler with 3 filters: ", down_result.shape)

# Generator decoder block.
def upsample(filters, size, apply_dropout=False):
    """Return a decoder block: Conv3DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU.

    The transposed convolution uses stride 2 with 'same' padding, so each
    spatial axis doubles. It has no bias, and its weights are drawn from
    N(0, 0.02).
    """
    block_layers = [
        tf.keras.layers.Conv3DTranspose(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0, 0.02),
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(block_layers)

# Sanity-check a 3-filter upsampler: it should double the spatial size that
# the downsampler halved.
up_model = upsample(filters=3, size=4)
up_result = up_model(down_result)
print("Shape of the output of a up-sampler with 3 filters: ", up_result.shape)

# Define the generator from the down- and up-sampler blocks.
def Generator(OUTPUT_CHANNELS=1):
    """Build the 3-D U-Net generator (anatomy volume -> flow volume).

    Encoder: six ``downsample`` blocks (64-128-256-512-512-512 filters). They
    reduce a ``[DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1]`` volume by 2**6 = 64
    per axis; the first block has no batch norm.
    Decoder: five ``upsample`` blocks (512-512-256-128-64 filters, dropout on
    the first two). Each is followed by concatenation with the mirrored encoder
    output.
    A final stride-2 transposed convolution restores full resolution. It has
    ``OUTPUT_CHANNELS`` channels and a linear activation, so outputs are
    unbounded (the data are not normalized).
    """
    inputs = tf.keras.layers.Input(shape=[DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1])

    # Encoder; shapes in comments are for a (1, 192, 64, 64, 1) input.
    down_stack = [downsample(64, 4, apply_batchnorm=False)]  # (1, 96, 32, 32, 64)
    down_stack += [downsample(n_filters, 4) for n_filters in (128, 256, 512, 512, 512)]
    # -> (1,48,16,16,128), (1,24,8,8,256), (1,12,4,4,512), (1,6,2,2,512), (1,3,1,1,512)

    # Decoder; dropout only in the two deepest blocks.
    up_stack = [upsample(512, 4, apply_dropout=True) for _ in range(2)]
    up_stack += [upsample(n_filters, 4) for n_filters in (256, 128, 64)]

    last = tf.keras.layers.Conv3DTranspose(
        OUTPUT_CHANNELS, 4, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0, 0.02),
        activation='linear')

    x = inputs
    encoder_outputs = []
    for block in down_stack:
        x = block(x)
        encoder_outputs.append(x)

    # Skip connections: every encoder output except the bottleneck, deepest first.
    for block, skip in zip(up_stack, encoder_outputs[-2::-1]):
        upsampled = block(x)
        x = tf.keras.layers.Concatenate()([upsampled, skip])

    return tf.keras.Model(inputs=inputs, outputs=last(x))

# Instantiate the generator and print its summary.
generator = Generator()
print("This is the generator model summary")
generator.summary()
print("This a the plot of the generator model")

# Plot the model graph (Keras writes it to ``to_file``, default 'model.png').
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
gen_output = generator(Asample[0], training=False)

# Show what the untrained generator produces: mean over the last axis.
print("This is an example of how the output of generator would look like:")
mean_projection = tf.reduce_mean(tf.squeeze(gen_output), axis=-1)
plt.imshow(mean_projection, cmap='jet')
plt.colorbar()
plt.show()

## Build the discriminator

The discriminator is a convolutional PatchGAN classifier—it tries to classify if each image _patch_ is real or not real.

- Each block in the discriminator is: Convolution -> Batch normalization -> Leaky ReLU.
- The discriminator receives 2 inputs: 
    - The input image and the target image, which it should classify as real.
    - The input image and the generated image (the output of the generator), which it should classify as fake.
    - The two inputs are concatenated along the channel axis (`tf.keras.layers.concatenate([anatomy, flow])`) before being processed.

In [None]:
# Build the PatchGAN discriminator.
def Discriminator():
    """Build the 3-D PatchGAN discriminator.

    Inputs are ``[anatomy, flow]``, each ``[DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1]``,
    named 'input_image' and 'target_image'. They are concatenated on the
    channel axis and passed through three ``downsample`` blocks. Then come
    ZeroPad -> Conv3D(512) -> BatchNorm -> LeakyReLU -> ZeroPad -> Conv3D(1).
    The output is a grid of raw logits, one per patch; there is no final
    activation. For the default 192x64x64 volumes the grid is 22 x 6 x 6 x 1.
    """
    init = tf.random_normal_initializer(0., 0.02)
    volume_shape = [DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1]
    anatomy = tf.keras.layers.Input(shape=volume_shape, name='input_image')
    flow = tf.keras.layers.Input(shape=volume_shape, name='target_image')

    x = tf.keras.layers.concatenate([anatomy, flow])
    x = downsample(64, 4, False)(x)
    for n_filters in (128, 256):
        x = downsample(n_filters, 4)(x)

    x = tf.keras.layers.ZeroPadding3D()(x)
    x = tf.keras.layers.Conv3D(512, 4, strides=1,
                               kernel_initializer=init,
                               use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.ZeroPadding3D()(x)
    patch_logits = tf.keras.layers.Conv3D(1, 4, strides=1, kernel_initializer=init)(x)
    return tf.keras.Model(inputs=[anatomy, flow], outputs=patch_logits)

# Instantiate the discriminator, then print and plot its architecture.
discriminator = Discriminator()
discriminator.summary()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

# Show the untrained discriminator's patch logits, averaged over the last axis.
disc_out = discriminator([Asample[0], gen_output], training=False)
patch_map = tf.reduce_mean(tf.squeeze(disc_out), axis=-1)
plt.imshow(patch_map, cmap='jet')
plt.colorbar()
plt.show()

## Define the generator and discriminator losses

cGANs learn a structured loss that penalizes a possible structure that differs from the network output and the target image.

- The generator loss is a sigmoid cross-entropy loss between the discriminator's output on the generated images and an **array of ones**.
- The pix2pix paper also mentions the L1 loss, which is a MAE (mean absolute error) between the generated image and the target image.
- This allows the generated image to become structurally similar to the target image.
- The formula to calculate the total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. 


- The `discriminator_loss` function takes 2 inputs: **real images** and **generated images**.
- `real_loss` is a sigmoid cross-entropy loss between the discriminator's output on the **real images** and an **array of ones (since these are the real images)**.
- `generated_loss` is a sigmoid cross-entropy loss between the discriminator's output on the **generated images** and an **array of zeros (since these are the fake images)**.
- The `total_loss` is the sum of `real_loss` and `generated_loss`.

In [None]:
# Weight of the L1 term relative to the adversarial term (pix2pix uses 100).
LAMBDA = 100
# Binary cross-entropy on raw logits; the discriminator has no final activation.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(disc_generated_output, gen_ouptput, target):
    """Return the generator losses ``(total, gan, l1)``.

    Args:
        disc_generated_output: discriminator logits for (anatomy, generated flow).
        gen_ouptput: generator output for the current batch. The misspelt
            parameter name is kept so any keyword callers keep working.
        target: ground-truth flow for the current batch.

    Returns:
        ``gan`` is the BCE of the logits against all ones (the generator wants
        its output judged real). ``l1`` is the mean absolute error between
        ``target`` and ``gen_ouptput``. ``total = gan + LAMBDA * l1``.
    """
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # Use the argument, NOT the module-level ``gen_output`` from the preview cell.
    # That global is a fixed tensor, so the L1 term would neither follow the
    # current batch nor contribute gradients to the generator.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_ouptput))
    total_gan_loss = gan_loss + (LAMBDA * l1_loss)
    return total_gan_loss, gan_loss, l1_loss

# Discriminator loss.
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss for one batch.

    It is the BCE that pushes logits for real pairs toward 1 plus the BCE that
    pushes logits for generated pairs toward 0 (``loss_object`` works on logits).
    """
    real_targets = tf.ones_like(disc_real_output)
    fake_targets = tf.zeros_like(disc_generated_output)
    return (loss_object(real_targets, disc_real_output)
            + loss_object(fake_targets, disc_generated_output))

## Define the generator and discriminator optimizers and a checkpoint-saver

In [None]:
# Adam with learning rate 2e-4 and beta_1 = 0.5 for both networks.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

output_dir = './pix2pix_output'  # root directory for all model outputs
if not tf.io.gfile.isdir(output_dir):
    tf.io.gfile.mkdir(output_dir)  # create it on first run

# Checkpoints are written as '<checkpoint_dir>/ckpt-<n>'.
checkpoint_dir = tf.io.gfile.join(output_dir, 'training_checkpoints')  # where model checkpoints are saved
checkpoint_prefix = tf.io.gfile.join(checkpoint_dir, "ckpt")
# NOTE: 'discrimnator_optimizer' is misspelt, but it is the key under which the
# optimizer state is stored. Renaming it would likely stop existing checkpoints
# from restoring that state.
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discrimnator_optimizer=discriminator_optimizer,
)

## Generate images

Write a function to plot some images during training.

- Pass images from the test set to the generator.
- The generator will then translate the input image into the output.

In [None]:
# Plot input, target and prediction for one test batch.
def generate_images(model, test_input, tar):
    """Show input, target and prediction side by side.

    The model runs with ``training=True``. Each panel takes the first batch
    element, squeezes it, and shows its maximum over the last remaining axis.
    The 'jet' colormap is used with the color limits fixed to [0, 2].
    """
    prediction = model(test_input, training=True)
    panels = (
        ('input_image', test_input[0]),
        ('target_image', tar[0]),
        ('predicted image', prediction[0]),
    )
    plt.figure(figsize=(15, 15))
    for position, (caption, volume) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        projection = tf.reduce_max(tf.squeeze(volume), axis=-1)
        plt.imshow(projection, cmap='jet', clim=[0, 2])
        plt.axis('off')
    plt.show()
# Preview the (untrained) generator on the first monitoring sample.
generate_images(generator, test_input=Asample[0], tar=Bsample[0])

## Training

- For each example, the generator takes the input image and generates an output.
- The discriminator receives the `input_image` and the generated image as the first input. The second input is the `input_image` and the `target_image`.
- Next, calculate the generator and the discriminator loss.
- Then, calculate the gradients of each loss with respect to the generator and discriminator variables, and apply them with the corresponding optimizer.

In [4]:
@tf.function
def train_step(anatomy, flow, step=None):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        anatomy: batch of input anatomy volumes.
        flow: matching batch of ground-truth flow volumes.
        step: optional step counter, currently unused. It defaults to ``None``
            so two-argument calls work. If you pass it, pass a tensor: a new
            Python int on every call would retrace this ``tf.function``.

    Returns:
        ``(gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss)`` for logging.
        Callers may ignore the result.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(anatomy, training=True)
        # The discriminator scores two pairs: (anatomy, real flow) and
        # (anatomy, generated flow).
        disc_real_output = discriminator([anatomy, flow], training=True)
        disc_generated_output = discriminator([anatomy, gen_output], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output,
                                                                   gen_output,
                                                                   flow)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))
    return gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss

epochs = 100  # number of epochs to train in this run

sample_dir = tf.io.gfile.join(output_dir, 'samples_training')  # directory for example .mat files
if not tf.io.gfile.isdir(sample_dir):
    tf.io.gfile.mkdir(sample_dir)

# Resume from the latest checkpoint in checkpoint_dir, if there is one.
checkpoint_address = tf.train.latest_checkpoint(checkpoint_dir)  # None when no checkpoint exists
epochs_so_far = 0
if checkpoint_address:
    checkpoint.restore(checkpoint_address)
    # Paths look like '<checkpoint_prefix>-<save number>', and one save is made
    # per epoch. Split on the LAST hyphen: the directory part may itself
    # contain hyphens.
    epochs_so_far = int(checkpoint_address.rsplit('-', 1)[-1])
    print("Restored the model from epoch {}".format(epochs_so_far))

# Batches per epoch (ceiling division; the last batch may be partial).
steps_per_epoch = -(-DATA_SIZE // BATCH_SIZE)
global_step = epochs_so_far * steps_per_epoch

for epoch in tqdm.trange(epochs_so_far + 1, epochs + epochs_so_far + 1,
                         desc="Outer Epoch", total=epochs):
    for anatomy, flow in tqdm.tqdm(dataset_train, desc="Inner Epoch", total=steps_per_epoch):
        # Pass the step as a tensor; a fresh Python int would retrace train_step.
        train_step(anatomy, flow, tf.constant(global_step, dtype=tf.int64))
        global_step += 1
    checkpoint.save(file_prefix=checkpoint_prefix)  # one checkpoint per epoch

    # Visual check on the next test example. Skip it if the test iterator is exhausted.
    example = next(dataset_test, None)
    if example is not None:
        example_anatomy, example_flow = example
        generate_images(generator, example_anatomy, example_flow)

    # Save the generator's predictions for the fixed monitoring samples as .mat files.
    for ii, (sample_a, sample_b) in enumerate(zip(Asample, Bsample)):
        A = tf.reshape(sample_a, (1, DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1))
        B = tf.reshape(sample_b, (1, DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, 1))
        A2B = generator(A, training=True)
        filename1 = tf.io.gfile.join(sample_dir, 'iter-%03u-%02u.mat' % (epoch, ii))
        scipy.io.savemat(filename1, {'anatomy': tf.squeeze(A).numpy(),
                                     'flow': tf.squeeze(B).numpy(),
                                     'anatomy2flow': tf.squeeze(A2B).numpy()})

(1, 192, 64, 64, 1)
