## CycleGAN CT Generation
# Graham Schloesser 08/13/21
# Based on code from https://keras.io/examples/generative/cyclegan/

This code generates synthetic CT images from MRI inputs. Before running it, update the file paths in the configuration cell below to match your own data locations.

In [None]:
pip install tensorflow-addons

In [None]:
import tensorflow as tf
import numpy as np
import nibabel as nib
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import os
import time
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')

# Adjust the following file paths
`example_filename_CT` and `example_filename_MRI` are the paths to the individual scans used for testing. Make sure these two scans form a CT/MRI pair.

`train_path_CT` and `train_path_MRI` are folders containing multiple scans to train the model on.

`checkpoint_path` is the location where the model is saved.


In [None]:
# --- File paths: adjust these to your own Drive layout ----------------------
# One CT scan and one MRI scan used for testing.
example_filename_CT = '/content/drive/My Drive/GAN_IM/CT_Pair/020_Ax_T1_BRAVO_Stealth_CT.nii.gz'
example_filename_MRI = '/content/drive/My Drive/GAN_IM/MR_Pair/020_Ax_T1_BRAVO_Stealth.nii.gz'
# Folders of scans used for training.
train_path_MRI = '/content/drive/My Drive/GAN_IM/MR_Pair'
train_path_CT = '/content/drive/My Drive/GAN_IM/CT_Pair'
# NOTE(review): this path uses 'MyDrive' while the ones above use 'My Drive';
# confirm both spellings resolve to the same mount point.
checkpoint_path = '/content/drive/MyDrive/GAN_IM/checkpoints'

# --- Data pipeline / image geometry ----------------------------------------
AUTOTUNE = tf.data.AUTOTUNE
OUTPUT_CHANNELS = 3
BUFFER_SIZE = 250
BATCH_SIZE = 1
IMG_WIDTH = 512
IMG_HEIGHT = 512
DEPTH = 232

# Lower-case alias kept for backward compatibility; it is the same setting as
# AUTOTUNE, so define it once and reuse it.
autotune = AUTOTUNE
# Weights initializer for the layers.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Gamma initializer for instance normalization.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# Derived from the constants above so the model input shape cannot drift out
# of sync with them (still evaluates to (512, 512, 3)).
input_img_size = (IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS)


In [None]:
def random_crop(image):
    """Return a random crop of ``image`` of size IMG_HEIGHT x IMG_WIDTH.

    The crop size passed to ``tf.image.random_crop`` has two entries.
    NOTE(review): that implies ``image`` is expected to be two-dimensional;
    confirm before using this on multi-channel slices.
    """
    crop_size = [IMG_HEIGHT, IMG_WIDTH]
    return tf.image.random_crop(image, size=crop_size)

def normalize_MRI(image):
    """Rescale an MRI array by its maximum value into the range ending at 1.

    The input is cast to float32, divided by its global maximum and then
    mapped with ``x * 2 - 1`` (so 0 -> -1 and the maximum -> 1).

    Args:
        image: Array or tensor of MRI intensities.
            NOTE(review): the [-1, 1] output range assumes non-negative
            intensities; confirm for the scans in use.

    Returns:
        A float32 tensor with the same shape as ``image``.
    """
    image = tf.cast(image, tf.float32)
    peak = tf.math.reduce_max(image)
    # divide_no_nan returns 0 when the denominator is 0, so a blank (all-zero)
    # slice maps to a constant -1 instead of NaN, which would poison training.
    image = tf.math.divide_no_nan(image, peak) * 2 - 1
    return image

def normalize_CT(image):
    """Min-max rescale a CT array to the range [-1, 1].

    The input is cast to float32, shifted so its minimum becomes 0, divided
    by the resulting maximum and then mapped with ``x * 2 - 1``.

    Args:
        image: Array or tensor of CT intensities.

    Returns:
        A float32 tensor with the same shape as ``image``.
    """
    image = tf.cast(image, tf.float32)
    # Subtract the minimum rather than adding abs(minimum): the two agree when
    # the minimum is <= 0, but only subtraction is correct when it is positive.
    image = image - tf.math.reduce_min(image)
    # divide_no_nan returns 0 when the denominator is 0, so a constant-valued
    # input maps to -1 everywhere instead of NaN.
    image = tf.math.divide_no_nan(image, tf.math.reduce_max(image)) * 2 - 1
    return image

def preprocess_MRI(image):
    """Preprocess an MRI array; currently this is normalisation only."""
    return normalize_MRI(image)
