<a href="https://colab.research.google.com/github/mtwenzel/image-video-understanding/blob/master/Section_4_DCGAN_for_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 Fraunhofer MEVIS, 2019 The Tensorflow Authors.
Based on a TensorFlow tutorial notebook, available at https://www.tensorflow.org/tutorials/generative/dcgan

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DC-GAN and WGAN for Segmentation

In [0]:
#@title Imports
from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
# To generate GIFs
!pip install imageio
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display
tf.__version__

## Data Preparation (the well-known CTs)

In [0]:
%%bash 
# Fetch the slice archive once: the download only runs when tmp_slices.npz is
# not already present (`test -e ... ||`). `-L` makes curl follow redirects.
test -e tmp_slices.npz || curl -L "https://drive.google.com/uc?export=download&id=1R2-H0dhhrj6XNK7Q-MazIWGeFDOf6Zya" --output tmp_slices.npz

In [0]:
TRAINING_SLICE_COUNT = 1600 #@param {min:100, max:3300, step:100}
EPOCHS = 50 #@param {min:1, max:200, step:1}

# np.load on an .npz archive returns a lazy NpzFile that keeps the file handle
# open; the context manager closes it as soon as both arrays have been read.
with np.load('tmp_slices.npz') as loaded:
    all_slices = loaded['x_train']
    all_labels = loaded['y_train']

# Both splits come from the 'x_train'/'y_train' arrays: the first
# TRAINING_SLICE_COUNT slices are used for training, the rest is held out.
x_train = all_slices[:TRAINING_SLICE_COUNT]
y_train = all_labels[:TRAINING_SLICE_COUNT]

x_test = all_slices[TRAINING_SLICE_COUNT:]
y_test = all_labels[TRAINING_SLICE_COUNT:]

assert len(x_train) == len(y_train)
assert len(x_test) == len(y_test)
print("Found %d training and %d testing slices with shape %s" % (len(x_train),len(x_test), x_train.shape))

In [0]:
# Collapse the label maps to binary masks: clipping to [0, 1] maps the lesion
# labels (values 2..3) onto 1, so only 0 and 1 remain.
y_train_binary = np.clip(y_train, 0, 1)
y_test_binary = np.clip(y_test, 0, 1)

## Wrap data
We wrap the raw NumPy arrays into a `tf.data.Dataset`.

In [0]:
BUFFER_SIZE = TRAINING_SLICE_COUNT
BATCH_SIZE = 64

# Inputs get a trailing channel axis and are reflect-padded by 20 pixels on
# each side of both spatial axes; the masks only get the channel axis.
# `np.pad` is the public NumPy entry point (`np.lib.pad` was merely an alias).
padded_inputs = np.pad(
    x_train[..., np.newaxis],
    [(0, 0), (20, 20), (20, 20), (0, 0)],
    'reflect',
)
mask_targets = y_train_binary[..., np.newaxis]

