# Generative Adversarial Networks

In [None]:
import gzip
import struct
import numpy as np
import tensorflow as tf
from tqdm import tqdm


import matplotlib.pyplot as plt
import sklearn.metrics
import seaborn as sns


import datetime

In [None]:
assert(int(tf.__version__[0])==2)  # Use TensorFlow 2
sns.set(rc={'figure.figsize':(11.7,8.27)})  # Figure size

%load_ext tensorboard

## Constants

In [None]:
# Root folder for run artifacts (TensorBoard event files, etc.).
LOG_PATH = 'log-GAN'

# Folder where the MNIST dataset is located.
DS_PATH = 'data'

# Hyperparameters of the experiment, grouped in one dictionary.
PARAMS = dict(
    generator=dict(learning_rate=0.0002),
    discriminator=dict(learning_rate=0.0002),
    latent_factors=100,  # Length of the noise vector
    epochs=200,
    batch_size=64,
)

## Dataset

In [None]:
# Get dataset
class MNIST_Dataset():
    """Loader for the MNIST digits stored as gzip-compressed idx files.

    Images are zero-padded from 28x28 to 32x32, rescaled from [0, 255] to
    [-1, 1] (the range of the generator's tanh output) and returned as
    float32 arrays of shape (n, 32, 32, 1). Labels are returned as read
    from disk.
    # TODO: replace the StandardScaler for a Tensorflow version. See: tf.nn.moments
    """

    # idx type code -> numpy dtype. Multi-byte values are stored big-endian.
    _IDX_DTYPES = {
        0x08: np.dtype('u1'),
        0x09: np.dtype('i1'),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }

    def __init__(self, ds_path):
        """ds_path is the folder where the 4 .gz files are located
        """
        self.FILES = {
            'train': ds_path+"/train-images-idx3-ubyte.gz",
            'train_labels': ds_path+"/train-labels-idx1-ubyte.gz",
            'test': ds_path+"/t10k-images-idx3-ubyte.gz",
            'test_labels': ds_path+"/t10k-labels-idx1-ubyte.gz"
        }

    @staticmethod
    def _read_idx(filename):
        """Read a tensor from a gzip-compressed file in idx format.

        The 4-byte header holds two zero bytes, the element type code and
        the number of dimensions; one big-endian uint32 per dimension follows.

        Raises ValueError when the header is not a valid idx header.
        """
        with gzip.open(filename, 'rb') as fd:
            zero, data_type, dims = struct.unpack('>HBB', fd.read(4))
            if zero != 0 or data_type not in MNIST_Dataset._IDX_DTYPES:
                raise ValueError(
                    "{} is not a valid idx file (type code 0x{:02X})".format(
                        filename, data_type
                    )
                )
            shape = struct.unpack('>{}I'.format(dims), fd.read(4 * dims))
            dtype = MNIST_Dataset._IDX_DTYPES[data_type]
            return np.frombuffer(fd.read(), dtype=dtype).reshape(shape)

    @staticmethod
    def _preprocess(x):
        """x: tensor of shape (-1, 28, 28) with pixel values in [0, 255].

        Returns a float32 tensor of shape (-1, 32, 32, 1) scaled to [-1, 1].
        """
        n = x.shape[0]
        # From 28x28 pixels to 32x32 (2 background pixels on every side)
        x_32 = np.pad(
            x,
            pad_width=((0, 0), (2, 2), (2, 2)),
            mode='constant',
            constant_values=0
        )
        # Real images must live in the same range as the generated ones,
        # otherwise the discriminator can tell them apart by scale alone.
        x_scaled = x_32.astype(np.float32) / 127.5 - 1.0
        return x_scaled.reshape(n, 32, 32, 1)

    def _load(self, split):
        """Return (preprocessed images, labels) for 'train' or 'test'."""
        return (
            self._preprocess(self._read_idx(self.FILES[split])),
            self._read_idx(self.FILES[split + '_labels'])
        )

    def train_data(self):
        """Return the training set as (images, labels)."""
        return self._load('train')

    def test_data(self):
        """Return the test set as (images, labels)."""
        return self._load('test')

