In [None]:
import datetime
import os
import time

import matplotlib.pyplot as plt
import tensorflow as tf

from IPython import display

## Input Pipeline

In [None]:
# Catalogue of supported datasets. Each entry provides:
#   url    - archive to download
#   epochs - number of training epochs used for that dataset
#   path   - sub-directory (relative to the extraction directory) holding the images
#   files  - glob pattern matching the image files inside `path`
DATASETS = {
  "mirflickr25k": {
    "url": "http://press.liacs.nl/mirflickr/mirflickr25k.v3/mirflickr25k.zip",
    "epochs": 50,
    "path": "mirflickr/",
    "files": "*.jpg"
  },
  "flowers": {
    "url": "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz",
    "epochs": 100,
    "path": "jpg/",
    "files": "*.jpg"
  },
  "landscapes": {  # URL from https://www.kaggle.com/arnaud58/landscape-pictures expires
    "url": "https://storage.googleapis.com/kaggle-data-sets/298806/1217826/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20210225%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20210225T150653Z&X-Goog-Expires=259199&X-Goog-SignedHeaders=host&X-Goog-Signature=5dff855ceda0e5e1092fd5a9c0d3c4293d76b6b2f3bd85d92964c980b5714c56f7431e5a5ccb0492363c7b9bef14f200a57d6957f47bd561d172ecfecd60fbdbfbb8240dd676bac543618c92395979db6d26cf317e22c56d5763c2051cb3f3ffe5b8067e72c447f91d12a56f4e3a8cf1ee3ca4635f0618f0da47ed977f20a06f809f863968a1a574622fe047b675e28d5b14a27e149d940fa762fd62c6f226479340e312454d8dfaf99658e3416584d50a83fc68cc3536c0d3e20ee0e5db9e6ffdfd684d54d1c508018f8bcc21cf147e06a1706ab5e77f4919f3a35b9446d2554e79ca4f627d70b375d4ff23e56c72ace3d5067e454f150a61f66b67808fbba6",
    "epochs": 100,
    "path": "",
    "files": "*.jpg"
  }
}
DATASET_NAME = "flowers"  # Change dataset here
DATASET = DATASETS[DATASET_NAME]

# The cached archive is named after the selected dataset. `get_file` reuses an
# existing file with the same name without checking where it came from, so a
# single shared name would keep serving the previously downloaded dataset
# after DATASET_NAME is changed.
path_to_zip = tf.keras.utils.get_file(
  f"{DATASET_NAME}_dataset.zip", origin=DATASET["url"], extract=True, cache_dir="/content"
)
# Directory (with trailing separator) that contains the extracted image files.
PATH = os.path.join(os.path.dirname(path_to_zip), DATASET["path"])

In [None]:
# Shuffle buffer size passed to `Dataset.shuffle` in the input pipeline.
BUFFER_SIZE = 400
# Number of images per batch for both the training and the test dataset.
BATCH_SIZE = 32
# Spatial size of the random crop produced by `load`. The Generator and
# Discriminator input layers hard-code 256x256, so these values have to stay
# at 256 unless those definitions are changed as well.
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [None]:
def resize_image_keep_aspect(image, lo_dim=IMG_WIDTH):
  """Resize `image` so its shorter side is `lo_dim` pixels, keeping the aspect ratio.

  Args:
    image: Image tensor whose first two axes are (height, width).
    lo_dim: Target length, in pixels, of the shorter side.

  Returns:
    The resized image as returned by `tf.image.resize`. Both spatial
    dimensions are guaranteed to be at least `lo_dim`.
  """
  # Axis 0 is the height and axis 1 the width of the image.
  dims = tf.shape(image)
  initial_height = dims[0]
  initial_width = dims[1]

  # Use the SHORTER side for the ratio so that it ends up `lo_dim` long.
  shorter_side = tf.minimum(initial_height, initial_width)
  ratio = tf.cast(shorter_side, dtype=tf.float32) / tf.cast(lo_dim, dtype=tf.float32)

  new_height = tf.cast(tf.cast(initial_height, dtype=tf.float32) / ratio, dtype=tf.int32)
  new_width = tf.cast(tf.cast(initial_width, dtype=tf.float32) / ratio, dtype=tf.int32)

  # float32 round-off followed by integer truncation can leave the shorter
  # side at `lo_dim - 1`, which would break a later fixed-size crop; clamp it.
  min_side = tf.cast(lo_dim, dtype=tf.int32)
  new_height = tf.maximum(new_height, min_side)
  new_width = tf.maximum(new_width, min_side)

  # tf.image.resize expects the target size as [height, width].
  return tf.image.resize(image, [new_height, new_width])

