In [None]:
import datetime
import os
import time
from zipfile import ZipFile

import h5py
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras.layers import Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, Concatenate, Dropout, Dense, Flatten
from keras.activations import sigmoid
from IPython import display
import pandas as pd


In [None]:
def extract_zip(input_zip):
    """Read every member of a zip archive into memory.

    Args:
        input_zip: Path to the archive (or any object `ZipFile` accepts,
            such as an open binary file object).

    Returns:
        dict mapping each member name, as listed by `namelist()`, to its
        raw bytes.
    """
    # The context manager closes the archive handle even if a member fails
    # to read; previously the ZipFile was left open.
    with ZipFile(input_zip) as archive:
        return {name: archive.read(name) for name in archive.namelist()}

def parse_function(filename, label):
    """Load one image / label pair from disk and normalise both to [-1, 1].

    Args:
        filename: Path of the 3-channel input image.
        label: Path of the single-channel label image.

    Returns:
        Tuple `(image, label)` of float32 tensors resized to 128 x 416,
        with 3 and 1 channels respectively, scaled to [-1, 1].
    """
    image_string = tf.io.read_file(filename)
    label_string = tf.io.read_file(label)

    # Don't use tf.image.decode_image, or the output shape will be undefined.
    # NOTE(review): the label file is decoded with decode_jpeg as well --
    # confirm this matches the label format in the dataset.
    image = tf.image.decode_jpeg(image_string, channels=3)
    label = tf.image.decode_jpeg(label_string, channels=1)

    # This will convert to float values in [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    label = tf.image.convert_image_dtype(label, tf.float32)

    image = tf.image.resize(image, [128, 416], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    label = tf.image.resize(label, [128, 416], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Map [0, 1] -> [-1, 1]. The earlier `x / 127.5 - 1` is the formula for
    # 0..255 inputs; applied after convert_image_dtype it squeezed every pixel
    # into roughly [-1, -0.992], wiping out nearly all contrast.
    image = image * 2.0 - 1.0
    label = label * 2.0 - 1.0

    return image, label

# Disabled augmentation step: the function below lives inside a string literal,
# so it is never defined or executed. It sketches a random left-right flip of
# the image (the label is not flipped), brightness/saturation jitter, and a
# final clip to [0, 1].
'''
def train_preprocess(image, label):
    image = tf.image.random_flip_left_right(image)

    image = tf.image.random_brightness(image, max_delta=32.0 / 255.0)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)

    #Make sure the image is still in [0, 1]
    image = tf.clip_by_value(image, 0.0, 1.0)

    return image, label
'''

In [None]:
def _parse_manifest(raw_csv):
    """Decode a CSV manifest given as bytes into a list of field lists.

    `splitlines()` handles both LF and CRLF line endings, so the last field
    of a row never carries a stray carriage return; blank or whitespace-only
    lines are skipped.
    """
    return [line.strip().split(',')
            for line in raw_csv.decode("utf-8").splitlines()
            if line.strip()]


# Load every archive member into memory, then parse the two CSV manifests.
data = extract_zip('nyu_data.zip')
nyu2_train = _parse_manifest(data['data/nyu2_train.csv'])
nyu2_test = _parse_manifest(data['data/nyu2_test.csv'])

In [None]:
# Split each manifest into two parallel lists: column 0 (used downstream as the
# image file name) and column 1 (used downstream as the label file name).
nyu2_train_image = [row[0] for row in nyu2_train]
nyu2_train_label = [row[1] for row in nyu2_train]

nyu2_test_image = [row[0] for row in nyu2_test]
nyu2_test_label = [row[1] for row in nyu2_test]


In [None]:
# Seed for the one-off shuffle that decides which samples are held out.
SPLIT_SEED = 42

# Shuffle the (cheap) path pairs once, in a fixed order, *before* splitting.
# With the default reshuffle_each_iteration=True the order changed on every
# pass, so take()/skip() selected different samples each epoch and the
# "test" samples leaked into training.
data = tf.data.Dataset.from_tensor_slices((nyu2_train_image, nyu2_train_label))
data = data.shuffle(1024, seed=SPLIT_SEED, reshuffle_each_iteration=False)

# First 50000 pairs train the model; the remainder is held out.
train_dataset = (
    data.take(50000)
    .shuffle(1024)  # fresh order every epoch; shuffles paths, not decoded images
    .map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    data.skip(50000)
    .map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
def downsample(filters, size, stride, apply_batchnorm = True):
  """Build a Conv2D -> [BatchNormalization] -> LeakyReLU block.

  Args:
    filters: Number of convolution filters. A value of 1 gives the
      convolution a sigmoid activation; otherwise it has no activation.
    size: Convolution kernel size.
    stride: Convolution stride.
    apply_batchnorm: Whether to insert BatchNormalization after the conv.

  Returns:
    A `keras.Sequential` containing the layers above. Note that for
    `filters == 1` the sigmoid output still passes through the optional
    batch norm and the LeakyReLU.
  """
  conv_activation = 'sigmoid' if filters == 1 else None

  block_layers = [
      Conv2D(filters = filters, kernel_size = size, strides = stride, padding = 'same',
             activation = conv_activation,
             kernel_initializer = tf.random_normal_initializer(0., 0.02),
             use_bias = False),
  ]
  if apply_batchnorm:
    block_layers.append(BatchNormalization())
  block_layers.append(LeakyReLU())

  return keras.Sequential(block_layers)

In [None]:
def upsample(filters, size, stride, apply_dropout = True):
  """Build a Conv2DTranspose -> BatchNormalization -> [Dropout] -> LeakyReLU block.

  Args:
    filters: Number of transposed-convolution filters.
    size: Kernel size.
    stride: Stride of the transposed convolution.
    apply_dropout: Whether to insert `Dropout(0.5)` after the batch norm.

  Returns:
    A `keras.Sequential` containing the layers above.
  """
  block_layers = [
      Conv2DTranspose(filters = filters, kernel_size = size, strides = stride, padding = 'same',
                      kernel_initializer = tf.random_normal_initializer(0., 0.02),
                      use_bias = False),
      BatchNormalization(),
  ]
  if apply_dropout:
    block_layers.append(Dropout(0.5))
  block_layers.append(LeakyReLU())

  return keras.Sequential(block_layers)


In [None]:
def resize_like(inputs, ref):
    """Resize `inputs` to the spatial size of `ref` using nearest-neighbour.

    Axes 1 and 2 of both tensors are compared; when they already match,
    `inputs` is returned untouched.
    """
    in_shape = inputs.get_shape()
    ref_shape = ref.get_shape()
    target_h, target_w = ref_shape[1], ref_shape[2]

    needs_resize = not (in_shape[1] == target_h and in_shape[2] == target_w)
    if needs_resize:
        return tf.image.resize(inputs, [target_h, target_w], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return inputs

# Generator output heads are computed as DISP_SCALING * head + MIN_DISP.
DISP_SCALING = 10
MIN_DISP = 0.01
# Height / width used by the generator when resizing intermediate maps;
# they match the 128 x 416 size produced by parse_function.
H = 128
W = 416

In [None]:
def Generator():
  """Build the encoder-decoder generator.

  The encoder applies seven stride-2 `downsample` blocks, each followed by a
  stride-1 block. The decoder mirrors it with `upsample` blocks and
  concatenates the matching encoder feature map at every level except the
  last. Single-channel maps (`disp4` .. `disp1`) are predicted at four
  scales; each coarser map is resized and concatenated into the next finer
  level.

  Returns:
    A `tf.keras.Model` mapping an `[H, W, 3]` input to the full-resolution
    single-channel map `disp1`.
  """
  inputs = tf.keras.layers.Input(shape=[H, W, 3])

  # ---- Encoder ----
  cnv1 = downsample(32, 7, 2, apply_batchnorm = False)(inputs)
  cnv1b = downsample(32, 7, 1)(cnv1)
  cnv2 = downsample(64, 5, 2, apply_batchnorm = False)(cnv1b)
  cnv2b = downsample(64, 5, 1)(cnv2)
  cnv3 = downsample(128, 3, 2, apply_batchnorm = False)(cnv2b)
  cnv3b = downsample(128, 3, 1)(cnv3)
  cnv4 = downsample(256, 3, 2, apply_batchnorm = False)(cnv3b)
  cnv4b = downsample(256, 3, 1)(cnv4)
  cnv5 = downsample(512, 3, 2, apply_batchnorm = False)(cnv4b)
  cnv5b = downsample(512, 3, 1)(cnv5)
  cnv6 = downsample(512, 3, 2, apply_batchnorm = False)(cnv5b)
  cnv6b = downsample(512, 3, 1)(cnv6)
  cnv7 = downsample(512, 3, 2, apply_batchnorm = False)(cnv6b)
  cnv7b = downsample(512, 3, 1)(cnv7)

  # ---- Decoder ----
  # At the deepest levels the odd spatial sizes mean the transposed conv
  # overshoots the skip tensor, hence resize_like before concatenating.
  upcnv7 = upsample(512, 3, 2)(cnv7b)
  upcnv7 = resize_like(upcnv7, cnv6b)
  i7_in = Concatenate()([upcnv7, cnv6b])
  icnv7 = downsample(512, 3, 1)(i7_in)

  upcnv6 = upsample(512, 3, 2)(icnv7)
  upcnv6 = resize_like(upcnv6, cnv5b)
  i6_in = Concatenate()([upcnv6, cnv5b])
  icnv6 = downsample(512, 3, 1)(i6_in)

  upcnv5 = upsample(256, 3, 2)(icnv6)
  upcnv5 = resize_like(upcnv5, cnv4b)
  i5_in = Concatenate()([upcnv5, cnv4b])
  icnv5 = downsample(256, 3, 1)(i5_in)

  # NOTE(review): every disp head below reads the concatenated tensor
  # (iN_in), not the refined icnvN. That matches the original code; confirm
  # it is intended.
  upcnv4 = upsample(128, 3, 2)(icnv5)
  i4_in = Concatenate()([upcnv4, cnv3b])
  icnv4 = downsample(128, 3, 1)(i4_in)
  # NOTE(review): unlike disp3..disp1, disp4 is not scaled by DISP_SCALING
  # or offset by MIN_DISP (a scaled variant was commented out originally).
  disp4 = downsample(1, 3, 1)(i4_in)
  # Plain integer division replaces np.int(...), which was removed from NumPy
  # in 1.24 and raises AttributeError there.
  disp4_up = tf.image.resize(disp4, [H // 4, W // 4], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  upcnv3 = upsample(64, 3, 2)(icnv4)
  i3_in = Concatenate()([upcnv3, cnv2b, disp4_up])
  icnv3 = downsample(64, 3, 1)(i3_in)
  disp3 = DISP_SCALING * downsample(1, 3, 1)(i3_in) + MIN_DISP
  disp3_up = tf.image.resize(disp3, [H // 2, W // 2], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  upcnv2 = upsample(32, 3, 2)(icnv3)
  i2_in = Concatenate()([upcnv2, cnv1b, disp3_up])
  icnv2 = downsample(32, 3, 1)(i2_in)
  disp2 = DISP_SCALING * downsample(1, 3, 1)(i2_in) + MIN_DISP
  disp2_up = tf.image.resize(disp2, [H, W], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # The original also built an `icnv1` block here that fed nothing; it was
  # never part of the model graph and has been dropped.
  upcnv1 = upsample(16, 3, 2)(icnv2)
  i1_in = Concatenate()([upcnv1, disp2_up])
  disp1 = DISP_SCALING * downsample(1, 3, 1)(i1_in) + MIN_DISP

  return tf.keras.Model(inputs = inputs, outputs = disp1)

In [None]:
# Weight of the L1 reconstruction term relative to the adversarial term.
LAMBDA = 100
def generator_loss(disc_generated_output, gen_output, target):
  """Compute the generator objective.

  Args:
    disc_generated_output: Discriminator output for the generated sample.
    gen_output: Generator prediction.
    target: Ground-truth tensor compared against `gen_output`.

  Returns:
    Tuple `(total, gan, l1)`: `gan` is `loss_object` evaluated against
    all-ones labels, `l1` is the mean absolute error between `target` and
    `gen_output`, and `total = gan + LAMBDA * l1`.
  """
  all_real = tf.ones_like(disc_generated_output)
  gan_loss = loss_object(all_real, disc_generated_output)

  abs_error = tf.abs(target - gen_output)
  l1_loss = tf.reduce_mean(abs_error)

  return gan_loss + (LAMBDA * l1_loss), gan_loss, l1_loss

In [None]:
def Discriminator():
  """Build the discriminator.

  It takes an input image (128 x 416 x 3) and a target map (128 x 416 x 1),
  concatenates them, applies eight stride-2 `downsample` blocks, flattens,
  and passes the result through three LeakyReLU dense layers and a final
  one-unit sigmoid layer.

  Returns:
    A `tf.keras.Model` with inputs `[input_image, target_image]` and a
    single sigmoid output per sample.
  """
  inp = keras.layers.Input(shape = [128, 416, 3], name = 'input_image')
  tar = keras.layers.Input(shape = [128, 416, 1], name = 'target_image')

  features = Concatenate()([inp, tar])

  # (filters, kernel size) for each stride-2 stage.
  conv_specs = [(32, 5), (64, 3), (64, 3), (128, 3),
                (128, 3), (256, 3), (256, 3), (512, 3)]
  for n_filters, kernel in conv_specs:
    features = downsample(n_filters, kernel, 2)(features)

  features = Flatten()(features)
  for units in (512, 256, 128):
    features = Dense(units, activation = LeakyReLU())(features)
  pred_head = Dense(1, activation = 'sigmoid')(features)

  return tf.keras.Model(inputs = [inp, tar], outputs = pred_head)
  


In [None]:
# Discriminator() ends in Dense(1, activation='sigmoid'), so its outputs are
# already probabilities. from_logits=True would apply a second sigmoid inside
# the loss and distort both the generator and the discriminator objectives.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Sum of the discriminator's loss on real and on generated samples.

  Real outputs are scored against all-ones labels and generated outputs
  against all-zeros labels, both via `loss_object`.
  """
  loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  loss_on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
  return loss_on_real + loss_on_fake

In [None]:
# One Adam optimizer per network, both with learning rate 2e-4 and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
def generate_images(model, test_input, tar):
  """Plot the first sample of a batch: input, ground truth and prediction.

  Args:
    model: Callable model used to produce the prediction.
    test_input: Batch of input images.
    tar: Batch of ground-truth tensors.

  The model is invoked with `training=True`, so layers such as dropout and
  batch norm run in training mode.
  """
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15,15))

  panels = [
      ('Input Image', test_input[0]),
      ('Ground Truth', tar[0].numpy().squeeze()),
      ('Predicted Image', prediction[0].numpy().squeeze()),
  ]
  for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(picture * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

In [None]:
# Instantiate both networks; train_step and the checkpoint refer to these globals.
generator = Generator()
discriminator = Discriminator()

In [None]:
# Checkpoint object tracking both models and both optimizers; fit() saves it
# with the prefix ./training_checkpoints/ckpt.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
# Print the discriminator's layer-by-layer summary.
discriminator.summary()

In [None]:
# Sanity check before training: plot the first sample of one held-out batch
# together with the generator's current prediction.
for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)

In [None]:
# TensorBoard writer; each run logs into its own timestamped directory under
# logs/fit/. (`datetime` is imported in the notebook's top import cell rather
# than mid-notebook.)
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
#@tf.function
def train_step(input_image, target, epoch):
  """Run one optimisation step for the generator and the discriminator.

  Args:
    input_image: Batch of generator inputs.
    target: Batch of ground-truth tensors.
    epoch: Current epoch index; used only in the progress message.

  Both networks are evaluated under their own gradient tape, their losses
  are computed from the same forward pass, and each optimizer then applies
  its gradients. Loss scalars are written to `summary_writer`.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))
  print(f"Epoch: {epoch} | gen_loss: {gen_total_loss} | disc_loss: {disc_loss} ")

  # Log against the optimizer's own step counter. Using `epoch` as the step
  # put every batch of an epoch on the same TensorBoard x-coordinate, so the
  # per-batch curves were unreadable.
  global_step = generator_optimizer.iterations
  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=global_step)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=global_step)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=global_step)
    tf.summary.scalar('disc_loss', disc_loss, step=global_step)

In [None]:
def fit(train_ds, epochs, test_ds):
  """Train for `epochs` passes over `train_ds`, checkpointing along the way.

  Args:
    train_ds: Batched dataset yielding `(input_image, target)` pairs.
    epochs: Number of passes over `train_ds`.
    test_ds: Held-out dataset. Currently unused; kept so existing callers
      keep working (a per-epoch preview of it was disabled).

  A checkpoint is written every 20 epochs and once more at the end, unless
  the final epoch was itself just checkpointed.
  """
  saved_after_last_epoch = False
  for epoch in range(epochs):
    start = time.time()

    display.clear_output(wait=True)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
      print('.', end='')
      if (n+1) % 100 == 0:
        print()
      train_step(input_image, target, epoch)
    print()

    # saving (checkpoint) the model every 20 epochs
    saved_after_last_epoch = (epoch + 1) % 20 == 0
    if saved_after_last_epoch:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))

  # Final save, skipped when the last epoch already produced a checkpoint so
  # the same state is not written twice.
  if not saved_after_last_epoch:
    checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# Number of passes over the training dataset.
EPOCHS = 5

In [None]:
# Start training; a checkpoint is written when fit() finishes.
fit(train_dataset, EPOCHS, test_dataset)