def preprocess_CT(image):
    """Preprocess a CT array; currently this is normalisation only."""
    return normalize_CT(image)


def pipe_train(path,type):
    """Build a shuffled, batched training dataset from the scans in a folder.

    Every file in ``path`` whose name ends in ``.gz`` is loaded with nibabel
    (at most 9 files, taken in sorted filename order). Each volume is cut
    along its third axis into consecutive groups of three slices, and one
    3-channel sample is produced per group.

    Args:
        path: Directory containing the ``.gz`` volumes.
        type: Modality selector: 0 for MRI, 1 for CT. It chooses which
            preprocessing function is applied.

    Returns:
        A ``tf.data.Dataset`` of float32 batches with shape
        (up to BATCH_SIZE, height, width, 3), shuffled with BUFFER_SIZE.

    Raises:
        ValueError: If ``type`` is not 0 or 1, or if ``path`` contains no
            ``.gz`` files.
    """
    max_scans = 9  # cap on the number of volumes held in memory at once

    if type == 0:     #for MRI images
        preprocess = preprocess_MRI
    elif type == 1:   #for CT images
        preprocess = preprocess_CT
    else:
        # Previously an unknown selector silently produced all-zero samples.
        raise ValueError("type must be 0 (MRI) or 1 (CT), got %r" % (type,))

    # os.listdir returns names in arbitrary order; sorting makes the choice of
    # scans reproducible when the folder holds more than max_scans files.
    scan_names = sorted(
        name for name in os.listdir(path) if name.endswith('.gz')
    )[:max_scans]
    if not scan_names:
        # Previously this surfaced later as an unbound-variable error.
        raise ValueError("no .gz scans found in %r" % (path,))

    per_scan = []
    for filename in scan_names:
        sample = nib.load(os.path.join(path, filename)).get_fdata()
        samp_shape = np.shape(sample)
        size = samp_shape[2] // 3
        # float32 matches the dtype the dataset is cast to afterwards, so the
        # values are unchanged while the staging buffer uses half the memory.
        temp_array = np.zeros(
            (size, samp_shape[0], samp_shape[1], 3), dtype=np.float32
        )
        for i in range(size):
            start = i * 3
            # NOTE(review): only two of the three slices in each group are
            # used and only channels 0-1 are filled, so channel 2 stays zero.
            # pipe_test does the same; if this is an off-by-one, fix both
            # together so training and test inputs stay consistent.
            temp_array[i,:,:,0:2] = preprocess(sample[:,:,start:start+2])
        per_scan.append(temp_array)

    # One concatenation at the end avoids re-copying the growing array for
    # every scan, which repeated np.append calls would do.
    combined_array = np.concatenate(per_scan, axis=0)

    train_data = tf.data.Dataset.from_tensor_slices(combined_array)

    tf_data = train_data.map(do_nothing, num_parallel_calls=AUTOTUNE).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    return tf_data

def do_nothing(image):
    """Cast ``image`` to float32 and return it otherwise unchanged."""
    as_float32 = tf.cast(image, dtype=tf.float32)
    return as_float32

In [None]:
# Clear the names first (presumably so datasets from a previous run of this
# cell can be released before the new ones are built).
train_MRI = None
train_CT = None

# The second argument selects the modality: 0 = MRI, 1 = CT.
train_MRI = pipe_train(train_path_MRI,0)
train_CT = pipe_train(train_path_CT,1)

In [None]:
# Preview four training samples per modality: MRI on the left, CT on the right.
# Both datasets are shuffled independently inside pipe_train, so the two images
# in a row are not a matched pair.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, samples in enumerate(zip(train_MRI.take(4), train_CT.take(4))):
    # samples[k][0] is the first image of the batch; (x + 1) * 127.5 maps the
    # normalised range [-1, 1] back to [0, 255] for display.
    MRI = (((samples[0][0]+1.0)*127.5).numpy()).astype(np.uint8)
    CT = (((samples[1][0]+1.0)*127.5).numpy()).astype(np.uint8)
    # Only channel 0 of each 3-channel sample is shown.
    ax[i, 0].imshow(MRI[:,:,0],cmap = 'gray')
    ax[i, 1].imshow(CT[:,:,0],cmap = 'gray')
plt.show()