def load(image_file):
  """Read one image file and return a (grayscale, colour) pair of float32 crops.

  Args:
    image_file: Path of the image file to read.

  Returns:
    A tuple `(input_image, image)`: the single-channel grayscale crop and the
    3-channel colour crop it was derived from, both cast to float32 and both
    of spatial size (IMG_HEIGHT, IMG_WIDTH).
  """
  raw_bytes = tf.io.read_file(image_file)
  decoded = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)

  # Resize with the helper's default target, then cut out a random crop.
  resized = resize_image_keep_aspect(decoded)
  crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
  color_crop = tf.image.random_crop(resized, size=crop_shape)

  # The model input is the grayscale version of the very same crop.
  gray_crop = tf.image.rgb_to_grayscale(color_crop)

  return tf.cast(gray_crop, tf.float32), tf.cast(color_crop, tf.float32)

def normalize(image):
  """Rescale pixel values linearly so that 0 maps to -1 and 255 maps to 1."""
  scaled = image / 127.5
  return scaled - 1

def unnormalize(image):
  """Rescale values linearly so that -1 maps to 0 and 1 maps to 255."""
  shifted = image + 1
  return shifted * 127.5

def load_image(image_file):
  """Load one image file and return the normalized (grayscale, colour) pair.

  Args:
    image_file: Path of the image file to read.

  Returns:
    A tuple `(input_image, real_image)` as produced by `load`, with
    `normalize` applied to both elements.
  """
  gray_crop, color_crop = load(image_file)
  return normalize(gray_crop), normalize(color_crop)

In [None]:
# List the image files in a fixed order so the train/test split is deterministic.
dataset = tf.data.Dataset.list_files(PATH + DATASET["files"], shuffle=False)

# 90% of the files are used for training, the remainder for testing.
dataset_length = int(dataset.cardinality().numpy())
train_length = int(dataset_length * 0.9)
test_length = dataset_length - train_length

train_dataset = dataset.take(train_length)
test_dataset = dataset.skip(train_length)

# Sanity check: no file may end up in both splits.
a = set(train_dataset.as_numpy_iterator())
b = set(test_dataset.as_numpy_iterator())
assert a.isdisjoint(b), "train and test splits share files"

# Shuffle the file PATHS before decoding: the buffer then holds short strings
# instead of decoded float32 images, so it can span the whole training split
# (a uniform shuffle of the name-sorted file list) at negligible memory cost.
train_dataset = train_dataset.shuffle(max(train_length, 1))
train_dataset = train_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Same ordering for the test split: drawing a single batch from it no longer
# has to decode a whole shuffle buffer of images first.
test_dataset = test_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)

## Generator and discriminator

