# Generative Adversarial Network

The goal of this notebook is to build a GAN that generates realistic handwritten digits like those in the MNIST dataset.

In [None]:
import datetime
import math

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm_notebook

from utils.mnist_dataset import MNIST_Dataset
from utils.batch import make_batches_all, make_batches_random

In [None]:
# Use TensorFlow 2. Parse the whole major component of the version string
# (not just its first character) so that a version such as "10.x" is not read as 1.
assert int(tf.__version__.split('.')[0]) == 2, "This notebook requires TensorFlow 2.x"

In [None]:
%load_ext tensorboard

In [None]:
# Check computing units
# The bare call on the last line is displayed as the cell output, showing the
# local devices TensorFlow reports for this machine.
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

## Constants

In [None]:
LOG_PATH = 'log-GAN'  # For tensorboard, etc. Each run gets a timestamped sub-directory.

DS_PATH = 'data' # Where the MNIST dataset is located

# Hyperparams
PARAMS = {
    'generator': {
        'learning_rate': 0.00001  # Learning rate of the generator's Adam optimizer
    },
    'discriminator': {
        'learning_rate': 0.00001  # Learning rate of the discriminator's Adam optimizer
    },
    'latent_factors': 100,  # Length of the noise vector fed to the generator
    'epochs': 200,
    'batch_size': 64,  # Use an even number: discriminator batches are half real, half fake
    'disc_gen_ratio': 1  # Number of discriminator updates per generator update.
}

## Dataset

In [None]:
!bash download_mnist.sh {DS_PATH}

In [None]:
class Dataset():
    """Wrapper around ``MNIST_Dataset`` that returns GAN-ready training images."""

    def __init__(self, ds_path):
        """ds_path: location handed over to ``MNIST_Dataset``."""
        self.ds = MNIST_Dataset(ds_path)

    @staticmethod
    def _preprocess_samples(x):
        """Pad the images to 32x32 and rescale the pixels to [-1, 1].

        x: array of shape (-1, 28, 28) representing the images. Pixel
           intensities are assumed to lie in [0, 255]
           (TODO confirm against MNIST_Dataset).

        Returns an array of shape (-1, 32, 32, 1).
        """
        n = x.shape[0]
        # From 28x28 pixels to 32x32: two rows/columns of zeros on every side.
        x_32 = np.pad(
            x,
            pad_width=((0,0),(2,2),(2,2)),
            mode='constant',
            constant_values=0
        )
        # 127.5 maps 0 -> -1 and 255 -> +1 exactly. Dividing by 128 would cap
        # the brightest pixel at ~0.992 instead of the documented upper bound 1.
        x_scaled = (x_32/127.5)-1 # Scaled -1,1
        return x_scaled.reshape(n,32,32,1)

    def train_data(self):
        """Return ``(images, labels)`` for the training split.

        Images are preprocessed with ``_preprocess_samples`` and cast to
        float32; labels are passed through untouched.
        """
        x,y = self.ds.get_train()
        return (
            self._preprocess_samples(x).astype(np.float32),
            y
        )

## Model architecture

### Generator

This model tries to generate realistic images.
-  INPUT: Noise vector (1x100).
-  OUTPUT: Image (32x32)

In [None]:
def get_generator(latent_factors):
    """ From noise to plausible examples.
    INPUT: Noise vector of `latent_factors` values (-1,1)
    OUTPUT: 32x32x1 (-1,1) grayscale image

    latent_factors: length of the input noise vector.
    Returns an uncompiled tf.keras.Sequential model named "Generator".
    """
    return tf.keras.Sequential(
        name="Generator",
        layers=[
            tf.keras.layers.InputLayer(
                # The trailing comma matters: `(latent_factors)` is just a
                # parenthesised int, while a shape has to be a tuple.
                input_shape=(latent_factors,)
            ),
            tf.keras.layers.Dense(
                name="D1",
                units=128,
                activation=tf.keras.layers.LeakyReLU(alpha=0.2)
            ),
            tf.keras.layers.Dense(
                name="D2",
                units=256,
                activation=tf.keras.layers.LeakyReLU(alpha=0.2)
            ),
            tf.keras.layers.Dense(
                name="D3",
                units=512,
                activation=tf.keras.layers.LeakyReLU(alpha=0.2)
            ),
            tf.keras.layers.BatchNormalization(
                momentum=0.8,
                name="BN1"
            ),
            # tanh keeps every output pixel inside (-1,1)
            tf.keras.layers.Dense(
                name="D4",
                units=1024,
                activation=tf.math.tanh
            ),
            # Reshape the 1024 outputs of D4 into a 32x32x1 image
            tf.keras.layers.Reshape(
                target_shape=(32,32,1)
            )
        ]
    )

