In [23]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import datetime,os
from tensorflow.keras import layers
import time
import math

In [24]:
# Switches read by later cells when building and training the model.
is_wgan = False                # True: train with train_wgan() (Wasserstein loss); False: train with train_gan() (cross-entropy loss)
is_complex_generator = False   # True: generate_models() uses make_complex_generator_model(); False: make_generator_model()

In [25]:
# set seeds for reproducibility (the same value seeds both NumPy and TensorFlow)
SEED = 1
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# load the TensorBoard IPython extension, which provides the %tensorboard magic used further down
%load_ext tensorboard

In [27]:
# load dataset: only the MNIST training split is kept, the test split is discarded
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
# append a channel axis so each image is 28x28x1, and switch to float32
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32')
# Normalize the images to [-1, 1]
train_images = (train_images - 127.5) / 127.5

In [28]:
# set parameters
latent_dim = 50    # length of the noise vector fed to the generator
EPOCHS = 50        # number of epochs passed to fit()
BATCH_SIZE = 128   # batch size passed to fit()
# checkpoint filename template; {epoch:04d} is filled in with the zero-padded epoch number
checkpoint_path = "/content/checkpoints/cp-{epoch:04d}.ckpt"
# directory part of checkpoint_path
checkpoint_dir = os.path.dirname(checkpoint_path)