In [None]:
class ReflectionPadding2D(layers.Layer):
    """Implements Reflection Padding as a layer.

    Axes 1 and 2 of the input are padded by reflection; axes 0 and 3 are left
    untouched (i.e. a 4-D input is expected; presumably batch, height, width,
    channels).

    Args:
        padding(tuple): Amount of padding for the spatial dimensions, given
        as ``(width, height)``: the first value pads axis 2 and the second
        pads axis 1, each on both sides.

    Returns:
        A padded tensor with the same type as the input tensor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base layer before setting attributes on it.
        super().__init__(**kwargs)
        self.padding = tuple(padding)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")

    def get_config(self):
        """Return the layer configuration, including ``padding``.

        Without this override the base implementation cannot describe a layer
        whose constructor takes extra arguments, so models containing this
        layer could not be serialised to a config or cloned.
        """
        config = super().get_config()
        config.update({"padding": self.padding})
        return config


def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Apply a two-convolution residual block to ``x``.

    Each half is reflection padding -> Conv2D -> instance normalisation; the
    activation is applied after the first half only. The block input is then
    added to the result. The number of convolution filters equals the channel
    count of ``x``.

    Args:
        x: Input tensor.
        activation: Callable applied after the first normalisation.
        kernel_initializer: Initialiser for the convolution kernels.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Padding mode passed to Conv2D.
        gamma_initializer: Gamma initialiser for instance normalisation.
        use_bias: Whether the convolutions use a bias term.

    Returns:
        The sum of the block input and the transformed tensor.
    """
    shortcut = x
    channels = x.shape[-1]

    def pad_conv_norm(tensor):
        # One half of the block: reflect-pad, convolve, instance-normalise.
        tensor = ReflectionPadding2D()(tensor)
        tensor = layers.Conv2D(
            channels,
            kernel_size,
            strides=strides,
            kernel_initializer=kernel_initializer,
            padding=padding,
            use_bias=use_bias,
        )(tensor)
        return tfa.layers.InstanceNormalization(
            gamma_initializer=gamma_initializer
        )(tensor)

    transformed = activation(pad_conv_norm(shortcut))
    transformed = pad_conv_norm(transformed)
    return layers.add([shortcut, transformed])


def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Strided Conv2D -> instance normalisation -> optional activation.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        activation: Callable applied last; skipped when falsy (e.g. ``None``).
        kernel_initializer: Initialiser for the convolution kernel.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Padding mode passed to Conv2D.
        gamma_initializer: Gamma initialiser for instance normalisation.
        use_bias: Whether the convolution uses a bias term.

    Returns:
        The transformed tensor.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )
    norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
    out = norm(conv(x))
    return activation(out) if activation else out