### Discriminator

This model estimates the probability that an image is real.
-  INPUT: Image (32x32).
-  OUTPUT: Probability (0,1)

In [None]:
def get_discriminator():
    """Build the discriminator: 32x32x1 image -> probability of being real.

    The image is flattened and passed through two LeakyReLU dense layers
    (F3, F4) followed by a single sigmoid unit (F5).
    Returns an uncompiled tf.keras.Sequential model named "Discriminator".
    """
    # Hidden dense layers, described as (layer name, number of units)
    hidden_layers = [
        tf.keras.layers.Dense(
            units=width,
            activation=tf.keras.layers.LeakyReLU(alpha=0.2),
            name=label
        )
        for label, width in (("F3", 512), ("F4", 256))
    ]
    # Single sigmoid unit producing the probability
    score_layer = tf.keras.layers.Dense(
        units=1,
        activation=tf.math.sigmoid,
        name="F5"
    )
    stack = [
        tf.keras.layers.InputLayer(input_shape=(32,32,1)),
        tf.keras.layers.Flatten(),
        *hidden_layers,
        score_layer
    ]
    return tf.keras.Sequential(layers=stack, name="Discriminator")

### Generative Adversarial Network

This model represents the pipeline of the GAN.

In [None]:
def get_gan(generator, discriminator):
    """Chain both models: latent vector -> generator -> image -> discriminator -> probability.

    generator: model whose first layer's input has the latent size on axis 1.
    discriminator: model applied to the generator output.
    Returns a tf.keras.Sequential named "GAN" that wraps the two given model
    objects themselves (they are not copied).
    """
    return tf.keras.Sequential(
        name="GAN",
        layers=[
            tf.keras.layers.InputLayer(
                # The trailing comma makes this a real 1-tuple; without it the
                # parentheses only group a bare value, which is not a shape.
                input_shape=(
                    generator.layers[0].input.shape[1],
                )
            ),
            generator,
            discriminator
        ]
    )

### Model objects

In [None]:
# Generator and its own Adam optimizer (learning rate taken from PARAMS).
generator = get_generator(PARAMS['latent_factors'])
gen_opt = tf.keras.optimizers.Adam(
    learning_rate=PARAMS['generator']['learning_rate']
)

# Discriminator and its own Adam optimizer.
discriminator = get_discriminator()
disc_opt = tf.keras.optimizers.Adam(
    learning_rate=PARAMS['discriminator']['learning_rate']
)

# Stacked model built from the very same generator and discriminator objects.
gan = get_gan(generator, discriminator)

# Print the layer summary of each model.
generator.summary()
discriminator.summary()
gan.summary()

## Util Functions

In [None]:
def sample_latent_factors(n):
    """Draw `n` noise vectors (one per row) to feed the generator.

    Each vector holds PARAMS['latent_factors'] values sampled uniformly
    between -1.0 and 1.0.
    """
    noise_shape = (n, PARAMS['latent_factors'])
    # Naive way: uniform noise.
    # Alternative kept for reference:
    #   tf.random.normal(shape=(n, PARAMS['latent_factors']), mean=0.0, stddev=1.0)
    return tf.random.uniform(shape=noise_shape, minval=-1.0, maxval=1.0)

In [None]:
def plot_generated_images(epoch, images, path_out, dim=(10,10), figsize=(32,32)):
    """Save a grid of generated images as `<path_out>/generator_<epoch>.png`.

    epoch: epoch number, zero-padded to three characters in the title and file name.
    images: array that can be reshaped to (-1, 32, 32).
    path_out: output directory (string).
    dim: (rows, columns) of the grid.
    figsize: figure size handed to matplotlib.
    """
    epoch_tag = str(epoch).rjust(3,"0")
    tiles = images.reshape(-1,32,32)
    fig = plt.figure(figsize=figsize)
    # One grid cell per image, filled row by row (subplot positions start at 1)
    for position, tile in enumerate(tiles, start=1):
        ax = fig.add_subplot(dim[0], dim[1], position)
        ax.imshow(tile, interpolation='nearest', cmap='gray')
        ax.axis('off')
    fig.tight_layout()
    fig.suptitle(
        'Epoch: {}'.format(epoch_tag),
        fontsize=75,
        horizontalalignment='center',
        verticalalignment='top',
        backgroundcolor="black",
        color='yellow',
        weight='bold'
    )
    fig.savefig(path_out + '/generator_{}.png'.format(epoch_tag))
    # Close the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    return None