In [None]:
def downsample(filters, size, apply_batchnorm=True):
  """Build a stride-2 Conv2D -> (BatchNormalization) -> LeakyReLU block.

  Args:
    filters: Number of output filters of the convolution.
    size: Kernel size of the convolution.
    apply_batchnorm: Whether to insert a BatchNormalization layer after the
      convolution.

  Returns:
    A `tf.keras.Sequential` model containing the layers of the block.
  """
  conv = tf.keras.layers.Conv2D(
    filters,
    size,
    strides=2,
    padding='same',
    kernel_initializer=tf.random_normal_initializer(0., 0.02),
    use_bias=False,
  )

  block_layers = [conv]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
  """Build a stride-2 Conv2DTranspose -> BatchNormalization -> (Dropout) -> ReLU block.

  Args:
    filters: Number of output filters of the transposed convolution.
    size: Kernel size of the transposed convolution.
    apply_dropout: Whether to insert a Dropout(0.5) layer after the batch
      normalization.

  Returns:
    A `tf.keras.Sequential` model containing the layers of the block.
  """
  deconv = tf.keras.layers.Conv2DTranspose(
    filters,
    size,
    strides=2,
    padding='same',
    kernel_initializer=tf.random_normal_initializer(0., 0.02),
    use_bias=False,
  )

  block_layers = [deconv, tf.keras.layers.BatchNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

In [None]:
def Generator():
  """Build the generator: an encoder-decoder with skip connections.

  The model takes a (256, 256, 1) input, passes it through eight `downsample`
  blocks and seven `upsample` blocks (each concatenated with the encoder
  output of matching resolution), and ends with a tanh-activated transposed
  convolution that yields a (256, 256, 3) output.

  Returns:
    A `tf.keras.Model` mapping the single-channel input to the 3-channel output.
  """
  gray_input = tf.keras.layers.Input(shape=[256, 256, 1])

  # Encoder: (bs, 256, 256, 1) -> (bs, 128, 128, 64) -> ... -> (bs, 1, 1, 512).
  # Only the first block skips batch normalization.
  encoder = [downsample(64, 4, apply_batchnorm=False)]
  encoder += [downsample(depth, 4) for depth in (128, 256, 512, 512, 512, 512, 512)]

  # Decoder: (bs, 1, 1, 512) -> (bs, 2, 2, 1024) -> ... -> (bs, 128, 128, 128),
  # channel counts given after the skip concatenation. The first three blocks
  # use dropout.
  decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
  decoder += [upsample(depth, 4) for depth in (512, 256, 128, 64)]

  to_rgb = tf.keras.layers.Conv2DTranspose(
    3,
    4,
    strides=2,
    padding='same',
    kernel_initializer=tf.random_normal_initializer(0., 0.02),
    activation='tanh',
  )  # (bs, 256, 256, 3)

  # Run the encoder and remember every intermediate output.
  features = gray_input
  encoder_outputs = []
  for block in encoder:
    features = block(features)
    encoder_outputs.append(features)

  # The innermost encoder output is the decoder's input, not a skip; the
  # remaining outputs are consumed deepest-first.
  skip_inputs = encoder_outputs[-2::-1]

  for block, skip in zip(decoder, skip_inputs):
    features = block(features)
    features = tf.keras.layers.Concatenate()([features, skip])

  return tf.keras.Model(inputs=gray_input, outputs=to_rgb(features))

generator = Generator()

In [None]:
# Weight of the L1 term relative to the adversarial term in the generator loss.
LAMBDA = 100

def generator_loss(disc_generated_output, gen_output, target):
  """Compute the generator loss: adversarial term plus LAMBDA-weighted L1 term.

  Relies on the module-level `loss_object`, which has to be defined before
  this function is called.

  Args:
    disc_generated_output: Discriminator output for the generated image.
    gen_output: Image produced by the generator.
    target: Ground-truth image.

  Returns:
    A tuple `(total_gen_loss, gan_loss, l1_loss)`.
  """
  # The generator is rewarded when the discriminator labels its output as real.
  real_labels = tf.ones_like(disc_generated_output)
  adversarial_loss = loss_object(real_labels, disc_generated_output)

  # Mean absolute error between the ground truth and the generated image.
  reconstruction_loss = tf.reduce_mean(tf.abs(target - gen_output))

  combined_loss = adversarial_loss + (LAMBDA * reconstruction_loss)
  return combined_loss, adversarial_loss, reconstruction_loss

In [None]:
def Discriminator():
  """Build the discriminator, which scores an (input, target) image pair.

  The (256, 256, 1) input image and the (256, 256, 3) target image are
  concatenated along the channel axis and reduced by convolutions to a
  (30, 30, 1) output with no final activation.

  Returns:
    A `tf.keras.Model` with inputs `[input_image, target_image]`.
  """
  weight_init = tf.random_normal_initializer(0., 0.02)

  gray_in = tf.keras.layers.Input(shape=[256, 256, 1], name='input_image')
  color_in = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  features = tf.keras.layers.concatenate([gray_in, color_in])  # (bs, 256, 256, 4)

  # Three stride-2 blocks: (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256).
  for depth, use_batchnorm in ((64, False), (128, True), (256, True)):
    features = downsample(depth, 4, use_batchnorm)(features)

  features = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 34, 34, 256)
  features = tf.keras.layers.Conv2D(
    512, 4, strides=1, kernel_initializer=weight_init, use_bias=False
  )(features)  # (bs, 31, 31, 512)
  features = tf.keras.layers.BatchNormalization()(features)
  features = tf.keras.layers.LeakyReLU()(features)

  features = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 33, 33, 512)
  patch_scores = tf.keras.layers.Conv2D(
    1, 4, strides=1, kernel_initializer=weight_init
  )(features)  # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[gray_in, color_in], outputs=patch_scores)

discriminator = Discriminator()

In [None]:
# Binary cross-entropy applied to raw (pre-sigmoid) discriminator outputs;
# shared by the generator and the discriminator loss.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(disc_real_output, disc_generated_output):
  """Compute the discriminator loss over a real and a generated sample.

  Args:
    disc_real_output: Discriminator output for the (input, target) pair.
    disc_generated_output: Discriminator output for the (input, generated) pair.

  Returns:
    The sum of the cross-entropy of the real output against all-ones labels
    and of the generated output against all-zeros labels.
  """
  loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  loss_on_generated = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
  return loss_on_real + loss_on_generated

In [None]:
# Generator and discriminator each get their own Adam optimizer with the same
# hyper-parameters (learning rate 2e-4, beta_1 0.5).
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

## Generate images

