### Imports

In [None]:
import os
# suppress TensorFlow INFO/WARNING log output (must be set before tensorflow is imported)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, ReLU, Dense, Flatten, Reshape, Input, Concatenate, Dropout
from tensorflow.keras.optimizers import Adam, SGD, Adadelta
import tensorflow as tf
import matplotlib.pyplot as plt
import time
from IPython import display
from utility.preprocess import get_images_path, augment_test, augment_data, load_image_test, load_image_train, load_image_test_random, load_image_test_user_input
import datetime
import numpy as np
from PIL import Image


In [None]:
# Build the tf.data pipelines.
# Paired datasets yield (input image, colour target) for the greyscale/edge models.
# The colour-hint datasets start from single file paths and decode them with py_function.
LARGE_ROOT = "/dcs/large/u2146727"
DATASET_ROOT = "/dcs/21/u2146727/cs310/dataset"

# True: full greyscale dataset. False: smaller edge-map dataset.
# NOTE(review): with False, training uses edge maps while test_Data below still
# uses greyscale inputs — confirm this is intended.
fullDataset = True
if fullDataset:
    path1 = get_images_path(f"{LARGE_ROOT}/greyscale")
    path2 = get_images_path(f"{LARGE_ROOT}/colourcropped")
else:
    path1 = get_images_path(f"{DATASET_ROOT}/edgecropped")
    path2 = get_images_path(f"{DATASET_ROOT}/colourcropped")

# Inputs and targets are paired by position. This assumes get_images_path returns
# both lists in a consistent order — TODO confirm.
assert len(path1) == len(path2), "input and colour image lists must be the same length"

# Shuffle the cheap file-name pairs before decoding/augmenting, so training batches
# are not drawn in file order. fit() repeats the dataset, so each pass reshuffles.
training_Data = tf.data.Dataset.from_tensor_slices((path1, path2))
training_Data = training_Data.shuffle(max(len(path1), 1))
training_Data = training_Data.map(augment_data)
training_Data = training_Data.batch(1)


test1 = get_images_path(f"{DATASET_ROOT}/greyscale_test")
test2 = get_images_path(f"{DATASET_ROOT}/colourcropped_test")
test_Data = tf.data.Dataset.from_tensor_slices((test1, test2))
test_Data = test_Data.map(augment_test).shuffle(100)
test_Data = test_Data.batch(1)


# Colour-hint data: each file path is decoded into two float32 tensors (input, target).
hint_train_path = get_images_path(f"{LARGE_ROOT}/kaggle/data/train")
hint_train = tf.data.Dataset.from_tensor_slices(hint_train_path)
hint_train = hint_train.map(lambda x: tf.py_function(load_image_train, [x], [tf.float32, tf.float32])).shuffle(100)
hint_train = hint_train.batch(1)

# NOTE: the "Colour Hint" evaluation cell rebuilds hint_test from a different
# val_manual directory, replacing this one.
hint_test_path = get_images_path(f"{LARGE_ROOT}/kaggle/data/val_manual")
hint_test = tf.data.Dataset.from_tensor_slices(hint_test_path)
hint_test = hint_test.map(lambda x: tf.py_function(load_image_test_user_input, [x], [tf.float32, tf.float32])).shuffle(100)
hint_test = hint_test.batch(1)


In [None]:
# encoder block
def encoder_block2(filters, size = 4, bn=True, activation='leakyrelu', batchsize=1):
  """Downsampling block: stride-2 Conv2D, then optional BatchNorm and activation.

  Args:
    filters: number of output channels of the convolution.
    size: convolution kernel size.
    bn: whether to append a BatchNormalization layer.
    activation: 'leakyrelu' (slope 0.25), 'relu', or anything else for no activation.
    batchsize: training batch size; selects the BatchNormalization momentum.

  Returns:
    A tf.keras.Sequential that halves the spatial resolution ('same' padding, stride 2).
  """
  init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)
  stack = [Conv2D(filters, size, strides=2, padding='same', kernel_initializer=init, use_bias=False)]

  if bn:
    # Small batches give noisy per-batch statistics, so use a higher momentum
    # (slower-moving running averages) below batch size 8.
    momentum = 0.98 if batchsize < 8 else 0.9
    stack.append(BatchNormalization(momentum = momentum))

  if activation == 'leakyrelu':
    stack.append(LeakyReLU(0.25))
  elif activation == 'relu':
    stack.append(ReLU())

  return tf.keras.Sequential(stack)