In [29]:
def make_generator_model():
    """Build the lightweight generator network.

    A bias-free dense projection of the latent vector is reshaped to a
    7x7x32 feature map, passed through two stride-2 transposed convolutions
    (32 then 16 filters, each followed by batch norm and LeakyReLU), and
    finished with a stride-1 transposed convolution producing one
    tanh-activated channel.

    Returns:
        An uncompiled `tf.keras.Sequential` taking `(latent_dim,)` inputs.
    """
    net = tf.keras.Sequential()

    # Project the latent vector and fold it into a 7x7 map with 32 channels.
    net.add(layers.Dense(7*7*32, use_bias=False, input_shape=(latent_dim,)))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((7, 7, 32)))

    # Two identical stride-2 upsampling stages that differ only in filter count.
    for n_filters in (32, 16):
        net.add(layers.Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding='same', use_bias=False))
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Output layer: a single channel squashed by tanh.
    net.add(layers.Conv2DTranspose(1, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh'))

    return net
  

In [30]:
def make_complex_generator_model():
    """Build the wider generator network (64/64/32 channels).

    Same layout as the simple generator but with twice the channel counts.
    The output shape is asserted after the reshape and after every
    transposed convolution so a wiring mistake fails immediately.

    Returns:
        An uncompiled `tf.keras.Sequential` mapping `(latent_dim,)` inputs to
        `(28, 28, 1)` tanh-activated images.
    """
    net = tf.keras.Sequential()

    # Dense projection folded into a 7x7 map with 64 channels.
    net.add(layers.Dense(7*7*64, use_bias=False, input_shape=(latent_dim,)))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((7, 7, 64)))
    assert net.output_shape == (None, 7, 7, 64)  # Note: None is the batch size

    # (filters, expected output shape) for each stride-2 upsampling stage.
    upsampling_stages = (
        (64, (None, 14, 14, 64)),
        (32, (None, 28, 28, 32)),
    )
    for n_filters, expected_shape in upsampling_stages:
        net.add(layers.Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding='same', use_bias=False))
        assert net.output_shape == expected_shape
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Output layer: a single channel squashed by tanh.
    net.add(layers.Conv2DTranspose(1, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh'))
    assert net.output_shape == (None, 28, 28, 1)

    return net

In [31]:
def make_discriminator_model():
    """Build the discriminator shared by the GAN and the WGAN.

    Two stride-2 convolution stages (64 then 128 filters), each followed by
    LeakyReLU and 30% dropout, then a flatten and a single linear unit. The
    final Dense layer has no activation, so the output is a raw score.

    Returns:
        An uncompiled `tf.keras.Sequential` taking `[28, 28, 1]` inputs.
    """
    critic = tf.keras.Sequential()

    # Only the first convolution declares the input shape.
    conv_stages = (
        (64, {"input_shape": [28, 28, 1]}),
        (128, {}),
    )
    for n_filters, extra_kwargs in conv_stages:
        critic.add(layers.Conv2D(n_filters, (3, 3), strides=(2, 2), padding='same', **extra_kwargs))
        critic.add(layers.LeakyReLU())
        critic.add(layers.Dropout(0.3))

    critic.add(layers.Flatten())
    critic.add(layers.Dense(1))

    return critic


In [32]:
def generate_models(summary=False):
  """Create a fresh generator/discriminator pair.

  The generator variant is chosen by the notebook-level `is_complex_generator`
  switch.

  Args:
    summary: when true, print the Keras summary of both networks.

  Returns:
    A `(generator, discriminator)` tuple of uncompiled models.
  """
  build_generator = make_complex_generator_model if is_complex_generator else make_generator_model
  generator = build_generator()
  discriminator = make_discriminator_model()

  if summary:
    # generator first, then discriminator
    for network in (generator, discriminator):
      network.summary()

  return generator, discriminator

In [33]:
class GAN(keras.Model):
    """Generator/discriminator pair trained through `keras.Model.fit`.

    Every `train_step` updates both networks once, using losses computed from
    the same forward pass.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Keep references to the two networks and the latent size.

        Args:
            discriminator: model scoring a batch of images.
            generator: model turning latent vectors into images.
            latent_dim: length of the latent vectors sampled in `train_step`.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach one optimizer and one loss function per network.

        Args:
            d_optimizer: optimizer applied to the discriminator variables.
            g_optimizer: optimizer applied to the generator variables.
            d_loss_fn: called as `d_loss_fn(real_output, fake_output)`.
            g_loss_fn: called as `g_loss_fn(fake_output)`.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, real_images):
        """Run one simultaneous generator + discriminator update.

        Args:
            real_images: a batch of real images; if a tuple is passed, only
                its first element is used.

        Returns:
            Dict with the batch's `"d_loss"` and `"g_loss"`.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # One latent vector per real image in the batch.
        n_samples = tf.shape(real_images)[0]
        latent_batch = tf.random.normal([n_samples, self.latent_dim])

        # Both tapes record the same forward pass; each is differentiated
        # against a different set of variables below.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            fake_images = self.generator(latent_batch, training=True)

            scores_real = self.discriminator(real_images, training=True)
            scores_fake = self.discriminator(fake_images, training=True)

            gen_loss = self.g_loss_fn(scores_fake)
            disc_loss = self.d_loss_fn(scores_real, scores_fake)

        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        disc_grads = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        # Generator is updated first, then the discriminator.
        self.g_optimizer.apply_gradients(zip(gen_grads, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables))

        return {"d_loss": disc_loss, "g_loss": gen_loss}

In [34]:
class WGAN(keras.Model):
    """Wasserstein GAN with gradient penalty, trained through `keras.Model.fit`.

    One `train_step` performs `d_steps` discriminator updates followed by a
    single generator update.
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        """Store the two networks and the gradient-penalty hyperparameters.

        Args:
            discriminator: model scoring a batch of images.
            generator: model turning latent vectors into images.
            latent_dim: length of the latent vectors sampled in `train_step`.
            discriminator_extra_steps: number of discriminator updates made
                per generator update.
            gp_weight: multiplier applied to the gradient penalty before it
                is added to the discriminator loss.
        """
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach one optimizer and one loss function per network.

        Args:
            d_optimizer: optimizer applied to the discriminator variables.
            g_optimizer: optimizer applied to the generator variables.
            d_loss_fn: called as `d_loss_fn(real_img=..., fake_img=...)`.
            g_loss_fn: called as `g_loss_fn(fake_logits)`.
        """
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Compute the gradient penalty on points between real and fake images.

        The penalty is the batch mean of `(||grad|| - 1)^2`, where `grad` is
        the gradient of the discriminator output with respect to an
        interpolated image.

        Args:
            batch_size: number of images in the batch (scalar tensor or int).
            real_images: batch of real images.
            fake_images: batch of generated images, same shape as `real_images`.

        Returns:
            Scalar tensor holding the penalty (not yet scaled by `gp_weight`).
        """
        # One mixing coefficient per sample, drawn uniformly from [0, 1) so
        # every interpolated image lies on the segment between its real and
        # fake endpoints. A normal draw would be unbounded and would place
        # points outside that segment.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            # `interpolated` is a plain tensor, so it must be watched explicitly.
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t. this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the per-sample norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        """Run `d_steps` discriminator updates, then one generator update.

        Args:
            real_images: a batch of real images; if a tuple is passed, only
                its first element is used.

        Returns:
            Dict with `"d_loss"` (from the last discriminator update,
            including the weighted gradient penalty) and `"g_loss"`.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # Train the discriminator; fresh latent vectors are drawn for every step.
        for _ in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty inside the tape so that it
                # contributes to the discriminator gradients
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

In [None]:
# tensorboard to view loss graphs: each run of this cell logs to its own
# timestamped folder under "logs"
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", run_stamp)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, histogram_freq=1)
print(logdir)

In [None]:
# run tensorboard from the notebook, pointed at the "logs" directory that the
# timestamped `logdir` folders are created under
%tensorboard --logdir logs

In [37]:
def build_gan():
  """Assemble and compile the cross-entropy GAN.

  Fresh networks come from `generate_models()`; both are optimised with
  Adam(1e-4) and trained with binary cross-entropy computed on raw logits.

  Returns:
    A compiled `GAN` instance ready for `fit`.
  """
  generator, discriminator = generate_models()
  gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)

  # One Adam optimizer per network, same learning rate.
  gen_opt = tf.keras.optimizers.Adam(1e-4)
  disc_opt = tf.keras.optimizers.Adam(1e-4)

  # The discriminator's last layer has no activation, hence from_logits=True.
  bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

  def discriminator_loss(real_output, fake_output):
    """Cross-entropy with label 1 for real images plus label 0 for fakes."""
    loss_on_real = bce(tf.ones_like(real_output), real_output)
    loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

  def generator_loss(fake_output):
    """Cross-entropy of the fake scores against label 1.

    The generator wants the discriminator to be wrong about fakes, which is
    expressed here as minimising the loss under the opposite label.
    """
    return bce(tf.ones_like(fake_output), fake_output)

  # Compile the GAN model.
  gan.compile(
      d_optimizer=disc_opt,
      g_optimizer=gen_opt,
      d_loss_fn=discriminator_loss,
      g_loss_fn=generator_loss,
  )
  return gan

In [38]:
def train_gan():
  """Build the cross-entropy GAN, fit it on `train_images`, and return it.

  Weights are checkpointed to `checkpoint_path` every 10 epochs, and training
  is logged through the notebook-level `tensorboard_callback`.

  Returns:
    The trained `GAN` model.
  """
  gan = build_gan()

  # save_freq is expressed in batches, so convert "every N epochs" to batches.
  save_every_n_epochs = 10
  batches_per_epoch = math.ceil(train_images.shape[0] / BATCH_SIZE)

  checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_path,
      save_weights_only=True,
      verbose=1,
      save_freq=save_every_n_epochs * batches_per_epoch,
  )

  # start training
  gan.fit(
      train_images,
      batch_size=BATCH_SIZE,
      epochs=EPOCHS,
      callbacks=[tensorboard_callback, checkpoint_cb],
  )
  return gan