In [None]:
# Dataset wrapper around the idx files located under DS_PATH.
ds = MNIST_Dataset(DS_PATH)
# Only the training images are kept; the labels are read and then discarded.
x_real, _ = ds.train_data()  # I don't care about labels

## Model architecture

In [None]:
def get_generator(latent_factors):
    """Build the generator: from noise to plausible examples.

    INPUT: Noise vector of latent factors (-1,1)
    OUTPUT: 32x32x1 (-1,1) grayscale image

    latent_factors: int, length of the input noise vector.
    Returns an uncompiled tf.keras.Sequential named "Generator".
    """
    return tf.keras.Sequential(
        name="Generator",
        layers=[
            tf.keras.layers.InputLayer(
                # A one-element tuple: `(latent_factors)` without the
                # trailing comma is just the integer itself.
                input_shape=(latent_factors,)
            ),
            tf.keras.layers.Dense(
                name="D1",
                units=128,
                activation=tf.keras.layers.LeakyReLU(alpha=0.2)
            ),
            tf.keras.layers.BatchNormalization(
                momentum=0.8,
                name="BN1"
            ),
            tf.keras.layers.Dense(
                name="D2",
                units=256,
                activation=tf.keras.layers.LeakyReLU(alpha=0.2)
            ),
            tf.keras.layers.BatchNormalization(
                momentum=0.8,
                name="BN2"
            ),
            tf.keras.layers.Dense(
                name="D3",
                units=512,
                activation=tf.keras.layers.LeakyReLU(alpha=0.2)
            ),
            # NOTE(review): the third normalization layer is named "BN4",
            # not "BN3". Left untouched because renaming a layer changes the
            # variable names stored in saved weights.
            tf.keras.layers.BatchNormalization(
                momentum=0.8,
                name="BN4"
            ),
            # tanh keeps every output pixel inside (-1, 1)
            tf.keras.layers.Dense(
                name="D4",
                units=1024,
                activation=tf.math.tanh
            ),
            # Reshape the 1024 outputs into a 32x32x1 image
            tf.keras.layers.Reshape(
                target_shape=(32, 32, 1)
            )
        ]
    )

In [None]:
def get_discriminator():
    """Build the discriminator network.

    A 32x32x1 image is flattened and passed through two LeakyReLU dense
    layers (512 and 256 units); the final single sigmoid unit outputs the
    probability for the example to be real.
    """
    keras_layers = tf.keras.layers

    stack = [
        keras_layers.InputLayer(input_shape=(32, 32, 1)),
        keras_layers.Flatten(),
    ]
    # Hidden layers, each one with its own LeakyReLU instance
    for layer_name, width in (("F3", 512), ("F4", 256)):
        stack.append(
            keras_layers.Dense(
                units=width,
                activation=keras_layers.LeakyReLU(alpha=0.2),
                name=layer_name
            )
        )
    # Output layer
    stack.append(
        keras_layers.Dense(units=1, activation=tf.math.sigmoid, name="F5")
    )
    return tf.keras.Sequential(layers=stack, name="Discriminator")

In [None]:
def get_gan(generator, discriminator):
    """Chain generator and discriminator into one model.

    Latent factors -> generator -> image -> discriminator -> probability.
    Both sub-models are reused as they are, so their weights are shared
    with the stacked model.

    generator: model mapping a noise vector to an image.
    discriminator: model mapping an image to a probability.
    Returns an uncompiled tf.keras.Sequential named "GAN".
    """
    # Length of the noise vector, read from the generator's first layer
    latent_factors = generator.layers[0].input.shape[1]
    return tf.keras.Sequential(
        name="GAN",
        layers=[
            tf.keras.layers.InputLayer(
                # input_shape has to be a tuple, not a bare integer
                input_shape=(latent_factors,)
            ),
            generator,
            discriminator
        ]
    )

