# Controllable generation of Chest X-rays
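In this notebook, we load the previously trained GAN and use it to generate chest X-rays conditioned on five labels (Cardiomegaly, Infiltration, Effusion, sex, and AP/PA view position).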

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import random
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import ipywidgets as widgets
from functools import partial

In [2]:
seed = 2022
# Seed Python's, NumPy's and TensorFlow's global RNGs so that the sampled latents
# (and hence the generated figures) repeat between runs.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(seed)

In [3]:
# Conditioning labels, in label-vector order; these names describe a label value of 1.
labels_target = [
    'Cardiomegaly',
    'Infiltration',
    'Effusion',
    'Male',
    'AP',
]
# Names for the same positions when the label value is 0.
labels_target_ = [
    'No Cardiomegaly',
    'No Infiltration',
    'No Effusion',
    'Female',
    'PA',
]

## Loading the trained models

At the end of training, we saved the weights of both the generator and the discriminator networks, so we can load them here and evaluate the model. Usually only the generator is kept once training is over. In this case, however, the trained discriminator can be repurposed as a classifier: besides the real/fake score, it outputs one prediction per label.

In [4]:
class chestGAN(tf.keras.Model):
    """GAN wrapper around a generator/discriminator pair with a custom Keras training step.

    Column 0 of the discriminator output is the adversarial (real/fake) score. When
    ``num_labels > 0``, the remaining ``num_labels`` columns are label predictions used for
    an auxiliary classification loss (the 'acgan' mode). The generator is then conditioned
    by concatenating the label vector to the latent noise along the last axis.

    Args:
        discriminator: model mapping a batch of images to ``1 + num_labels`` outputs per sample.
        generator: model mapping ``(batch, 1, 1, latent_dim [+ num_labels])`` inputs to images.
        latent_dim: size of the noise vector fed to the generator.
        num_labels: number of conditioning labels (0 for an unconditional GAN).
    """

    def __init__(self, discriminator, generator, latent_dim=512, num_labels=0):

        super(chestGAN, self).__init__()

        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.num_labels = num_labels
        # Running means reported by train_step.
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = tf.keras.metrics.Mean(name="discriminator_loss")
        self.clf_loss_tracker = tf.keras.metrics.Mean(name="classifier_loss")
        self.ssim_tracker = tf.keras.metrics.Mean(name="ssim")

        # Loss weights and update schedule; compile() overrides these according to `mode`.
        self.config = {
            'lambda_gp': 0,     # gradient-penalty weight ('gp' modes)
            'lambda_drift': 0,  # weight of the penalty on squared real critic scores ('wgan' modes)
            'lambda_clf': 0,    # auxiliary classification loss weight ('acgan' modes)
            'd_steps': 1,       # discriminator updates per generator update
        }

    @property
    def metrics(self):
        """Trackers exposed to Keras (so it can reset them each epoch); classifier loss only when used."""
        trackers = [self.gen_loss_tracker, self.disc_loss_tracker]
        if self.config['lambda_clf'] > 0:
            trackers.append(self.clf_loss_tracker)
        trackers.append(self.ssim_tracker)
        return trackers

    def compile(self, d_optimizer, g_optimizer, loss_fn, clf_loss_fn=None, mode='dcgan'):
        """Store optimizers and losses and derive ``self.config`` from ``mode``.

        ``mode`` is matched by substring:

        * 'wgan' sets 5 discriminator steps and a drift penalty, and requires a loss
          named "wasserstein".
        * 'wgan' + 'gp' additionally enables the gradient penalty.
        * 'acgan' enables the auxiliary classification loss and requires ``num_labels > 0``.
          Any mode without 'acgan' requires ``num_labels == 0``.
        """
        super(chestGAN, self).compile()

        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.clf_loss_fn = clf_loss_fn
        self.mode = mode

        if 'wgan' in self.mode:
            assert self.loss_fn.name == "wasserstein", "Please compile with the Wassertein Loss!"
            self.config['d_steps'] = 5
            self.config['lambda_drift'] = 1e-3
            if 'gp' in self.mode:
                self.config['lambda_gp'] = 10
        else:
            assert self.loss_fn.name != "wasserstein", "Please compile with a different adversarial loss!"

        if 'acgan' not in self.mode:
            assert not self.num_labels, 'mode and num_labels are incompatible'
        else:
            assert self.num_labels, 'mode and num_labels are incompatible'
            self.config['lambda_clf'] = 10

        print('Training configuration:', self.config)

    def get_labels(self, predictions, real=True):
        """Adversarial targets shaped like `predictions`.

        With the "wasserstein" loss the targets are -1 (real) or +1 (fake). Otherwise they
        are 1 (real) or 0 (fake).
        """
        if self.loss_fn.name == "wasserstein":
            if real:
                labels = -tf.ones_like(predictions)
            else:
                labels = tf.ones_like(predictions)
        else:
            if real:
                labels = tf.ones_like(predictions)
            else:
                labels = tf.zeros_like(predictions)
        return labels

    def train_step(self, data):
        """Run ``d_steps`` discriminator updates, then one generator update.

        ``data`` is a ``(real_images, target_labels)`` pair. Returns the running means of the
        tracked metrics.
        """
        lambda_gp = self.config['lambda_gp']
        lambda_drift = self.config['lambda_drift']
        lambda_clf = self.config['lambda_clf']
        d_steps = self.config['d_steps']

        real_images, target_labels = data

        # Reshape the labels so they can be concatenated to the (batch, 1, 1, latent_dim) noise.
        # Assumes target_labels is (batch, num_labels) — TODO confirm against the dataset pipeline.
        prompt = tf.expand_dims(tf.expand_dims(tf.cast(target_labels, tf.float32), axis=1), axis=2)

        batch_size = tf.shape(real_images)[0]

        # NOTE(review): generator/discriminator are called without `training=True` here. If Keras
        # falls back to inference mode, the generator's BatchNormalization layers would use moving
        # statistics during training — confirm before retraining.

        # Train discriminator

        for _ in range(d_steps):

            z = tf.random.normal(shape=(batch_size, 1, 1, self.latent_dim))
            gen_in = tf.concat([z, prompt], axis=-1) if lambda_clf > 0 else z
            generated_images = self.generator(gen_in)

            with tf.GradientTape() as tape:

                predictions_real = self.discriminator(real_images)
                predictions_gen = self.discriminator(generated_images)

                # Column 0 is the adversarial score.
                labels_real = self.get_labels(predictions_real[:, 0])
                labels_gen = self.get_labels(predictions_gen[:, 0], False)

                d_loss_real = self.loss_fn(labels_real, predictions_real[:, 0])
                d_loss_gen = self.loss_fn(labels_gen, predictions_gen[:, 0])

                d_loss = d_loss_real + d_loss_gen

                if lambda_gp > 0:

                    # Gradient penalty on random interpolations between real and generated images.
                    epsilon = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
                    interpolates = (1-epsilon)* real_images + epsilon * generated_images

                    with tf.GradientTape() as gp_tape:

                        gp_tape.watch(interpolates)
                        predictions = self.discriminator(interpolates)[:, 0]

                    grads_gp = gp_tape.gradient(predictions, [interpolates])[0]

                    norm = tf.sqrt(tf.reduce_sum(tf.square(grads_gp), axis=[1, 2, 3]))
                    gradient_penalty = tf.reduce_mean((norm - 1.0) ** 2)

                    d_loss += lambda_gp * gradient_penalty

                if lambda_drift > 0:

                    # Keeps real critic scores from drifting far from zero.
                    drift_penalty = tf.reduce_mean(predictions_real[:, 0] ** 2)
                    d_loss += lambda_drift * drift_penalty

                if lambda_clf > 0:

                    # Columns 1: are the per-label predictions (auxiliary classifier).
                    clf_loss_disc = self.clf_loss_fn(target_labels, predictions_real[:, 1:]) \
                                    + self.clf_loss_fn(target_labels, predictions_gen[:, 1:])
                    d_loss += lambda_clf * clf_loss_disc


            grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(
                zip(grads, self.discriminator.trainable_weights)
            )

            self.disc_loss_tracker.update_state(d_loss)
            if lambda_clf > 0:
                self.clf_loss_tracker.update_state(clf_loss_disc)

        # Train generator

        z = tf.random.normal(shape=(batch_size, 1, 1, self.latent_dim))
        gen_in = tf.concat([z, prompt], axis=-1) if lambda_clf > 0 else z

        with tf.GradientTape() as tape:

            fake_images = self.generator(gen_in)
            predictions = self.discriminator(fake_images)
            # The generator is rewarded when its images are scored as real.
            fake_labels = self.get_labels(predictions[:, 0])
            g_loss = self.loss_fn(fake_labels, predictions[:, 0])

            if lambda_clf > 0:
                clf_loss_gen = self.clf_loss_fn(target_labels, predictions[:, 1:])
                g_loss += lambda_clf * clf_loss_gen

        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )

        self.gen_loss_tracker.update_state(g_loss)
        if lambda_clf > 0:
            self.clf_loss_tracker.update_state(clf_loss_gen)

        # Compute Structural Similarity Index between real and fake images
        # (images are mapped from [-1, 1] to [0, 1], hence max_val=1).
        ssim = tf.reduce_mean(tf.image.ssim((real_images + 1) / 2, (fake_images + 1) / 2, 1), axis=0)
        self.ssim_tracker.update_state(ssim)

        if lambda_clf > 0:
            return {
                "g_loss": self.gen_loss_tracker.result(),
                "d_loss": self.disc_loss_tracker.result(),
                "clf_loss": self.clf_loss_tracker.result(),
                "ssim": self.ssim_tracker.result()
            }

        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
            "ssim": self.ssim_tracker.result()
        }

    def show_fake(self, path='', num_images=4, label=None, z=None):
        """Plot a horizontal strip of generated images and optionally save it.

        Args:
            path: file to save the figure to; nothing is written when empty.
            num_images: number of samples to draw. Ignored when ``z`` is given; then one
                image is drawn per row of ``z``.
            label: conditioning vector of length ``num_labels``; defaults to all zeros.
            z: optional latent batch ``(n, 1, 1, latent_dim)``, e.g. to reuse the same
                latents across calls. A new batch is sampled when None or empty.
        """
        if label is None or len(label) == 0:
            label = self.num_labels * [0]  # All zeros (with this notebook's label order: healthy female, PA view)
        elif self.num_labels:
            assert len(label) == self.num_labels, 'label must have num_labels entries'

        if z is None or len(z) == 0:
            z = tf.random.normal(shape=(num_images, 1, 1, self.latent_dim))
        else:
            z = tf.cast(z, tf.float32)
            num_images = int(z.shape[0])  # keep the prompt and plot in sync with the given latents

        # Same label vector for every sample: (num_images, 1, 1, num_labels).
        prompt = tf.reshape(tf.cast(label, tf.float32), (1, -1))
        prompt = tf.expand_dims(tf.expand_dims(prompt, axis=1), axis=2)
        prompt = tf.tile(prompt, [num_images, 1, 1, 1])

        gen_in = tf.concat([z, prompt], axis=-1) if self.num_labels else z
        fake_images = self.generator(gen_in)
        fake_images = (fake_images + 1) / 2  # assumes generator outputs in [-1, 1]; map to [0, 1]
        # Place the first channel of every sample side by side.
        strip = tf.concat([fake_images[i, :, :, 0] for i in range(num_images)], axis=1)

        plt.figure(figsize=(num_images*4, 4), facecolor='k')
        plt.axis('off')
        plt.imshow(strip, cmap='gray')
        if path:
            plt.savefig(path)

In [5]:
# Number of conditioning labels. It sizes the discriminator's label head and the generator's input.
# `evaluate` indexes the positive and negative name lists in parallel, so they must align.
assert len(labels_target) == len(labels_target_), 'labels_target and labels_target_ must have the same length'
num_labels = len(labels_target)

In [6]:
def _disc_downsample(filters, **kwargs):
    """3x3, stride-2, 'same'-padded Conv2D with its own LeakyReLU(0.2) activation layer."""
    return tf.keras.layers.Conv2D(
        filters, 3, 2, padding='same', activation=tf.keras.layers.LeakyReLU(0.2), **kwargs
    )

# Six conv + LayerNormalization stages on 64x64x1 inputs.
# The head has one channel for the real/fake score plus one per label, averaged over space.
_disc_layers = []
for _filters in (16, 32, 64, 128, 256, 512):
    _extra = {} if _disc_layers else {'input_shape': (64, 64, 1)}
    _disc_layers.append(_disc_downsample(_filters, **_extra))
    _disc_layers.append(tf.keras.layers.LayerNormalization(axis=[1, 2, 3]))
_disc_layers.append(_disc_downsample(1 + num_labels))
_disc_layers.append(tf.keras.layers.GlobalAveragePooling2D())

discriminator = tf.keras.Sequential(_disc_layers)

In [7]:
# Input: a 512-dim latent concatenated with the num_labels label vector (512 must match
# chestGAN.latent_dim).
# Each stride-2 'same' Conv2DTranspose doubles the spatial size: five relu + BatchNorm stages
# (1 -> 32), then a tanh layer that yields a 64x64 single-channel image.
_gen_layers = []
for _filters in (512, 256, 128, 64, 32):
    _extra = {} if _gen_layers else {'input_shape': (1, 1, 512 + num_labels)}
    _gen_layers.append(
        tf.keras.layers.Conv2DTranspose(_filters, 3, 2, padding='same', activation='relu', **_extra)
    )
    _gen_layers.append(tf.keras.layers.BatchNormalization())
_gen_layers.append(tf.keras.layers.Conv2DTranspose(1, 3, 2, padding='same', activation='tanh'))

generator = tf.keras.Sequential(_gen_layers)

In [8]:
# latent_dim (the constructor default, stated explicitly) must match the 512 latent channels
# of the generator's input layer.
chest_gan = chestGAN(discriminator, generator, latent_dim=512, num_labels=num_labels)

In [9]:
# Weights saved at the end of training, read from the notebook's input directory.
weights_dir = "../input/train-gan-chest-xrays-64"
chest_gan.generator.load_weights(f"{weights_dir}/generator.h5")
chest_gan.discriminator.load_weights(f"{weights_dir}/discriminator.h5")

## Evaluating the GAN

Visual inspection of the fake samples shows that the model does not reach the quality of the original dataset. It did, however, learn the basic structure of a chest radiograph (black corners, dark lungs, etc.), but it fails to capture small-scale structures such as bones. For this and other reasons, more complex models than the `WGAN-GP` are usually preferred. Deeper networks take longer to train, but if properly trained they can pick up on more patterns. Examples include [StyleGAN](https://arxiv.org/pdf/1812.04948) (and its successors) and [BigGAN](https://arxiv.org/abs/1809.11096).

In [10]:
# All-zero label vector: no Cardiomegaly/Infiltration/Effusion, female, PA view.
chest_gan.show_fake(path="healthy-4samples.png", num_images=4, label=[0] * num_labels)

In [11]:
# Same all-zero (healthy female, PA) conditioning, with a larger sample strip.
chest_gan.show_fake(path="healthy-20.png", num_images=20, label=[0] * num_labels)

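Next, we check how well the generator learned the label conditioning. For each label, we generate images with that label switched off and on, and we measure how often the discriminator's label head recognizes the intended value.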

In [12]:
def evaluate(gan, labels_target, labels_target_, num_samples=100, threshold=0.5):
    """Print how often the discriminator's label head recognizes each conditioned label value.

    For each label index ``i`` and each target value (0, then 1), ``num_samples`` images are
    generated from a label vector that is all zeros except that position ``i`` is set to the
    target. A sample counts as correct when the sigmoid probability the discriminator assigns
    to the intended value (``p`` for target 1, ``1 - p`` for target 0) exceeds ``threshold``.
    Here ``p`` is computed from output column ``1 + i``.

    Args:
        gan: object exposing ``generator``, ``discriminator``, ``latent_dim`` and ``num_labels``.
        labels_target: display names for a label value of 1, in label-vector order.
        labels_target_: display names for a label value of 0, in the same order.
        num_samples: images generated per (label, target) pair.
        threshold: probability cut-off for counting a prediction as correct.
    """
    assert gan.num_labels > 0, 'evaluate requires a label-conditioned GAN'
    assert len(labels_target) <= gan.num_labels, 'more label names than conditioning labels'
    assert len(labels_target_) >= len(labels_target), 'missing names in labels_target_'

    for i, label_name in enumerate(labels_target):

        for target in [0, 1]:

            # All-zero label vector, with label i switched on for the positive case.
            label = gan.num_labels * [0]
            if target:
                label[i] = 1
                name = label_name
            else:
                name = labels_target_[i]

            # Same conditioning for every sample: (num_samples, 1, 1, num_labels).
            prompt = tf.reshape(tf.cast(label, tf.float32), (1, 1, 1, -1))
            prompt = tf.tile(prompt, [num_samples, 1, 1, 1])

            z = tf.random.normal(shape=(num_samples, 1, 1, gan.latent_dim))
            gen_in = tf.concat([z, prompt], axis=-1)
            fake_images = gan.generator(gen_in)
            predictions = gan.discriminator(fake_images)
            # Column 0 is the real/fake score; column 1 + i is the prediction for label i.
            proba = tf.math.sigmoid(predictions[:, 1 + i])
            if not target:
                proba = 1 - proba  # probability of the intended (zero) value
            compare = tf.cast(proba > threshold, tf.float32)
            print(f"Label '{name}' predicted correctly {tf.reduce_mean(compare)*100:.3f}% of the time.")

In [13]:
# 1000 generated images per (label, value) pair.
evaluate(
    gan=chest_gan,
    labels_target=labels_target,
    labels_target_=labels_target_,
    num_samples=1000,
)

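Finally, we look at the effect of each label on the generated images. The sliders vary the label values while the latent vectors stay fixed, so every change in the images comes from the labels. The Reset button sets all labels back to 0 and draws new latent vectors.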

In [14]:
num_images = 4
# Shared latent batch: sliders change only the labels, so differences come from the labels alone.
z = tf.random.normal(shape=(num_images, 1, 1, chest_gan.latent_dim))

def show_label_effect(a, b, c, d, e):
    """Render `num_images` fakes from the shared latents `z` for the current slider values.

    a..e are the Cardiomegaly, Infiltration, Effusion, Male and AP slider values. They are
    passed in that order, which is the order of `labels_target`.
    """
    chest_gan.show_fake(num_images=num_images, label=[a, b, c, d, e], z=z)

def reset_values(sliders, b):
    """Reset-button callback: draw new latents, zero all sliders and show the result.

    `b` is the clicked button (unused). `sliders` must be in `show_label_effect` argument order.
    """
    global z
    # Draw the new latents first so the redraws triggered by the slider changes below use them.
    z = tf.random.normal(shape=(num_images, 1, 1, chest_gan.latent_dim))
    already_zero = all(slider.value == 0 for slider in sliders)
    for slider in sliders:
        slider.value = 0
    if already_zero:
        # No slider changed, so interactive_output will not redraw by itself.
        with out:
            out.clear_output(wait=True)
            show_label_effect(*(slider.value for slider in sliders))
            plt.show()

layout = widgets.Layout(display='flex', flex_flow='column', align_items='center', width='50%')
slider_a = widgets.FloatSlider(min=0, max=1.0, step=0.1, description='Cardiomegaly:', readout_format='.1f', style = {'description_width': 'initial'})
slider_b = widgets.FloatSlider(min=0, max=1.0, step=0.1, description='Infiltration:', readout_format='.1f')
slider_c = widgets.FloatSlider(min=0, max=1.0, step=0.1, description='Effusion:', readout_format='.1f')
slider_d = widgets.FloatSlider(min=0, max=1.0, step=0.1, description='Male:', readout_format='.1f')
slider_e = widgets.FloatSlider(min=0, max=1.0, step=0.1, description='AP:', readout_format='.1f')
sliders_list = [slider_a, slider_b, slider_c, slider_d, slider_e]
reset_button = widgets.Button(description = "Reset")
reset_button.on_click(partial(reset_values, sliders_list))

out = widgets.interactive_output(show_label_effect, {'a': slider_a, 'b': slider_b, 'c': slider_c, 'd': slider_d, 'e': slider_e})
widgets.Box([widgets.VBox([slider_a, slider_b, slider_c, slider_d, slider_e, reset_button], layout=layout), out])

You can see that by increasing the value corresponding to `Cardiomegaly`, the heart is enlarged.