In [None]:
# decoder block
def decoder_block2(filters, size = 4, dropout=False, activation='leakyrelu', batchsize=1):
  """Upsampling block: stride-2 Conv2DTranspose, BatchNorm, Dropout, activation.

  Args:
    filters: number of output channels of the transposed convolution.
    size: kernel size.
    dropout: if truthy, use Dropout(0.5); otherwise a no-op Dropout(0.0).
    activation: 'leakyrelu' for LeakyReLU with the Keras default slope (not the
      0.25 used by encoder_block2); any other value selects ReLU.
    batchsize: accepted for signature parity with encoder_block2 but unused here.

  Returns:
    A tf.keras.Sequential that doubles the spatial resolution ('same' padding, stride 2).
  """
  init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)
  upsample = tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=init, use_bias=False)
  norm = tf.keras.layers.BatchNormalization()
  drop = tf.keras.layers.Dropout(0.5 if dropout else 0.0)
  act = tf.keras.layers.ReLU() if activation != 'leakyrelu' else tf.keras.layers.LeakyReLU()
  return tf.keras.Sequential([upsample, norm, drop, act])

In [None]:
def Generator():
  """U-Net generator mapping a (256, 256, 3) image to a (256, 256, 3) tanh output.

  The encoder has eight stride-2 blocks (256 -> 1 spatially); the first block has no
  batch norm. The decoder has seven stride-2 blocks, each concatenated with the
  mirrored encoder output. A final stride-2 Conv2DTranspose with tanh returns to
  256x256x3, so outputs lie in [-1, 1].
  """
  initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)
  base = 64

  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  # Encoder, shallowest first; every block output is kept for the skip connections.
  encoder_filters = [base, base*2, base*4, base*8, base*8, base*8, base*8, base*8]
  skips = []
  x = inputs
  for depth, n_filters in enumerate(encoder_filters):
    x = encoder_block2(n_filters, bn=(depth != 0))(x)
    skips.append(x)

  # Decoder: (filters, dropout) per block. The deepest encoder output (x) is the
  # bottleneck; the rest are concatenated deepest-first.
  decoder_specs = [(base*8, True), (base*8, True), (base*8, False), (base*8, False),
                   (base*4, False), (base*2, False), (base, False)]
  for (n_filters, use_dropout), skip in zip(decoder_specs, reversed(skips[:-1])):
    x = decoder_block2(n_filters, dropout=use_dropout)(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  to_rgb = Conv2DTranspose(3, 4, strides=2, padding='same', kernel_initializer=initializer, activation='tanh')  # (batch_size, 256, 256, 3)
  return tf.keras.Model(inputs=inputs, outputs=to_rgb(x))

In [None]:
def Discriminator():
  """Conditional discriminator producing a spatial map of sigmoid scores.

  Takes two (256, 256, 3) inputs. train_step passes [input_image, target_or_generated],
  and they are concatenated on the channel axis. Two stride-2 encoder blocks
  (256 -> 64) are followed by two zero-pad + 4x4 'valid' Conv2D(512)/BN/LeakyReLU
  stages and a final zero-pad + Conv2D(1, sigmoid). The output shape is
  (batch, 61, 61, 1): one probability per patch.
  """
  initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)
  size = 64
  label = tf.keras.layers.Input(shape=[256, 256, 3])
  tar = tf.keras.layers.Input(shape=[256, 256, 3])
  x = tf.keras.layers.concatenate([label, tar])

  x = encoder_block2(size, 4, False)(x)   # no batch norm on the first block
  x = encoder_block2(size*2, 4)(x)

  # Two identical stride-1 stages; each 'valid' 4x4 conv after 1-pixel padding
  # shrinks the map by one pixel.
  for _ in range(2):
    x = tf.keras.layers.ZeroPadding2D()(x)
    x = Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

  x = tf.keras.layers.ZeroPadding2D()(x)
  patch_scores = Conv2D(1, 4, strides=1, kernel_initializer=initializer, activation='sigmoid')(x)

  return tf.keras.Model(inputs=[label, tar], outputs=patch_scores)


In [None]:
# loss functions
def generator_loss(disc_generated_output, gen_output, target, lamb = 100, norm = "l1"):
  """Generator objective: adversarial BCE plus a weighted reconstruction term.

  Args:
    disc_generated_output: discriminator outputs for generated images. The BCE
      uses from_logits=False, so these must be probabilities (e.g. sigmoid outputs).
    gen_output: images produced by the generator.
    target: ground-truth images, same shape as gen_output.
    lamb: weight applied to the reconstruction term.
    norm: "l1" for mean absolute error; any other value selects mean squared error.

  Returns:
    (total_gen_loss, gan_loss, norm_loss) as scalar tensors.
  """
  # The generator wants the discriminator to label its outputs as real (1).
  # ones_like works even when the batch dimension is unknown at trace time;
  # tf.ones(x.shape) needs a fully defined static shape.
  real_labels = tf.ones_like(disc_generated_output)
  binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(label_smoothing=0)
  gan_loss = binary_cross_entropy(real_labels, disc_generated_output)

  if norm == "l1":
    # Mean absolute error
    norm_loss = tf.reduce_mean(tf.abs(gen_output - target))
  else:
    # Mean squared error
    norm_loss = tf.reduce_mean((target - gen_output)**2)
  total_gen_loss = gan_loss + (lamb * norm_loss)

  return total_gen_loss, gan_loss, norm_loss