In [None]:
# Each network gets its own Adam optimizer, configured with the learning
# rate declared for it in PARAMS (the optimizer default would otherwise be
# used and the declared hyperparameters silently ignored).
generator = get_generator(PARAMS['latent_factors'])
gen_opt = tf.keras.optimizers.Adam(
    learning_rate=PARAMS['generator']['learning_rate']
)

discriminator = get_discriminator()
disc_opt = tf.keras.optimizers.Adam(
    learning_rate=PARAMS['discriminator']['learning_rate']
)

# Stacked model used to train the generator through the discriminator
gan = get_gan(generator, discriminator)

generator.summary()
discriminator.summary()
gan.summary()

## Util Functions

In [None]:
def sample_latent_factors(n):
    """Draw `n` noise vectors to feed the generator.

    Every vector holds PARAMS['latent_factors'] values sampled uniformly
    between -1.0 and 1.0.
    """
    noise_shape = (n, PARAMS['latent_factors'])
    # Alternative prior, not in use: tf.random.normal with mean=0.0, stddev=1.0
    return tf.random.uniform(noise_shape, -1.0, 1.0)

In [None]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(32,32), close=True):
    """Save a grid of freshly generated images to 'generator_<epoch>.png'.

    epoch: value used in the output file name.
    generator: model turning noise vectors into 32x32 images.
    examples: number of images to generate and draw.
    dim: (rows, columns) of the grid; must have room for `examples` images.
    figsize: size of the figure, as passed to matplotlib.
    close: release the figure once it is written to disk. The function is
        called repeatedly from the training loop, and every figure left
        open keeps its memory until the cell finishes.

    Raises ValueError when the grid is too small for `examples` images.
    """
    rows, cols = dim
    if examples > rows * cols:
        raise ValueError(
            "a {}x{} grid cannot hold {} images".format(rows, cols, examples)
        )
    z = sample_latent_factors(examples)
    generated_images = generator.predict(z).reshape(examples, 32, 32)
    fig = plt.figure(figsize=figsize)
    try:
        for i, image in enumerate(generated_images):
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.imshow(image, interpolation='nearest', cmap='gray')
            ax.axis('off')
        fig.tight_layout()
        fig.savefig('generator_{}.png'.format(epoch))
    finally:
        if close:
            plt.close(fig)

## Operative

In [None]:
def train_step(model, optimizer, fn_loss, x_train, y_train):
    """Run one optimization step of `model` on a single batch.

    model: network to update; only its `trainable_variables` are touched.
    optimizer: optimizer that applies the gradients.
    fn_loss: callable (y_true, y_pred) -> scalar loss.
    x_train: batch of inputs.
    y_train: batch of targets.

    Returns the loss of the batch, computed before the update, so callers
    can log it (callers that ignore the return value are unaffected).
    """
    with tf.GradientTape() as tape:
        # training=True so layers such as BatchNormalization use batch statistics
        predictions = model(x_train, training=True)
        loss = fn_loss(y_train, predictions)
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss

In [None]:
def fn_loss(y_true, y_pred):
    """Loss function.
    Log loss or cross-entropy, averaged over the batch.

    y_true: target probabilities, one per example.
    y_pred: predicted probabilities, one per example (any shape holding the
        same number of elements as y_true, e.g. (n, 1) for (n,) targets).
    Returns a scalar tensor.
    """
    # The models output shape (n, 1) while the labels have shape (n,).
    # Without flattening, the element-wise cross-entropy broadcasts both to
    # (n, n) and scores every prediction against every label.
    y_true = tf.reshape(y_true, (-1,))
    y_pred = tf.reshape(y_pred, (-1,))
    return tf.math.reduce_mean(
        tf.keras.losses.binary_crossentropy(
            y_true=y_true,
            y_pred=y_pred,
            label_smoothing=0  # Not useful since I only need one sided label smoothing
        )
    )

