In [None]:
# basis is the official pix2pix tutorial on the tensorflow website

In [None]:
# Which prepared dataset folder under data/ to use: "original", "retiled" or "maps1800".
dataset = "maps1800"

In [None]:
import tensorflow as tf
#print(tf.__version__)

In [None]:
!pip install tensorflow==2.2.0

# Pix2Pix

## Import TensorFlow and other libraries

In [None]:
import tensorflow as tf

import os
import time

from matplotlib import pyplot as plt
from IPython import display

In [None]:
!pip install -U tensorboard

In [None]:
# Root of the selected dataset; expects images/ and groundtruth/ sub-folders.
PATH = f"/home/jonathan/CIL-street/data/{dataset}/all/"

In [None]:
# Display the resolved training-data directory.
PATH

In [None]:
# tf.data shuffle buffer size (should not make a difference)
BUFFER_SIZE = 512
# Larger batches did not speed up training significantly and caused OOM errors.
BATCH_SIZE = 1
# Working resolution of the generator / discriminator.
IMG_WIDTH = IMG_HEIGHT = 512

In [None]:
def load(image_file):
  """Read an image file and return it as a float32 tensor with 3 channels.

  Pixel values are left in [0, 255]. channels=3 forces a 3-channel result
  because some images/masks on disk are stored with one channel and others
  with three.
  NOTE(review): the dataset files are .png; this relies on decode_jpeg
  accepting them (tf.io.decode_png / decode_image would be the explicit choice).
  """
  print("load")
  raw_bytes = tf.io.read_file(image_file)
  decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
  return tf.cast(decoded, tf.float32)

In [None]:
# Load one raw pair from the original training set and upscale it to the
# 512x512 working resolution so the helpers below can be tried on it.
# `inp` and `re` are reused by later demo cells.
PATH_original = "/home/jonathan/CIL-street/data/original/training/"
path = PATH_original + 'images/satImage_002.png'
path_gt = PATH_original + 'groundtruth/satImage_002.png'
inp = tf.image.resize(load(path), (512, 512))
re = tf.image.resize(load(path_gt), (512, 512))

# load() returns floats in [0, 255]; divide by 255 so matplotlib gets [0, 1].
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(re / 255.0, cmap='Greys_r')

In [None]:
def resize(input_image, real_image, height, width):
  """Resize both images of an (input, target) pair to height x width.

  Nearest-neighbour sampling is used for both, so no new (blended) pixel
  values are interpolated.
  """
  new_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_input = tf.image.resize(input_image, new_size, method=nearest)
  resized_real = tf.image.resize(real_image, new_size, method=nearest)
  return resized_input, resized_real

In [None]:
def random_crop(input_image, real_image):
  """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

  Both images must have 3 channels, because they are stacked into one tensor.
  """
  # Stacking makes a single random_crop call pick one offset for the pair.
  pair = tf.stack([input_image, real_image], axis=0)
  window = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=window)
  return cropped_pair[0], cropped_pair[1]

In [None]:
# normalizing the images to [-1, 1]

def normalize(input_image, real_image):
  """Map pixel values of both images from [0, 255] to [-1, 1]."""
  half_range = 127.5
  return input_image / half_range - 1, real_image / half_range - 1

In [None]:
@tf.function()
def random_jitter(input_image, real_image):
  """Apply one shared random augmentation to an (input, target) pair.

  Steps, in order: rotate both images by the same random multiple of 90
  degrees, upscale to 572x572 (286*2), crop the same random
  IMG_HEIGHT x IMG_WIDTH window, and mirror left-right with probability 0.5.
  """
  quarter_turns = tf.random.uniform(
      (), minval=0, maxval=4, dtype=tf.dtypes.int32)
  input_image, real_image = (tf.image.rot90(input_image, k=quarter_turns),
                             tf.image.rot90(real_image, k=quarter_turns))

  upscaled = 286 * 2
  input_image, real_image = resize(input_image, real_image, upscaled, upscaled)
  input_image, real_image = random_crop(input_image, real_image)

  # random mirroring
  if tf.random.uniform(()) > 0.5:
    input_image, real_image = (tf.image.flip_left_right(input_image),
                               tf.image.flip_left_right(real_image))

  return input_image, real_image

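The images below show the effect of random jittering on a training image.
As described in the pix2pix paper, random jittering is to

1. Resize an image to a bigger height and width (here 572 x 572),
2. Randomly crop it back to the target size (512 x 512),
3. Randomly flip the image horizontally.

In addition, this notebook first rotates the input and the target by the same random multiple of 90 degrees.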

In [None]:
# Four independent random augmentations of the same sample image.
plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i + 1)
  plt.imshow(rj_inp / 255.0)  # values are still in [0, 255] here
  plt.axis('off')
plt.show()

In [None]:
def load_image_train(image_file, image_file_gt):
  """Load, augment and normalize one training pair.

  Returns the satellite image in [-1, 1] with 3 channels and the mask as a
  single-channel grayscale tensor (values derived from the [-1, 1] range).
  """
  print(image_file)
  input_image = load(image_file)
  img_gray = load(image_file_gt)
  print(img_gray.shape)
  # random_jitter stacks input and mask, so the mask needs 3 channels too.
  if img_gray.shape[2] == 1:
    real_image = tf.image.grayscale_to_rgb(img_gray)
  else:
    real_image = img_gray
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)
  return input_image, tf.image.rgb_to_grayscale(real_image)

In [None]:
def load_image_test(image_file, image_file_gt):
  """Load and normalize one evaluation pair without random augmentation.

  Both images are resized to IMG_HEIGHT x IMG_WIDTH. The mask is returned as a
  single grayscale channel.
  """
  print("load image test")
  print(image_file)
  input_image = load(image_file)
  img_gray = load(image_file_gt)
  print(img_gray.shape)
  if img_gray.shape[2] == 1:
    real_image = tf.image.grayscale_to_rgb(img_gray)
  else:
    real_image = img_gray
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)
  return input_image, tf.image.rgb_to_grayscale(real_image)

## Input Pipeline

In [None]:
# Pair every satellite image with its ground-truth mask. Both file listings use
# shuffle=False, so they come back in the same deterministic order and zip()
# lines image i up with mask i (assumes matching file names in both folders).
print(PATH)
full_dataset_images = tf.data.Dataset.list_files(PATH+'images/*.png', shuffle=False)
full_dataset_gt = tf.data.Dataset.list_files(PATH+'groundtruth/*.png',shuffle=False)
full_dataset = tf.data.Dataset.zip((full_dataset_images, full_dataset_gt))
n_pairs = sum(1 for _ in full_dataset)
print(n_pairs)

N_TRAIN = 1800   # number of shuffled pairs used for training; the rest are held out
SPLIT_SEED = 42  # fixes which pairs end up in the held-out split

# Shuffle exactly once, with a fixed seed and reshuffle_each_iteration=False.
# take() and skip() below each iterate this dataset independently; with the
# default per-iteration reshuffle they would see different orders, so the
# held-out pairs would change every epoch and leak into training. A buffer
# covering the whole dataset makes the split uniform (a smaller buffer biases
# it towards the order of the file listing).
full_dataset = full_dataset.shuffle(max(n_pairs, 1), seed=SPLIT_SEED,
                                    reshuffle_each_iteration=False)

# Map the *taken* subset, not full_dataset, so held-out pairs never reach training.
train_dataset = full_dataset.take(N_TRAIN)
# New order every epoch, but only within the training split.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)

test_dataset = full_dataset.skip(N_TRAIN)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)

In [None]:
# Disabled alternative: evaluate on the original CIL validation split instead of
# the held-out part of `dataset`. Kept as a string literal so it does not run;
# if re-enabled it would overwrite test_dataset from the previous cell.
'''TEST_PATH = '/home/jonathan/CIL-street/data/original/validation/'
test_dataset_images = tf.data.Dataset.list_files(TEST_PATH+'images/*.png', shuffle=False)
test_dataset_gt = tf.data.Dataset.list_files(TEST_PATH+'groundtruth/*.png',shuffle=False)
test_dataset = tf.data.Dataset.zip((test_dataset_images, test_dataset_gt))
print(test_dataset)
print(sum(1 for _ in test_dataset))
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)'''

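## Build the Generator
  * The architecture of the pix2pix generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net).
  * Note: this notebook keeps the pix2pix `downsample`/`upsample` building blocks (the discriminator uses `downsample`), but the generator actually trained is `RoadNet` below. It uses a pretrained classification backbone as the encoder, an upsampling decoder with skip connections, and a single-channel sigmoid output.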


In [None]:
# Number of channels of the predicted road mask (single grayscale channel).
# NOTE(review): not referenced by the code visible here; RoadNet's mask head is
# hard-coded to filters=1.
OUTPUT_CHANNELS = 1

In [None]:
def downsample(filters, size, apply_batchnorm=True):
  """Encoder block: Conv2D (stride 2, 'same') -> [BatchNorm] -> LeakyReLU.

  Halves the spatial resolution. Kernel weights are drawn from N(0, 0.02) and
  the convolution has no bias.
  """
  stack = [
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=tf.random_normal_initializer(0., 0.02),
                             use_bias=False),
  ]
  if apply_batchnorm:
    stack.append(tf.keras.layers.BatchNormalization())
  stack.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(stack)

In [None]:
# Sanity check: one downsample block halves the 512x512 sample image.
down_model = downsample(3, 4)
batched_sample = tf.expand_dims(inp, 0)
down_result = down_model(batched_sample)
print(down_result.shape)

In [None]:
def upsample(filters, size, apply_dropout=False):
  """Decoder block: Conv2DTranspose (stride 2) -> BatchNorm -> [Dropout 0.5] -> ReLU.

  Doubles the spatial resolution. Kernel weights are drawn from N(0, 0.02) and
  the transposed convolution has no bias.
  """
  stack = [
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                      use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    stack.append(tf.keras.layers.Dropout(0.5))
  stack.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(stack)

In [None]:
# Sanity check: one upsample block doubles the downsampled result back.
up_model = upsample(3, 4)
up_result = up_model(down_result)
print(up_result.shape)

In [None]:
!pip install -U image-classifiers
import tensorflow as tf
from tensorflow.keras.layers import *
from classification_models.tfkeras import Classifiers

def RoadNet(backbone_name='seresnext50', input_shape=(None, None, 3), encoder_weights='imagenet',
            encoder_freeze=False, predict_distance=False, predict_contour=False, aspp=False, se=False):
    """
    Encoder-decoder based architecture for road segmentation in aerial images.
    :param backbone_name: name of the backbone network, lower case. Supported backbones are 'resnet50', 'resnet101',
                          'seresnet50', 'seresnet101', 'resnext50', 'resnext101', 'seresnext50' and 'seresnext101'.
    :param input_shape: input shape. The decoder applies 5 x2 upsampling stages and concatenates encoder skips of
                        matching size, so height and width should be multiples of 32.
                        NOTE(review): an earlier version of this docstring said 16; confirm against the backbone.
    :param encoder_weights: name of dataset for which to load weights. Only ImageNet is supported.
    :param encoder_freeze: freezes the weights in the backbone save from batch normalization layers
    :param predict_distance: if true, adds an additional output predicting the distance map of the road mask
    :param predict_contour: if true, adds an additional output predicting the contour of the road mask
    :param aspp: if true, the encoder output is passed through an ASPP module. More info at
                 http://liangchiehchen.com/projects/DeepLab.html
    :param se: if true, enables Squeeze and Excitation on the decoder convolutional blocks. More info at
               https://arxiv.org/abs/1709.01507
    :raises KeyError: if backbone_name is not one of the supported backbones.
    :return: a tf.keras instance of the model. The mask output is a single-channel sigmoid map; with
             predict_contour / predict_distance the model returns a list [mask, (contour), (distance)].
    """

    decoder_filters = (256, 128, 64, 32, 16)
    n_blocks = len(decoder_filters)
    # encoder layers (by index or by name) whose outputs feed the decoder skip connections, deepest first
    skip_layers_dict = {'seresnext50': (1078, 584, 254, 4), 'seresnext101': (2472, 584, 254, 4),
                        'seresnet101': (552, 136, 62, 4), 'seresnet50': (246, 136, 62, 4),
                        'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
                        'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
                        'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
                        'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0')}
    if backbone_name not in skip_layers_dict:
        raise KeyError('Unsupported backbone {!r}; expected one of {}'.format(
            backbone_name, sorted(skip_layers_dict)))
    skip_layers = skip_layers_dict[backbone_name]

    # load backbone network from external library
    backbone_fn, _ = Classifiers.get(backbone_name)
    backbone = backbone_fn(input_shape=input_shape, weights=encoder_weights, include_top=False)
    skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
              else backbone.get_layer(index=i).output for i in skip_layers])

    x = backbone.output

    # build ASPP if requested
    if aspp:
        b0 = GlobalAveragePooling2D()(x)
        b0 = Lambda(lambda x: tf.keras.backend.expand_dims(x, 1))(b0)
        b0 = Lambda(lambda x: tf.keras.backend.expand_dims(x, 1))(b0)
        b0 = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp_pooling')(b0)
        b0 = BatchNormalization(name='aspp_pooling_bn')(b0)
        b0 = Activation('relu', name='aspp_pooling_relu')(b0)
        # Spread the 1x1 pooled features over x's spatial grid. This gives the same values as resizing a 1x1 map,
        # but follows x's actual size instead of a hard-coded 12x12, which only matched 384x384 inputs and made the
        # concatenation below fail for any other (or dynamic) input size.
        b0 = Lambda(lambda t: t[0] * tf.ones_like(t[1][..., :1]))([b0, x])

        b1 = Conv2D(256, 1, padding='same', dilation_rate=(1, 1), kernel_initializer='he_normal', name='aspp_b1_conv')(x)
        b1 = BatchNormalization(axis=3, name='aspp_b1_bn')(b1)
        b1 = Activation('relu', name='aspp_b1_relu')(b1)
        b2 = Conv2D(256, 3, padding='same', dilation_rate=(3, 3), kernel_initializer='he_normal', name='aspp_b2_conv')(x)
        b2 = BatchNormalization(axis=3, name='aspp_b2_bn')(b2)
        b2 = Activation('relu', name='aspp_b2_relu')(b2)
        b3 = Conv2D(256, 3, padding='same', dilation_rate=(6, 6), kernel_initializer='he_normal', name='aspp_b3_conv')(x)
        b3 = BatchNormalization(axis=3, name='aspp_b3_bn')(b3)
        b3 = Activation('relu', name='aspp_b3_relu')(b3)

        x = Concatenate(axis=3, name='aspp_concat')([b0, b1, b2, b3])
        x = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp_concat_conv')(x)
        x = BatchNormalization(axis=3, name='aspp_concat_bn')(x)
        x = Activation('relu', name='aspp_concat_relu')(x)

    # create the decoder blocks sequentially
    for i in range(n_blocks):

        filters = decoder_filters[i]

        x = UpSampling2D(size=2, name='decoder_stage{}_upsample'.format(i))(x)
        # skip connection (the last stage has no matching encoder feature map)
        if i < len(skips):
            x = Concatenate(axis=3, name='decoder_stage{}_concat'.format(i))([x, skips[i]])

        x = Conv2D(filters=filters, kernel_size=3, padding='same', use_bias=False, kernel_initializer='he_uniform', name='decoder_stage{}a_conv'.format(i))(x)
        x = BatchNormalization(axis=3, name='decoder_stage{}a_bn'.format(i))(x)
        x = Activation('relu', name='decoder_stage{}a_activation'.format(i))(x)

        # Squeeze and Excitation on the first convolution
        if se:
            w = GlobalAveragePooling2D(name='decoder_stage{}a_se_avgpool'.format(i))(x)
            w = Dense(filters // 8, activation='relu', name='decoder_stage{}a_se_dense1'.format(i))(w)
            w = Dense(filters, activation='sigmoid', name='decoder_stage{}a_se_dense2'.format(i))(w)
            x = Multiply(name='decoder_stage{}a_se_mult'.format(i))([x, w])

        x = Conv2D(filters=filters, kernel_size=3, padding='same', use_bias=False, kernel_initializer='he_uniform', name='decoder_stage{}b_conv'.format(i))(x)
        x = BatchNormalization(axis=3, name='decoder_stage{}b_bn'.format(i))(x)
        x = Activation('relu', name='decoder_stage{}b_activation'.format(i))(x)

        # Squeeze and Excitation on the second convolution
        if se:
            w = GlobalAveragePooling2D(name='decoder_stage{}b_se_avgpool'.format(i))(x)
            w = Dense(filters // 8, activation='relu', name='decoder_stage{}b_se_dense1'.format(i))(w)
            w = Dense(filters, activation='sigmoid', name='decoder_stage{}b_se_dense2'.format(i))(w)
            x = Multiply(name='decoder_stage{}b_se_mult'.format(i))([x, w])

    # main output: single-channel road probability map in [0, 1]
    task1 = Conv2D(filters=1, kernel_size=(3, 3), padding='same', kernel_initializer='glorot_uniform', name='final_conv_mask')(x)
    task1 = Activation('sigmoid', name='final_activation_mask')(task1)

    # prepare for Multitask Learning
    if predict_contour:
        task2 = Conv2D(filters=1, kernel_size=(3, 3), padding='same', kernel_initializer='glorot_uniform', name='final_conv_contour')(x)
        task2 = Activation('sigmoid', name='final_activation_contour')(task2)
    if predict_distance:
        task3 = Conv2D(filters=1, kernel_size=(3, 3), padding='same', kernel_initializer='glorot_uniform', name='final_conv_distance')(x)
        task3 = Activation('linear', name='final_activation_distance')(task3)

    if predict_contour and predict_distance:
        output = [task1, task2, task3]
    elif predict_contour:
        output = [task1, task2]
    elif predict_distance:
        output = [task1, task3]
    else:
        output = task1

    model = tf.keras.models.Model(backbone.input, output)

    # freeze encoder weights if requested (BatchNormalization layers stay trainable)
    if encoder_freeze:
        for layer in backbone.layers:
            if not isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = False

    return model

In [None]:
# SEResNeXt101 encoder with the RoadNet decoder. Author's note on the flags:
# "see not learning, aspp no improvement" (presumably: se=True did not learn,
# aspp=True gave no improvement).
generator = RoadNet(backbone_name='seresnext101',
                    predict_contour=False,
                    aspp=False,
                    se=False)
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# Forward the sample image through the freshly built generator and show the mask.
batched_inp = inp[tf.newaxis, ...]
gen_output = generator(batched_inp, training=False)
first_mask = gen_output[0, ...]
print(first_mask.shape)
plt.imshow(tf.reshape(first_mask, (512, 512)) / 255.0, cmap='Greys_r')

In [None]:

def generator_loss(disc_generated_output, gen_output, target):
  """Generator objective: fool the discriminator and stay close to the target.

  Returns (total, gan, l1, l2, kld), with total = gan + 50 * l1. The L2 and KL
  terms are disabled to save computation. They are returned as the constant 0
  so the 5-tuple shape stays the same.

  NOTE(review): `target` comes out of normalize() in [-1, 1], while RoadNet ends
  in a sigmoid ([0, 1]). Background pixels (-1) can therefore never be matched
  exactly, and the discriminator may be able to separate real from generated
  masks by value range alone. Confirm whether this is intended.
  """
  # adversarial term: the generator wants its output scored as real (1)
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  # disabled: tf.reduce_mean(tf.keras.losses.kullback_leibler_divergence(tf.ones_like(disc_generated_output), disc_generated_output))
  kld = 0

  # reconstruction term: mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
  # disabled: tf.reduce_mean((target - gen_output)*(target - gen_output))
  l2_loss = 0

  l1_weight = 50
  total_gen_loss = gan_loss + l1_weight * l1_loss

  return total_gen_loss, gan_loss, l1_loss, l2_loss, kld

In [None]:
def Discriminator(height=512, width=512):
  """Build the PatchGAN discriminator.

  The satellite image (3 channels) is concatenated with a real or generated
  single-channel mask. The network returns a grid of raw logits, one per
  receptive-field patch. The last Conv2D has no activation.

  Args:
    height, width: spatial size of both inputs. Defaults to 512x512, the size
      used throughout this notebook; the shape comments below assume it.

  Returns:
    A tf.keras.Model mapping [input_image, target_image] to logits of shape
    (bs, height/8 - 2, width/8 - 2, 1) for sizes divisible by 8, i.e.
    (bs, 62, 62, 1) for 512x512.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[height, width, 3], name='input_image')
  target = tf.keras.layers.Input(shape=[height, width, 1], name='target_image')

  x = tf.keras.layers.concatenate([inp, target]) # (bs, 512, 512, 4)

  down1 = downsample(64, 4, False)(x) # (bs, 256, 256, 64)
  down2 = downsample(128, 4)(down1) # (bs, 128, 128, 128)
  down3 = downsample(256, 4)(down2) # (bs, 64, 64, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 66, 66, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 63, 63, 512)
  # add more layers if the discriminator needs to be stronger
  #conv2 = tf.keras.layers.Conv2D(512, 4, strides=1,
  #                              kernel_initializer=initializer,
  #                              use_bias=False)(conv) # (bs, 60, 60, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 65, 65, 512)

  # one raw logit per patch (no activation)
  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 62, 62, 1)

  return tf.keras.Model(inputs=[inp, target], outputs=last)

In [None]:
# Build the discriminator and render its layer graph with shapes.
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, dpi=64, show_shapes=True)

In [None]:
# Discriminator logits for (sample image, freshly built generator output);
# the diverging colormap shows positive vs. negative logits, clipped at +-20.
disc_inputs = [inp[tf.newaxis, ...], gen_output]
disc_out = discriminator(disc_inputs, training=False)
plt.imshow(disc_out[0, ..., -1], cmap='RdBu_r', vmin=-20, vmax=20)
plt.colorbar()

In [None]:
# Binary cross-entropy on raw logits. The discriminator's last Conv2D has no
# activation, so from_logits=True applies the sigmoid inside the loss.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Discriminator objective: real pairs should score 1, generated pairs 0.

  Returns the sum of the two binary cross-entropy terms (on logits).
  """
  return (loss_object(tf.ones_like(disc_real_output), disc_real_output)
          + loss_object(tf.zeros_like(disc_generated_output), disc_generated_output))

## Define the Optimizers and Checkpoint-saver


In [None]:
# Generator: Adam with learning rate 2e-3, beta_1 0.9.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-3, beta_1=0.9)
# Discriminator: very low learning rate so it does not overpower the generator.
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-7, beta_1=0.6)

In [None]:
checkpoint_dir = './training_checkpoints_new6'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both models and both optimizers so training can be resumed exactly.
checkpoint = tf.train.Checkpoint(generator=generator,
                                 discriminator=discriminator,
                                 generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer)

## Generate Images

Write a function to plot some images during training.

* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output.
* Last step is to plot the predictions and **voila!**

Note: The `training=True` is intentional here since
we want the batch statistics while running the model
on the test dataset. If we use training=False, we will get
the accumulated statistics learned from the training dataset
(which we don't want)

In [None]:
def generate_images(model, test_input, tar):
  """Plot input, ground truth and prediction for the first example of a batch.

  Args:
    model: the generator.
    test_input: batch of input images, normalized to [-1, 1].
    tar: batch of single-channel target masks.
  """
  print(test_input.shape)
  print("generate images")
  # training=True on purpose: use batch statistics rather than the
  # accumulated moving averages.
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15,15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # inputs/targets were normalized to [-1, 1]; this maps them back to [0, 1]
    denorm = display_list[i] * 0.5 + 0.5
    if denorm.shape[2] == 1:
      # Single-channel mask: drop the channel axis (works for any H x W, not
      # just 512x512). imshow auto-scales grayscale data, so no extra rescale.
      plt.imshow(denorm[..., 0], cmap='Greys_r')
    else:
      plt.imshow(denorm)
    plt.axis('off')
  plt.show()

In [None]:
# Preview the current generator on three augmented training pairs.
for example_input, example_target in train_dataset.take(3):
  print(example_input.shape)
  print(example_target.shape)
  generate_images(generator, example_input, example_target)

In [None]:
# Preview the current generator on ten held-out pairs.
for example_input, example_target in test_dataset.take(10):
  generate_images(generator, example_input, example_target)

In [None]:
import datetime
log_dir = "logs/"
# one TensorBoard run directory per notebook run, named by its start timestamp
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + run_stamp)

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs #alternatively run it on the command line

In [None]:
@tf.function
def train_step(input_image, target, epoch):
  """Run one simultaneous generator/discriminator update on a batch.

  Both networks are evaluated under their own GradientTape. Both gradient sets
  are computed before either optimizer applies its update. All loss terms are
  written to summary_writer with step=epoch, so every batch of an epoch logs
  at the same step.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    # score the real (input, target) pair and the generated pair
    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    (gen_total_loss, gen_gan_loss, gen_l1_loss,
     gen_l2_loss, gen_kld_loss) = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  gen_vars = generator.trainable_variables
  disc_vars = discriminator.trainable_variables
  generator_gradients = gen_tape.gradient(gen_total_loss, gen_vars)
  discriminator_gradients = disc_tape.gradient(disc_loss, disc_vars)

  generator_optimizer.apply_gradients(zip(generator_gradients, gen_vars))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients, disc_vars))

  logged = (('gen_total_loss', gen_total_loss),
            ('gen_gan_loss', gen_gan_loss),
            ('gen_l1_loss', gen_l1_loss),
            ('gen_l2_loss', gen_l2_loss),
            ('gen_kld_loss', gen_kld_loss),
            ('disc_loss', disc_loss))
  with summary_writer.as_default():
    for tag, value in logged:
      tf.summary.scalar(tag, value, step=epoch)

In [None]:
def fit(train_ds, start_epoch, epochs, test_ds):
  """Train the GAN for epochs start_epoch .. epochs-1.

  Each epoch first plots the generator's output on one test batch, then runs
  train_step over every batch of train_ds. A checkpoint is written every 5th
  epoch and once more after the loop.

  Args:
    train_ds: batched dataset of (input_image, target) pairs.
    start_epoch: first epoch index (use > 0 when resuming from a checkpoint).
    epochs: epoch index to stop before.
    test_ds: batched dataset used only for the per-epoch preview plot.
  """
  for epoch in range(start_epoch, epochs):
    start = time.time()

    display.clear_output(wait=True)

    for example_input, example_target in test_ds.take(1):
      generate_images(generator, example_input, example_target)
    print("Epoch: ", epoch)

    # Pass the epoch to the tf.function as a tensor. Each new Python int value
    # would trigger a retrace (rebuilding the whole training graph) every
    # epoch, while a tensor of the same dtype/shape reuses the traced graph.
    epoch_step = tf.constant(epoch, dtype=tf.int64)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
      print('.', end='')
      if (n+1) % 100 == 0:
        print()
      train_step(input_image, target, epoch_step)
    print()

    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
  checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# Set to True to resume from the newest checkpoint in checkpoint_dir instead of
# training from freshly initialized weights.
continueTraining = False
if continueTraining:
  latest = tf.train.latest_checkpoint(checkpoint_dir)
  checkpoint.restore(latest)

In [None]:
# Train for epochs 0..59; fit() checkpoints every 5 epochs and after the last one.
fit(train_dataset, 0, 60, test_dataset)

## Restore the latest checkpoint and test

In [None]:
!ls {checkpoint_dir}

In [None]:
# restoring the latest checkpoint in checkpoint_dir (generator, discriminator
# and both optimizers, as tracked by `checkpoint`)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Generate using test dataset

In [None]:
# Run the trained model on a few examples from the test dataset
# (note: this rebinds the globals `inp` and `tar`)
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)

In [None]:
def load_test_images(image_file):
  """Load one unlabeled test image at its native size, scaled to [-1, 1]."""
  print("load image test")
  print(image_file)
  image = load(image_file)
  print(image.shape)
  return image / 127.5 - 1

In [None]:
# Unlabeled test images to predict. The unshuffled file-name dataset is kept
# separately so each prediction can be written under its source file name.
PATH_test_images = "/home/jonathan/CIL-street/data/test_images/"
dataset_test_images_names = tf.data.Dataset.list_files(PATH_test_images+'*.png', shuffle=False)
print(sum(1 for _ in dataset_test_images_names)) # should be 94
dataset_test_images = dataset_test_images_names.map(load_test_images)
# Batch size pinned to 1 rather than BATCH_SIZE: the prediction loop zips these
# batches 1:1 with the (unbatched) file names and only uses element [0] of each
# batch, so larger batches would mislabel and skip images.
dataset_test_images = dataset_test_images.batch(1)


In [None]:
import numpy
import cv2

TILE = 400   # side length of each square crop fed to the generator
FULL = 608   # side length of the test images
# (row, col) offsets of the four overlapping crops; 208 = FULL - TILE.
CROP_OFFSETS = [(0, 0), (208, 0), (0, 208), (208, 208)]

for input_image, filename in zip(dataset_test_images, dataset_test_images_names):
  print(filename)
  # filename is a scalar tf.string tensor. Decode the bytes rather than calling
  # str() on them, which yields the "b'...'" repr instead of the path.
  filename = os.path.basename(filename.numpy().decode('utf-8')).split('.')[0] + '.png'

  patches = []
  for row, col in CROP_OFFSETS:
    crop = tf.image.crop_to_bounding_box(input_image, row, col, TILE, TILE)
    # the generator works at 512x512
    crop = tf.image.resize(crop, [512, 512],
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    print(filename)
    prediction = generator(crop, training=True)
    plt.figure(figsize=(15,15))

    title = ['Input Image', 'Predicted Image']
    input0 = crop[0] * 0.5 + 0.5
    prediction0 = prediction[0] * 0.5 + 0.5
    prediction0 = tf.image.resize(prediction0, (TILE, TILE))
    combined = prediction0

    plt.subplot(1, 6, 1)
    plt.title(title[0])
    plt.imshow(input0)
    plt.subplot(1, 6, 2)
    plt.imshow(tf.reshape(prediction0, (TILE, TILE)), cmap='Greys_r')

    # Test-time augmentation: predict on the crop rotated by k*90 degrees,
    # rotate each prediction back and add it to the running sum.
    for k in range(1, 4):
      prediction = generator(tf.image.rot90(crop, k=k), training=True) # seems no difference if training or not
      plt.subplot(1, 6, k + 2)
      pred_norm = tf.image.rot90(prediction[0] * 0.5 + 0.5, k=-k)
      pred_norm = tf.image.resize(pred_norm, (TILE, TILE))
      plt.imshow(tf.reshape(pred_norm, (TILE, TILE)), cmap='Greys_r')
      combined += pred_norm

    plt.subplot(1, 6, 6)
    plt.imshow(tf.reshape(combined, (TILE, TILE)), cmap='Greys_r')
    patches.append(combined)

    plt.show()

  # Stitch the four crops back into a FULL x FULL image, averaging where they
  # overlap. Each `combined` is the sum of 4 maps in [0.5, 1] (sigmoid output
  # mapped by * 0.5 + 0.5), i.e. in [2, 4]; subtracting 2 shifts it to [0, 2].
  patched = numpy.zeros((FULL, FULL, 3))
  coverage = numpy.zeros((FULL, FULL, 1))
  for (row, col), combined in zip(CROP_OFFSETS, patches):
    patched[row:row + TILE, col:col + TILE] += tf.image.grayscale_to_rgb(combined - 2).numpy()
    coverage[row:row + TILE, col:col + TILE] += 1
  patched /= coverage

  plt.imshow(input_image[0] * 0.5 + 0.5)
  plt.show()

  patched = tf.clip_by_value(patched, 0, 1)

  # binarize; optional morphological opening/closing kept for experimentation
  kernel = numpy.ones((10,10),numpy.uint8)
  patched = patched.numpy()
  #patched = cv2.morphologyEx(patched, cv2.MORPH_OPEN, kernel)
  _, patched = cv2.threshold(patched,0.25,1.0,cv2.THRESH_BINARY)
  #patched = cv2.morphologyEx(patched, cv2.MORPH_CLOSE, kernel)
  #patched = cv2.morphologyEx(patched, cv2.MORPH_CLOSE, kernel)

  plt.imshow(patched)
  plt.show()

  pred_uint = tf.image.convert_image_dtype(patched, tf.uint16)
  enc = tf.image.encode_png(pred_uint)
  tf.io.write_file(tf.constant("/home/jonathan/cil_pix2pix/" + filename), enc)


In [None]:
# alternative but worse results

'''
import numpy
for ind, (input_image, filename) in enumerate(zip(dataset_test_images.take(94), dataset_test_images_names.take(94))):
  print(filename)
  filename = filename.numpy()
  filename = os.path.basename(str(filename)).split('.')[0]+'.png'

  input = tf.image.resize(input_image, [512, 512],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  print(filename)
  prediction = generator(input, training=False)
  plt.figure(figsize=(15,15))

  title = ['Input Image', 'Predicted Image']
  input0 = input[0] * 0.5 + 0.5
  prediction0 = prediction[0] * 0.5 + 0.5
  prediction0 = tf.image.resize(prediction0, (608, 608))
  combined = prediction0

  plt.subplot(1, 6, 1)
  plt.title(title[0])
  plt.imshow(input0)
  plt.subplot(1, 6, 2)
  plt.imshow(tf.reshape(prediction0, (608,608)), cmap='Greys_r')

  for i in range(2, 5):
    input_rot = tf.image.rot90(
      input, k=(i-1)
    )
    prediction = generator(input_rot, training=False)
    plt.subplot(1, 6, i+1)
    pred_norm = prediction[0] * 0.5 + 0.5
    pred_norm = tf.image.rot90(
      pred_norm, k=-(i-1)
    )
    pred_norm = tf.image.resize(pred_norm, (608, 608))
    plt.imshow(tf.reshape(pred_norm, (608,608)), cmap='Greys_r')
    combined += pred_norm

  plt.subplot(1, 6, 6)
  plt.imshow(tf.reshape(combined, (608,608)), cmap='Greys_r')
  plt.show()

  combined = tf.image.grayscale_to_rgb(combined - 2)
  patched = tf.image.resize(combined*255, [608, 608], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 
  
  plt.imshow(input_image[0] * 0.5 + 0.5)
  plt.show()

  patched = tf.clip_by_value(
    patched, 0, 1
  )
  plt.imshow(patched)
  plt.show()

  pred_uint = tf.image.convert_image_dtype(patched, tf.uint16)
  enc = tf.image.encode_png(pred_uint)
  fname = tf.constant("/home/jonathan/cil_pix2pix/"+filename)
  fwrite = tf.io.write_file(fname, enc)

'''