# WGAN-GP

# Import TensorFlow 2.x.

In [0]:
# Notebook setup: select the TensorFlow runtime, import dependencies and fix
# the random seeds.
try:
  # IPython line magic (not plain Python); presumably the Colab selector for
  # the TensorFlow 2.x runtime -- confirm against the target environment.
  %tensorflow_version 2.x
except Exception:
  # Best effort: any failure of the magic is ignored and the installed
  # TensorFlow version is used as-is.
  pass

import tensorflow as tf
# Seed TensorFlow's global random generator.
tf.random.set_seed(7)

import tensorflow.keras.layers as layers
import tensorflow.keras.models as models

import numpy as np
# Seed NumPy's global random generator.
np.random.seed(7)

import matplotlib.pyplot as plot

# Report the TensorFlow version that was actually loaded.
print(tf.__version__)

# Set the root directory.

In [0]:
# Make `root_dir` the process working directory, then list its contents.
import os

# NOTE(review): '/content/' looks like the Colab default workspace -- confirm
# or adjust when running elsewhere.
root_dir = '/content/'
os.chdir(root_dir)

# IPython shell escape (not plain Python): list the working directory.
!ls -al

In [0]:
class Tanh(layers.Layer):
    """Stateless Keras layer that applies the hyperbolic tangent element-wise."""

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``tanh(inputs)`` with the same shape as ``inputs``."""
        activated = tf.keras.activations.tanh(inputs)
        return activated

In [0]:
class Conv2D(layers.Layer):
    """Thin wrapper around ``layers.Conv2D`` with fixed initialisation.

    The wrapped convolution always uses 'same' padding, a truncated-normal
    kernel initialiser (stddev 0.02) and a bias initialised to zero.

    Args:
        filters: number of output channels, forwarded to ``layers.Conv2D``.
        kernel_size: convolution window size, forwarded to ``layers.Conv2D``.
        strides: convolution stride, forwarded to ``layers.Conv2D``
            (defaults to 2).
    """

    def __init__(self, filters, kernel_size, strides=2):
        super().__init__()
        kernel_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
        bias_init = tf.keras.initializers.Constant(value=0.0)
        self.conv_op = layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=True,
            bias_initializer=bias_init,
        )

    def call(self, inputs):
        """Apply the wrapped convolution to ``inputs``."""
        return self.conv_op(inputs)

In [0]:
class BatchNorm(layers.Layer):
    """Wrapper around ``tf.keras.layers.BatchNormalization``.

    The wrapped layer uses epsilon 1e-5, momentum 0.9 and a learned scale.

    Args:
        is_training: forwarded as the wrapped layer's ``trainable`` flag.
            This is distinct from the call-time ``training`` argument.
            NOTE(review): per the Keras documentation, a BatchNormalization
            layer with ``trainable=False`` is frozen and runs in inference
            mode, so the default (False) disables batch-statistics updates
            regardless of ``training`` -- confirm this is intended.
    """

    def __init__(self, is_training=False):
        super().__init__()
        self.bn = tf.keras.layers.BatchNormalization(epsilon=1e-5,
                                                     momentum=0.9,
                                                     scale=True,
                                                     trainable=is_training)

    def call(self, inputs, training=None):
        """Normalise ``inputs``.

        Args:
            inputs: tensor to normalise.
            training: forwarded to the wrapped layer. Defaults to ``None`` so
                the layer can also be called without an explicit flag, in
                which case Keras resolves the mode itself.

        Returns:
            The normalised tensor produced by the wrapped layer.
        """
        return self.bn(inputs, training=training)

In [0]:
class DenseLayer(layers.Layer):
    """Fully connected layer with this notebook's fixed initialisation.

    Wraps ``layers.Dense`` with a random-normal kernel initialiser
    (stddev 0.02) and a bias initialised to zero. No activation is applied.

    Args:
        hidden_n: number of output units.
        is_input: accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self, hidden_n, is_input=False):
        super().__init__()
        kernel_init = tf.keras.initializers.RandomNormal(stddev=0.02)
        bias_init = tf.keras.initializers.Constant(value=0.0)
        self.fc_op = layers.Dense(
            hidden_n,
            kernel_initializer=kernel_init,
            bias_initializer=bias_init,
        )

    def call(self, inputs):
        """Apply the wrapped dense projection to ``inputs``."""
        return self.fc_op(inputs)