In [None]:
# Tensorboard writer.
# Every run logs into its own folder, named after the time the cell is executed.
current_time = "{:%Y%m%d-%H%M%S}".format(datetime.datetime.now())
train_log_dir = '/'.join((LOG_PATH, current_time, 'train'))
train_summary_writer = tf.summary.create_file_writer(logdir=train_log_dir)

In [None]:
""" Precalculated data to speedup calculations.
"""
# Real data
n_data = x_real.shape[0]
bs = PARAMS['batch_size']
n_real = n_fake = bs//2
x_real = np.concatenate([
    x_real,
    np.take(  # Data to fill last batch
        a=x_real,
        indices=np.random.randint(
            low=0,
            high=n_data, 
            size=n_data%PARAMS['batch_size']
        ),
        axis=0
    )
])
n_data = x_real.shape[0]
assert(n_data % bs == 0)
batches = ((n_data*2)//bs)  # Batches per epoch (real+fake data)

# Batch labels for the discriminator
disc_y = tf.concat(
    [
        tf.fill((n_real,), 0.90), # Real. One sided label smoothing
        tf.fill((n_fake,), 0.0)   # Fake
    ],
    axis=0
)

# Batch labels for the generator
gen_y = tf.fill((bs,), 0.90)  # Objective is to fool the discriminator, one sided label smoothing


%tensorboard --logdir {train_log_dir}

In [None]:
""" Training loop.
"""

# Overrides, for this run, the number of epochs declared in PARAMS above.
PARAMS['epochs'] = 100


for epoch in range(0, PARAMS['epochs']):
    print("Epoch: {}".format(epoch+1))
    # Discriminator data
    # In-place shuffle of the real images along the first axis.
    np.random.shuffle(x_real)
    # All fake images of the epoch are generated here, once, before the batch
    # loop, i.e. with the generator as it is at the start of the epoch.
    disc_x_fake = generator(sample_latent_factors(n_data))
    # Generator data
    gen_x_fake = sample_latent_factors(n_data*2)  # Same amount of data as for the discriminator
   
    for batch in tqdm(range(batches)):
        # DISCRIMINATOR
        # One step on a batch made of n_real real images followed by n_fake
        # fake ones, in the same order as the labels in disc_y.
        discriminator.trainable=True
        train_step(
            model=discriminator,
            optimizer=disc_opt,
            fn_loss=fn_loss,
            x_train=np.concatenate([
                x_real[batch*(n_real) : (batch+1)*n_real],      # Real samples
                disc_x_fake[batch*(n_fake) : (batch+1)*n_fake]  # Fake samples
            ]),
            y_train=disc_y
        )        
        # GENERATOR
        # The discriminator is frozen so that the step on the stacked model
        # is meant to update the generator only (train_step reads
        # gan.trainable_variables).
        discriminator.trainable=False
        train_step(
            model=gan,  # Latent params -> Generator -> image -> Discriminator -> Probability
            optimizer=gen_opt,
            fn_loss=fn_loss,
            x_train=gen_x_fake[batch*bs : (batch+1)*bs],
            y_train=gen_y
        )
        
    # EPOCH METRICS
    # NOTE(review): despite their names, disc_loss and gen_loss hold binary
    # accuracies (tf.metrics.binary_accuracy), not losses, and are written to
    # TensorBoard under the 'disc_loss' / 'gen_loss' tags.
    random_noise = sample_latent_factors(256)
    # Discriminator accuracy on 256 random real images plus 256 fresh fakes.
    disc_loss = tf.metrics.binary_accuracy(
        y_true=tf.concat(
            [
                tf.fill((256,), 1.0), # Real
                tf.fill((256,), 0.0)  # Fake
            ],
            axis=0
        ),
        y_pred=tf.reshape(
            tensor=discriminator(
                tf.concat([
                    np.take(  # Real samples
                        a=x_real,
                        indices=np.random.randint(
                            low=0,
                            high=n_data, 
                            size=256
                        ),
                        axis=0
                    ),
                    generator(random_noise) # Fake samples
                ],
                axis=0
            )),
            shape=(-1,)
        )
    )
    # Generator "accuracy": the same 256 fakes scored against the label 1.0
    # (real), i.e. how often the discriminator is fooled.
    gen_loss = tf.metrics.binary_accuracy(
        y_true=tf.fill((256,), 1.0),
        y_pred=tf.reshape(gan(random_noise), (-1,))
    )
    
    print("Discriminator Acc: {}".format(disc_loss))
    print("Generator Acc: {}".format(gen_loss))
    with train_summary_writer.as_default():
        tf.summary.scalar('disc_loss', disc_loss, step=epoch)
        tf.summary.scalar('gen_loss', gen_loss, step=epoch)
    
    # A grid of generated images is written to disk every second epoch.
    if epoch%2 == 0:
        plot_generated_images(epoch+1, generator) # Output generated images
        # TODO: Save model
    
    

## Generate fake samples

In [None]:
# One image generated from fresh noise, without the batch and channel axes.
sample = np.reshape(generator(sample_latent_factors(1))[0].numpy(), (32, 32))

# Probability that the discriminator assigns to the sample being real;
# the batch and channel axes are added back for the model call.
print("Disc prob: {}".format(
    discriminator(sample[np.newaxis, :, :, np.newaxis]).numpy()[0, 0]
))

plt.imshow(sample, cmap='gray')

## Save Model

In [None]:
# Save the architecture (as JSON) and the weights of the model.
import os

# `model` is not defined anywhere else in this notebook; it is bound here to
# the stacked GAN (assumption -- confirm which network should be persisted).
model = gan

json_path = 'saved_models/GAN01/model.json'
# NOTE(review): the weights go to a different folder than the JSON file;
# both paths are kept as they were -- confirm whether this is intended.
weights_path = 'saved_models/01-k/weights'

# Neither open() nor save_weights() is relied upon to create missing folders.
for target in (json_path, weights_path):
    os.makedirs(os.path.dirname(target), exist_ok=True)

with open(json_path, 'w') as fd:
    fd.write(model.to_json())
model.save_weights(weights_path)

## References

1. <a name="bib-web-gantf"></a>[Building a simple Generative Adversarial Network (GAN) using TensorFlow](https://blog.paperspace.com/implementing-gans-in-tensorflow/)
2. <a name="bib-web-adversarialtf"></a>[Generative Adversarial Nets in TensorFlow](https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/)
3. <a name="bib-web-gankeras"></a>[Generative Adversarial Network(GAN) using Keras](https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3)
4. <a name="bib-web-poolstride"></a>[Pooling VS Striding - Striving for Simplicity: The All Convolutional Net](https://arxiv.org/abs/1412.6806)
5. <a name="bib-web-collapse1"></a>[Mode collapse: GAN — Why it is so hard to train Generative Adversarial Networks!](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b)
6. <a name="bib-web-collapse2"></a>[Mode collapse: What does it mean if all produced images of a GAN look the same?](https://www.quora.com/What-does-it-mean-if-all-produced-images-of-a-GAN-look-the-same)
7. <a name="bib-vid-gans"></a>[NIPS 2016 - Generative Adversarial Networks - Ian Goodfellow](https://www.youtube.com/watch?v=AJVyzd0rqdc)
  1. [On divergence](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=52m10s)
  1. [On labeled/conditioned GANS](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=1h09m50s)
  1. [On mode collapse](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=1h31m53s)
  1. [One sided label smoothing](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=1h11m34s)
  1. [Question: GANs vs VAEs](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=37m10s)
  1. [Question: mode collapse/same sample](https://www.youtube.com/watch?v=AJVyzd0rqdc&t=35m47s)
8. <a name="bib-web-tbstarted"></a>[Get started with TensorBoard](https://www.tensorflow.org/tensorboard/r2/get_started#using_tensorboard_with_other_methods)
9. <a name="bib-web-codegan1"></a>[Github: PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py)
 