In [1]:
from tensorflow.keras.layers import Activation, Dense, Input, Conv2D, Flatten, Reshape, Conv2DTranspose, LeakyReLU, BatchNormalization
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Model, load_model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
import numpy as np
import math
import matplotlib.pyplot as plt
import os

In [2]:
def gen_block(x, filters = 128, kernel_size = 5, strides = 2, padding = "same"):
    """Apply one generator stage: BatchNormalization -> ReLU -> Conv2DTranspose.

    Args:
        x: input tensor for the stage.
        filters: number of output channels of the transposed convolution.
        kernel_size: kernel size of the transposed convolution.
        strides: stride of the transposed convolution.
        padding: padding mode passed to the transposed convolution.

    Returns:
        The output tensor of the transposed convolution.
    """
    normalized = BatchNormalization()(x)
    activated = Activation("relu")(normalized)
    deconv = Conv2DTranspose(filters=filters,
                             kernel_size=kernel_size,
                             strides=strides,
                             padding=padding)
    return deconv(activated)

def build_generator(inputs, image_size):
    """Build the generator model that maps `inputs` to a single-channel image.

    The input is projected by a Dense layer and reshaped to a
    (image_size // 4, image_size // 4, 128) feature map, passed through four
    `gen_block` stages (128, 64, 32 and 1 filters; the first two use stride 2,
    the last two stride 1) and finished with a sigmoid activation.

    Args:
        inputs: Keras input tensor of the generator.
        image_size: side length used to derive the initial feature-map size.

    Returns:
        A Keras `Model` named "generator".
    """
    base_size = image_size // 4
    # (filters, strides) of each successive transposed-convolution stage
    stage_specs = ((128, 2), (64, 2), (32, 1), (1, 1))

    x = Dense(base_size * base_size * 128)(inputs)
    x = Reshape((base_size, base_size, 128))(x)
    for n_filters, stride in stage_specs:
        x = gen_block(x=x, filters=n_filters, kernel_size=5, strides=stride, padding="same")
    image = Activation('sigmoid')(x)
    return Model(inputs, image, name="generator")


def disc_block(x, filters = 64, kernel_size = 5, strides = 2, padding = "same"):
    """Apply one discriminator stage: BatchNormalization -> LeakyReLU(0.2) -> Conv2D.

    Args:
        x: input tensor for the stage.
        filters: number of output channels of the convolution.
        kernel_size: kernel size of the convolution.
        strides: stride of the convolution.
        padding: padding mode passed to the convolution.

    Returns:
        The output tensor of the convolution.
    """
    normalized = BatchNormalization()(x)
    activated = LeakyReLU(alpha=0.2)(normalized)
    conv = Conv2D(filters=filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  padding=padding)
    return conv(activated)

def build_discriminator(inputs):
    """Build the discriminator (critic) model producing one unbounded score per input.

    The input goes through a LeakyReLU and a 32-filter stride-2 convolution,
    then three `disc_block` stages (64, 128 and 256 filters; strides 2, 2, 1),
    and is flattened into a single Dense unit with a linear activation.

    Args:
        inputs: Keras input tensor of the discriminator.

    Returns:
        A Keras `Model` named 'discriminator'.
    """
    # first stage has no batch normalization in front of it
    x = LeakyReLU(alpha=0.2)(inputs)
    x = Conv2D(filters=32, kernel_size=5, strides=2, padding="same")(x)

    # (filters, strides) of each following stage
    for n_filters, stride in ((64, 2), (128, 2), (256, 1)):
        x = disc_block(x=x, filters=n_filters, kernel_size=5, strides=stride, padding="same")

    score = Dense(1)(Flatten()(x))
    score = Activation("linear")(score)
    return Model(inputs, score, name='discriminator')

In [7]:
# --- data: only the training images are kept, labels and test split are discarded
(x_train, _), (_, _) = mnist.load_data()
image_size = x_train.shape[1]
# add a trailing channel axis, cast to float32 and divide by 255
x_train = x_train.reshape(-1, image_size, image_size, 1).astype('float32') / 255