In [0]:
class UpConv2D(layers.Layer):
    """Transposed 2-D convolution with this notebook's fixed initialisation.

    Wraps ``layers.Conv2DTranspose`` with 'same' padding, a random-normal
    kernel initialiser (stddev 0.02) and a bias initialised to zero.

    Args:
        filters: number of output channels.
        kernel_size: transposed-convolution window size.
        strides: transposed-convolution stride.
    """

    def __init__(self, filters, kernel_size, strides):
        super().__init__()
        kernel_init = tf.keras.initializers.RandomNormal(stddev=0.02)
        bias_init = tf.keras.initializers.Constant(value=0.0)
        self.up_conv_op = layers.Conv2DTranspose(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=True,
            bias_initializer=bias_init,
        )

    def call(self, inputs):
        """Apply the wrapped transposed convolution to ``inputs``."""
        return self.up_conv_op(inputs)

In [0]:
def create_discriminator_model(is_training):
    """Build the critic as a ``tf.keras.Sequential`` stack.

    Layout: two stride-2 4x4 convolutions (64 then 128 filters), flatten,
    a 1024-unit dense layer and a final single-unit dense layer that emits an
    unbounded score (no output activation). Hidden layers use LeakyReLU(0.2).

    NOTE(review): the WGAN-GP paper advises against batch normalisation in
    the critic; this stack uses it twice -- confirm that is intended.

    Args:
        is_training: forwarded to every ``BatchNorm`` layer in the stack.

    Returns:
        The assembled (not yet built) Sequential model.
    """
    stack = [
        # Convolutional feature extractor.
        Conv2D(64, 4, 2),
        layers.LeakyReLU(alpha=0.2),
        Conv2D(128, 4, 2),
        BatchNorm(is_training=is_training),
        layers.LeakyReLU(alpha=0.2),
        # Dense head.
        layers.Flatten(),
        DenseLayer(1024),
        BatchNorm(is_training=is_training),
        layers.LeakyReLU(alpha=0.2),
        DenseLayer(1),
    ]
    return tf.keras.Sequential(stack)

In [0]:
def create_generator_model(z_dim, is_training):
    """Build the generator as a ``tf.keras.Sequential`` stack.

    Layout: dense 1024 -> dense 128*7*7 -> reshape to (7, 7, 128) ->
    stride-2 transposed convolution to 64 channels -> stride-2 transposed
    convolution to 1 channel -> tanh. With 'same' padding the two stride-2
    steps take the 7x7 map to 28x28.

    Args:
        z_dim: latent dimensionality. Accepted but not used here; the input
            width is inferred when the model is first called.
        is_training: forwarded to every ``BatchNorm`` layer in the stack.

    Returns:
        The assembled (not yet built) Sequential model.
    """
    stack = [
        # Dense projection of the latent vector.
        DenseLayer(1024),
        BatchNorm(is_training=is_training),
        layers.ReLU(),
        DenseLayer(128 * 7 * 7),
        BatchNorm(is_training=is_training),
        layers.ReLU(),
        # Fold into a spatial map and upsample twice.
        layers.Reshape((7, 7, 128)),
        UpConv2D(64, 4, 2),
        BatchNorm(is_training=is_training),
        layers.ReLU(),
        UpConv2D(1, 4, 2),
        Tanh(),
    ]
    return tf.keras.Sequential(stack)

In [0]:
# Training hyperparameters.
epochs = 20              # passes over the training set
batch_size = 64          # samples per training step
z_dim = 62               # width of the latent noise vector
lam = 10.                # weight of the gradient-penalty term
learnning_rate = 2e-4    # base optimizer step size (spelling kept: name is used elsewhere)

In [0]:
# Instantiate both networks with their batch-norm layers marked trainable.
generator = create_generator_model(z_dim=z_dim, is_training=True)
discriminator = create_discriminator_model(is_training=True)

In [0]:
# RMSprop optimizers; the generator steps with 5x the discriminator's rate.
# `learning_rate` is the supported keyword: the legacy `lr` alias is
# deprecated and rejected by newer Keras releases.
generator_optimizer = tf.keras.optimizers.RMSprop(learning_rate=5 * learnning_rate)
discriminator_optimizer = tf.keras.optimizers.RMSprop(learning_rate=learnning_rate)

In [0]:
# Load Fashion-MNIST, then give the training images an explicit channel axis
# and convert them to float32.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.fashion_mnist.load_data()
)
# The shuffle buffer covers the whole training set.
BUFFER_SIZE = train_images.shape[0]
train_images = train_images.reshape(BUFFER_SIZE, 28, 28, 1).astype('float32')

In [0]:
# Centre and scale pixel values: 0 maps to -1 and 255 maps to +1.
train_images = (train_images - 127.5) / 127.5