# Shuffle and batch inputs and masks together as (X, Y) tuples.
train_dataset_liver = (
    tf.data.Dataset.from_tensor_slices((padded_inputs, mask_targets))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

## Define new generator and discriminator

In [0]:
def make_generator_model_seg(_filters=32, filters_add=0, _kernel_size=(3,3), _padding='same', _activation='relu', _kernel_regularizer=None, _final_layer_nonlinearity='sigmoid', _input_shape=(116,116,1)):
    """Build the segmentation generator: a small conv encoder/decoder.

    The network is two (conv, conv, max-pool) stages, a (conv, conv)
    bottleneck, two (up-sample, conv, conv) stages, and a final 1x1
    convolution producing a single output channel.

    Args:
        _filters: Number of filters in the outermost conv layers.
        filters_add: Extra filters added per resolution level
            (`_filters + level * filters_add`).
        _kernel_size: Kernel size of all convolutions except the final 1x1.
        _padding: Padding mode of those convolutions. With `'valid'`, 3x3
            kernels and the default 116x116 input, the output is 76x76
            (116 -> 112 -> pool 56 -> 52 -> pool 26 -> 22 -> up 44 -> 40
            -> up 80 -> 76).
        _activation: Activation of those convolutions.
        _kernel_regularizer: Optional kernel regularizer for those convolutions.
        _final_layer_nonlinearity: Activation of the final 1x1 convolution.
        _input_shape: Shape (height, width, channels) of one input image.
            Defaults to the previously hard-coded (116, 116, 1).

    Returns:
        A `tf.keras.Sequential` model.
    """
    def conv_layer(n_filters, **extra):
        # All feature convolutions share the same hyper-parameters.
        return layers.Conv2D(filters=n_filters, kernel_size=_kernel_size,
                             padding=_padding, activation=_activation,
                             kernel_regularizer=_kernel_regularizer, **extra)

    model = tf.keras.Sequential()
    # The input size is fixed, which keeps layer output shapes inspectable
    # (e.g. via model.summary()).
    model.add(layers.InputLayer(input_shape=_input_shape))

    # Encoder, full resolution.
    model.add(conv_layer(_filters, name='firstConvolutionalLayer'))
    model.add(conv_layer(_filters))
    model.add(layers.MaxPool2D())

    # Encoder, half resolution.
    model.add(conv_layer(_filters + filters_add))
    model.add(conv_layer(_filters + filters_add))
    model.add(layers.MaxPool2D())

    # Bottleneck, quarter resolution.
    model.add(conv_layer(_filters + 2 * filters_add))
    model.add(conv_layer(_filters + 2 * filters_add))
    model.add(layers.UpSampling2D())

    # Decoder, half resolution.
    model.add(conv_layer(_filters + filters_add))
    model.add(conv_layer(_filters + filters_add))
    model.add(layers.UpSampling2D())

    # Decoder, full resolution.
    model.add(conv_layer(_filters))
    model.add(conv_layer(_filters))

    # Per-pixel prediction with a single output channel.
    model.add(layers.Conv2D(1, kernel_size=(1,1), activation=_final_layer_nonlinearity))
    return model

In [0]:
def make_discriminator_model_seg():
    """Build the discriminator that scores 76x76 single-channel masks.

    Two strided 5x5 convolutions (64 and 128 filters), each followed by
    LeakyReLU and 30% dropout, feed a flattened feature vector into a single
    dense unit. The dense unit has no activation, so the model outputs one
    unbounded score per input.

    Returns:
        A `tf.keras.Sequential` model.
    """
    layer_stack = [
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='valid',
                      input_shape=[76, 76, 1]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='valid'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1),
    ]
    return tf.keras.Sequential(layer_stack)

## Create the full GAN

In [0]:
# 'valid' padding lets every 3x3 convolution shrink its input, so the
# generator's 116x116 input comes out as 76x76 -- the size the discriminator
# expects as input.
generator = make_generator_model_seg(_padding='valid')
discriminator = make_discriminator_model_seg()

You may try the `binary_crossentropy` discriminator loss as before, or select a Wasserstein loss. Note that the Wasserstein loss requires the discriminator to be Lipschitz-continuous. This is handled in the definition of the train steps.

In [0]:
#@title Select Loss { run: "auto" }
WASSERSTEIN_LOSS = False #@param {type:"boolean"}
#@markdown You don't need to re-execute the cell after changing your selection, it auto-executes when it has been executed once.

if WASSERSTEIN_LOSS:
    # Wasserstein loss (may work only when using a gradient penalty term or at least weight clipping)
    def discriminator_loss(real_output, fake_output):
        """Return (total, real, fake) Wasserstein critic losses.

        The critic is trained to score real samples high and generated
        samples low, hence the negated mean on the real scores.
        """
        real_loss = -tf.reduce_mean(real_output)
        fake_loss = tf.reduce_mean(fake_output)
        total_loss = real_loss + fake_loss
        # Same 3-tuple as the cross-entropy variant: train_step_seg unpacks
        # three values, so returning only the total would fail there.
        return total_loss, real_loss, fake_loss

    def generator_loss(fake_output):
        """Return the Wasserstein generator loss (negated mean critic score)."""
        return -tf.reduce_mean(fake_output)