In [None]:
def generate_images(model, input, target, filename=None):
  """Plot ground truth, grayscale input and model prediction side by side.

  Only the first element of the batch is shown.

  Args:
    model: Callable generator, invoked as `model(input, training=True)`.
    input: Batch of normalized single-channel input images.
    target: Batch of normalized ground-truth colour images.
    filename: Optional path. When given, the figure is also saved there
      before being shown.
  """
  # NOTE(review): training=True keeps dropout and batch-statistics
  # normalization active for the preview - presumably intentional; confirm.
  prediction = model(input, training=True)

  # (title, pixel data scaled to [0, 1], colormap) for each of the three panels.
  panels = (
    ("Ground Truth", unnormalize(target)[0, :, :, :] / 255.0, None),
    ("Input Image", unnormalize(input)[0, :, :, 0] / 255.0, "gray"),
    ("Predicted Image", unnormalize(prediction)[0, :, :, :] / 255.0, None),
  )

  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
  for ax, (title, pixels, cmap) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(pixels, cmap=cmap)
    ax.axis('off')

  # Save before showing: some backends release the figure on show().
  if filename is not None:
    fig.savefig(filename)

  plt.show()
  # This function is called repeatedly during training; close the figure
  # explicitly so that figures do not accumulate in memory.
  plt.close(fig)

## Training

In [None]:
# TensorBoard summaries go to logs/fit/<timestamp of this run>.
log_dir = "logs/"
summary_writer = tf.summary.create_file_writer(
  f"{log_dir}fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
)

# Checkpoints bundle both models together with their optimizer state; the
# manager keeps only the five most recent ones.
checkpoint_dir = './training_checkpoints'
checkpoint = tf.train.Checkpoint(
  generator_optimizer=generator_optimizer,
  discriminator_optimizer=discriminator_optimizer,
  generator=generator,
  discriminator=discriminator,
)
checkpoint_manager = tf.train.CheckpointManager(
  checkpoint,
  directory=checkpoint_dir,
  max_to_keep=5,
)

@tf.function
def train_step(input_image, target, epoch):
  """Run one optimization step for both the generator and the discriminator.

  Both networks are evaluated once on the batch; each is then updated from
  its own loss, and the four loss values are written to `summary_writer`.

  Args:
    input_image: Batch of generator inputs.
    target: Batch of ground-truth images.
    epoch: Value used as the `step` of the written summaries. Passing it as a
      tensor rather than a Python int avoids `tf.function` retracing for each
      distinct value.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake_image = generator(input_image, training=True)

    real_score = discriminator([input_image, target], training=True)
    fake_score = discriminator([input_image, fake_image], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(fake_score, fake_image, target)
    disc_loss = discriminator_loss(real_score, fake_score)

  gen_vars = generator.trainable_variables
  disc_vars = discriminator.trainable_variables

  gen_grads = gen_tape.gradient(gen_total_loss, gen_vars)
  disc_grads = disc_tape.gradient(disc_loss, disc_vars)

  generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
  discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

  scalars = (
    ('gen_total_loss', gen_total_loss),
    ('gen_gan_loss', gen_gan_loss),
    ('gen_l1_loss', gen_l1_loss),
    ('disc_loss', disc_loss),
  )
  with summary_writer.as_default():
    for tag, value in scalars:
      tf.summary.scalar(tag, value, step=epoch)

In [None]:
def fit(train_ds, epochs, test_ds):
  """Train the GAN for `epochs` epochs, periodically previewing a test sample.

  Relies on the module-level `train_length` for the progress-bar target.

  Args:
    train_ds: Batched dataset of (input_image, target) training pairs.
    epochs: Number of passes over `train_ds`.
    test_ds: Batched dataset of (input_image, target) pairs used only for the
      preview images.
  """
  for epoch in range(1, epochs + 1):
    progbar = tf.keras.utils.Progbar(train_length)
    start = time.time()

    # Hand the epoch to the tf.function as a tensor: a Python int would make
    # `train_step` retrace for every new epoch value.
    epoch_step = tf.constant(epoch, dtype=tf.int64)

    # Train
    samples_since_preview = 0
    samples_seen = 0
    for input_image, target in train_ds:
      # Refresh the preview at the start of the epoch and then roughly every
      # 1000 training samples.
      if samples_since_preview == 0 or samples_since_preview >= 1000:
        samples_since_preview = 0
        display.clear_output(wait=True)

        print("Epoch:", epoch, "/", epochs)

        for example_input, example_target in test_ds.take(1):
          generate_images(generator, example_input, example_target)
      train_step(input_image, target, epoch_step)

      # Count the real batch size: the last batch of an epoch may be smaller,
      # and assuming full batches would push the bar past its target.
      batch_len = int(input_image.shape[0])
      samples_since_preview += batch_len
      samples_seen += batch_len
      progbar.update(samples_seen)

    # Save the model at every epoch
    # Uncomment to save checkpoints
    # checkpoint_manager.save()

    print (f"Time taken for epoch {epoch} is {time.time()-start} sec\n")

In [None]:
# Train for the number of epochs configured for the selected dataset, using
# the test split only for the preview images.
fit(train_dataset, DATASET["epochs"], test_dataset)