# --- run configuration
model_name = "wagan_mnist"  # used for the output folder and the saved model file
train_steps = 3000          # one can train for more number of steps to get better quality
batch_size = 64
latent_size = 100           # latent vector of 100-dim

# --- WGAN-specific settings
n_critic = 5                # discriminator updates per generator update
clip_value = 0.01           # discriminator weights are clipped to [-clip_value, clip_value]
lr = 5e-5                   # learning rate handed to RMSprop

# shape of one image fed to the discriminator
input_shape = (image_size, image_size, 1)

# loss function for the WGAN
def wassertian_loss(y_label, y_pred):
    """Return the negated mean of the element-wise product of labels and predictions.

    Args:
        y_label: target tensor (the training loop in this notebook passes +1 / -1).
        y_pred: tensor predicted by the model.
    """
    signed_scores = y_label * y_pred
    return -K.mean(signed_scores)

# discriminator model
# Built on the image-shaped `input_shape` and compiled on its own so that it can
# be trained directly with train_on_batch in the training loop.
inputs = Input(shape = input_shape, name = "discriminator_input")
discriminator = build_discriminator(inputs)
# The 'accuracy' metric makes train_on_batch return (loss, acc), which is the
# pair the training loop unpacks.
discriminator.compile(loss = wassertian_loss,
                      optimizer = RMSprop(learning_rate = lr),
                      metrics = ['accuracy'])
discriminator.summary()

# generator model
# NOTE: `input_shape` and `inputs` are rebound here; from this point on they
# describe the latent vector, not the image.
input_shape = (latent_size,)
inputs = Input(shape = input_shape, name = 'generator_input')
generator = build_generator(inputs, image_size)
generator.summary()

# adversarial model
# The discriminator is marked non-trainable only after it has been compiled above
# and before the stacked model is compiled.
# NOTE(review): this relies on Keras taking the `trainable` flag into account at
# compile time, so that only the generator is updated through `adversarial` —
# confirm for the Keras version in use.
discriminator.trainable = False
adversarial = Model(inputs, discriminator(generator(inputs)), name = model_name)
adversarial.compile(loss = wassertian_loss,
                    optimizer = RMSprop(learning_rate = lr),
                    metrics = ['accuracy'])
adversarial.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 discriminator_input (Input  [(None, 28, 28, 1)]       0         
 Layer)                                                          
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 28, 28, 1)         0         
                                                                 
 conv2d_4 (Conv2D)           (None, 14, 14, 32)        832       
                                                                 
 batch_normalization_7 (Bat  (None, 14, 14, 32)        128       
 chNormalization)                                                
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 14, 14, 32)        0         
                                                                 
 conv2d_5 (Conv2D)           (None, 7, 7, 64)        