def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Conv2DTranspose -> instance normalisation -> optional activation.

    Args:
        x: Input tensor.
        filters: Number of transposed-convolution filters.
        activation: Callable applied last; skipped when falsy (e.g. ``None``).
        kernel_size: Kernel size of the transposed convolution.
        strides: Strides of the transposed convolution.
        padding: Padding mode passed to Conv2DTranspose.
        kernel_initializer: Initialiser for the kernel.
        gamma_initializer: Gamma initialiser for instance normalisation.
        use_bias: Whether the transposed convolution uses a bias term.

    Returns:
        The transformed tensor.
    """
    deconv = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )
    norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
    out = norm(deconv(x))
    return activation(out) if activation else out

In [None]:
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    """Build the ResNet-style generator used for image-to-image translation.

    Structure: 7x7 convolution stem -> ``num_downsampling_blocks`` strided
    convolutions (filters doubled each time) -> ``num_residual_blocks``
    residual blocks -> ``num_upsample_blocks`` transposed convolutions
    (filters halved each time) -> 7x7 convolution with ``tanh`` output.

    Args:
        filters: Number of filters in the stem convolution.
        num_downsampling_blocks: Number of downsampling stages.
        num_residual_blocks: Number of residual blocks.
        num_upsample_blocks: Number of upsampling stages.
        gamma_initializer: Gamma initialiser for the stem's instance
            normalisation.
        name: Optional model name; when given, the input layer is named
            ``name + "_img_input"``.

    Returns:
        A ``keras.models.Model`` mapping an image of shape ``input_img_size``
        to a ``tanh``-activated image with the same number of channels.
    """
    # Only derive an input-layer name when a model name was supplied; with the
    # default name=None the string concatenation used to raise TypeError.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters, activation=layers.Activation("relu"))

    # Final block. The output must have as many channels as the input so the
    # cycle and identity losses can compare generated and real images, hence
    # the channel count is taken from input_img_size (currently 3).
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(input_img_size[-1], (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model

In [None]:
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    """Build the convolutional discriminator.

    Structure: a strided 4x4 convolution with LeakyReLU, then
    ``num_downsampling`` 4x4 downsample blocks with the filter count doubled
    each time (stride 2, except the last block which uses stride 1), then a
    final 4x4 convolution producing a single-channel map of scores.

    Args:
        filters: Number of filters in the first convolution.
        kernel_initializer: Initialiser for all convolution kernels.
        num_downsampling: Number of downsample blocks after the first
            convolution.
        name: Optional model name; when given, the input layer is named
            ``name + "_img_input"``.

    Returns:
        A ``keras.models.Model`` mapping an image of shape ``input_img_size``
        to a single-channel score map.
    """
    # Only derive an input-layer name when a model name was supplied; with the
    # default name=None the string concatenation used to raise TypeError.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    # Honour num_downsampling (it used to be ignored in favour of a literal 3).
    for block_index in range(num_downsampling):
        num_filters *= 2
        # The last block keeps the spatial resolution; earlier ones halve it.
        is_last_block = block_index == num_downsampling - 1
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            # Forward the caller's initialiser so it applies to every
            # convolution, not just the first and last.
            kernel_initializer=kernel_initializer,
            kernel_size=(4, 4),
            strides=(1, 1) if is_last_block else (2, 2),
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model


# Get the generators. In CycleGan.train_step, gen_G is applied to the first
# dataset of each batch and gen_F to the second; the fit() call in this file
# zips (train_MRI, train_CT), so G maps MRI -> CT and F maps CT -> MRI.
gen_G = get_resnet_generator(name="generator_G")
gen_F = get_resnet_generator(name="generator_F")

# Get the discriminators. disc_X scores images of the first domain (MRI here)
# and disc_Y scores images of the second domain (CT here).
disc_X = get_discriminator(name="discriminator_X")
disc_Y = get_discriminator(name="discriminator_Y")

In [None]:
class CycleGan(keras.Model):
    """CycleGAN training wrapper around two generators and two discriminators.

    ``gen_G`` translates domain X to domain Y and ``gen_F`` translates Y to X;
    ``disc_X`` and ``disc_Y`` score images of the respective domains. In this
    file ``fit`` is called with ``Dataset.zip((train_MRI, train_CT))``, so X is
    the MRI domain and Y is the CT domain.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        """Store the four networks and the loss weights.

        Args:
            generator_G: Generator mapping X -> Y.
            generator_F: Generator mapping Y -> X.
            discriminator_X: Discriminator for domain X.
            discriminator_Y: Discriminator for domain Y.
            lambda_cycle: Weight of the cycle-consistency loss.
            lambda_identity: Extra factor for the identity loss, which is
                weighted by ``lambda_cycle * lambda_identity``.
        """
        super(CycleGan, self).__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        """Attach one optimizer per network and the adversarial loss functions.

        Args:
            gen_G_optimizer: Optimizer for ``gen_G``.
            gen_F_optimizer: Optimizer for ``gen_F``.
            disc_X_optimizer: Optimizer for ``disc_X``.
            disc_Y_optimizer: Optimizer for ``disc_Y``.
            gen_loss_fn: Callable taking the discriminator output on fake
                images and returning the generator's adversarial loss.
            disc_loss_fn: Callable taking the discriminator outputs on real
                and fake images and returning the discriminator loss.
        """
        super(CycleGan, self).compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        # Cycle and identity losses are both mean absolute error (L1).
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()

    def train_step(self, batch_data):
        """Run one optimisation step on all four networks.

        Args:
            batch_data: A pair ``(real_x, real_y)`` of image batches, one from
                each domain.

        Returns:
            A dict with the total generator losses (``"G_loss"``,
            ``"F_loss"``) and the discriminator losses (``"D_X_loss"``,
            ``"D_Y_loss"``).
        """
        # x is the first domain and y the second; with the fit() call in this
        # file that is x = MRI and y = CT.
        real_x, real_y = batch_data

        # For CycleGAN, we need to calculate different
        # kinds of losses for the generators and discriminators.
        # We will perform the following steps here:
        #
        # 1. Pass real images through the generators and get the generated images
        # 2. Pass the generated images back to the generators to check if
        #    we can predict the original image from the generated image.
        # 3. Do an identity mapping of the real images using the generators.
        # 4. Pass the generated images in 1) to the corresponding discriminators.
        # 5. Calculate the generators total loss (adversarial + cycle + identity)
        # 6. Calculate the discriminators loss
        # 7. Update the weights of the generators
        # 8. Update the weights of the discriminators
        # 9. Return the losses in a dictionary

        # persistent=True because four separate gradients are taken from this
        # one tape below.
        with tf.GradientTape(persistent=True) as tape:
            # x to fake y
            fake_y = self.gen_G(real_x, training=True)
            # y to fake x -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (x to fake y to reconstructed x): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (y to fake x to reconstructed y): y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping: each generator is fed an image that is
            # already in its output domain.
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)

            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }


