# WGAN and WGAN-GP

## All Required Libs and Parameters

In [1]:
# All required libraries are imported in this cell
import os, glob, multiprocessing
import tensorflow as tf
import  numpy as np
from PIL import Image
from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Flatten, Reshape, Conv2DTranspose, ReLU, LeakyReLU, Activation
from tensorflow.keras import Sequential, optimizers, Input, layers
from tensorflow import keras
print(tf.__version__)

# Training hyper-parameters shared by the cells below.
z_dim = 100            # length of the latent vector z fed to the generators
epochs = 3000000       # number of outer training-loop iterations (each consumes several batches), not passes over the dataset
batch_size = 128       # images per batch
learning_rate = 0.0002 # learning rate of the RMSprop optimizers used for both networks
is_training = True     # value passed as the Keras `training` flag during optimisation steps

2.2.0-dev20200315


## All Required Functions/Methods

In [2]:
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched ``tf.data`` pipeline of RGB images scaled to [-1, 1].

    Parameters
    ----------
    img_paths : list of str
        Paths of the image files to load.
    batch_size : int
        Number of images per batch.
    resize : int
        Every image is resized to ``resize x resize`` pixels.
    drop_remainder : bool
        Drop the final partial batch when ``len(img_paths)`` is not a multiple
        of ``batch_size``.
    shuffle : bool
        Shuffle the file order before decoding.
    repeat : int or None
        Forwarded to ``Dataset.repeat``; ``None`` repeats indefinitely.

    Returns
    -------
    dataset : tf.data.Dataset
        Batches of float32 images with values in [-1, 1].
    img_shape : tuple
        ``(resize, resize, 3)``.
    len_dataset : int
        Number of batches in ONE pass over ``img_paths`` (``repeat`` is not
        taken into account).
    """

    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        return img / 127.5 - 1  # [0, 255] -> [-1, 1]

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)

    n_full, n_rest = divmod(len(img_paths), batch_size)
    # A trailing partial batch is only produced when drop_remainder is False.
    len_dataset = n_full + (1 if n_rest and not drop_remainder else 0)

    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Shuffle, filter/map, batch, repeat and prefetch a ``tf.data.Dataset``.

    Parameters
    ----------
    dataset : tf.data.Dataset
        Source dataset (one element per example).
    batch_size : int
        Elements per batch.
    drop_remainder : bool
        Forwarded to ``Dataset.batch``.
    n_prefetch_batch : int
        Number of batches to prefetch.
    filter_fn, map_fn : callable or None
        Optional per-element predicate / transform.
    n_map_threads : int or None
        Parallelism of ``map``; defaults to the CPU count.
    filter_after_map : bool
        Apply ``filter_fn`` after ``map_fn`` instead of before it (slower).
    shuffle : bool
        Shuffle elements before filtering/mapping.
    shuffle_buffer_size : int or None
        Defaults to ``max(batch_size * 128, 2048)``.
    repeat : int or None
        Forwarded to ``Dataset.repeat``; ``None`` repeats indefinitely.
    """
    n_threads = multiprocessing.cpu_count() if n_map_threads is None else n_map_threads

    if shuffle:
        # Shuffling before map/filter is cheap because those stages can be costly.
        buffer_size = shuffle_buffer_size
        if buffer_size is None:
            buffer_size = max(batch_size * 128, 2048)  # never below 2048
        dataset = dataset.shuffle(buffer_size)

    def _apply_filter(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def _apply_map(ds):
        return ds.map(map_fn, num_parallel_calls=n_threads) if map_fn else ds

    # Default order is filter -> map; map -> filter is the slower alternative.
    stages = (_apply_map, _apply_filter) if filter_after_map else (_apply_filter, _apply_map)
    for stage in stages:
        dataset = stage(dataset)

    return (dataset
            .batch(batch_size, drop_remainder=drop_remainder)
            .repeat(repeat)
            .prefetch(n_prefetch_batch))


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Slice in-memory data into examples and run it through ``batch_dataset``.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
        Sliced along its first dimension by ``Dataset.from_tensor_slices``.
    Remaining parameters are forwarded unchanged to ``batch_dataset``.
    """
    return batch_dataset(
        tf.data.Dataset.from_tensor_slices(memory_data),
        batch_size,
        drop_remainder=drop_remainder,
        n_prefetch_batch=n_prefetch_batch,
        filter_fn=filter_fn,
        map_fn=map_fn,
        n_map_threads=n_map_threads,
        filter_after_map=filter_after_map,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        repeat=repeat,
    )


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of images read from disk (PNG and JPEG).

    Parameters
    ----------
        img_paths : 1d-tensor/ndarray/list of str
            Image file paths.
        labels : nested structure of tensors/ndarrays/lists, optional
            Per-image labels; when given, each element is ``(image, *labels)``.

    Each path is read and decoded with ``tf.image.decode_jpeg(..., channels=3)``.
    If ``map_fn`` is given, it runs right after decoding inside the same map
    stage. All other parameters are forwarded to ``memory_data_batch_dataset``.
    """
    memory_data = img_paths if labels is None else (img_paths, labels)

    def parse_fn(path, *label):
        encoded = tf.io.read_file(path)
        decoded = tf.image.decode_jpeg(encoded, channels=3)  # always 3 channels
        return (decoded, *label)

    if not map_fn:
        fused_fn = parse_fn
    else:
        def fused_fn(*args):
            # Decode first, then apply the caller's transform in one map stage.
            return map_fn(*parse_fn(*args))

    return memory_data_batch_dataset(memory_data,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=fused_fn,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)

def get_random_z(z_dim, batch_size):
    """Return a ``[batch_size, z_dim]`` latent batch drawn uniformly from [-1, 1)."""
    shape = [batch_size, z_dim]
    return tf.random.uniform(shape, minval=-1, maxval=1)


 
def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a grid and write it to disk.

    Parameters
    ----------
    val_out : np.ndarray
        Images shaped ``[n, h, w, c]`` with values in [-1, 1] (e.g. a ``tanh``
        generator output).
    val_block_size : int
        Number of images per grid row. Only ``n // val_block_size`` complete
        rows are written; leftover images are dropped.
    image_path : str
        Destination file. Its parent directory is created if it is missing.
    color_mode : str
        Unused; kept so existing callers keep working.

    Raises
    ------
    ValueError
        If ``val_block_size`` is not positive, or if ``val_out`` holds fewer
        than ``val_block_size`` images (no complete row to write).
    """
    images = np.asarray(val_out)
    if val_block_size <= 0:
        raise ValueError("val_block_size must be positive, got %r" % (val_block_size,))
    n_rows = images.shape[0] // val_block_size
    if n_rows == 0:
        raise ValueError("need at least %d images to fill one row, got %d"
                         % (val_block_size, images.shape[0]))

    # Map [-1, 1] -> [0, 255]; clip first so out-of-range values cannot wrap
    # around when cast to uint8.
    pixels = np.clip((images + 1.0) * 127.5, 0, 255).astype(np.uint8)

    # Build the grid in one reshape/transpose instead of repeated concatenation:
    # grid[r*h + y, k*w + x] = pixels[r*val_block_size + k, y, x].
    _, h, w, c = pixels.shape
    grid = (pixels[:n_rows * val_block_size]
            .reshape(n_rows, val_block_size, h, w, c)
            .transpose(0, 2, 1, 3, 4)
            .reshape(n_rows * h, val_block_size * w, c))
    if c == 1:
        grid = grid[:, :, 0]  # single-channel images are saved as a 2-D array

    out_dir = os.path.dirname(image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(grid).save(image_path)
# Load every face image, build the training pipeline and sanity-check one batch.
img_path = glob.glob('./faces/*.jpg')
print("Num of Files:", len(img_path))
if not img_path:
    raise FileNotFoundError("No .jpg images found in ./faces/")

dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)
print(dataset, img_shape)
sample = next(iter(dataset))
# Expected: (batch_size, 64, 64, 3) with max/min close to 1.0 / -1.0.
print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())
# Repeat indefinitely. The training loops call next(db_iter) several times per
# iteration for `epochs` iterations, and a finite repeat count (e.g. 100 passes
# of ~400 batches) would be exhausted and raise StopIteration mid-training.
dataset = dataset.repeat()
db_iter = iter(dataset)

Num of Files: 51223
<PrefetchDataset shapes: (128, 64, 64, 3), types: tf.float32> (64, 64, 3)
(128, 64, 64, 3) 1.0 -1.0


## Model Definitions and Test Code
### SubClassing Mode

In [4]:
# WGAN can reuse the DCGAN architecture with minor modifications:
# no batch normalization, and no sigmoid on the discriminator (critic) output.
class Generator(keras.Model):
    """DCGAN-style generator without batch normalization: z -> [b, 64, 64, 3].

    A dense layer projects z to a 2x2x1024 map, then five stride-2 transposed
    convolutions upsample it 2 -> 4 -> 8 -> 16 -> 32 -> 64. A final ``tanh``
    keeps the pixels in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        self.s = 2       # spatial size of the first feature map
        self.k = 4       # kernel size of every transposed convolution
        self.n_f = 1024  # channels of the first feature map

        self.dense1 = Dense(self.s * self.s * self.n_f)
        self.reshape1 = Reshape(target_shape=(self.s, self.s, self.n_f))
        self.conv1 = Conv2DTranspose(filters=512, kernel_size=self.k, strides=2, padding='same')
        self.conv2 = Conv2DTranspose(filters=256, kernel_size=self.k, strides=2, padding='same')
        self.conv3 = Conv2DTranspose(filters=128, kernel_size=self.k, strides=2, padding='same')
        self.conv4 = Conv2DTranspose(filters=64, kernel_size=self.k, strides=2, padding='same')
        self.conv5 = Conv2DTranspose(filters=3, kernel_size=self.k, strides=2, padding='same')

    def call(self, inputs, training=None):
        h = self.reshape1(tf.nn.leaky_relu(self.dense1(inputs)))
        for upsample in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = tf.nn.leaky_relu(upsample(h))
        return tf.tanh(self.conv5(h))
class Discriminator(keras.Model):
    """DCGAN-style critic without batch norm or sigmoid: image -> score [b, 1].

    Four stride-2 'same' convolutions with leaky ReLU, then a linear unit that
    outputs an unbounded score.
    """

    def __init__(self):
        super().__init__()
        self.s = 4  # not read by call()
        self.k = 4  # kernel size of every convolution

        self.conv1 = Conv2D(filters=64, kernel_size=self.k, strides=2, padding='same')
        self.conv2 = Conv2D(filters=128, kernel_size=self.k, strides=2, padding='same')
        self.conv3 = Conv2D(filters=256, kernel_size=self.k, strides=2, padding='same')
        self.conv4 = Conv2D(filters=512, kernel_size=self.k, strides=2, padding='same')
        self.flatten = Flatten()
        self.dense = Dense(1)

    def call(self, inputs, training=None):
        h = inputs
        for downsample in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = tf.nn.leaky_relu(downsample(h))
        return self.dense(self.flatten(h))
    
# Generator_s and Discriminator_s are lightweight (fewer-layer) DCGAN-style models that use batch normalization.
class Generator_s(keras.Model):
    """Lightweight generator: z [b, 100] -> image [b, 64, 64, 3] in [-1, 1].

    The spatial size grows 3 -> 9 -> 21 -> 64 through 'valid' transposed
    convolutions. Batch norm follows the first two of them.
    """

    def __init__(self):
        super().__init__()

        # Project z to a 3x3x512 seed feature map.
        self.fc = layers.Dense(units=3 * 3 * 512)

        # 3x3 -> 9x9
        self.conv1 = layers.Conv2DTranspose(filters=256, kernel_size=3, strides=3, padding='valid')
        self.bn1 = layers.BatchNormalization()

        # 9x9 -> 21x21
        self.conv2 = layers.Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization()

        # 21x21 -> 64x64, RGB output
        self.conv3 = layers.Conv2DTranspose(filters=3, kernel_size=4, strides=3, padding='valid')

    def call(self, inputs, training=None):
        seed = tf.reshape(self.fc(inputs), [-1, 3, 3, 512])
        h = tf.nn.leaky_relu(seed)
        for deconv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = tf.nn.leaky_relu(norm(deconv(h), training=training))
        return tf.tanh(self.conv3(h))


class Discriminator_s(keras.Model):
    """Lightweight critic: image [b, 64, 64, 3] -> unbounded score [b, 1].

    The spatial size shrinks 64 -> 20 -> 6 -> 1 through stride-3 'valid'
    convolutions. Batch norm follows the second and third of them.
    """
    # NOTE(review): the WGAN-GP paper advises against batch normalization in the
    # critic when a gradient penalty is used (it couples samples in the batch);
    # consider a norm-free or layer-norm critic for the WGAN-GP run.

    def __init__(self):
        super().__init__()

        self.conv1 = layers.Conv2D(filters=64, kernel_size=5, strides=3, padding='valid')

        self.conv2 = layers.Conv2D(filters=128, kernel_size=5, strides=3, padding='valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(filters=256, kernel_size=5, strides=3, padding='valid')
        self.bn3 = layers.BatchNormalization()

        self.flatten = layers.Flatten()  # [b, h, w, c] -> [b, h*w*c]
        self.fc = layers.Dense(units=1)  # [b, h*w*c] -> [b, 1]

    def call(self, inputs, training=None):
        h = tf.nn.leaky_relu(self.conv1(inputs))  # no normalization on the first layer
        for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = tf.nn.leaky_relu(norm(conv(h), training=training))
        return self.fc(self.flatten(h))

def gradient_penalty(discriminator, batch_x, fake_image):
    """WGAN-GP gradient penalty on random real/fake interpolations.

    Draws one mixing coefficient per sample, interpolates
    ``x_hat = t * real + (1 - t) * fake`` and returns the batch mean of
    ``(||dD/dx_hat||_2 - 1)^2``. The norm is taken per sample over h, w, c.
    The critic is run with ``training=True``.
    """
    n = batch_x.shape[0]
    # [b, 1, 1, 1] coefficients broadcast to the full image shape [b, h, w, c].
    alpha = tf.broadcast_to(tf.random.uniform([n, 1, 1, 1]), batch_x.shape)
    x_hat = alpha * batch_x + (1 - alpha) * fake_image

    with tf.GradientTape() as inner_tape:
        inner_tape.watch([x_hat])
        critic_scores = discriminator(x_hat, training=True)
    slopes = inner_tape.gradient(critic_scores, x_hat)

    flat = tf.reshape(slopes, [slopes.shape[0], -1])  # [b, h*w*c]
    per_sample_norm = tf.norm(flat, axis=1)           # [b]
    return tf.reduce_mean((per_sample_norm - 1) ** 2)

# WGAN has different loss functions for generator and discriminator. 
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """WGAN critic loss ``mean(D(G(z))) - mean(D(x))``; the critic minimizes it.

    Lower values mean the critic scores real images higher than generated ones.
    """
    generated = generator(batch_z, is_training)
    score_fake = tf.reduce_mean(discriminator(generated, is_training))
    score_real = tf.reduce_mean(discriminator(batch_x, is_training))
    return score_fake - score_real


def g_loss_fn(generator, discriminator, batch_z, is_training):
    """WGAN generator loss ``-mean(D(G(z)))``: the generator raises the critic score of its samples."""
    scores = discriminator(generator(batch_z, is_training), is_training)
    return -tf.reduce_mean(scores)



In [None]:
# WGAN
def main():
    """Train the weight-clipping WGAN on the global ``db_iter`` image stream.

    Each outer iteration runs ``n_critic`` critic updates, with the critic's
    weights clipped to [-clip_value, clip_value] after every update. Then one
    generator update follows on fresh noise.
    Every 100 iterations the losses are printed/recorded and a 10x10 sample
    grid is written to ``gan_images/``. Weights are saved whenever
    ``epoch % 10000 == 1``.

    Returns
    -------
    (d_losses, g_losses) : lists of the losses recorded every 100 iterations.
    """
    n_critic = 5
    clip_value = 0.01

    with tf.device('/GPU:0'):
        generator = Generator_s()
        generator.build(input_shape=(None, z_dim))

        discriminator = Discriminator_s()
        discriminator.build(input_shape=(None, 64, 64, 3))

        # Avoid momentum-based optimizers (e.g. Adam) for weight-clipped WGAN;
        # use RMSProp (or SGD) instead.
        g_opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
        d_opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

        os.makedirs('gan_images', exist_ok=True)
        d_losses, g_losses = [], []

        for epoch in range(epochs):
            # 1) Several critic steps per generator step.
            for _ in range(n_critic):
                batch_z = get_random_z(z_dim, batch_size)
                batch_x = next(db_iter)

                with tf.GradientTape() as d_tape:
                    d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
                d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
                d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

                # Clip AFTER the update so the critic used by the next step
                # stays inside the clipped (Lipschitz-bounded) weight range.
                for p in discriminator.trainable_variables:
                    p.assign(tf.clip_by_value(p, -clip_value, clip_value))

            # 2) One generator step on fresh noise.
            batch_zz = get_random_z(z_dim, batch_size)
            with tf.GradientTape() as g_tape:
                g_loss = g_loss_fn(generator, discriminator, batch_zz, is_training)
            g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
            g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

            if epoch % 100 == 0:
                print("Epoch:", epoch, "D-loss:", float(d_loss), "G-loss:", float(g_loss))
                # Preview latents come from the same distribution as training.
                z = get_random_z(z_dim, 100)
                generated_images = generator(z, training=False)
                img_path = os.path.join('gan_images', 'wgan-dcgan-%d.png' % epoch)
                save_result(generated_images.numpy(), 10, img_path, color_mode='P')

                d_losses.append(float(d_loss))
                g_losses.append(float(g_loss))

            if epoch % 10000 == 1:
                generator.save_weights('generator.ckpt')
                discriminator.save_weights('discriminator.ckpt')

    return d_losses, g_losses

# Entry point: run WGAN (weight clipping) training when executed as the main module.
if __name__ == '__main__':
    main()
        

In [None]:
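# WGAN-GP (Gradient Penalty)
# GP softly enforces the critic's 1-Lipschitz constraint by pushing the norm of the
# critic's gradient with respect to its input toward 1 (it does not constrain the weights directly).
# Most parts of WGAN-GP are the same as in WGAN.
# Differences: no weight clipping; instead, a gradient-penalty term is added to D's loss.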

def gradient_penalty(discriminator, real, fake):
    """WGAN-GP penalty: batch mean of ``(||dD/dx_hat||_2 - 1)^2``.

    ``x_hat = eps * real + (1 - eps) * fake``, with one ``eps ~ U[0, 1)`` per
    sample. The gradient norm is taken per sample over all pixel dimensions.
    The critic is called with ``training=True``.
    """
    eps = tf.random.uniform([real.shape[0], 1, 1, 1])
    eps = tf.broadcast_to(eps, real.shape)
    mixed = eps * real + (1 - eps) * fake

    with tf.GradientTape() as tape:
        tape.watch([mixed])
        logits = discriminator(mixed, True)
    grad = tape.gradient(logits, mixed)

    norms = tf.norm(tf.reshape(grad, [grad.shape[0], -1]), axis=1)  # [b]
    return tf.reduce_mean((norms - 1.) ** 2)

def gradient_penalty2(discriminator, batch_x, fake_image):
    """Alternative WGAN-GP gradient penalty (same computation as gradient_penalty).

    Interpolates ``t * batch_x + (1 - t) * fake_image``, with one random ``t``
    per sample. It then returns the batch mean of ``(||grad||_2 - 1)^2``, where
    ``grad`` is the critic's gradient w.r.t. the interpolates and the norm is
    taken per sample.
    """
    batchsz = batch_x.shape[0]

    # [b, 1, 1, 1] => [b, h, w, c]
    t = tf.random.uniform([batchsz, 1, 1, 1])
    t = tf.broadcast_to(t, batch_x.shape)

    interplate = t * batch_x + (1 - t) * fake_image

    with tf.GradientTape() as tape:
        tape.watch([interplate])
        d_interplote_logits = discriminator(interplate, training=True)
    grads = tape.gradient(d_interplote_logits, interplate)

    # grads: [b, h, w, c] => [b, -1]
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # [b]
    gp = tf.reduce_mean((gp - 1) ** 2)

    # (A debug print of gp on every call was removed; callers can log the value.)
    return gp

# WGAN-GP counterpart of d_loss_fn(): adds the weighted gradient penalty to the critic loss.
def d_loss_gp_fn(g, d, z, real, l, is_training):
    """WGAN-GP critic loss ``mean(D(G(z))) - mean(D(x)) + l * GP``.

    Parameters
    ----------
    g, d : generator and critic models.
    z : latent batch fed to ``g``.
    real : batch of real images.
    l : float
        Weight of the gradient-penalty term (the WGAN-GP lambda). It was
        previously ignored in favour of a hard-coded 10.
    is_training : passed to both models as the Keras ``training`` flag.

    Returns
    -------
    (loss, gp) : the total critic loss and the unweighted gradient penalty.
    """
    fake_images = g(z, is_training)
    fake_outputs = d(fake_images, is_training)
    real_outputs = d(real, is_training)

    gp = gradient_penalty(d, real, fake_images)

    loss = tf.reduce_mean(fake_outputs) - tf.reduce_mean(real_outputs) + l * gp

    return loss, gp
    
import tensorflow as tf
def main():
    """Train WGAN-GP on the global ``db_iter`` image stream.

    Each outer iteration runs 5 critic updates with the gradient-penalty loss,
    then one generator update. Every 100 iterations the losses are printed and
    a 10x10 sample grid is written to ``gan_images/``.
    """
    l = 10  # gradient-penalty weight (lambda)
    generator = Generator_s()
    generator.build(input_shape=(None, z_dim))

    discriminator = Discriminator_s()
    discriminator.build(input_shape=(None, 64, 64, 3))

    g_opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
    d_opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

    os.makedirs('gan_images', exist_ok=True)

    for epoch in range(epochs):
        # 1) Critic steps (no weight clipping: the penalty enforces Lipschitzness).
        for _ in range(5):
            batch_z = get_random_z(z_dim, batch_size)
            batch_x = next(db_iter)

            with tf.GradientTape() as d_tape:
                d_loss, gp = d_loss_gp_fn(generator, discriminator, batch_z, batch_x, l, is_training)
            d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
            d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

        # 2) One generator step on fresh noise.
        batch_zz = get_random_z(z_dim, batch_size)
        with tf.GradientTape() as g_tape:
            g_loss = g_loss_fn(generator, discriminator, batch_zz, is_training)
        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss), 'gp:', float(gp))
            # Sample preview latents from the SAME uniform[-1, 1] prior used in
            # training; a normal prior would show off-distribution samples.
            z = get_random_z(z_dim, 100)
            fake_image = generator(z, training=False)
            img_path = os.path.join('gan_images', 'dc-wgan-gp-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')
        
# Entry point: run WGAN-GP training when executed as the main module
# (`main` here is the WGAN-GP version defined just above).
if __name__ == "__main__":
    main()

0 d-loss: 0.04687626659870148 g-loss: 0.5297729969024658 gp: 0.007219888269901276
100 d-loss: -4.183963775634766 g-loss: 0.7209520936012268 gp: 0.040356747806072235