In [8]:
def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """Generate images from noise and save them as one square grid PNG.

    The figure is written to ``<model_name>/<step as 5 digits>.png``; the
    folder is created when it does not exist yet.

    Args:
        generator: model exposing ``predict``; called with `noise_input`, or
            with a list of inputs when labels / codes are supplied.
        noise_input: array of latent vectors, one row per image.
        noise_label: optional extra input appended after the noise.
        noise_codes: optional list of further inputs appended after the label;
            only used when `noise_label` is given.
        show: when True the figure is displayed with ``plt.show()``; otherwise
            all open figures are closed after saving.
        step: training step, used to build the file name.
        model_name: name of the output folder.
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    num_images = images.shape[0]
    image_size = images.shape[1]
    # Smallest square grid that holds every image. Truncating the square root
    # (as int(sqrt(n)) does) leaves too few cells whenever n is not a perfect
    # square, and plt.subplot then rejects the out-of-range index.
    rows = max(1, math.ceil(math.sqrt(num_images)))

    plt.figure(figsize=(2.2, 2.2))
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

In [9]:
def train(models, x_train, params):
    """Train the WGAN by alternating critic updates with generator updates.

    Each training step runs `n_critic` discriminator updates (one batch of real
    images labelled +1 and one batch of generated images labelled -1, followed
    by clipping of the discriminator weights), then one update of the
    adversarial model on fresh noise labelled +1.

    Args:
        models: tuple ``(generator, discriminator, adversarial)``.
        x_train: array of real images, sampled along axis 0.
        params: tuple ``(batch_size, latent_size, n_critic, clip_value,
            train_steps, model_name)``.

    Side effects:
        Prints the averaged discriminator loss/accuracy and the adversarial
        loss/accuracy on every step. Every 500 steps a grid of generated
        samples is written through `plot_images` and the generator is saved to
        ``model_name + ".keras"``; the generator is saved once more when
        training ends.
    """
    generator, discriminator, adversarial = models
    batch_size, latent_size, n_critic, clip_value, train_steps, model_name = params
    save_interval = 500
    model_path = model_name + ".keras"

    # fixed noise reused for every sample grid, so the grids are comparable
    noise_input = np.random.uniform(-1.0, 1.0, size = [16, latent_size])
    train_size = x_train.shape[0]
    real_labels = np.ones((batch_size, 1))

    for i in range(train_steps):

        loss = 0
        acc = 0
        for _ in range(n_critic):

            rand_indexes = np.random.randint(0, train_size, size = batch_size)
            real_images = x_train[rand_indexes]
            noise = np.random.uniform(-1.0, 1.0, size = [batch_size, latent_size])
            # verbose=0: this call runs n_critic times per step and would
            # otherwise flood the cell output with progress bars
            fake_images = generator.predict(noise, verbose=0)
            real_loss, real_acc = discriminator.train_on_batch(real_images, real_labels)
            fake_loss, fake_acc = discriminator.train_on_batch(fake_images, -real_labels)

            # average
            loss += 0.5 * (real_loss + fake_loss)
            acc  += 0.5 * (real_acc + fake_acc)

            # clip weights
            # NOTE(review): every array returned by get_weights() is clipped,
            # whatever kind of layer it belongs to — confirm that is intended.
            for layer in discriminator.layers:
                weights = layer.get_weights()
                weights = [np.clip(weight, -clip_value, clip_value) for weight in weights]
                layer.set_weights(weights)

        loss /= n_critic
        acc /= n_critic
        log = f"{i}: [Discriminator Loss: {loss}, acc: {acc}]"
        print(log)

        noise = np.random.uniform(-1.0, 1.0, size = [batch_size, latent_size])
        loss, acc = adversarial.train_on_batch(noise, real_labels)
        log = f"{i}: [Adversarial Loss: {loss}, acc: {acc}]"
        print(log)

        if (i + 1) % save_interval == 0:
            plot_images(generator,
                        noise_input=noise_input,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)
            # checkpoint alongside the sample grid instead of on every step
            generator.save(model_path)

    # final model, written once after the last step
    generator.save(model_path)

In [10]:
# bundle the three networks and the run settings in the order train() unpacks them
models = generator, discriminator, adversarial
params = (
    batch_size,
    latent_size,
    n_critic,
    clip_value,
    train_steps,
    model_name,
)
train(models=models, x_train=x_train, params=params)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
2286: [Discriminator Loss: -0.6408732652664184, acc: 0.0328125]
2286: [Adversarial Loss: -0.8805802464485168, acc: 1.0]
2287: [Discriminator Loss: -0.6365715071558953, acc: 0.028125]
2287: [Adversarial Loss: -0.8811934590339661, acc: 1.0]
2288: [Discriminator Loss: -0.625835494697094, acc: 0.03125]
2288: [Adversarial Loss: -0.8842266798019409, acc: 1.0]
2289: [Discriminator Loss: -0.6473916083574295, acc: 0.0359375]
2289: [Adversarial Loss: -0.8821587562561035, acc: 1.0]
2290: [Discriminator Loss: -0.6203341871500015, acc: 0.03125]
2290: [Adversarial Loss: -0.882789671421051, acc: 1.0]
2291: [Discriminator Loss: -0.6335417419672013, acc: 0.0328125]
2291: [Adversarial Loss: -0.884346067905426, acc: 1.0]
2292: [Discriminator Loss: -0.613951712846756, acc: 0.0296875]
2292: [Adversarial Loss: -0.8836376667022705, acc: 1.0]
2293: [Discriminator Loss: -0.6170762091875076, acc: 0.0296875]
2293: [Adversarial Loss: -0.882901012897