In [None]:
class GANMonitor(keras.callbacks.Callback):
    """A callback that plots input/translated image pairs after each epoch.

    At the end of every epoch, ``num_img`` batches are taken from the
    module-level ``train_MRI`` dataset, passed through the model's ``gen_G``
    generator and shown next to their inputs (channel 0 only, in grayscale).

    Args:
        num_img: Number of image pairs (rows) to plot.
    """

    def __init__(self, num_img=4):
        # Let the base Callback set up its own attributes first.
        super().__init__()
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        # One row per requested image (the grid used to be fixed at 4 rows,
        # which raised IndexError for num_img > 4). squeeze=False keeps ``ax``
        # two-dimensional even when num_img == 1.
        _, ax = plt.subplots(self.num_img, 2, figsize=(12, 12), squeeze=False)
        for i, img in enumerate(train_MRI.take(self.num_img)):
            prediction = self.model.gen_G(img)[0].numpy()
            # (x + 1) * 127.5 maps the [-1, 1] range back to [0, 255].
            prediction = ((prediction + 1.0) * 127.5).astype(np.uint8)
            img = ((img[0] + 1.0) * 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img[:,:,0],cmap='gray')
            ax[i, 1].imshow(prediction[:,:,0],cmap='gray')
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")
        plt.show()
        plt.close()

In [None]:
# Loss function for evaluating adversarial loss: mean squared error between
# the discriminator outputs and the all-ones / all-zeros targets built in
# generator_loss_fn and discriminator_loss_fn.
adv_loss_fn = keras.losses.MeanSquaredError()
#checkpoint_path = "/content/drive/MyDrive/GAN_IM/checkpoints.{epoch:03d}"
def generator_loss_fn(fake):
    """Adversarial loss for a generator.

    Args:
        fake: Discriminator output on generated images.

    Returns:
        ``adv_loss_fn`` evaluated against an all-ones target, i.e. the loss is
        low when the discriminator scores the generated images as real.
    """
    real_targets = tf.ones_like(fake)
    return adv_loss_fn(real_targets, fake)


def discriminator_loss_fn(real, fake):
    """Adversarial loss for a discriminator.

    Args:
        real: Discriminator output on real images (target: all ones).
        fake: Discriminator output on generated images (target: all zeros).

    Returns:
        The average of the two ``adv_loss_fn`` terms.
    """
    loss_on_real = adv_loss_fn(tf.ones_like(real), real)
    loss_on_fake = adv_loss_fn(tf.zeros_like(fake), fake)
    return (loss_on_real + loss_on_fake) * 0.5


# Create cycle gan model
cycle_gan_model = CycleGan(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

# Compile the model: one Adam optimizer per network, all with the same
# learning rate and beta_1.
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)
# Callbacks
# NOTE(review): `plotter` and `model_checkpoint_callback` are created here but
# the fit() call below has no `callbacks=` argument, so neither is ever run.
# Confirm whether they were left out on purpose before wiring them in.
plotter = GANMonitor()
# NOTE(review): this path starts with "./content", i.e. it is relative to the
# working directory, unlike the other "/content/drive/..." paths in this file.
checkpoint_filepath = "./content/drive/MyDrive/GAN_IM/checkpoints.{epoch:03d}"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath
)

# Each training batch is a (MRI, CT) tuple drawn from the two datasets.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_MRI, train_CT)),
    epochs= 20,
)



In [None]:
# Only the X -> Y generator (MRI -> CT with the training call in this file) is
# saved; gen_F and the discriminators are not.
mod = cycle_gan_model.gen_G

mod.save("/content/drive/MyDrive/GAN_IM/checkpoints/Keras_Test2")

In [None]:
# Reload the generator from the same path it was saved to in the cell above.
gen = ("/content/drive/MyDrive/GAN_IM/checkpoints/Keras_Test2")
model = keras.models.load_model(gen)