In [39]:
def build_wgan():
  """Assemble and compile the Wasserstein GAN.

  Both networks use Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9); the
  discriminator is updated 3 times per generator update.

  Returns:
    A compiled `WGAN` instance ready for `fit`.
  """
  # Identical Adam settings for both networks.
  adam_settings = dict(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
  gen_opt = keras.optimizers.Adam(**adam_settings)
  disc_opt = keras.optimizers.Adam(**adam_settings)

  def discriminator_loss(real_img, fake_img):
    """Mean fake score minus mean real score (minimised by the discriminator)."""
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)

  def generator_loss(fake_img):
    """Negated mean fake score (minimised by the generator)."""
    return -tf.reduce_mean(fake_img)

  # build models
  generator, discriminator = generate_models()

  wgan = WGAN(
      discriminator=discriminator,
      generator=generator,
      latent_dim=latent_dim,
      discriminator_extra_steps=3,
  )
  wgan.compile(
      d_optimizer=disc_opt,
      g_optimizer=gen_opt,
      d_loss_fn=discriminator_loss,
      g_loss_fn=generator_loss,
  )
  return wgan

In [40]:
def train_wgan():
  """Build the Wasserstein GAN, fit it on `train_images`, and return it.

  Weights are checkpointed to `checkpoint_path` every 10 epochs, and training
  is logged through the notebook-level `tensorboard_callback`.

  Returns:
    The trained `WGAN` model.
  """
  wgan = build_wgan()

  # save_freq is expressed in batches, so convert "every N epochs" to batches.
  save_every_n_epochs = 10
  batches_per_epoch = math.ceil(train_images.shape[0] / BATCH_SIZE)

  checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_path,
      save_weights_only=True,
      verbose=1,
      save_freq=save_every_n_epochs * batches_per_epoch,
  )

  # Start training the model.
  wgan.fit(
      train_images,
      batch_size=BATCH_SIZE,
      epochs=EPOCHS,
      callbacks=[tensorboard_callback, checkpoint_cb],
  )
  return wgan