## Operative

In [None]:
def train_step(model, optimizer, fn_loss, x_train, y_train):
    """Run a single optimisation step on `model` and return the batch loss.

    model: callable model exposing `trainable_variables`.
    optimizer: optimizer whose `apply_gradients` updates those variables.
    fn_loss: callable `(y_true, y_pred) -> loss`.
    x_train, y_train: inputs and target labels of the batch.
    """
    with tf.GradientTape() as recorder:
        # Forward pass in training mode, recorded so it can be differentiated
        loss = fn_loss(y_train, model(x_train, training=True))
    weights = model.trainable_variables
    gradients = recorder.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss

In [None]:
def fn_loss(y_true, y_pred):
    """Mean binary cross-entropy (log loss) between labels and predicted probabilities.

    NOTE: tf.keras.losses.BinaryCrossentropy() could be used on its own;
    its reduction='SUM_OVER_BATCH_SIZE' parameter works the same.
    """
    per_sample_loss = tf.keras.losses.binary_crossentropy(
        y_true=y_true,
        y_pred=y_pred,
        from_logits=False,
        # Not useful here: only one-sided label smoothing is needed
        label_smoothing = 0
    )
    return tf.math.reduce_mean(per_sample_loss)

In [None]:
# Tensorboard writers

# Timestamped run directory: executing this cell again starts a new run.
log_dir = LOG_PATH + '/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# One writer per model, each with its own sub-directory of the run.
gen_summary_writer = tf.summary.create_file_writer(log_dir + '/gen')
disc_summary_writer = tf.summary.create_file_writer(log_dir + '/disc')

# Directory where the generated-image grids (PNG) are written.
img_out_dir = log_dir + '/imgs'

In [None]:
!mkdir {img_out_dir}

In [None]:
# Load the dataset

ds = Dataset(DS_PATH)
# Preprocessed training images as returned by Dataset.train_data(); labels are discarded.
x_real, _ = ds.train_data()  # I don't care about labels

In [None]:
# Precalculate batch labels to gain speed.

bs = PARAMS['batch_size']
n_batches = math.ceil(len(x_real)*2 / bs) # Real + Fake data / batch_size