def discriminator_loss(disc_real_output, disc_generated_output):
  """Discriminator objective: BCE pushing real pairs to 1 and generated pairs to 0.

  Args:
    disc_real_output: discriminator probabilities for (input, real target) pairs.
    disc_generated_output: discriminator probabilities for (input, generated) pairs.

  Returns:
    Scalar tensor: sum of the real and generated BCE terms.
  """
  # *_like works with partially unknown (e.g. batch) dimensions, unlike tf.ones(x.shape).
  real_labels = tf.ones_like(disc_real_output)
  fake_labels = tf.zeros_like(disc_generated_output)
  binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(label_smoothing=0)
  real_loss = binary_cross_entropy(real_labels, disc_real_output)
  generated_loss = binary_cross_entropy(fake_labels, disc_generated_output)
  total_disc_loss = real_loss + generated_loss

  return total_disc_loss

In [None]:
def colourise(model, test_input, tar):
  """Colourise a batch with `model`, then plot and score its first example.

  Shows input, ground truth and prediction side by side. Images are treated as
  [-1, 1] and mapped to [0, 1] for display.

  Args:
    model: generator to run; it is called with training=True.
    test_input: batched model input; only element 0 is plotted and scored.
    tar: batched ground-truth images aligned with test_input.

  Returns:
    (mae, psnr, ssim) for element 0. The first value is the mean absolute error
    in [-1, 1] pixel units, even though callers in this notebook name it `mse`.
  """
  # training=True is kept from the original. It is presumably deliberate
  # (pix2pix-style inference with dropout/batch statistics active) — confirm.
  prediction = model(test_input, training=True)
  source, target, generated = test_input[0], tar[0], prediction[0]

  mae = tf.reduce_mean(tf.abs(generated - target))
  # tf.image.ssim expects non-negative pixel values, so score [0, 1]-scaled images
  # with max_val=1.0. PSNR is invariant to this rescaling (error and max_val shrink
  # together), so it equals PSNR on the [-1, 1] images with max_val=2.0.
  target_01 = target * 0.5 + 0.5
  generated_01 = generated * 0.5 + 0.5
  psnr = tf.image.psnr(target_01, generated_01, max_val=1.0)
  ssim = tf.image.ssim(target_01, generated_01, max_val=1.0)

  fig = plt.figure(figsize=(30, 30))
  for position, image in enumerate((source, target, generated), start=1):
    plt.subplot(1, 3, position)
    # imshow clips float RGB data to [0, 1] anyway. Clipping here avoids a
    # matplotlib warning for every image.
    plt.imshow(np.clip(np.asarray(image) * 0.5 + 0.5, 0.0, 1.0))
    plt.axis('off')
  plt.show()
  # The evaluation cells call this once per test image, so release each figure.
  plt.close(fig)
  return mae, psnr, ssim

