In [1]:
from pathlib import Path
import matplotlib.pyplot as plt

import numpy as np

import tensorflow as tf
import keras
from keras import backend as K
from keras.layers import Conv2D, Dense, ReLU, BatchNormalization, Input, Flatten, Concatenate, Reshape,\
Activation, Conv2DTranspose, LeakyReLU, Dropout
from keras import Model
from keras.optimizers import Adam, SGD, RMSprop
from keras.metrics import Mean

from tensorflow.keras.utils import to_categorical, image_dataset_from_directory
from keras.losses import BinaryCrossentropy

from tensorflow.keras.preprocessing.image import ImageDataGenerator

## Load CelebA data set

In [2]:
# Dataset identifier, training image shape (height, width, channels) and mini-batch size.
data_set = 'celeba'
IMG_SHAPE = (128, 128, 3)  # generator output and critic input size
batch_size = 32            # images per training step

In [3]:
def celeba_transform(x):
    """Map pixel values in [0, 255] to float32 values in [-1, 1]."""
    half_range = 127.5
    as_float = tf.cast(x, tf.float32)
    return (as_float - half_range) / half_range

In [5]:
# Stream CelebA as an infinite dataset of normalized image batches.
# Labels are disabled (the GAN is unconditional); smart_resize crops so images
# reach IMG_SHAPE[:2] without aspect-ratio distortion.
train_data = (
    image_dataset_from_directory(
        '../data/input/celeba',
        image_size=IMG_SHAPE[:-1],
        batch_size=None,
        labels=None,
        smart_resize=True,
    )
    # Pass the transform directly (no lambda wrapper) and run it in parallel.
    .map(celeba_transform, num_parallel_calls=tf.data.AUTOTUNE)
    # drop_remainder keeps every batch at exactly `batch_size` images.
    .batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
    # Infinite stream; the epoch length is set by `steps_per_epoch` in fit().
    .repeat()
    # Overlap input preprocessing with training steps.
    .prefetch(tf.data.AUTOTUNE)
)

Found 202599 files belonging to 1 classes.


## Define the GAN model class (unconditional WGAN-GP: no class labels are used, although the class is named `CGAN`)

In [59]:
class CGAN(keras.Model):
    """Unconditional Wasserstein GAN with gradient penalty (WGAN-GP style).

    Despite the name, no class labels are used. The generator maps a latent
    vector of size ``latent_dim`` to a 128x128x3 image in [-1, 1]. The critic
    (``discriminator``) maps an ``IMG_SHAPE`` image to one unbounded score.

    Args:
        n_critic: Number of critic updates per generator update (must be >= 1).
        latent_dim: Size of the generator's input noise vector.
        kernel_size: Kernel size of every stride-2 (transposed) convolution.

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.
    """

    # Weight of the gradient penalty in the critic loss. The WGAN-GP paper
    # uses 10; this notebook keeps its original value of 2.
    gp_weight = 2.0

    def __init__(self, n_critic, latent_dim=100, kernel_size=5):
        super().__init__()
        if n_critic < 1:
            # train_step reports the loss of the last critic update, so at
            # least one critic update per step is required.
            raise ValueError(f"n_critic must be >= 1, got {n_critic}")
        self.n_critic = n_critic
        self.z_dim = latent_dim
        self.kernel_size = kernel_size
        self.discriminator = self._build_discriminator()
        self.generator = self._build_generator()

    def compile_model(self, discriminator_optimizer, generator_optimizer, loss_fn):
        """Attach both optimizers, the adversarial loss and the loss trackers.

        Args:
            discriminator_optimizer: Optimizer for the critic weights.
            generator_optimizer: Optimizer for the generator weights.
            loss_fn: Callable ``(y_true, y_pred) -> scalar`` used for both
                networks (e.g. a Wasserstein loss with +1/-1 labels).
        """
        # run_eagerly=False lets Keras compile train_step into a graph.
        super().compile(run_eagerly=False)
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_optimizer = generator_optimizer
        self.loss_fn = loss_fn

        self.discriminator_loss_metric = Mean(name='d_loss')
        self.generator_loss_metric = Mean(name='g_loss')

    def _build_generator(self):
        """Build the generator: latent vector (z_dim,) -> 128x128x3 image in [-1, 1].

        An 8x8x512 projection is upsampled by four stride-2 transposed
        convolutions (8 -> 16 -> 32 -> 64 -> 128). A tanh convolution then maps
        to 3 channels. The output size must match IMG_SHAPE.
        """
        latent = Input((self.z_dim,))
        x = Dense(8 * 8 * 512)(latent)
        x = LeakyReLU()(x)
        x = Reshape((8, 8, 512))(x)

        for filters in (128, 128, 64, 32):
            x = Conv2DTranspose(
                filters=filters,
                kernel_size=self.kernel_size,
                strides=2,
                padding='same',
            )(x)
            x = LeakyReLU(alpha=0.2)(x)

        output = Conv2D(
            filters=3,
            kernel_size=5,
            strides=1,
            padding='same',
            activation='tanh',
        )(x)
        return Model(inputs=latent, outputs=output)

    def _build_discriminator(self):
        """Build the critic: IMG_SHAPE image -> one unbounded score (no sigmoid)."""
        img = Input(IMG_SHAPE)
        x = img
        for filters in (64, 128, 128):
            x = Conv2D(
                filters=filters,
                kernel_size=self.kernel_size,
                strides=2,
                padding='same',
            )(x)
            x = LeakyReLU()(x)

        x = Flatten()(x)
        x = Dropout(0.3)(x)
        output = Dense(1)(x)
        return Model(inputs=img, outputs=output)

    def train_step(self, batch_data):
        """Run ``n_critic`` critic updates followed by one generator update.

        Args:
            batch_data: A batch of real images scaled to [-1, 1]
                (presumably (batch, *IMG_SHAPE) from the tf.data pipeline).

        Returns:
            Dict mapping 'd_loss' / 'g_loss' to their running means.
        """
        real_img = batch_data
        # Dynamic batch size: works even if the static batch dim is unknown.
        batch_size = tf.shape(real_img)[0]
        real_target_labels = tf.ones((batch_size, 1))
        fake_target_labels = -tf.ones((batch_size, 1))
        # Loop-invariant: the critic always sees [real; fake] in this order.
        labels = tf.concat([real_target_labels, fake_target_labels], axis=0)

        # Critic updates. training=True so the critic's Dropout is active.
        for _ in range(self.n_critic):
            with tf.GradientTape() as tape:
                z = tf.random.normal((batch_size, self.z_dim))
                fake_img = self.generator(z, training=True)
                combined_img = tf.concat([real_img, fake_img], axis=0)
                predictions = self.discriminator(combined_img, training=True)
                d_loss = 0.5 * self.loss_fn(labels, predictions)
                # Computed inside the outer tape so the penalty's dependence
                # on the critic weights is differentiated too.
                d_loss = d_loss + self.gp_weight * self.gp_loss(real_img, fake_img)

            d_grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.discriminator_optimizer.apply_gradients(
                zip(d_grads, self.discriminator.trainable_weights)
            )

        # Generator update: push the critic to score fakes like real images.
        with tf.GradientTape() as tape:
            z = tf.random.normal((batch_size, self.z_dim))
            fake_img = self.generator(z, training=True)
            fake_predictions = self.discriminator(fake_img, training=True)
            g_loss = self.loss_fn(real_target_labels, fake_predictions)

        g_grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.generator_optimizer.apply_gradients(
            zip(g_grads, self.generator.trainable_weights)
        )

        # d_loss is the loss of the last critic update in this step.
        self.discriminator_loss_metric.update_state(d_loss)
        self.generator_loss_metric.update_state(g_loss)

        return {m.name: m.result() for m in self.metrics}

    @property
    def metrics(self):
        # Listed here so Keras resets these trackers at the start of each epoch.
        return [self.discriminator_loss_metric, self.generator_loss_metric]

    def gp_loss(self, real_img, fake_img):
        """Two-sided gradient penalty mean((||grad D(x_hat)||_2 - 1)^2).

        x_hat is drawn uniformly on the straight line between each real
        image and its paired fake image, with one alpha per sample.

        Args:
            real_img: Real images, assumed rank 4 (batch, H, W, C).
            fake_img: Generated images with the same shape as ``real_img``.

        Returns:
            Scalar penalty tensor.
        """
        batch_size = tf.shape(real_img)[0]
        # One interpolation coefficient per sample, broadcast over H, W, C.
        alpha = tf.random.uniform((batch_size, 1, 1, 1), 0.0, 1.0)
        interpolated_img = alpha * real_img + (1.0 - alpha) * fake_img

        with tf.GradientTape() as tape:
            tape.watch(interpolated_img)
            predictions = self.discriminator(interpolated_img, training=True)

        # Pass a tensor (not a list): a list result would add a leading axis
        # and make the reduction below span the batch instead of each sample.
        gradients = tape.gradient(predictions, interpolated_img)
        # Tiny epsilon keeps sqrt differentiable if a gradient is exactly zero.
        gradient_l2_norm = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12
        )
        return tf.reduce_mean(tf.square(1.0 - gradient_l2_norm))

## Define Callback

In [60]:
class CustomCallback(keras.callbacks.Callback):
    """Record per-batch losses and save sample grids and loss curves to disk.

    After every epoch, it saves a 10x10 grid of faces generated from a fixed
    noise tensor, so grids from different epochs are directly comparable. It
    also saves the loss curves recorded so far.

    Args:
        latent_dim: Size of the noise vectors; must match the generator input.
        output_dir: Directory for the saved PNG files (created if missing).
    """

    def __init__(self, latent_dim=100, output_dir="../data/tmp/celeba"):
        super().__init__()
        self.z_dim = latent_dim
        self.output_dir = Path(output_dir)
        # Fixed noise reused every epoch: grid row i is generated from self.z[i].
        self.z = tf.random.normal((10, 10, self.z_dim))
        # Also initialised here so on_train_end is safe to call even if
        # training fails before on_train_begin has run.
        self.d_losses = []
        self.g_losses = []

    def on_train_begin(self, logs=None):
        self.d_losses = []
        self.g_losses = []

    def on_train_batch_end(self, batch, logs=None):
        # logs carry the values returned by the model's train_step.
        self.d_losses.append(logs['d_loss'])
        self.g_losses.append(logs['g_loss'])

    def on_epoch_end(self, epoch, logs=None):
        fig, axes = plt.subplots(10, 10, figsize=(20, 20))
        for i in range(10):
            faces = self._to_uint8(self.model.generator(self.z[i], training=False))
            for j in range(10):
                # imshow (not matshow) is the image API; cmap is ignored for RGB.
                axes[i][j].imshow(faces[j])
                axes[i][j].axis('off')
        fig.savefig(self._output_path(f"generated_faces_{epoch:03}.png"))
        plt.close(fig)

        self._save_loss_plot()

    def on_train_end(self, logs=None):
        self._save_loss_plot()

    @staticmethod
    def _to_uint8(images):
        """Map generator outputs in [-1, 1] to uint8 pixels in [0, 255]."""
        return np.clip(np.asarray(images) * 127.5 + 127.5, 0, 255).astype(np.uint8)

    def _output_path(self, filename):
        """Return output_dir / filename, creating output_dir if needed."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        return self.output_dir / filename

    def _save_loss_plot(self):
        """Save the discriminator and generator loss curves recorded so far."""
        fig, ax = plt.subplots(1, 1, figsize=(8, 5))
        ax.plot(self.d_losses, label='Discriminator loss')
        ax.plot(self.g_losses, label='Generator loss')
        ax.set_xlabel('Training batch')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig(self._output_path("training_losses.png"))
        plt.close(fig)

In [61]:
# Sampling/loss-logging callback. Its latent size must equal the `latent_dim`
# passed to CGAN in the training section (both are 128).
callback = CustomCallback(
    latent_dim=128,
)

In [62]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: negative mean of label-weighted critic scores.

    With labels +1 for real and -1 for fake, minimizing this raises the
    scores of real samples and lowers the scores of fake ones.
    """
    weighted_scores = y_true * y_pred
    return -K.mean(weighted_scores)

## Define parameters

In [63]:
# Training schedule and per-network learning rates.
epochs = 20
generator_lr = 1e-4
discriminator_lr = 5e-5  # critic learns more slowly than the generator here

# One Adam optimizer per network.
generator_optimizer = Adam(learning_rate=generator_lr)
discriminator_optimizer = Adam(learning_rate=discriminator_lr)

# Shared adversarial objective (labels: +1 real, -1 fake).
loss_fn = wasserstein_loss

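## Define a Data Iterator

No separate iterator is needed: the `tf.data` pipeline built in the data-loading section already yields an infinite stream of normalized batches.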

## Train CGAN model

In [64]:
# Size of the generator's noise vector; the callback above assumes 128.
latent_dim = 128
# Critic (discriminator) updates per generator update.
n_critic = 3

In [65]:
# Build generator and critic; kernel_size=4 applies to every stride-2 (transposed) convolution.
cgan = CGAN(
    n_critic=n_critic,
    latent_dim=latent_dim,
    kernel_size=4,
)

In [66]:
# Attach both optimizers and the Wasserstein loss to the model.
cgan.compile_model(
    discriminator_optimizer=discriminator_optimizer,
    generator_optimizer=generator_optimizer,
    loss_fn=loss_fn,
)

## GAN Training Hacks (https://github.com/soumith/ganhacks)

### Tracking failures early
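- D loss goes to 0: failure mode. (This heuristic applies to the standard GAN loss; the Wasserstein critic loss used here is unbounded and can legitimately be negative.)
- Check the norms of the gradients: if they are over 100, training is going wrong.
- When things are working, D loss has low variance and goes down over time, rather than having huge variance and spiking.
- If the generator loss steadily decreases, the generator is probably fooling D with garbage (says Martin).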

In [None]:
# `train_data` repeats forever, so `steps_per_epoch` defines one epoch.
# `use_multiprocessing` is not passed: Keras only uses it for generator /
# Sequence inputs, and `train_data` is a tf.data.Dataset.
try:
    cgan.fit(
        train_data,
        epochs=epochs,
        steps_per_epoch=1000,
        callbacks=[callback],
    )
except BaseException:
    # On an error or a manual interrupt, still save the loss curve collected
    # so far. Then re-raise so the failure and its traceback stay visible
    # instead of being swallowed.
    if hasattr(callback, "d_losses"):
        callback.on_train_end(None)
    raise

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20

In [None]:
# Sample a fresh 10x10 grid of faces from the trained generator.
n_rows, n_cols = 10, 10
z = tf.random.normal((n_rows * n_cols, latent_dim))
generated_faces = cgan.generator(z, training=False)
# Map tanh outputs in [-1, 1] back to displayable uint8 pixels.
faces_uint8 = np.clip(np.asarray(generated_faces) * 127.5 + 127.5, 0, 255).astype(np.uint8)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20))
for ax, face in zip(axes.flat, faces_uint8):
    ax.imshow(face)
    ax.axis('off')
# plt.show() renders inline without the non-GUI-backend warning of fig.show().
plt.show()

In [None]:
def generate_model_summary():
    """Print this run's hyperparameters and both network summaries.

    Reads the notebook globals ``generator_lr``, ``discriminator_lr``,
    ``epochs``, ``n_critic``, ``batch_size`` and ``cgan``.
    """
    print(f"celeba_lr_g_{generator_lr}_d_{discriminator_lr}_epochs_{epochs}_n_critic_{n_critic}_bs_{batch_size}")
    print(f"Generator learning rate: {generator_lr}")
    print(f"Discriminator learning rate: {discriminator_lr}")
    print()
    print(f"Number of epochs: {epochs}")
    print(f"Batch size: {batch_size}")
    print()
    print(f"n_critic: {n_critic}")
    print()
    print("Generator summary:")
    # summary() prints by itself and returns None; wrapping it in print()
    # would emit a stray "None" line.
    cgan.generator.summary()
    print()
    print("Discriminator summary:")
    cgan.discriminator.summary()
    print()

In [None]:
# Print hyperparameters and layer-by-layer summaries of both networks.
generate_model_summary();