In [0]:
# Encode the integer class labels as 10-way one-hot vectors.
train_labels = tf.one_hot(indices=train_labels, depth=10)

In [0]:
# Input pipeline: (image, label) pairs, shuffled over the full buffer and
# grouped into fixed-size batches; a trailing partial batch is dropped.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)

In [0]:
# First epoch index used by the training loop.
start_epoch = 0

In [0]:
def train_one_step(batch_images):
    """Run one simultaneous WGAN-GP update of the discriminator and generator.

    The discriminator loss is ``mean(D(fake)) - mean(D(real))`` plus ``lam``
    times a gradient penalty evaluated on random interpolations between the
    real and generated images. The generator loss is ``-mean(D(fake))``.
    Both networks are updated once per call, each by its own optimizer.

    Args:
        batch_images: batch of real images with the batch on the leading
            axis. Assumes a rank-4 (batch, height, width, channels) tensor
            compatible with the generator output -- confirm against caller.

    Returns:
        Tuple ``(discriminator_loss, generator_loss)`` of scalar tensors; the
        discriminator loss includes the gradient-penalty term.
    """
    # Size the noise and the interpolation weights from the batch actually
    # received, so a batch that differs from the global `batch_size` does not
    # fail on broadcasting. Fall back to `batch_size` if the static batch
    # dimension is unknown.
    n_samples = batch_images.shape[0]
    if n_samples is None:
        n_samples = batch_size

    batch_z = np.random.uniform(-1, 1, [n_samples, z_dim]).astype(np.float32)
    real_images = batch_images

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
        fake_images = generator(batch_z, training=True)
        fake_predictions = discriminator(fake_images, training=True)
        real_predictions = discriminator(real_images, training=True)

        discriminator_loss = tf.reduce_mean(-real_predictions) + tf.reduce_mean(fake_predictions)
        generator_loss = tf.reduce_mean(-fake_predictions)

        # Inner tape: gradient of the discriminator output with respect to
        # the interpolated images (one interpolation weight per sample).
        with tf.GradientTape() as gp_tape:
            alpha = tf.random.uniform([n_samples], 0., 1., dtype=tf.float32)
            alpha = tf.reshape(alpha, (-1, 1, 1, 1))
            sample_images = real_images + alpha * (fake_images - real_images)

            # sample_images is a computed tensor, not a variable, so it must
            # be watched explicitly.
            gp_tape.watch(sample_images)
            sample_predictions = discriminator(sample_images, training=False)

        # Taken inside the outer tapes so the penalty itself is
        # differentiable with respect to the discriminator weights.
        gradients = gp_tape.gradient(sample_predictions, sample_images)
        # The small epsilon keeps the derivative of sqrt finite when a
        # per-sample gradient is exactly zero (otherwise the second-order
        # gradient becomes NaN).
        squared_norm = tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])
        grad_l2 = tf.sqrt(squared_norm + 1e-12)
        gradient_penalty = tf.reduce_mean((grad_l2 - 1) ** 2)
        discriminator_loss += lam * gradient_penalty

    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, discriminator.trainable_variables)
    generator_gradients = generator_tape.gradient(generator_loss, generator.trainable_variables)

    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

    return (discriminator_loss, generator_loss)

In [0]:
def generate_images():
    """Sample a batch from the generator and display its first image.

    Draws ``batch_size`` latent vectors uniformly from [-1, 1), runs the
    generator in inference mode, rescales the output from [-1, 1] to
    [0, 255] as uint8 and shows the first 28x28 image with matplotlib.
    """
    latent = tf.random.uniform(shape=(batch_size, z_dim),
                               minval=-1,
                               maxval=1,
                               dtype=tf.float32)
    fakes = generator(latent, training=False).numpy()
    # Drop the channel axis: (N, 28, 28, 1) -> (N, 28, 28).
    fakes = fakes.reshape(fakes.shape[0], 28, 28).astype('float32')
    pixels = ((fakes + 1.) / 2. * 255.).astype('uint8')
    plot.imshow(pixels[0])
    plot.show()

In [0]:
# Main training loop: one optimisation step per batch; every 100th step of an
# epoch shows a generated sample and prints the latest losses.
for epoch in range(start_epoch, epochs):
  # Steps are counted from 1 within each epoch; labels are not used.
  for step, (batch_images, _) in enumerate(train_dataset, start=1):
    discriminator_loss, generator_loss = train_one_step(batch_images)

    if step % 100 != 0:
      continue

    generate_images()
    print('discriminator loss', discriminator_loss.numpy(), 'generator loss', generator_loss.numpy())