# Wasserstein GAN (W-GAN)

Originally proposed by [Arjovsky et al.](https://arxiv.org/pdf/1701.07875.pdf) in their work titled *Wasserstein GAN*. This notebook uses a basic DCGAN-style implementation where the generator and critic (discriminator) models use convolutional layers, batch normalization and upsampling.
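This notebook trains both networks with the RMSprop optimizer and clips the critic's weights, as in the original paper. We showcase its effectiveness by generating 128×128 chest X-rays (AP view, labeled Pneumothorax) from the CheXpert dataset.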

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PacktPublishing/Hands-On-Generative-AI-with-Python-and-TensorFlow-2/blob/master/Chapter_6/wasserstein_gan.ipynb)

## Load Libraries

In [None]:
from tensorflow.keras import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import Adam,RMSprop
from tensorflow.keras import datasets
import numpy as np


import keras
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Dense, Conv2DTranspose, Flatten, LeakyReLU, Reshape
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import cv2
from tensorflow.keras.models import Sequential
import os
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from skimage import exposure
from IPython.display import clear_output

## Load Utility Functions

In [None]:
from gan_utils import build_critic
from gan_utils import build_dc_generator
from gan_utils import sample_images
from gan_utils import wasserstein_loss

In [None]:
# Dataset root of the CheXpert-v1.0-small download; the CSV's "Path" column
# holds image paths relative to this directory.
# Linux equivalent of this root:
#   /media/nicholasjprimiano/8A5C72285C720F67/ML_C/CheXpert/CheXpert-Keras-master/data/default_split/CheXpert-v1.0-small/
CHEXPERT_ROOT = r"C:/ML_C/CheXpert/CheXpert-Keras-master/data/default_split/CheXpert-v1.0-small/"

# Training labels CSV; every column is read as a string, so labels compare
# against text such as "1.0".
train_dir = os.path.abspath(CHEXPERT_ROOT + "CheXpert-v1.0-small/train.csv")
traindf = pd.read_csv(train_dir, dtype=str)

# Validation CSV (currently unused):
# valid_dir = os.path.abspath(CHEXPERT_ROOT + "CheXpert-v1.0-small/valid.csv")
# validdf = pd.read_csv(valid_dir, dtype=str)

# Turn the relative image paths in column 0 into full paths with a single
# vectorized string concatenation rather than a per-row .iloc assignment.
traindf.iloc[:, 0] = CHEXPERT_ROOT + traindf.iloc[:, 0]

# Only AP (anterior-posterior) view X-rays. Column 4 is assumed to be the
# CheXpert "AP/PA" view column -- TODO confirm against the CSV header.
# Boolean indexing keeps the column layout even when no rows match.
aptraindf = traindf[traindf.iloc[:, 4] == "AP"]

# Only X-rays labeled positive ("1.0") for Pneumothorax. The mask is built
# once instead of re-filtering the DataFrame for every element.
paths = aptraindf.loc[aptraindf["Pneumothorax"] == "1.0", "Path"].tolist()

# Contrast normalization (CLAHE) applied by get_imgs() to every loaded X-ray
def normalize_xray(img):
    """Scale an X-ray to [0, 1] and apply adaptive histogram equalization.

    The image is divided by its maximum value and then passed to
    ``skimage.exposure.equalize_adapthist`` (CLAHE).

    Args:
        img: 2-D array of pixel values; get_imgs() passes the uint8 grayscale
            output of ``cv2.cvtColor``. Values are assumed non-negative.

    Returns:
        Float array with the same shape as ``img``. An all-zero image is
        returned as zeros instead of NaNs.
    """
    peak = np.max(img)
    if peak == 0:
        # Dividing by a zero maximum would yield 0/0 = NaN for every pixel.
        return np.zeros(np.shape(img), dtype=np.float64)
    return exposure.equalize_adapthist(img / peak)

# Load grayscale X-rays resized to IMG_SIZE x IMG_SIZE (128x128 by default)

IMG_SIZE = 128
def get_imgs(paths, img_size=None):
    """Read, resize, grayscale and normalize a sequence of X-ray images.

    Args:
        paths: iterable of image file paths readable by ``cv2.imread``.
        img_size: side length of the square output images; when ``None`` the
            module-level ``IMG_SIZE`` is used (looked up at call time).

    Returns:
        List of 2-D arrays, one per path, as produced by ``normalize_xray``.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the paths.
    """
    size = IMG_SIZE if img_size is None else img_size
    images = []
    for path in paths:
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising;
            # without this check the error surfaces as an opaque cv2.resize
            # assertion.
            raise FileNotFoundError(f"Could not read image: {path}")
        # Resize first, then convert to grayscale (same order as before).
        gray = cv2.cvtColor(cv2.resize(bgr, (size, size)), cv2.COLOR_BGR2GRAY)
        images.append(normalize_xray(gray))
    return images



## W-GAN Training Loop
- As proposed in the original paper
- Train the critic on a batch of real samples and a batch of generated (fake) samples
- Calculate the critic (discriminator) loss
- Train the critic 5 times per training cycle of the generator, clipping its weights to [-clip_value, clip_value] after each update
- Use the Wasserstein loss for both the generator and the discriminator
- Freeze the discriminator and train the generator

In [None]:
def _clip_weights(model, clip_value):
    """Clip every weight of every layer of ``model`` to [-clip_value, clip_value]."""
    # NOTE(review): get_weights() also returns non-trainable weights, so this
    # clips e.g. BatchNormalization moving statistics as well if build_critic
    # uses BN -- confirm that is intended.
    for layer in model.layers:
        layer.set_weights([np.clip(w, -clip_value, clip_value)
                           for w in layer.get_weights()])


def train(generator=None, discriminator=None, gan_model=None,
          epochs=1000, discriminator_cycles=5, batch_size=128, sample_interval=50,
          z_dim=100, clip_value=0.01, images=None):
    """Train a W-GAN with critic weight clipping (Arjovsky et al.).

    Each epoch updates the critic ``discriminator_cycles`` times (one real
    batch and one fake batch per cycle, clipping its weights afterwards) and
    then updates the generator once through ``gan_model``.

    Args:
        generator: Keras model mapping ``(batch_size, z_dim)`` noise to images.
        discriminator: compiled critic; ``train_on_batch`` must return a list
            (it is compiled with a metric) whose first entry is the loss.
        gan_model: compiled generator -> frozen-critic stack, same return
            convention as ``discriminator``.
        epochs: number of generator updates.
        discriminator_cycles: critic updates per generator update (>= 1).
        batch_size: samples per batch.
        sample_interval: call ``sample_images`` every this many epochs.
        z_dim: noise vector length; must match the generator's input.
        clip_value: critic weights are clipped to [-clip_value, clip_value].
        images: optional array-like of training images with values in [0, 1]
            that reshapes to ``(-1, IMG_SIZE, IMG_SIZE, 1)``. When ``None``,
            the X-rays listed in the module-level ``paths`` are loaded with
            ``get_imgs``.

    Raises:
        ValueError: if ``discriminator_cycles < 1`` or no images are available.
    """
    if discriminator_cycles < 1:
        # The loss report below needs at least one critic update per epoch.
        raise ValueError("discriminator_cycles must be at least 1")

    if images is None:
        images = get_imgs(paths)
    X_train = np.asarray(images, dtype=np.float32)
    if X_train.size == 0:
        # np.random.randint(0, 0, ...) would otherwise fail with a cryptic error.
        raise ValueError("No training images were loaded")

    # Add a channel axis and shift pixel values from [0, 1] to [-1, 1] to match
    # a tanh generator output.
    X_train = X_train.reshape(-1, IMG_SIZE, IMG_SIZE, 1) * 2. - 1.

    # Wasserstein targets: -1 for real samples, +1 for fake samples.
    real_y = -np.ones((batch_size, 1))
    fake_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # Train the critic several times per generator update.
        for _ in range(discriminator_cycles):
            # Random batch of real samples (drawn with replacement).
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_imgs = X_train[idx]

            # Fake samples from normally distributed noise.
            noise = np.random.normal(0, 1, (batch_size, z_dim))
            fake_imgs = generator.predict(noise)

            disc_loss_real = discriminator.train_on_batch(real_imgs, real_y)
            disc_loss_fake = discriminator.train_on_batch(fake_imgs, fake_y)
            discriminator_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)

            # Weight clipping keeps the critic (approximately) K-Lipschitz.
            _clip_weights(discriminator, clip_value)

        # Train the generator: it is rewarded when the frozen critic scores its
        # samples like real ones (target -1).
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        gen_loss = gan_model.train_on_batch(noise, real_y)

        # Report the raw Keras losses. The "1 - loss" convention of
        # sigmoid/BCE GANs has no meaning for an unbounded Wasserstein loss.
        print(f"{epoch} [Discriminator loss: {discriminator_loss[0]:f}] "
              f"[Generator loss: {gen_loss[0]:f}]")

        if epoch % sample_interval == 0:
            sample_images(epoch, generator)

## Prepare Discriminator Model or Critic

In [None]:
# Critic trained with the Wasserstein loss and RMSprop (learning rate 5e-5).
# NOTE(review): assumes build_critic() accepts IMG_SIZE x IMG_SIZE x 1 inputs
# -- confirm against gan_utils.
discriminator = build_critic()
discriminator.compile(loss=wasserstein_loss,
                      # `learning_rate`, not the deprecated `lr` alias, which
                      # newer Keras releases no longer honour.
                      optimizer=RMSprop(learning_rate=0.00005),
                      # Accuracy is not meaningful for +/-1 critic targets, but
                      # with a metric train_on_batch returns [loss, metric],
                      # which train() indexes with [0].
                      metrics=['accuracy'])

## Prepare Generator Model

In [None]:
# Generator from gan_utils; it is applied to noise vectors of length z_dim when
# the GAN model is stacked below.
# NOTE(review): assumes build_dc_generator() takes 100-dim noise and outputs
# IMG_SIZE x IMG_SIZE x 1 images with tanh range -- confirm against gan_utils.
generator = build_dc_generator()

## Prepare GAN Model

In [None]:
# Noise for generator
z_dim = 100
z = Input(shape=(z_dim,))
img = generator(z)

# Fix the discriminator
discriminator.trainable = False

# Get discriminator output
valid = discriminator(img)

# Stack discriminator on top of generator
gan_model = Model(z, valid)
gan_model.compile(loss=wasserstein_loss,
    optimizer=RMSprop(lr=0.00005),
    metrics=['accuracy'])
gan_model.summary()

## Train W-GAN

In [None]:
# Run 4000 generator updates; pass z_dim so the sampled noise matches the
# Input(shape=(z_dim,)) the GAN model was built with.
train(generator, discriminator, gan_model, epochs=4000, batch_size=64,
      sample_interval=100, z_dim=z_dim)