## <center>Conditional WGAN with Gradient Penalty</center>

Import TensorFlow, enable eager execution, and check GPU availability.

In [None]:
import tensorflow as tf
# Enable eager mode explicitly (required on TF 1.x; TF 2.x already runs eagerly by default).
tf.compat.v1.enable_eager_execution()
print(f"You are using tensorflow version {tf.__version__} in eager-execution mode")

In [None]:
print("Checking for GPU availability...")
# tf.test.is_gpu_available() is deprecated and tf.test provides no
# gpu_device_type(), so query the physical GPU devices directly instead.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    for gpu in gpu_devices:
        print("--> Valid GPU of type "+str(gpu.device_type)+" with name "+gpu.name+" found!")
else:
    print("--> No GPU detected!")

In [None]:
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import time
from IPython import display

### Import dataset

In [None]:
# load_data() returns ((train_x, train_y), (test_x, test_y)); only the training split is used.
train_x, train_y = tf.keras.datasets.mnist.load_data()[0]

### Preprocess the dataset

In [None]:
# Add an explicit channel axis -> (N, 28, 28, 1), cast to float32, and map
# pixel values from [0, 255] to [-1, 1].
train_x = (train_x.reshape((train_x.shape[0], 28, 28, 1)).astype('float32') - 127.5) / 127.5

In [None]:
# --- Data / image geometry ---
BUFFER_SIZE = 60000                      # shuffle buffer; equal to the MNIST training-set size
IMG_WIDTH, IMG_HEIGHT, IMG_CHANNEL = 28, 28, 1
n_classes = 10                           # digit classes 0-9

# --- Optimisation ---
BATCH_SIZE = 300                         # 60000 / 300 = 200 full batches per epoch
LAMBDA = 10                              # weight of the gradient-penalty term in the critic loss
N_CRITIC = 3                             # critic (discriminator) updates per generator update

# --- Generator / sampling ---
noise_dim = 100                          # length of the latent noise vector
EPOCHS = 50
num_examples_to_generate = 16            # samples shown in the 4x4 progress grid

#### Before creating the training dataset, attach the one-hot labels to the training data

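First concatenate the one-hot labels with the real images; later, concatenate them with the generator's noise input as well. (In the cells below this label conditioning is currently disabled, so the models train unconditionally.)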

In [None]:
# One-hot encode the digit labels, e.g. 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
one_hot_labels = tf.one_hot(indices=train_y, depth=n_classes)

In [None]:
# Only the images are sliced; the one-hot labels (needed for the conditional
# variant) are not attached yet, so each element is a single image tensor.
dataset = (
    tf.data.Dataset.from_tensor_slices(train_x)
    .shuffle(buffer_size=BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
print(f"Shape of training images: {train_x.shape}")
print(f"Shape of training labels: {one_hot_labels.shape}")
print("Example: ")
# Drop the trailing channel axis: older matplotlib versions reject (H, W, 1)
# arrays. The grayscale colormap matches the generated-image plots.
plt.imshow(train_x[9, :, :, 0], cmap='gray')
print(train_y[9])

In [None]:
# Input widths of the two networks. The conditional variant would add
# n_classes to each (label channels appended to the noise / image). That is
# currently disabled, so both models are unconditional.
generator_in_channels = noise_dim
discriminator_in_channels = IMG_CHANNEL
print(generator_in_channels, discriminator_in_channels)

### Create the models

In [None]:
def make_generator_model():
    """Build the DCGAN-style generator: latent vector -> 28x28 image.

    Input: (batch, generator_in_channels) noise vectors.
    Output: (batch, IMG_WIDTH, IMG_HEIGHT, discriminator_in_channels) images,
    squashed into [-1, 1] by a tanh so they match the range of the
    preprocessed real images.

    The batch dimension is left dynamic so the same model can be called with
    training batches, the 16-vector progress seed, or a single test vector.
    """
    # shape must be a tuple; the batch size is intentionally not fixed.
    gen_input_img = tf.keras.Input(shape=(generator_in_channels,))
    dense_0 = tf.keras.layers.Dense(7*7*generator_in_channels, use_bias=False)(gen_input_img)
    bt_norm_0 = tf.keras.layers.BatchNormalization()(dense_0)
    leaky_0 = tf.keras.layers.LeakyReLU()(bt_norm_0)

    # Project to a 7x7 feature map that the transposed convolutions upsample.
    reshaped_0 = tf.keras.layers.Reshape((7, 7, generator_in_channels))(leaky_0)
    assert tuple(reshaped_0.shape[1:]) == (7, 7, generator_in_channels)

    # 7x7 -> 7x7 (stride 1), 128 channels.
    conv2dT_0 = tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)(reshaped_0)
    assert tuple(conv2dT_0.shape[1:]) == (7, 7, 128)
    bt_norm_1 = tf.keras.layers.BatchNormalization()(conv2dT_0)
    leaky_1 = tf.keras.layers.LeakyReLU()(bt_norm_1)

    # 7x7 -> 14x14, 64 channels.
    conv2dT_1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(leaky_1)
    assert tuple(conv2dT_1.shape[1:]) == (14, 14, 64)
    bt_norm_2 = tf.keras.layers.BatchNormalization()(conv2dT_1)
    leaky_2 = tf.keras.layers.LeakyReLU()(bt_norm_2)

    # 14x14 -> 28x28. tanh bounds the output to [-1, 1], the same range as the
    # real data and the range generate_and_save_images rescales from.
    output_img = tf.keras.layers.Conv2DTranspose(discriminator_in_channels, (5, 5), strides=(2, 2), padding='same',
                                                 use_bias=False, activation='tanh')(leaky_2)
    assert tuple(output_img.shape[1:]) == (IMG_WIDTH, IMG_HEIGHT, discriminator_in_channels)

    model = tf.keras.Model(inputs=gen_input_img, outputs=output_img)

    return model

In [None]:
generator = make_generator_model()

# Smoke test: push one random latent vector through the untrained generator
# and show the resulting image.
noise = tf.random.normal(shape=(1, generator_in_channels))
test_generated_image = generator(noise, training=False)
plt.imshow(test_generated_image[0, ..., 0], cmap='gray')


In [None]:
def make_discriminator_model():
    """Build the critic: (batch, IMG_WIDTH, IMG_HEIGHT, C) image -> one score per image.

    The final Dense(1) has no activation, so the critic outputs an unbounded
    real-valued score, as the Wasserstein losses in the training steps expect.
    The batch dimension is left dynamic so the model accepts the single-image
    smoke test as well as training batches.
    """
    disc_input_img = tf.keras.Input(shape=(IMG_WIDTH, IMG_HEIGHT, discriminator_in_channels))

    # 28x28 -> 14x14, 64 channels.
    conv2d_0 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(disc_input_img)
    leaky_0 = tf.keras.layers.LeakyReLU()(conv2d_0)
    drop_0 = tf.keras.layers.Dropout(0.3)(leaky_0)

    # 14x14 -> 7x7, 128 channels.
    conv2d_1 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(drop_0)
    leaky_1 = tf.keras.layers.LeakyReLU()(conv2d_1)
    drop_1 = tf.keras.layers.Dropout(0.3)(leaky_1)

    flatten_0 = tf.keras.layers.Flatten()(drop_1)
    dense_0 = tf.keras.layers.Dense(1)(flatten_0)

    model = tf.keras.Model(inputs=disc_input_img, outputs=dense_0)

    return model

In [None]:
discriminator = make_discriminator_model()

# Sanity check only: score the untrained generator's sample with the untrained critic.
test_decision = discriminator(test_generated_image)
print(test_decision)

## Define optimizers

Define the optimizers for both models. (The Wasserstein losses themselves are computed inside the training-step functions below.)


In [None]:
# Separate Adam instances so the generator and critic keep independent moment estimates.
# NOTE(review): the WGAN-GP paper uses Adam(1e-4, beta_1=0.0, beta_2=0.9); the Keras
# defaults (beta_1=0.9, beta_2=0.999) are used here. Consider aligning.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

### Save checkpoints

In [None]:
# Fresh, untrained models for the real run (the smoke-test instances are replaced).
generator = make_generator_model()
discriminator = make_discriminator_model()

ckpt_dir = 'training_ckpts'
ckpt_prefix = os.path.join(ckpt_dir, 'ckpt')

# Track both models and both optimizers so a run can be resumed from disk.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)

In [None]:
# Fixed latent vectors reused after every epoch, so the progress images are comparable over time.
seed = tf.random.normal(shape=(num_examples_to_generate, generator_in_channels))

In [None]:
@tf.function
def train_g_step(images):
    """Run one generator update with the WGAN generator loss.

    Args:
        images: the current batch of real images. Only its batch size is used,
            so the generated batch always matches the size of the data batch.
    """
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, generator_in_channels])

    # Only one gradient() call is made, so a non-persistent tape suffices
    # (a persistent tape keeps its resources alive needlessly).
    with tf.GradientTape() as g_tape:
        generated_image = generator(noise, training=True)
        generated_pred = discriminator(generated_image, training=True)
        # Generator maximises the critic's score on generated samples.
        G_loss = -tf.reduce_mean(generated_pred)

    G_grad = g_tape.gradient(G_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(G_grad, generator.trainable_variables))
            

In [None]:
@tf.function
def train_d_step(images):
    """Run one critic (discriminator) update with the WGAN-GP loss.

    D_loss = mean(D(fake)) - mean(D(real)) + LAMBDA * mean((||dD(x_hat)/dx_hat||_2 - 1)^2)
    where x_hat is a per-sample random interpolation between real and generated images.

    Args:
        images: batch of real images, rank 4 (batch, H, W, C). The batch size
            is read from this tensor, so a smaller final batch does not break
            the interpolation broadcast.
    """
    real_images = tf.dtypes.cast(images, tf.float32)
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, generator_in_channels])
    # One interpolation weight per sample, broadcast over H, W and C.
    alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)

    # Only one gradient() call is made on d_tape, so it need not be persistent.
    with tf.GradientTape() as d_tape:
        with tf.GradientTape() as gp_tape:
            generated_image = generator(noise, training=True)
            interpolated_image = alpha * real_images + (1 - alpha) * generated_image
            # Watch explicitly: the penalty differentiates w.r.t. this input tensor.
            gp_tape.watch(interpolated_image)
            interpolated_pred = discriminator(interpolated_image, training=True)

        # Gradient penalty. It is computed inside d_tape so its dependence on the
        # critic weights (a second-order gradient) is recorded for D_grad.
        grads = gp_tape.gradient(interpolated_pred, interpolated_image)
        # The small constant keeps sqrt differentiable when a gradient is exactly zero
        # (d sqrt(x)/dx is infinite at 0, which would produce NaNs).
        grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        gradient_penalty = tf.reduce_mean(tf.square(grad_norms - 1))

        generated_pred = discriminator(generated_image, training=True)
        real_pred = discriminator(real_images, training=True)

        D_loss = tf.reduce_mean(generated_pred) - tf.reduce_mean(real_pred) + LAMBDA * gradient_penalty

    D_grad = d_tape.gradient(D_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(D_grad, discriminator.trainable_variables))
            

In [None]:
def train(dataset, epochs, ckpt_every=15):
    """Train the WGAN-GP for `epochs` passes over `dataset`.

    The critic is updated on every batch; the generator once every N_CRITIC
    critic updates (the counter carries over across epoch boundaries). After
    each epoch a progress grid generated from the fixed `seed` is saved.

    Args:
        dataset: iterable of real-image batches.
        epochs: number of epochs to train.
        ckpt_every: positive checkpoint interval in epochs (default 15). A
            checkpoint is also written after the final epoch, so restoring the
            latest checkpoint afterwards yields the fully trained models.
    """
    n_critic_count = 0

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            # Train the critic/discriminator on every batch.
            train_d_step(image_batch)

            n_critic_count += 1
            if n_critic_count >= N_CRITIC:
                # One generator step per N_CRITIC critic steps.
                train_g_step(image_batch)
                n_critic_count = 0

        # Produce images for the GIF as training progresses.
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        if (epoch + 1) % ckpt_every == 0:
            ckpt.save(file_prefix=ckpt_prefix)

        print('Time for epoch {} is {} sec'.format(epoch+1, time.time()-start))

    # Without this, e.g. 50 epochs with saves at 15/30/45 would leave the last
    # 5 epochs unsaved, and restoring the latest checkpoint would roll them back.
    if epochs > 0 and epochs % ckpt_every != 0:
        ckpt.save(file_prefix=ckpt_prefix)

    # Generate after the final epoch.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

In [None]:
def generate_and_save_images(model, epoch, test_input, out_dir='epoch_images'):
    """Plot a 4x4 grid of samples from `model` and save it as a PNG.

    Args:
        model: generator mapping latent vectors to images; outputs are assumed
            to lie in [-1, 1] (they are rescaled to [0, 255] for display).
        epoch: epoch number, used in the file name image_at_epoch_XXXX.png.
        test_input: batch of latent vectors (at most 16, for the 4x4 grid).
        out_dir: directory to write the PNG into; created if it is missing.
    """
    # training=False: this is inference, so layers such as BatchNorm run in inference mode.
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Map [-1, 1] back to [0, 255] pixel values.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    # savefig does not create directories; without this the first epoch crashes.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure so many epochs do not accumulate open figures.
    plt.close(fig)

In [None]:
# Full training run (long-running: EPOCHS passes over the dataset).
train(dataset=dataset, epochs=EPOCHS)

In [None]:
# Reload the most recent checkpoint written during training.
latest_ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
ckpt.restore(latest_ckpt_path)

### Create GIF

In [None]:
def display_image(epoch_no, out_dir='epoch_images'):
    """Open the progress image saved for epoch `epoch_no`.

    Args:
        epoch_no: epoch number used in the file name image_at_epoch_XXXX.png.
        out_dir: directory the progress images were saved to.

    Returns:
        A PIL image; as a cell's last expression it renders inline.
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it
    # explicitly instead of relying on another library having done so.
    import PIL.Image
    return PIL.Image.open(os.path.join(out_dir, 'image_at_epoch_{:04d}.png'.format(epoch_no)))

In [None]:
# Show the grid saved after epoch 3 as an early-training example.
display_image(epoch_no=3)

Use imageio to create an animated gif using the images saved during training.

In [None]:
anim_file = 'dcgan.gif'

# Frames are written into epoch_images/ by generate_and_save_images, so glob
# there. The zero-padded epoch numbers make lexicographic order chronological.
filenames = sorted(glob.glob(os.path.join('epoch_images', 'image*.png')))
if not filenames:
    raise FileNotFoundError("No progress images found in 'epoch_images/'; run training first.")

with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
    # Repeat the final frame so the animation ends on the last training samples.
    writer.append_data(image)

In [None]:
# Render the GIF inline (requires the tensorflow_docs package).
from tensorflow_docs.vis import embed
embed.embed_file(anim_file)