else:
    # Helper objects to compute the binary cross entropy, from logits and
    # from probabilities respectively.
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    cross_entropy_prob = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    # The discriminator outputs one raw logit (no sigmoid) per real or fake
    # segmentation, which is why `from_logits=True` is used.
    def discriminator_loss(real_output, fake_output):
        """Return (total, real, fake) cross-entropy discriminator losses.

        Real samples are labelled 1 and generated samples 0.
        """
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        total_loss = real_loss + fake_loss
        return total_loss, real_loss, fake_loss

    # The loss for the generator still only sees if the discriminator can tell its generations from real input.
    def generator_loss(fake_output):
        """Return the cross-entropy generator loss.

        The generator succeeds when the discriminator labels its output as
        real, so the target is 1. A target of 0 would reward the generator
        for being detected.
        """
        return cross_entropy(tf.ones_like(fake_output), fake_output)



In [0]:
# Separate Adam optimizers, since generator and discriminator are updated
# independently; the discriminator uses the smaller learning rate.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)

In [0]:
# Checkpoints are written as ./training_checkpoints/ckpt-<n> and track both
# models together with their optimizers.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)

Here is an image output function.

In [0]:
def generate_and_save_images(model, epoch, test_input):
    """Run `model` on `test_input`, plot the predictions and save the plot.

    The first channel of every prediction is drawn into a square grid of
    subplots. The figure is saved as `image_at_epoch_<epoch>.png` (epoch
    zero-padded to four digits) in the working directory, and then shown.

    Args:
        model: Callable model; invoked as `model(test_input, training=False)`.
        epoch: Integer used in the output file name.
        test_input: Batch of inputs passed to the model.
    """
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    num_images = predictions.shape[0]
    # Keep the 4x4 layout for up to 16 images; grow the grid for larger
    # batches instead of failing on subplot index 17.
    grid_size = max(4, int(np.ceil(np.sqrt(num_images))))

    fig = plt.figure(figsize=(4, 4))

    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # This function is called once per epoch; release the figure so figures
    # do not pile up when the backend does not close them on show().
    plt.close(fig)

# Draw a random sample of 16 images (with replacement) from the test set.
example_test_slices = np.random.randint(0,len(x_test),16)
seed = x_test[example_test_slices]
# Add the channel axis and reflect-pad by 20 pixels per side on both spatial
# axes. `np.pad` is the public NumPy entry point (`np.lib.pad` was an alias).
seed = np.pad(seed[...,np.newaxis], [(0,0), (20,20), (20,20), (0,0)], 'reflect')
print(seed.shape)