In [None]:
def pipe_test(path,type):
    """Build an unshuffled, batched test dataset from a single scan.

    The volume at ``path`` is loaded with nibabel and turned into one
    3-channel sample per position along its third axis, using a sliding
    window that advances one slice at a time.

    Args:
        path: Path to a single ``.gz`` volume.
        type: Modality selector: 0 for MRI, 1 for CT. It chooses which
            preprocessing function is applied.

    Returns:
        A ``tf.data.Dataset`` of float32 batches with shape
        (up to BATCH_SIZE, height, width, 3), in slice order.

    Raises:
        ValueError: If ``path`` does not end in ``.gz`` or ``type`` is not
            0 or 1.
    """
    if not path.endswith('.gz'):
        # Previously this surfaced later as an unbound-variable error.
        raise ValueError("expected a path to a .gz scan, got %r" % (path,))

    if type == 0:     #for MRI images
        preprocess = preprocess_MRI
    elif type == 1:   #for CT images
        preprocess = preprocess_CT
    else:
        # Previously an unknown selector silently produced all-zero samples.
        raise ValueError("type must be 0 (MRI) or 1 (CT), got %r" % (type,))

    sample = nib.load(path).get_fdata()
    samp_shape = np.shape(sample)
    depth = samp_shape[2]
    # float32 matches the dtype the dataset is cast to afterwards, so the
    # values are unchanged while the staging buffer uses half the memory.
    slices = np.zeros((depth, samp_shape[0], samp_shape[1], 3), dtype=np.float32)
    # NOTE(review): each window uses two slices and fills channels 0-1 only,
    # so channel 2 stays zero (pipe_train does the same). The loop also stops
    # two positions early, leaving the last two samples all zero. Both are kept
    # as-is so the dataset length and contents match the existing behaviour.
    for i in range(depth - 2):
        slices[i,:,:,0:2] = preprocess(sample[:,:,i:i+2])

    test_data = tf.data.Dataset.from_tensor_slices(slices)
    #creates tensorflow dataset from array

    tf_data = test_data.map(do_nothing, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)
    #casts to float32 and batches; no shuffling, so slice order is preserved

    return tf_data

In [None]:
# Override the example scan paths set in the configuration cell with the scan
# pair used for this evaluation.
example_filename_CT = '/content/drive/My Drive/GAN_IM/CT_Pair/030_Ax_T1_BRAVO_Stealth_CT.nii.gz'
example_filename_MRI = '/content/drive/My Drive/GAN_IM/MR_Pair/030_Ax_T1_BRAVO_Stealth.nii.gz'
# The second argument selects the modality: 0 = MRI, 1 = CT.
test_MRI = pipe_test(example_filename_MRI,0)
test_CT = pipe_test(example_filename_CT,1)

In [None]:
# For each test slice, plot the MRI input, the generator output and the CT
# slice from the other dataset side by side, and save the figure to Drive.
i = 0
for image_x, image_y in tf.data.Dataset.zip((test_MRI, test_CT)):
  _, ax = plt.subplots(1, 3,figsize = (18,12))
  pred = model(image_x)[0].numpy()
  # (x + 1) * 127.5 maps the [-1, 1] range back to [0, 255] for display.
  prediction = ((pred + 1.0) * 127.5).astype(np.uint8)
  img = ((image_x[0] + 1.0) * 127.5).numpy().astype(np.uint8)
  img1 = ((image_y[0] + 1.0) * 127.5).numpy().astype(np.uint8)

  # Only channel 0 of each 3-channel image is shown.
  ax[0].imshow(img[:,:,0],cmap='gray')
  ax[1].imshow(prediction[:,:,0],cmap='gray')
  ax[2].imshow(img1[:,:,0],cmap='gray')

  ax[0].set_title("Input image")
  ax[1].set_title("Translated image")
  ax[2].set_title("True CT image")

  ax[0].axis("off")
  ax[1].axis("off")
  ax[2].axis("off")
  # NOTE(review): hard-coded stop. The break happens before savefig/show, so
  # the figure built for this iteration is neither saved nor displayed.
  if i == 229:
    break
  # i is incremented before saving, so the saved file numbers start at 1.
  i = i + 1
  plt.savefig('/content/drive/MyDrive/GAN_IM/Keras_20e2_Images/keras_50e_' + str(i) + '.png')
  plt.show()