In [None]:
# Networks and their optimizers. Both use Adam with lr=2e-4 and beta_1=0.5.
generator = Generator()
discriminator = Discriminator()
generator_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)
@tf.function
def train_step(input_image, target, step):
  """Run one adversarial update of the generator and discriminator on a batch.

  Args:
    input_image: batched generator input.
    target: batched ground-truth colour images.
    step: training step; fit() passes it, but it is not used here.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake = generator(input_image, training=True)
    real_scores = discriminator([input_image, target], training=True)
    fake_scores = discriminator([input_image, fake], training=True)

    # Only the total generator loss is optimised; the GAN and reconstruction
    # components are discarded here.
    gen_total_loss, _, _ = generator_loss(fake_scores, fake, target)
    disc_loss = discriminator_loss(real_scores, fake_scores)

  # Compute both gradients before applying either update.
  gen_grads = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
  disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))


In [None]:
def fit(train_ds, test_ds, steps, save_path=None, save_every=5000):
  """Train the global `generator`/`discriminator` pair for `steps` batches.

  Every 1000 steps the cell output is cleared and a fixed example from `test_ds`
  is re-colourised to show progress. A dot is printed every 10 steps.

  Args:
    train_ds: batched dataset of (input_image, target); repeated as needed.
    test_ds: batched dataset of (input_image, target); its first batch is the
      fixed preview example.
    steps: total number of training steps.
    save_path: optional file path (e.g. 'generator_greyscale.h5'). When given,
      the generator is saved there every `save_every` steps and once more at the
      end, so it can later be reloaded with tf.keras.models.load_model.
    save_every: checkpoint interval in steps; only used when save_path is set.

  Raises:
    ValueError: if save_path is set and save_every is not positive.
  """
  if save_path is not None and save_every <= 0:
    raise ValueError(f"save_every must be positive, got {save_every}")

  example_input, example_target = next(iter(test_ds.take(1)))

  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    if step % 1000 == 0:
      display.clear_output(wait=True)

      colourise(generator, example_input, example_target)
      print(f"Step: {step//1000} thousand")

    # Training step
    train_step(input_image, target, step)

    if (step + 1) % 10 == 0:
      print('.', end='', flush=True)

    if save_path is not None and (step + 1) % save_every == 0:
      generator.save(save_path)

  if save_path is not None:
    generator.save(save_path)


In [None]:
# Training loop (commented out; uncomment to retrain the generator)
#fit(training_Data, test_Data, steps=80000)

#### Greyscale

In [None]:
# Evaluate the saved greyscale -> colour generator on every test pair, then
# summarise the per-image metrics returned by colourise.
greyscale_test = tf.data.Dataset.from_tensor_slices((test1, test2))
greyscale_test = greyscale_test.map(augment_test).shuffle(100)
greyscale_test = greyscale_test.batch(1)

new_model = tf.keras.models.load_model('generator_greyscale.h5', compile=False)
# One row per test image: (mean absolute error, PSNR, SSIM).
greyscale_scores = np.array([[float(v) for v in colourise(new_model, inp, tar)]
                             for inp, tar in greyscale_test])
# Report total, mean, min and max of each metric.
if greyscale_scores.size:
    for metric, values in zip(("MAE", "PSNR", "SSIM"), greyscale_scores.T):
        print(f"{metric}: total={values.sum():.4f} mean={values.mean():.4f} "
              f"min={values.min():.4f} max={values.max():.4f}")

#### D of G (edge-map input)

In [None]:
# Evaluate the saved edge-map -> colour generator, then summarise the per-image
# metrics returned by colourise.
# NOTE(review): inputs come from the training directory 'edgecropped' but targets
# come from 'colourcropped_test'. Unless the two lists are aligned, input i and
# target i show different images — confirm (a dedicated edge test set may be intended).
edge1 = get_images_path("/dcs/21/u2146727/cs310/dataset/edgecropped")
edge_test = tf.data.Dataset.from_tensor_slices((edge1, test2))
edge_test = edge_test.map(augment_test).shuffle(100)
edge_test = edge_test.batch(1)

new_model = tf.keras.models.load_model('generator_edges.h5', compile=False)
# One row per test image: (mean absolute error, PSNR, SSIM).
edge_scores = np.array([[float(v) for v in colourise(new_model, inp, tar)]
                        for inp, tar in edge_test])
if edge_scores.size:
    for metric, values in zip(("MAE", "PSNR", "SSIM"), edge_scores.T):
        print(f"{metric}: total={values.sum():.4f} mean={values.mean():.4f} "
              f"min={values.min():.4f} max={values.max():.4f}")

### Colour Hint

In [None]:
# Evaluate the colour-hint generator on the manually prepared validation images.
# This rebuilds hint_test from a different val_manual directory than the
# data-loading cell, replacing that dataset.
hint_test_path = get_images_path("/dcs/21/u2146727/cs310/dataset/val_manual/")
hint_test = tf.data.Dataset.from_tensor_slices(hint_test_path)
hint_test = hint_test.map(lambda x: tf.py_function(load_image_test_user_input, [x], [tf.float32, tf.float32])).shuffle(100)
hint_test = hint_test.batch(1)

new_model = tf.keras.models.load_model('generator-third.h5', compile = False)
# One row per image: (mean absolute error, PSNR, SSIM). Then report the total,
# mean, min and max of each metric.
hint_scores = np.array([[float(v) for v in colourise(new_model, inp, tar)]
                        for inp, tar in hint_test])
if hint_scores.size:
    for metric, values in zip(("MAE", "PSNR", "SSIM"), hint_scores.T):
        print(f"{metric}: total={values.sum():.4f} mean={values.mean():.4f} "
              f"min={values.min():.4f} max={values.max():.4f}")