In [41]:
def plot_images(generator, num_row=10, num_col=10, seed=0):
  """Plot a grid of images generated from fixed latent samples.

  The latent vectors come from a private random generator seeded with `seed`,
  so repeated calls show the same latent points and the notebook's global
  NumPy random state is left untouched.

  Args:
    generator: model whose `predict` maps an array of shape
      `(n, latent_dim)` to images of shape `(n, height, width, 1)`.
    num_row: number of grid rows.
    num_col: number of grid columns.
    seed: seed for the latent samples; the default reproduces the draws
      obtained from `np.random.seed(0)` followed by `np.random.normal`.
  """
  num = num_row * num_col  # number of images to generate

  # draw std. normal variables from latent space without reseeding the global RNG
  rand_vars = np.random.RandomState(seed).normal(size=(num, latent_dim))

  # one batched forward pass instead of one predict() call per image
  predictions = generator.predict(rand_vars)
  # undo the [-1, 1] normalisation applied to the training images
  digits = predictions[:, :, :, 0] * 127.5 + 127.5

  # squeeze=False keeps `axes` two-dimensional even for a 1xN or Nx1 grid
  fig, axes = plt.subplots(num_row, num_col, figsize=(1.5*num_col, 2*num_row), squeeze=False)
  # axes.flat walks the grid row by row, matching the order of `digits`
  for ax, digit in zip(axes.flat, digits):
    ax.imshow(digit, cmap='gray')
  plt.tight_layout()
  plt.show()


In [42]:
def get_checkpoint_model(epoch):
  """Rebuild the model and restore the weights checkpointed at `epoch`.

  The model variant follows the notebook-level `is_wgan` switch, the same
  switch the training cell uses, so a checkpoint written by a WGAN run is
  loaded back into a WGAN rather than a plain GAN.

  Args:
    epoch: epoch number used to fill in `checkpoint_path`; a checkpoint for
      that epoch must already exist.

  Returns:
    The freshly built model with the checkpoint weights loaded.
  """
  cp = checkpoint_path.format(epoch=epoch)
  model = build_wgan() if is_wgan else build_gan()
  model.load_weights(cp)
  return model


In [None]:
# start training with the trainer selected by the is_wgan switch
gan = train_wgan() if is_wgan else train_gan()

In [None]:
# plot a grid of images sampled from the generator after all training epochs
plot_images(gan.generator)

In [None]:
# plot images from a saved checkpoint; checkpoints are written every 10 epochs,
# so `epoch` must be a multiple of 10
plot_images(get_checkpoint_model(epoch = 10).generator)

In [None]:
# === THE CODE BELOW SAVES AND DOWNLOADS THE TRAINED MODELS ===

In [None]:
# save trained model (after all epochs): the generator and the discriminator are
# written separately, to "gan_generator" and "gan_discriminator" relative to
# the current working directory
gan.generator.save("gan_generator")
gan.discriminator.save("gan_discriminator")

In [None]:
# zip saved model folders to download
# (assumes the working directory is /content, since the models were saved with relative paths)
!zip -r /content/gan_generator.zip /content/gan_generator
!zip -r /content/gan_discriminator.zip /content/gan_discriminator

In [None]:
# save trained checkpoint models
epoch = 10
# restore the checkpoint once and save both networks from that single model,
# instead of rebuilding the model and reloading the weights for each network
checkpoint_gan = get_checkpoint_model(epoch = epoch)
checkpoint_gan.generator.save("gan_generator-"+str(epoch))
checkpoint_gan.discriminator.save("gan_discriminator-"+str(epoch))


In [None]:
# zip saved model folders to download
# NOTE: the "-10" suffix is hard-coded and must match the `epoch` value used when saving the checkpoint models
!zip -r /content/gan_generator-10.zip /content/gan_generator-10
!zip -r /content/gan_discriminator-10.zip /content/gan_discriminator-10

In [None]:
# zip checkpoint folder to download (the directory part of checkpoint_path)
!zip -r /content/checkpoints.zip /content/checkpoints