# Discriminator labels
# The first half of the rows labels the real samples, the second half the fake ones.
disc_y = tf.concat(
    [
        tf.fill((bs//2,1), 0.95), # Real. One sided label smoothing
        tf.fill((bs//2,1), 0.0)   # Fake
    ],
    axis=0
)

# Generator labels
gen_y = tf.fill((bs,1), 1.0)  # Do not smooth generator samples!!

# Fails for an odd batch size, because bs//2 + bs//2 != bs.
assert(len(disc_y)==len(gen_y))

In [None]:
def evaluate(generator, discriminator, val_data_real, val_data_fake):
    """Score the discriminator on a stacked batch of real and generated images.

    generator: not used by this function; kept so existing callers keep working.
    discriminator: callable mapping a batch of images to probabilities of being real.
    val_data_real: batch of real images.
    val_data_fake: batch of generated images; must hold as many samples as
        `val_data_real`, because the labels are built from that single length.

    Returns a dict with:
        'disc_loss': loss of the discriminator on real + fake samples,
        'gen_loss':  loss of the fake samples against the label "real",
        'acc':       accuracy at a 0.5 threshold,
        'fpr':       share of fake samples classified as real,
        'cm':        2x2 confusion matrix (rows: true label, columns: predicted label).
    """
    n = len(val_data_real)
    y_true = tf.concat([tf.fill((n,1), 1.0),tf.fill((n,1), 0.0)], axis=0)
    y_pred = discriminator( tf.concat([val_data_real, val_data_fake ],axis=0) )
    # sklearn gets flat integer label vectors rather than (2n,1) float tensors
    y_true_bin = y_true.numpy().flatten().astype(int)
    y_pred_bin = np.where(y_pred.numpy().flatten()>=0.5, 1, 0)

    # Metrics
    # labels=[0, 1] pins the matrix to 2x2 so it always unpacks into four
    # counts, whatever classes show up in the predictions.
    cf = sklearn.metrics.confusion_matrix(y_true_bin, y_pred_bin, labels=[0, 1])
    tn, fp, fn, tp = cf.ravel()
    return {
        'disc_loss': fn_loss(y_true, y_pred),
        'gen_loss': fn_loss(tf.fill((n,1), 1.0), discriminator(val_data_fake)),
        'acc': sklearn.metrics.accuracy_score(y_true_bin, y_pred_bin),
        'fpr': fp/(fp+tn), # False positive ratio
        'cm': cf
    }

## Training

### Tensorboard

We will track the metrics evolution and the generated images with Tensorboard.

In [None]:
%tensorboard --logdir {log_dir}

### Training loop

In [None]:
""" Training loop.

Use tensorboard to see examples and what is happening.


"""

# Generator for real data, random sampling strategy.
# Each draw is unpacked into three values and only the middle one (the samples) is used.
x_real_g=make_batches_random(
    x=x_real,
    y=None,
    batch_size=bs//2,
    stop_after_epoch=False
)

# Validation data: drawn once and reused at every epoch, so that metrics and
# generated images are comparable across epochs.
val_data_real=np.take(  # Real samples
    a=x_real,
    indices=np.random.randint(
        low=0,
        high=len(x_real),
        size=256
    ),
    axis=0
)
val_params_fake = sample_latent_factors(256)


for epoch in range(0, PARAMS['epochs']):
    # Batch
    with tqdm_notebook(total=n_batches, unit='batch', desc="Epoch: {} ".format(epoch)) as pbar:
        for n_batch in range(n_batches):
            for i in range(PARAMS['disc_gen_ratio']):
                # DISCRIMINATOR
                # Half batch of real samples followed by half batch of fake
                # ones, matching the row order of the labels in disc_y.
                _, disc_x_real, _ = next(x_real_g)
                disc_x = tf.concat([
                    disc_x_real,
                    generator(sample_latent_factors(bs//2))  # Fake data
                ], axis=0)
                discriminator.trainable=True
                disc_b_loss = train_step(
                    model=discriminator,
                    optimizer=disc_opt,
                    fn_loss=fn_loss,
                    x_train=disc_x,
                    y_train=disc_y
                )
            # GENERATOR
            # Freeze the discriminator so that the step on the stacked model
            # only updates the generator weights.
            discriminator.trainable=False
            gen_b_loss = train_step(
                model=gan,  # Latent params -> Generator -> image -> Discriminator -> Probability
                optimizer=gen_opt,
                fn_loss=fn_loss,
                x_train=sample_latent_factors(bs),
                y_train=gen_y
            )

            #Update progress bar
            pbar.set_postfix(disc_loss=disc_b_loss.numpy(), gen_loss=gen_b_loss.numpy())
            pbar.update(1)


    # EPOCH METRICS
    val_data_fake = generator(val_params_fake)
    val_metrics = evaluate(generator, discriminator, val_data_real, val_data_fake)
    for m in ['disc_loss','acc','gen_loss','fpr','cm']:
        print("{}\t{}".format(m, val_metrics[m]))

    with disc_summary_writer.as_default():
        tf.summary.scalar('loss', val_metrics["disc_loss"], step=epoch)
        tf.summary.scalar('acc', val_metrics["acc"], step=epoch)
    with gen_summary_writer.as_default():
        tf.summary.scalar('loss', val_metrics["gen_loss"], step=epoch)
        tf.summary.scalar('fpr', val_metrics["fpr"], step=epoch)
        # tf.summary.image clips float data to [0,1]; the generator output is
        # in (-1,1), so rescale it or everything below zero is logged as black.
        tf.summary.image("Training data", (val_data_fake + 1.0) / 2.0, step=epoch, max_outputs=25)

    if epoch%2 == 0:
        plot_generated_images(epoch, val_data_fake.numpy()[:100], img_out_dir) # Output generated images

### Generate a gif with the progression

In [None]:
!convert -resize 25% -delay 100 -loop 0 {log_dir}/imgs/*.png img/gan/evolution.gif

![](img/gan/evolution.gif)

## Generate fake samples

In [None]:
# Generate one image from a single random latent vector and keep it as a 32x32 array.
sample = generator(sample_latent_factors(1))[0].numpy().reshape(32,32)
# Discriminator output for that image, fed back as a batch of one.
print("Disc prob: {}".format(discriminator(sample.reshape(-1,32,32,1)).numpy()[0][0]))

plt.imshow(
    sample, 
    cmap='gray'
)

## Save the models

In [None]:
# Save
# Export both models with tf.saved_model.save into the log directory of this run.
tf.saved_model.save(generator, log_dir + '/saved_generator')
tf.saved_model.save(discriminator, log_dir + '/saved_discriminator')

## Load & use the models

### Generator

In [None]:
# Load
tf_gen = tf.saved_model.load(log_dir + '/saved_generator')

# Serving function
infer_gen = tf_gen.signatures['serving_default']
# Model input
print("Model input: \n\t{}".format(infer_gen.structured_input_signature))
# Model output
print("Model output layer \n\t{}".format(infer_gen.structured_outputs))

# Serving function
def generate(x):
    """Run the reloaded generator on `x` and return its output tensor.

    The serving signature returns a dict of outputs. The 'reshape' key is used
    when present; otherwise the single output is returned whatever its key is.
    NOTE(review): the key presumably follows the auto-generated name of the
    unnamed Reshape layer, which would get a numeric suffix when the generator
    is rebuilt in the same kernel -- confirm against the printed signature.
    """
    outputs = infer_gen(x)
    if 'reshape' in outputs:
        return outputs['reshape']
    # Exactly one output is expected; unpacking fails loudly otherwise.
    (only_output,) = outputs.values()
    return only_output

### Discriminator

In [None]:
# Load
tf_disc = tf.saved_model.load(log_dir + '/saved_discriminator')

# Serving function
infer_disc = tf_disc.signatures['serving_default']
# Model input
print("Model input: \n\t{}".format(infer_disc.structured_input_signature))
# Model output
print("Model output layer \n\t{}".format(infer_disc.structured_outputs))

# Serving function
def discriminate(x):
    """Run the reloaded discriminator on `x` and return the output of its 'F5' layer."""
    return infer_disc(x)['F5']

### Test

In [None]:
# Generate a single image with the reloaded generator...
samples = generate(sample_latent_factors(1))

# ...and score it with the reloaded discriminator.
print("Disc prob: {}".format(discriminate(samples).numpy()[0][0]))

plt.imshow(
    samples[0].numpy().reshape(32,32), 
    cmap='gray'
)

In [None]:
# Problem using keras
# NOTE(review): the note above does not say what the problem is -- presumably
# to_json() fails to serialise this model as defined; confirm and record the actual error.
discriminator.to_json()

## References

1. <a name="bib-web-gantf"></a>[Building a simple Generative Adversarial Network (GAN) using TensorFlow](https://blog.paperspace.com/implementing-gans-in-tensorflow/)
2. <a name="bib-web-adversarialtf"></a>[Generative Adversarial Nets in TensorFlow](https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/)
3. <a name="bib-web-gankeras"></a>[Generative Adversarial Network(GAN) using Keras](https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3)
4. <a name="bib-web-poolstride"></a>[Pooling VS Striding - Striving for Simplicity: The All Convolutional Net](https://arxiv.org/abs/1412.6806)
5. <a name="bib-web-collapse1"></a>[Mode collapse: GAN — Why it is so hard to train Generative Adversarial Networks!](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b)
6. <a name="bib-web-collapse2"></a>[Mode collapse: What does it mean if all produced images of a GAN look the same?](https://www.quora.com/What-does-it-mean-if-all-produced-images-of-a-GAN-look-the-same)
7. <a name="bib-vid-gans"></a>[NIPS 2016 - Generative Adversarial Networks - Ian Goodfellow](https://www.youtube.com/watch?v=AJVyzd0rqdc)
  1. [On divergence](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=52m10s)
  1. [On labeled/conditioned GANS](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=1h09m50s)
  1. [On mode collapse](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=1h31m53s)
  1. [One sided label smoothing](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=1h11m34s)
  1. [Question: GANs vs VAEs](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=37m10s)
  1. [Question: Sampling distributions, Uniform vs. Normal](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=37m57s)
  1. [Question: mode collapse/same sample](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=35m47s)
8.  <a name="bib-web-tbstarted"></a>[Get started with TensorBoard](https://www.tensorflow.org/tensorboard/r2/get_started#using_tensorboard_with_other_methods)
9. <a name="bib-web-codegan1"></a>[Github: PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py)
10. [TQDM (status bar library)](https://tqdm.github.io/)
11. [Inside TensorFlow: Summaries and TensorBoard](https://www.youtube.com/watch?v=OI4cskHUslQ)
 