Set up TensorBoard, delete old logs, and create summary writers for train and test as shown [here](https://www.tensorflow.org/tensorboard/get_started).

In [0]:
import datetime
# Load the TensorBoard notebook extension
%load_ext tensorboard
# Clear any logs from previous runs
!rm -rf ./logs/ 
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

## Train the model

Previously, we only wanted to generate images from random numbers. This time, we replace the random input with a CT image, and the desired output with the segmentation mask.

Recipe:
1. Replace `noise` with CT image (first index in tuple)
1. `generated_images` are the output of the AE
1. Submit the correct images/masks to the loss calculation
1. Decouple generator and discriminator updates


In [0]:
# `tf.function` would "compile" this function and speed it up a lot. It is
# left disabled here: `epoch` and `num_gen_updates` are passed as Python
# values, and (as far as tf.function's tracing rules go) each new Python value
# would trigger a fresh trace -- verify before enabling.

# The parameter "images" is expected to hold images and masks, therefore only submit the correct index.

#@tf.function
def train_step_seg(images, epoch, num_gen_updates, wasserstein_loss, clip_value=0.01):
    """Run one discriminator update followed by `num_gen_updates` generator updates.

    Args:
        images: Batch tuple; `images[0]` is fed to the generator and
            `images[1]` is shown to the discriminator as the real sample.
        epoch: Current epoch index. Kept for interface compatibility; the
            summaries are indexed by the optimizers' iteration counters
            instead, because all batches of an epoch share the same `epoch`
            and would otherwise be written to the same step.
        num_gen_updates: Number of generator updates per discriminator update.
        wasserstein_loss: If true, the discriminator weights are clipped to
            `[-clip_value, clip_value]` after its update.
        clip_value: Symmetric clipping bound for the Wasserstein case.
    """
    ct_batch, mask_batch = images[0], images[1]

    # Train discriminator
    with tf.GradientTape() as disc_tape:
        generated_images = generator(ct_batch, training=True)
        fake_output = discriminator(generated_images, training=True)
        real_output = discriminator(mask_batch, training=True)
        disc_losses = discriminator_loss(real_output, fake_output)

    # Accept both a (total, real, fake) tuple and a bare total loss, so the
    # step works with either loss definition.
    if isinstance(disc_losses, (tuple, list)):
        disc_loss, d_loss_real, d_loss_fake = disc_losses
    else:
        disc_loss, d_loss_real, d_loss_fake = disc_losses, None, None

    # Write losses into tensorboard log, one step per discriminator update.
    with train_summary_writer.as_default():
        disc_step = discriminator_optimizer.iterations
        tf.summary.scalar('disc_loss', disc_loss, step=disc_step)
        if d_loss_real is not None:
            tf.summary.scalar('d_loss_real', d_loss_real, step=disc_step)
        if d_loss_fake is not None:
            tf.summary.scalar('d_loss_fake', d_loss_fake, step=disc_step)

    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Simple weight clipping as a crude Lipschitz constraint; a gradient
    # penalty would go here instead. It runs AFTER the optimizer step so the
    # weights used from now on are actually bounded, and the range is
    # symmetric around zero: clipping to [0, 1] would forbid negative weights.
    if wasserstein_loss:
        for weight in discriminator.trainable_variables:
            weight.assign(tf.clip_by_value(weight, -clip_value, clip_value))

    # Train Generator
    for _ in range(num_gen_updates):
        with tf.GradientTape() as gen_tape:
            generated_images = generator(ct_batch, training=True)
            fake_output = discriminator(generated_images, training=True)
            gen_loss = generator_loss(fake_output)

        # Write losses into tensorboard log, one step per generator update.
        with train_summary_writer.as_default():
            tf.summary.scalar('gen_loss', gen_loss, step=generator_optimizer.iterations)

        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))


def train_seg(dataset, epochs, num_gen_updates=1, wasserstein_loss=False):
    """Train generator and discriminator for `epochs` passes over `dataset`.

    After every epoch the notebook output is cleared and the generator's
    predictions for the global `seed` batch are plotted and saved. A
    checkpoint is written every 15 epochs, and the predictions are plotted
    once more after the last epoch.

    Args:
        dataset: Iterable of batches handed to `train_step_seg`.
        epochs: Number of passes over `dataset`.
        num_gen_updates: Generator updates per discriminator update.
        wasserstein_loss: Forwarded to `train_step_seg`.
    """
    for epoch_index in range(epochs):
        epoch_number = epoch_index + 1
        started_at = time.time()

        for batch in dataset:
            train_step_seg(batch, epoch_index, num_gen_updates, wasserstein_loss)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch_number, seed)

        # Save the model every 15 epochs
        if epoch_number % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        elapsed = time.time() - started_at
        print('Time for epoch {} is {} sec'.format(epoch_number, elapsed))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

Set up Tensorboard connection and call inline

In [0]:
# Start TensorBoard inline in the notebook, reading from the
# logs/gradient_tape directory.
%tensorboard --logdir logs/gradient_tape

In [0]:
# Use the EPOCHS form parameter from the data cell; a hard-coded epoch count
# here would silently ignore the value chosen in the form.
train_seg(train_dataset_liver, epochs=EPOCHS, num_gen_updates=5, wasserstein_loss=WASSERSTEIN_LOSS)

## Next Steps
You may try to make the model train better. 

1. Adjust the learning rates.
1. Train the generator more often than the discriminator.
1. Pooling and upsampling are known to hamper performance. 
   
   Replace them with strided convolutions and transposed convolutions (up-convolutions), respectively.
1. Next, BatchNorm should be employed in the Generator.
1. The cross entropy loss is not optimal. 

  Anything that tells a real divergence/distance would be better, leading to Wasserstein GANs.

  This requires ensuring Lipschitz property of the discriminator, e.g. by weight clipping or gradient penalty, and adjusting the loss accordingly. 
