Install additional packages

In [None]:
# To generate GIFs
!pip install imageio
!pip install git+https://github.com/tensorflow/docs
!pip install tensorflow-addons

Imports

In [None]:
import glob
import matplotlib.pyplot as plt
import imageio
import numpy as np
import os
import PIL
import time
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers

from IPython import display


Parameters

In [None]:
# Default figure size (width, height) in inches for the image grids below.
plt.rcParams.update({"figure.figsize": (7, 8)})

In [None]:
# Image geometry and the (height, width) of the block that gets cut out.
IMAGE_SHAPE = (32, 32, 3)
BLOCK_SIZE = (16, 16)
# A block keeps the image's channel count: (16, 16, 3).
BLOCK_SHAPE = BLOCK_SIZE + IMAGE_SHAPE[-1:]

# Dataset

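Load CIFAR-10 (labels are discarded) and scale the pixels to [0, 1]. These images are the reconstruction targets `y`.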

In [None]:
(y_train, _), (y_test, _) = tf.keras.datasets.cifar10.load_data()
# uint8 pixels -> float32 tensors normalized to [0, 1].
y_train, y_test = (
    tf.convert_to_tensor(images, dtype=tf.float32) / 255.
    for images in (y_train, y_test)
)

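Cutout: the network input `x` starts as an exact copy of the target `y`; the block to inpaint is cut out of it on the fly by `RandomCutout` below.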

In [None]:
# Inputs are copies of the targets; the hole is cut into them later, per batch.
x_train, x_test = tf.identity(y_train), tf.identity(y_test)

In [None]:
class RandomCutout():
  """Cut one rectangular block out of every image in a batch.

  One block centre is drawn per call and shared by the whole batch. Along each axis the
  centre is sampled uniformly so the block keeps at least ``border`` pixels from the image
  edges; when there is no room for a random position the block is centred on that axis.

  Args:
    mask_size: (height, width) of the block. tfa.image.cutout expects even sizes.
    border: (height, width) minimum margin in pixels between the block and the image edge.
    name: stored as ``self.name``; not used otherwise.
    **kwargs: forwarded to ``object.__init__``, so any extra keyword raises TypeError.
  """
  def __init__(self, mask_size=(16,16), border=(0,0), name = 'random_cutout', **kwargs):
    super(RandomCutout, self).__init__(**kwargs)
    self.name = name
    self.mask_size = mask_size
    self.border = border

  @staticmethod
  def _half(size):
    """Half of a block side, rounded up, as an int32 tensor."""
    return tf.cast(tf.math.ceil(size / 2.), dtype=tf.int32)

  def _sample_center(self, extent, size, margin):
    """Pick a block centre along one axis of length ``extent``."""
    lo = tf.cast(tf.math.ceil(size / 2.) + margin, dtype=tf.int32)
    hi = extent - lo
    if lo < hi:
      center = tf.random.uniform(shape=[], minval=lo, maxval=hi, dtype=tf.dtypes.int32)
    else:
      # No room for a random position: centre the block.
      center = tf.cast(extent / 2, dtype=tf.int32)
    return center

  def __call__(self, image_batch):
    """Cut the block out of ``image_batch``.

    Assumes ``image_batch`` is (batch, height, width, channels) — height is read from
    axis 1 and width from axis 2 below.

    Returns:
      context: ``image_batch`` with the block zeroed.
      random_block: ``image_batch`` with everything except the block zeroed.
      mask: same shape as ``image_batch``; 1 inside the block, 0 elsewhere.
      bbox: (offset_height, offset_width, target_height, target_width), the argument
        order of ``tf.image.crop_to_bounding_box``.
    """
    height = tf.shape(image_batch)[1]
    width = tf.shape(image_batch)[2]

    # mask_size[0]/border[0] always refer to height and mask_size[1]/border[1] to width,
    # both for the cutout and for the crop box, so non-square masks stay aligned.
    center_h = self._sample_center(height, self.mask_size[0], self.border[0])
    center_w = self._sample_center(width, self.mask_size[1], self.border[1])
    half_h = self._half(self.mask_size[0])
    half_w = self._half(self.mask_size[1])

    # tfa.image.cutout takes the block centre as (height, width); cutting a block of ones
    # out of zeros yields the binary mask.
    mask = tfa.image.cutout(tf.zeros_like(image_batch), mask_size=self.mask_size,
                            offset=(center_h, center_w), constant_values=1.)
    context = image_batch * (1 - mask)
    random_block = image_batch * mask

    top = center_h - half_h
    left = center_w - half_w
    return context, random_block, mask, (top, left, 2 * half_h, 2 * half_w)

In [None]:
# Shared cutout: BLOCK_SIZE hole at a random position, at least 2 px from the image edge.
random_cutout = RandomCutout(BLOCK_SIZE, (2, 2))

# Models


## generator

In [None]:
def Generator(input_shape=(32,32,3)):
    """Build the inpainting generator: an encoder-decoder from a masked image to a full image.

    The encoder halves height and width three times with strided convolutions and the
    decoder mirrors it with transposed convolutions, so both must be divisible by 8.
    The sigmoid output matches images scaled to [0, 1].

    Args:
        input_shape: (height, width, channels) of the white-filled masked input image.

    Returns:
        A ``tf.keras.Model`` named "Generator" whose output has the same shape as its input.
    """
    inputs = tf.keras.Input(shape=input_shape)

    x = inputs
    # Encoder: 32x32 -> 16x16 -> 8x8 -> 4x4 for the default input.
    for filters in (64, 128, 256):
        x = layers.Conv2D(filters, 4, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    # Decoder: 4x4 -> 8x8 -> 16x16, then a final upsampling back to the input resolution.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, 4, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    outputs = layers.Conv2DTranspose(input_shape[-1], 4, strides=2, padding="same",
                                     activation="sigmoid")(x)

    model = tf.keras.Model(inputs, outputs, name="Generator")

    return model

## discriminator

In [None]:
def Discriminator(input_shape=(16,16,3)):
    """Build the block discriminator: a small CNN scoring a cut-out block as real or fake.

    The output is a sigmoid probability rather than a logit. That is what the notebook's
    ``tf.keras.losses.BinaryCrossentropy()`` (default ``from_logits=False``) expects.

    Args:
        input_shape: (height, width, channels) of the cropped block.

    Returns:
        A ``tf.keras.Model`` named "Discriminator" with one output per example.
    """
    inputs = tf.keras.Input(shape=input_shape)

    x = inputs
    # 16x16 -> 8x8 -> 4x4 -> 2x2 for the default input.
    for filters in (64, 128, 256):
        x = layers.Conv2D(filters, 4, strides=2, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs, outputs, name="Discriminator")
    return model

# Build a stand-alone discriminator only to print its layer summary;
# ContextEncoder is given its own Discriminator() instance further down.
discriminator = Discriminator()
discriminator.summary()

## metrics

In [None]:
class psnr_metric(tf.keras.metrics.Metric):
    """Running average of per-batch PSNR (dB) between target and predicted images.

    Each ``update_state`` call adds the batch-mean PSNR as one sample, so every batch
    weighs the same regardless of its size. ``sample_weight`` is accepted for API
    compatibility but ignored. A batch with zero error yields an infinite PSNR.

    Args:
        name: metric name used in Keras logs.
        max_val: dynamic range of the pixel values. The dataset cell scales images to
            [0, 1], hence the default of 1.0; a larger value inflates the score.
    """
    def __init__(self, name = 'psnr', max_val=1.0, **kwargs):
        # Forward `name` to the base class so Keras reports the metric under it.
        super(psnr_metric, self).__init__(name=name, **kwargs)
        self.max_val = max_val
        # Keyword `name=` works across Keras versions (positional args differ).
        self.value = self.add_weight(name='value', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        batch_psnr = tf.image.psnr(y_true, y_pred, max_val=self.max_val)
        self.value.assign_add(tf.reduce_mean(batch_psnr))
        self.count.assign_add(1.)

    def reset_state(self):
        self.value.assign(0.)
        self.count.assign(0.)

    def result(self):
        # 0 instead of NaN when no batch has been seen yet.
        return tf.math.divide_no_nan(self.value, self.count)


In [None]:
class ssim_metric(tf.keras.metrics.Metric):
    """Running average of per-batch SSIM between target and predicted images.

    Each ``update_state`` call adds the batch-mean SSIM as one sample, so every batch
    weighs the same regardless of its size. ``sample_weight`` is accepted for API
    compatibility but ignored.

    Args:
        name: metric name used in Keras logs.
        max_val: dynamic range of the pixel values. The dataset cell scales images to
            [0, 1], hence the default of 1.0. SSIM's stabilizing constants scale with
            max_val, so a wrong value distorts the score.
    """
    def __init__(self, name = 'ssim', max_val=1.0, **kwargs):
        # Forward `name` to the base class so Keras reports the metric under it.
        super(ssim_metric, self).__init__(name=name, **kwargs)
        self.max_val = max_val
        # Keyword `name=` works across Keras versions (positional args differ).
        self.value = self.add_weight(name='value', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        batch_ssim = tf.image.ssim(y_true, y_pred, max_val=self.max_val)
        self.value.assign_add(tf.reduce_mean(batch_ssim))
        self.count.assign_add(1.)

    def reset_state(self):
        self.value.assign(0.)
        self.count.assign(0.)

    def result(self):
        # 0 instead of NaN when no batch has been seen yet.
        return tf.math.divide_no_nan(self.value, self.count)


# Context encoder class

In [None]:
class ContextEncoder(tf.keras.Model):
    """Context encoder trained as a GAN to inpaint a cut-out block of each image.

    For every batch one block (same position for the whole batch) is cut out of the
    inputs ``x``. The generator sees the context with the hole filled white and predicts
    a full image; only its prediction inside the hole is kept. The discriminator sees
    only the cropped block, either real (from ``y``) or generated.

    Generator objective:
        lam * reconstruction_loss(real_block, fake_block)
        + (1 - lam) * adversarial_loss(discriminator(fake_block))
    Discriminator objective:
        discriminator_loss(discriminator(real_block), discriminator(fake_block))

    Args:
        discriminator: Keras model scoring a cropped block.
        generator: Keras model mapping a white-filled masked image to a full image.
        generator_extra_steps: generator updates per batch (must be >= 1; g_loss comes from the last one).
        discriminator_extra_steps: discriminator updates per batch (must be >= 1; d_loss comes from the last one).
        mask_size: (height, width) of the cut-out block.
        border: minimum (height, width) margin between the block and the image edge.
        name: model name.
    """
    def __init__(self,
                 discriminator,
                 generator,
                 generator_extra_steps=1,
                 discriminator_extra_steps=1,
                 mask_size=(16,16),
                 border=(2,2),
                 name="context_encoder"):
        super(ContextEncoder, self).__init__(name=name)

        self.discriminator = discriminator
        self.generator = generator

        self.d_steps = discriminator_extra_steps
        self.g_steps = generator_extra_steps

        self.mask_size = mask_size
        self.border = border
        # Cut blocks with this model's own settings instead of a notebook-level global.
        self.random_cutout = RandomCutout(mask_size=mask_size, border=border)

        # Mean trackers so the logged losses are averaged over the epoch.
        self.g_loss_tracker = tf.keras.metrics.Mean(name="g_loss")
        self.d_loss_tracker = tf.keras.metrics.Mean(name="d_loss")
        self.psnr_metric = psnr_metric(name="psnr")
        self.ssim_metric = ssim_metric(name="ssim")

    def compile(self, generator_optimizer, discriminator_optimizer, discriminator_loss, adversarial_loss, reconstruction_loss, lam=0.99):
        """Attach optimizers and loss callables.

        ``lam`` weights the reconstruction term; the adversarial term gets ``1 - lam``.
        """
        super(ContextEncoder, self).compile()
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.adversarial_loss = adversarial_loss
        self.reconstruction_loss = reconstruction_loss
        self.discriminator_loss = discriminator_loss
        self.lam = lam

    @property
    def metrics(self):
        # Listed here so Keras resets them at the start of each epoch and of evaluate().
        return [self.g_loss_tracker, self.d_loss_tracker, self.psnr_metric, self.ssim_metric]

    @staticmethod
    def _unpack(data):
        """Split a Keras batch into (inputs, targets); targets default to the inputs."""
        if isinstance(data, (tuple, list)):
            x_batch = data[0]
            y_batch = data[1] if len(data) > 1 else data[0]
            return x_batch, y_batch
        return data, data

    def _inpaint(self, context, mask, coords, real_images, training):
        """Run the generator, composite its output into the hole, and crop both blocks."""
        # White-filled hole as generator input.
        generated = self.generator(context + mask, training=training)
        # Keep the known context pixels; only the hole comes from the generator.
        fake_images = context + generated * mask
        fake_block = tf.image.crop_to_bounding_box(fake_images, *coords)
        real_block = tf.image.crop_to_bounding_box(real_images, *coords)
        return fake_images, fake_block, real_block

    def _generator_loss(self, real_block, fake_block, fake_preds):
        """Weighted sum of the reconstruction and adversarial losses."""
        return (self.lam * self.reconstruction_loss(real_block, fake_block)
                + (1 - self.lam) * self.adversarial_loss(fake_preds))

    def _update_metrics(self, g_loss, d_loss, real_images, fake_images):
        """Update every tracker and return the values under the keys used in history."""
        self.g_loss_tracker.update_state(g_loss)
        self.d_loss_tracker.update_state(d_loss)
        # PSNR/SSIM are measured on whole images; outside the hole the composite reuses input pixels.
        self.psnr_metric.update_state(real_images, fake_images)
        self.ssim_metric.update_state(real_images, fake_images)
        return {"g_loss": self.g_loss_tracker.result(),
                "d_loss": self.d_loss_tracker.result(),
                "psnr": self.psnr_metric.result(),
                "ssim": self.ssim_metric.result()}

    @tf.function
    def test_step(self, data):
        x_batch, y_batch = self._unpack(data)
        context, _, mask, coords = self.random_cutout(x_batch)
        real_images = y_batch

        fake_images, fake_block, real_block = self._inpaint(
            context, mask, coords, real_images, training=False)
        fake_preds = self.discriminator(fake_block, training=False)
        real_preds = self.discriminator(real_block, training=False)

        d_loss = self.discriminator_loss(real_preds, fake_preds)
        g_loss = self._generator_loss(real_block, fake_block, fake_preds)
        return self._update_metrics(g_loss, d_loss, real_images, fake_images)

    @tf.function
    def train_step(self, data):
        x_batch, y_batch = self._unpack(data)
        context, _, mask, coords = self.random_cutout(x_batch)
        real_images = y_batch

        for _ in range(self.d_steps):
            with tf.GradientTape() as tape:
                _, fake_block, real_block = self._inpaint(
                    context, mask, coords, real_images, training=True)
                fake_preds = self.discriminator(fake_block, training=True)
                real_preds = self.discriminator(real_block, training=True)
                d_loss = self.discriminator_loss(real_preds, fake_preds)
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.discriminator_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables))

        for _ in range(self.g_steps):
            with tf.GradientTape() as tape:
                fake_images, fake_block, real_block = self._inpaint(
                    context, mask, coords, real_images, training=True)
                # Score the fake block again inside THIS tape. Predictions from the
                # discriminator loop were recorded on another tape and would give the
                # adversarial term no gradient path to the generator.
                gen_preds = self.discriminator(fake_block, training=True)
                g_loss = self._generator_loss(real_block, fake_block, gen_preds)
            gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
            self.generator_optimizer.apply_gradients(
                zip(gen_gradient, self.generator.trainable_variables))

        return self._update_metrics(g_loss, d_loss, real_images, fake_images)

# Compile

## loss

In [None]:
# BinaryCrossentropy keeps its default from_logits=False, so it expects probabilities.
cross_entropy, mse, mae = (
    tf.keras.losses.BinaryCrossentropy(),
    tf.keras.losses.MeanSquaredError(),
    tf.keras.losses.MeanAbsoluteError(),
)

In [None]:
def discriminator_loss(real_preds, fake_preds):
  """Standard GAN discriminator loss: BCE pushing real predictions to 1 and fake ones to 0."""
  loss_on_real = cross_entropy(tf.ones_like(real_preds), real_preds)
  loss_on_fake = cross_entropy(tf.zeros_like(fake_preds), fake_preds)
  return loss_on_real + loss_on_fake

In [None]:
def adv_loss(fake_preds):
  """Adversarial (generator) loss: BCE rewarding fakes the discriminator scores as real."""
  all_real = tf.ones_like(fake_preds)
  return cross_entropy(all_real, fake_preds)

def rec_loss(y_true, y_pred):
  """Reconstruction loss: mean squared error (L2) between real and generated blocks.

  Use ``mae(y_true, y_pred)`` instead for an L1 reconstruction loss.
  """
  squared_error = mse(y_true, y_pred)
  return squared_error

## optimizers

In [None]:

# Adam for both networks. The discriminator starts with a tiny learning rate (1e-7);
# the learning-rate scheduler in the training section raises it later (see LR_SCHEDULE),
# presumably so the generator can learn the reconstruction task first.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9)          
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-7, beta_1=0.9)  

## model compile

In [None]:
# Cut-out geometry for training; same values as BLOCK_SIZE and the border of `random_cutout`.
MASK_SIZE = (16,16)
BORDER = (2,2)
# Fresh discriminator/generator instances; the `discriminator` built for the summary is not reused.
ctxtenc = ContextEncoder(discriminator=Discriminator(), 
                         generator= Generator(),
                         mask_size=MASK_SIZE,
                         border=BORDER)

# lam = 0.999 weights the reconstruction loss; the adversarial loss gets 1 - lam = 0.001.
ctxtenc.compile(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    discriminator_loss=discriminator_loss,
    adversarial_loss = adv_loss,
    reconstruction_loss=rec_loss,
    lam = 0.999)

# Training

In [None]:
# Fixed set of test images shown before training (epoch 0) and by GANMonitor after each epoch.
images_to_show = x_test[:25]
# Cut with the same settings GANMonitor uses: a 16x16 hole with an 8 px border leaves
# no room for a random position on 32x32 images, so the hole is centred.
context, random_block, mask, coords = RandomCutout(mask_size=(16,16),border=(8,8))(images_to_show)

plt.figure()
plt.suptitle("Input", fontsize=14)
# context + mask fills the hole with white so it is visible.
for i, img in enumerate(context+mask):
    plt.subplot(5,5 ,i+1)
    plt.imshow(img)
    plt.axis('off')

# Saved as epoch 0 so it becomes the first frame of the GIF.
plt.savefig('image_at_epoch_0000.png')
plt.show()

In [None]:
class GANMonitor(tf.keras.callbacks.Callback):
    """Keras callback that saves a grid of inpainted images every ``save_every`` epochs.

    Files are written as ``image_at_epoch_{epoch:04d}.png`` with 1-based epoch numbers.
    A new cutout is drawn at every save. The hole therefore only stays in the same place
    when ``border`` leaves no slack for a random position (e.g. a 16x16 hole with an
    8 px border on 32x32 images).

    Args:
        images_to_show: batch of images to inpaint; required despite the ``None`` default.
        save_every: save period in epochs.
        mask_size: (height, width) of the hole.
        border: minimum (height, width) margin between the hole and the image edge.

    Raises:
        ValueError: if ``images_to_show`` is None.
    """
    def __init__(self, images_to_show=None, save_every=1,mask_size=(16,16),border=(8,8)):
        # Run the base initialiser so Keras can set up the callback's internal state.
        super(GANMonitor, self).__init__()
        if images_to_show is None:
            raise ValueError("GANMonitor requires a batch of images_to_show.")
        self.images_to_show = images_to_show
        self.save_every = save_every
        self.random_cutout = RandomCutout(mask_size, border)

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.save_every != 0:
            return

        context, _, mask, _ = self.random_cutout(self.images_to_show)
        # White-filled hole as generator input; keep the known pixels of the context.
        generated = self.model.generator(context + mask, training=False)
        reconstructed_images = context + generated * mask

        # Near-square grid sized to the batch (5x5 for 25 images).
        n_images = reconstructed_images.shape[0]
        n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
        n_rows = max(1, int(np.ceil(n_images / n_cols)))

        plt.ioff()
        fig = plt.figure()
        plt.suptitle(f"epoch:{epoch+1}")
        for i in range(n_images):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(reconstructed_images[i])
            plt.axis('off')
        plt.tight_layout()
        plt.savefig('image_at_epoch_{:04d}.png'.format(epoch+1))
        # Close the figure so per-epoch plots do not accumulate in memory.
        plt.close(fig)

# Save an inpainting grid of `images_to_show` after every epoch (centred 16x16 hole).
show_images = GANMonitor(
    images_to_show=images_to_show,
    save_every=1,
    mask_size=(16, 16),
    border=(8, 8),
)

In [None]:
class CustomLearningRateScheduler(tf.keras.callbacks.Callback):
    """Set an optimizer's learning rate at the start of each epoch from a schedule.

    Args:
        schedule: callable ``(epoch, current_lr) -> new_lr`` taking a 0-based epoch index.
        optimizer_attr: name of the model attribute holding the optimizer to schedule.
            Defaults to ContextEncoder's discriminator optimizer.

    Raises:
        ValueError: if the optimizer has no ``learning_rate`` attribute.
    """

    def __init__(self, schedule, optimizer_attr="discriminator_optimizer"):
        super(CustomLearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.optimizer_attr = optimizer_attr

    def on_epoch_begin(self, epoch, logs=None):
        optimizer = getattr(self.model, self.optimizer_attr)
        # `learning_rate` is the canonical attribute; the `lr` alias is deprecated/removed
        # in newer Keras optimizers.
        if not hasattr(optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from the optimizer.
        lr = float(tf.keras.backend.get_value(optimizer.learning_rate))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Attribute assignment updates the rate for both legacy and current Keras optimizers.
        optimizer.learning_rate = scheduled_lr
        if scheduled_lr != lr:
            # Scientific notation: a fixed 4-decimal format would print 1e-5 as 0.0000.
            print("\nEpoch %d: Learning rate is %.2e." % (epoch+1, scheduled_lr))

# (epoch to start, learning rate) pairs. Assumed to be sorted by epoch: the range check
# in lr_schedule only looks at the first and last entries.
LR_SCHEDULE = [
    (30, 1e-5),
]


def lr_schedule(epoch, lr):
    """Return the learning rate for ``epoch`` given the current rate ``lr``.

    A new rate is returned only on an epoch listed in ``LR_SCHEDULE`` (first match wins).
    On every other epoch ``lr`` passes through unchanged, so the last rate set persists.
    Epochs outside [first listed epoch, last listed epoch] never match.
    """
    first_start, last_start = LR_SCHEDULE[0][0], LR_SCHEDULE[-1][0]
    if not first_start <= epoch <= last_start:
        return lr
    return next((rate for start, rate in LR_SCHEDULE if start == epoch), lr)

# Applies lr_schedule to the discriminator optimizer at the start of every epoch.
scheduler = CustomLearningRateScheduler(schedule=lr_schedule)

In [None]:
EPOCHS = 50
BATCH_SIZE = 128

# x: images that get a block cut out per batch; y: the untouched targets.
history = ctxtenc.fit(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    initial_epoch=0,
    shuffle=True,
    validation_data=(x_test, y_test),
    callbacks=[show_images, scheduler],
)

# Visualization

In [None]:
# Inpaint 25 test images with the trained generator, using the random-position cutout.
test_batch = y_test[:25]
context, random_block, mask, coords = random_cutout(test_batch)
context_white = context + mask
generated_batch = ctxtenc.generator(context_white, training=False)
reconstructed = context + generated_batch * mask

fig = plt.figure()
plt.suptitle("Output", fontsize=14)
for position, img in enumerate(reconstructed, start=1):
    plt.subplot(5, 5, position)
    plt.imshow(img)
    plt.axis('off')

plt.show()

## plots

In [None]:
## Plot train and validation curves
# Per-epoch values recorded by fit(); the `val_` entries come from validation_data.
g_loss, val_g_loss = history.history['g_loss'], history.history['val_g_loss']
d_loss, val_d_loss = history.history['d_loss'], history.history['val_d_loss']
psnr, val_psnr = history.history['psnr'], history.history['val_psnr']
ssim, val_ssim = history.history['ssim'], history.history['val_ssim']

In [None]:
def _plot_curve(position, train_values, val_values, name, ylabel, title,
                ylim=None, legend_loc='upper right'):
    """Draw one train-vs-validation panel at ``position`` of a 2x2 grid."""
    plt.subplot(2, 2, position)
    plt.plot(train_values, label=name)
    plt.plot(val_values, label='val_' + name)
    plt.legend(loc=legend_loc)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    if ylim is not None:
        plt.ylim(ylim)
    plt.title(title)


plt.figure(figsize=(14,10))
_plot_curve(1, g_loss, val_g_loss, 'g_loss', 'Loss',
            'Training and Validation Generator Loss', ylim=[0, .1], legend_loc='lower right')
_plot_curve(2, d_loss, val_d_loss, 'd_loss', 'Loss',
            'Training and Validation Discriminator Loss', ylim=[0, 2])
# PSNR is autoscaled: its range depends on the metric's max_val and on model quality,
# so a fixed window can clip the curves.
_plot_curve(3, psnr, val_psnr, 'psnr', 'PSNR (dB)', 'Training and Validation PSNR')
_plot_curve(4, ssim, val_ssim, 'ssim', 'SSIM', 'Training and Validation SSIM', ylim=[0, 1])
plt.show()

## last epoch output

In [None]:
# Display a single image using the epoch number
def display_image(epoch_no):
  """Open the image grid saved as ``image_at_epoch_{epoch_no:04d}.png`` (0 = input grid)."""
  # `import PIL` alone does not load the PIL.Image submodule; import it explicitly.
  from PIL import Image
  return Image.open(f'image_at_epoch_{epoch_no:04d}.png')

# history.epoch holds 0-based indices; files are numbered from 1, so +1 selects the last grid.
display_image(history.epoch[-1] + 1)

## gif

In [None]:
anim_file = 'dcgan.gif'
# Keep every FRAME_STRIDE-th saved grid (1 = every frame).
FRAME_STRIDE = 1

# Only the grids written by this notebook; zero-padded epoch numbers make the
# lexicographic sort chronological (epoch 0000 = the input grid).
filenames = sorted(glob.glob('image_at_epoch_*.png'))
frames = [imageio.imread(filename) for filename in filenames[::FRAME_STRIDE]]
imageio.mimsave(anim_file, frames, format='GIF', fps=4)


In [None]:
from tensorflow_docs.vis import embed
# Render the training-progress GIF inline.
embed.embed_file(anim_file)
