# Pix2Pix
This notebook trains a conditional adversarial network (CAN) that is closely related to pix2pix. Details can be found in our report.

## Imports & Setup

In [None]:
import tensorflow as tf
 
import os
import time
import random
import numpy as np
import glob
import cv2
 
from matplotlib import pyplot as plt
from IPython import display

from google.colab import drive

In [None]:
!pip install -U tensorboard

### Mount Data Location
Mounts Google Drive, where the raw prediction output of the models resides.

In [None]:
# Colab mount point for Google Drive.
_MOUNT_POINT = '/content/drive'
drive.mount(_MOUNT_POINT)

# Project folder on the mounted Drive that holds the image folders.
PATH = _MOUNT_POINT + "/My Drive/CIL Project Images"
print(PATH)

### Global Parameters

In [None]:
# Shuffle buffer of the training pipeline and number of images per batch.
BUFFER_SIZE = 600
BATCH_SIZE = 1

# Spatial size of the patches fed to the networks.
IMG_WIDTH = IMG_HEIGHT = 256

# Number of training files held out for validation.
NUM_VALIDATION_IMAGES = 6 * 6

# The generator predicts a single-channel mask.
OUTPUT_CHANNELS = 1

## Input Pipeline

Helper functions for the input pipeline.

In [None]:
def augment(image, label, resize=400, scale=1, horizontal_flip=True, random_crop=True, rotate90=True, color=True, block_noise=True):
    """Randomly augment a training image together with its mask.

    Geometric transforms are applied to the image and the mask jointly (they
    are stacked along the channel axis), so the pair stays aligned. Colour
    jitter and the Gaussian noise block touch the image only.

    Args:
      image: float RGB image (3 channels), values in [0, 1] as produced by
        `decode_img`.
      label: single-channel float mask as produced by `decode_mask`.
      resize: side length the stacked pair is resized to before cropping, or
        None to skip this resize.
      scale: unused; kept so existing callers keep working.
      horizontal_flip: randomly mirror left-right.
      random_crop: take a random IMG_HEIGHT x IMG_WIDTH crop.
      rotate90: rotate by a random multiple of 90 degrees.
      color: apply random hue/saturation/brightness/contrast jitter.
      block_noise: overwrite a random square of the image with Gaussian noise
        (see `apply_gaussian_block`).

    Returns:
      (image, label) with spatial size IMG_HEIGHT x IMG_WIDTH. When `color`
      is set, the image is clipped back to [0, 1] after the jitter.
    """
    stacked = tf.concat([image, label], axis=-1)
    if resize is not None:
        stacked = tf.image.resize(stacked, [resize, resize])
    if random_crop:
        stacked = tf.image.random_crop(stacked, size=[IMG_HEIGHT, IMG_WIDTH, 4])
    if horizontal_flip:
        stacked = tf.image.random_flip_left_right(stacked)
    if rotate90:
        stacked = tf.image.rot90(stacked, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32))

    # tf.image.resize expects [new_height, new_width].
    stacked = tf.image.resize(stacked, [IMG_HEIGHT, IMG_WIDTH])

    image, label = tf.split(stacked, num_or_size_splits=[3, 1], axis=-1)

    if color:
        image = tf.image.random_hue(image, 0.125)
        image = tf.image.random_saturation(image, 0.5, 1.5)
        image = tf.image.random_brightness(image, 0.075)
        image = tf.image.random_contrast(image, 0.7, 1.3)

        def brightness(x):
            return tf.image.random_brightness(x, max_delta=32. / 255.)

        def saturation(x):
            return tf.image.random_saturation(x, lower=0.5, upper=1.5)

        def hue(x):
            return tf.image.random_hue(x, max_delta=0.2)

        def contrast(x):
            return tf.image.random_contrast(x, lower=0.5, upper=1.5)

        # The four jitter orderings; index == ordering number.
        orderings = [
            (brightness, saturation, hue, contrast),
            (saturation, brightness, contrast, hue),
            (contrast, hue, brightness, saturation),
            (hue, saturation, contrast, brightness),
        ]

        def make_branch(ops, x):
            # Bind ops and x now so each branch applies its own sequence.
            def branch():
                y = x
                for op in ops:
                    y = op(y)
                return y
            return branch

        # Pick the ordering per element with a tensor op. A Python
        # random.randint here would run only once, when tf.data traces this
        # function, and every image would then get the same ordering.
        ordering = tf.random.uniform(shape=[], minval=0, maxval=len(orderings), dtype=tf.int32)
        image = tf.switch_case(ordering, branch_fns=[make_branch(ops, image) for ops in orderings])

        # The jitter can push values outside [0, 1]; `normalize` assumes [0, 1].
        image = tf.clip_by_value(image, 0.0, 1.0)

    if block_noise:
        # apply_gaussian_block uses NumPy/Python randomness, so run it eagerly
        # per element; py_function drops the static shape, so restore it.
        #https://stackoverflow.com/questions/57374732/how-to-apply-imgaug-augmentation-to-tf-datadataset-in-tensorflow-2-0
        image_shape = image.shape
        [image,] = tf.py_function(apply_gaussian_block, [image], [tf.float32])
        image.set_shape(image_shape)

    return image, label
 
 
def augmentAutomatically(images):
    """Run a batch of images once through a random Keras ImageDataGenerator.

    The generator applies random shifts (10%), rotations (up to 360 degrees),
    shear (0.1) and zoom (20%), filling uncovered pixels by reflection. The
    whole input is returned as a single augmented batch, in input order.

    Args:
      images: batch of images; presumably a 4-D NumPy array (N, H, W, C) as
        `ImageDataGenerator.flow` expects — TODO confirm with callers.

    Returns:
      The augmented batch, same length as `images`.
    """
    # ImageDataGenerator is not imported at the top of the notebook, so
    # import it here; without this the call raised NameError.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    datagen = ImageDataGenerator(
        width_shift_range=0.1,
        height_shift_range=0.1,
        rotation_range=360,
        shear_range=0.1,
        zoom_range=0.2,
        fill_mode="reflect"
    )

    # flow() expects a scalar seed; np.random.choice(100000, 1) returned a
    # one-element array.
    seed = int(np.random.randint(100000))
    batches = datagen.flow(images, batch_size=len(images), shuffle=False, seed=seed)

    return next(batches)
 
def augment_validation(image, label):
    """Apply one random left-right flip and 90-degree rotation to an image/mask pair.

    The mask is stacked onto the image along the channel axis so both get
    exactly the same transform (3 image channels, 1 mask channel).
    """
    pair = tf.concat([image, label], axis=-1)
    pair = tf.image.random_flip_left_right(pair)
    # A left-right flip combined with a random multiple of 90 degrees already
    # covers up-down flips, so no separate random_flip_up_down is needed.
    turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    pair = tf.image.rot90(pair, turns)
    image, label = tf.split(pair, num_or_size_splits=[3, 1], axis=-1)
    return image, label
 
def augment_validation_with_rand_crop(image, label):
    """Take a random 256x256 crop of an image/mask pair, then randomly flip and rotate it.

    The mask is stacked onto the image along the channel axis so crop, flip
    and rotation are identical for both (3 image channels, 1 mask channel).
    """
    pair = tf.concat([image, label], axis=-1)
    pair = tf.image.random_crop(pair, size=[256, 256, 4])
    pair = tf.image.random_flip_left_right(pair)
    # Flip + random quarter turns already include up-down flips.
    turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    pair = tf.image.rot90(pair, turns)
    image, label = tf.split(pair, num_or_size_splits=[3, 1], axis=-1)
    return image, label
 
def apply_gaussian_block(image, size_bounds=(60, 100)):
    """Return a copy of `image` with a random square replaced by Gaussian noise.

    Meant to run eagerly (e.g. via `tf.py_function` in `augment`): it calls
    `image.numpy()` and draws from Python's `random` module.

    Args:
      image: eager tensor laid out as (height, width, channels) with 3
        channels, to match the noise block.
      size_bounds: inclusive (min, max) side length of the noise square. The
        side is capped at the smaller image dimension so the square always fits.

    Returns:
      float32 tensor with the same shape as `image`. The pasted noise comes
      from `get_gaussian_noise_block_image` and lies in [0, 1].
    """
    # Axis 0 is rows (y), axis 1 is columns (x). The previous `w, h, _`
    # unpacking swapped them and broke on non-square images.
    height, width = image.shape[0], image.shape[1]
    max_size = min(size_bounds[1], height, width)
    min_size = min(size_bounds[0], max_size)
    size = random.randint(min_size, max_size)

    noise_block = get_gaussian_noise_block_image(size=size)
    y_offset = random.randint(0, height - size)
    x_offset = random.randint(0, width - size)

    # Work on an explicit copy so the input tensor's buffer is never written.
    noisy = np.array(image.numpy(), dtype=np.float32)
    noisy[y_offset:y_offset + size, x_offset:x_offset + size] = noise_block
    return tf.convert_to_tensor(noisy, dtype=tf.float32)
 
 
def get_gaussian_noise_block_image(size):
    """Return a (size, size, 3) float32 block of Gaussian noise min-max scaled to [0, 1].

    Samples are drawn with mean 0 and variance 0.1 before rescaling.
    """
    std_dev = 0.1 ** 0.5
    samples = np.random.normal(0, std_dev, (size, size, 3))
    # NORM_MINMAX maps the smallest sample to 0 and the largest to 1.
    return cv2.normalize(samples, None, alpha=0, beta=1,
                         norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
 
  
def resize_validation(image, label):
    """Resize an image and its mask to IMG_HEIGHT x IMG_WIDTH.

    Returns:
      (image, label), both resized with tf.image.resize's default method.
    """
    # tf.image.resize takes [new_height, new_width].
    target_size = [IMG_HEIGHT, IMG_WIDTH]
    return tf.image.resize(image, target_size), tf.image.resize(label, target_size)
 
def resize_test(image):
    """Resize a test image to IMG_HEIGHT x IMG_WIDTH."""
    # tf.image.resize takes [new_height, new_width].
    return tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
 
# Test images are cut into overlapping 256x256 patches. The offsets are
# multiples of 88 up to 352 (352 + 256 = 608). They form a 3x3 grid with
# stride 176 and a 5x5 grid with stride 88; `patch[r, c]` below is the
# row/column of the patch in that grid. crop_test_0..8 give the 3x3 grid
# (every other position of the 5x5 grid), and the crop_test5x5_* functions
# supply the remaining 16 positions of the 5x5 grid.
_TEST_PATCH_SIZE = 256


def _crop_test_patch(image, offset_height, offset_width):
    """Return the 256x256 window of `image` whose top-left corner is at the given offsets."""
    return tf.image.crop_to_bounding_box(
        image, offset_height, offset_width, _TEST_PATCH_SIZE, _TEST_PATCH_SIZE)


def crop_test_0(image):
    """3x3: patch[0, 0]; 5x5: patch[0, 0]."""
    return _crop_test_patch(image, 0, 0)


def crop_test_1(image):
    """3x3: patch[0, 1]; 5x5: patch[0, 2]."""
    return _crop_test_patch(image, 0, 176)


def crop_test_2(image):
    """3x3: patch[0, 2]; 5x5: patch[0, 4]."""
    return _crop_test_patch(image, 0, 352)


def crop_test_3(image):
    """3x3: patch[1, 0]; 5x5: patch[2, 0]."""
    return _crop_test_patch(image, 176, 0)


def crop_test_4(image):
    """3x3: patch[1, 1]; 5x5: patch[2, 2]."""
    return _crop_test_patch(image, 176, 176)


def crop_test_5(image):
    """3x3: patch[1, 2]; 5x5: patch[2, 4]."""
    return _crop_test_patch(image, 176, 352)


def crop_test_6(image):
    """3x3: patch[2, 0]; 5x5: patch[4, 0]."""
    return _crop_test_patch(image, 352, 0)


def crop_test_7(image):
    """3x3: patch[2, 1]; 5x5: patch[4, 2]."""
    return _crop_test_patch(image, 352, 176)


def crop_test_8(image):
    """3x3: patch[2, 2]; 5x5: patch[4, 4]."""
    return _crop_test_patch(image, 352, 352)


def crop_test5x5_01(image):
    """5x5: patch[0, 1]."""
    return _crop_test_patch(image, 0, 88)


def crop_test5x5_12(image):
    """5x5: patch[0, 3]."""
    return _crop_test_patch(image, 0, 264)


def crop_test5x5_03(image):
    """5x5: patch[1, 0]."""
    return _crop_test_patch(image, 88, 0)


def crop_test5x5_04(image):
    """5x5: patch[1, 1]."""
    return _crop_test_patch(image, 88, 88)


def crop_test5x5_14(image):
    """5x5: patch[1, 2]."""
    return _crop_test_patch(image, 88, 176)


def crop_test5x5_24(image):
    """5x5: patch[1, 3]."""
    return _crop_test_patch(image, 88, 264)


def crop_test5x5_25(image):
    """5x5: patch[1, 4]."""
    return _crop_test_patch(image, 88, 352)


def crop_test5x5_34(image):
    """5x5: patch[2, 1]."""
    return _crop_test_patch(image, 176, 88)


def crop_test5x5_45(image):
    """5x5: patch[2, 3]."""
    return _crop_test_patch(image, 176, 264)


def crop_test5x5_36(image):
    """5x5: patch[3, 0]."""
    return _crop_test_patch(image, 264, 0)


def crop_test5x5_37(image):
    """5x5: patch[3, 1]."""
    return _crop_test_patch(image, 264, 88)


def crop_test5x5_47(image):
    """5x5: patch[3, 2]."""
    return _crop_test_patch(image, 264, 176)


def crop_test5x5_48(image):
    """5x5: patch[3, 3]."""
    return _crop_test_patch(image, 264, 264)


def crop_test5x5_58(image):
    """5x5: patch[3, 4]."""
    return _crop_test_patch(image, 264, 352)


def crop_test5x5_67(image):
    """5x5: patch[4, 1]."""
    return _crop_test_patch(image, 352, 88)


def crop_test5x5_78(image):
    """5x5: patch[4, 3]."""
    return _crop_test_patch(image, 352, 264)
 
# Test-time augmentations: together with the identity these are the eight
# symmetries of the square. Transforms 1-3 rotate counter-clockwise by
# k * 90 degrees; transforms 4-7 first mirror left-right and then rotate by
# (k - 4) * 90 degrees.
def _flip_and_rotate(image, flip, quarter_turns):
    """Optionally mirror `image` left-right, then rotate it by `quarter_turns` * 90 degrees."""
    if flip:
        image = tf.image.flip_left_right(image)
    if quarter_turns:
        image = tf.image.rot90(image, quarter_turns)
    return image


def augment_test_time_1(image):
    """Rotate by 90 degrees."""
    return _flip_and_rotate(image, flip=False, quarter_turns=1)


def augment_test_time_2(image):
    """Rotate by 180 degrees."""
    return _flip_and_rotate(image, flip=False, quarter_turns=2)


def augment_test_time_3(image):
    """Rotate by 270 degrees."""
    return _flip_and_rotate(image, flip=False, quarter_turns=3)


def augment_test_time_4(image):
    """Mirror left-right."""
    return _flip_and_rotate(image, flip=True, quarter_turns=0)


def augment_test_time_5(image):
    """Mirror left-right, then rotate by 90 degrees."""
    return _flip_and_rotate(image, flip=True, quarter_turns=1)


def augment_test_time_6(image):
    """Mirror left-right, then rotate by 180 degrees."""
    return _flip_and_rotate(image, flip=True, quarter_turns=2)


def augment_test_time_7(image):
    """Mirror left-right, then rotate by 270 degrees."""
    return _flip_and_rotate(image, flip=True, quarter_turns=3)

def decode_img(img):
  """Decode PNG bytes into a float32 RGB image with values in [0, 1].

  No resizing happens here; the image keeps its stored size.
  """
  rgb = tf.image.decode_png(img, channels=3)
  # convert_image_dtype rescales integer pixels into the [0, 1] float range.
  return tf.image.convert_image_dtype(rgb, tf.float32)
 
def decode_mask(mask):
  """Decode PNG bytes into a single-channel float32 mask of 0.0 / 1.0 values.

  No resizing happens here; the mask keeps its stored size.
  """
  grey = tf.image.decode_png(mask, channels=1)
  # The stored masks contain values strictly between 0 and 255 as well, so
  # binarize at the midpoint instead of just rescaling.
  return tf.cast(grey > 127, tf.float32)
 
 
def process_path(file_path, old, new):
  """Load an (image, mask) pair for one training file.

  The mask path is `file_path` with the regex `old` replaced by `new`.

  Returns:
    (img, mask) as decoded by `decode_img` and `decode_mask`.
  """
  img = decode_img(tf.io.read_file(file_path))
  mask_path = tf.strings.regex_replace(file_path, old, new)
  mask = decode_mask(tf.io.read_file(mask_path))
  return img, mask
 
 
def test_process_path(file_path):
  """Read and decode one test image (no mask) with `decode_img`."""
  return decode_img(tf.io.read_file(file_path))


def normalize(input_image, real_image=None):
  """Rescale values from [0, 1] to [-1, 1], the range of the generator's tanh output.

  Returns the rescaled `input_image`, or a tuple `(input_image, real_image)`
  with both rescaled when `real_image` is given.
  """
  def rescale(values):
    return (values / 0.5) - 1

  if real_image is None:
    return rescale(input_image)
  return rescale(input_image), rescale(real_image)

### Generate Datasets
Create the train, validation and test dataset.

In [None]:
training_images_path = PATH + '/images_rotated/'
test_images_path = PATH + '/test_images/'

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Seed for the train/validation split so it is reproducible across runs.
SPLIT_SEED = 42

# `list_files` shuffles by default and reshuffles on every iteration. With
# that, `skip` (training) and `take` (validation) each see a different file
# order, and validation draws new files on every pass, so validation images
# leak into training. Instead, list the files in sorted order and shuffle them
# exactly once, with a fixed seed and no reshuffling, so both splits are
# taken from the same fixed order.
# NOTE(review): if images_rotated holds several rotated copies of each source
# image, copies of one source can still land on both sides of the split —
# consider splitting by source image.
sorted_train_files = tf.data.Dataset.list_files(training_images_path + '*', shuffle=False)
num_train_files = int(tf.data.experimental.cardinality(sorted_train_files))
list_ds = sorted_train_files.shuffle(num_train_files, seed=SPLIT_SEED,
                                     reshuffle_each_iteration=False)
test_list_ds = tf.data.Dataset.list_files(test_images_path + '*', shuffle=False)


def load_training_pair(file_path):
  """Load an image and its ground-truth mask from the rotated training folders."""
  return process_path(file_path, "images_rotated", "groundtruth_rotated")


train = list_ds.skip(NUM_VALIDATION_IMAGES).map(load_training_pair, num_parallel_calls=AUTOTUNE)
train_dataset = (train.cache()
                 .shuffle(BUFFER_SIZE)
                 .map(augment, num_parallel_calls=AUTOTUNE)
                 .map(normalize)
                 .batch(BATCH_SIZE)
                 .prefetch(buffer_size=AUTOTUNE))
validation_dataset = (list_ds.take(NUM_VALIDATION_IMAGES)
                      .map(load_training_pair, num_parallel_calls=AUTOTUNE)
                      .map(augment_validation_with_rand_crop)
                      .map(normalize)
                      .batch(BATCH_SIZE))
test_dataset = test_list_ds.map(test_process_path).map(normalize)

DEBUG/TEST code for input pipeline

In [None]:
# Print the value ranges of the first (image, mask) batch of each pipeline and
# plot the image mapped back from [-1, 1] to [0, 1].
for dataset, show_now in ((train_dataset, True), (validation_dataset, False)):
  for img, msk in dataset.take(1):
    for tensor in (img, msk):
      print(tf.reduce_max(tensor))
      print(tf.reduce_min(tensor))
    plt.imshow(tf.squeeze(img) * 0.5 + 0.5)
    if show_now:
      plt.show()
 

## Generator
  * The generator architecture is a modified U-Net.
  * Filter counts are halved compared with the original pix2pix generator.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder


In [None]:
def downsample(filters, size, apply_batchnorm=True):
  """Encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

  The 'same'-padded stride-2 convolution halves the spatial size. It has no
  bias term and uses N(0, 0.02) weight initialisation.
  """
  conv = tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)
  block_layers = [conv]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(block_layers)

In [None]:
def upsample(filters, size, apply_dropout=False):
  """Decoder block: stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU.

  The 'same'-padded stride-2 transposed convolution doubles the spatial size.
  It has no bias term and uses N(0, 0.02) weight initialisation.
  """
  deconv = tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)
  block_layers = [deconv, tf.keras.layers.BatchNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(block_layers)

In [None]:
def Generator():
  """Build the U-Net generator.

  Layout follows the pix2pix U-Net with every filter count halved. Eight
  stride-2 encoder blocks reduce a (bs, 256, 256, 3) image to
  (bs, 1, 1, 256). Seven decoder blocks (dropout on the first three) each
  upsample and concatenate the mirrored encoder output. A final transposed
  convolution with tanh yields (bs, 256, 256, OUTPUT_CHANNELS) in [-1, 1].
  """
  inputs = tf.keras.layers.Input(shape=[256,256,3])

  down_stack = [
    downsample(32, 4, apply_batchnorm=False), # (bs, 128, 128, 32)
    downsample(64, 4), # (bs, 64, 64, 64)
    downsample(128, 4), # (bs, 32, 32, 128)
    downsample(256, 4), # (bs, 16, 16, 256)
    downsample(256, 4), # (bs, 8, 8, 256)
    downsample(256, 4), # (bs, 4, 4, 256)
    downsample(256, 4), # (bs, 2, 2, 256)
    downsample(256, 4), # (bs, 1, 1, 256)
  ]

  # Shapes below are after concatenation with the matching skip connection.
  up_stack = [
    upsample(256, 4, apply_dropout=True), # (bs, 2, 2, 512)
    upsample(256, 4, apply_dropout=True), # (bs, 4, 4, 512)
    upsample(256, 4, apply_dropout=True), # (bs, 8, 8, 512)
    upsample(256, 4), # (bs, 16, 16, 512)
    upsample(128, 4), # (bs, 32, 32, 256)
    upsample(64, 4), # (bs, 64, 64, 128)
    upsample(32, 4), # (bs, 128, 128, 64)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, OUTPUT_CHANNELS)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The bottleneck (last encoder output) has no skip partner; pair the rest
  # with the decoder in reverse (deepest first) order.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Run to see a visualization of the Generator

In [None]:
# Build the generator and render its architecture inline.
generator = Generator()
tf.keras.utils.plot_model(generator, dpi=64, show_shapes=True)

### Generator loss
  * It is the sigmoid cross-entropy loss between the discriminator's output on the generated images and an array of ones.
  * Also L1 loss is applied between the generated image and the target image to generate images that are structurally similar to the target image.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 200. The original paper uses LAMBDA = 100 but for our task 200 showed to produce better results.

The training procedure for the generator is shown below:

In [None]:
LAMBDA = 200  # weight of the L1 term in the total generator loss

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
  """Return (total, adversarial, L1) losses for the generator.

  The adversarial term scores the discriminator's output on generated pairs
  against ones (i.e. "real"). The L1 term is the mean absolute error between
  target and output after mapping both from [-1, 1] back to [0, 1].
  total = adversarial + LAMBDA * L1.
  """
  adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  target_unit = target * 0.5 + 0.5
  output_unit = gen_output * 0.5 + 0.5
  l1 = tf.reduce_mean(tf.abs(target_unit - output_unit))

  return adversarial + (LAMBDA * l1), adversarial, l1

## Discriminator
  * The Discriminator is a PatchGAN.
  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU)

In [None]:
def Discriminator():
  """Build the PatchGAN discriminator, with filter counts halved relative to pix2pix.

  Inputs are the RGB input image (bs, 256, 256, 3) and a real or generated
  mask (bs, 256, 256, 1). The output is a (bs, 30, 30, 1) grid of raw scores
  with no final activation, so the loss must be computed from logits.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 1], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, 4)

  down1 = downsample(32, 4, False)(x) # (bs, 128, 128, 32)
  down2 = downsample(64, 4)(down1) # (bs, 64, 64, 64)
  down3 = downsample(128, 4)(down2) # (bs, 32, 32, 128)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 128)
  conv = tf.keras.layers.Conv2D(256, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 31, 31, 256)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 256)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

Run to plot the discriminator:

In [None]:
# Build the discriminator and render its architecture inline.
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, dpi=64, show_shapes=True)

### Discriminator loss
  * The discriminator loss function takes 2 inputs: the discriminator's outputs on real images and on generated images
  * Calculates the sigmoid cross entropy for both
  * Then the total_loss is the sum of real_loss and the generated_loss


In [None]:
# The discriminator's final Conv2D has no activation, so its outputs are logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Return the discriminator loss.

  The loss is the BCE of real pairs against ones plus the BCE of generated
  pairs against zeros.
  """
  real_term = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  fake_term = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
  return real_term + fake_term

## Define the Optimizers and Checkpoint-saver


In [None]:
# Both networks use Adam with learning rate 2e-5 and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5, beta_1=0.5)

In [None]:
checkpoint_dir = '/content/drive/My Drive/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both networks and their optimizers so training can be resumed.
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)

## Generate Images


Note: The `training=True` is intentional here since
we want the batch statistics while running the model
on the test dataset. If we use training=False, we will get
the accumulated statistics learned from the training dataset
(which we don't want)

In [None]:
def generate_images(model, test_input, tar):
  """Plot input, ground truth and prediction for the first sample of a batch.

  The model runs with training=True, so BatchNorm uses the statistics of this
  batch (and dropout is active). All panels are mapped from [-1, 1] to
  [0, 1] for display.
  """
  prediction = model(test_input, training=True)
  panels = (
      ('Input Image', test_input[0] * 0.5 + 0.5),
      ('Ground Truth', tf.squeeze(tar[0]) * 0.5 + 0.5),
      ('Predicted Image', tf.squeeze(prediction[0]) * 0.5 + 0.5),
  )

  plt.figure(figsize=(15, 15))
  for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    plt.imshow(picture, cmap="gray")
    plt.axis('off')
  plt.show()

DEBUG/TEST function

In [None]:
# Preview the (untrained) generator on one validation batch.
for example_input, example_target in validation_dataset.take(1):
  generate_images(generator, example_input, example_target)

## Training


In [None]:
EPOCHS = 600  # number of passes over the training set

In [None]:
import datetime

log_dir = "logs/"
# One TensorBoard run directory per launch, named after its start time.
run_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(f"{log_dir}fit/{run_name}")

In [None]:
@tf.function
def train_step(input_image, target, epoch, update_discriminator=True):
  """Run one adversarial training step on a batch and log its losses.

  Args:
    input_image: batch of normalized input images for generator and
      discriminator.
    target: matching batch of normalized ground-truth masks.
    epoch: step value for the TensorBoard scalars. Pass a tensor (as `fit`
      does); each distinct Python int would trigger a new trace of this
      tf.function.
    update_discriminator: Python bool. When False only the generator is
      updated. Each value is traced once.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  # Only backpropagate through the discriminator when its weights will be
  # updated. `update_discriminator` is a Python bool, so the branch is
  # resolved at trace time and generator-only steps skip this backward pass.
  if update_discriminator:
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  if update_discriminator:
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
    tf.summary.scalar('disc_loss', disc_loss, step=epoch)

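* Iterates over the number of epochs.
* At the start of each epoch it runs `generate_images` on a validation batch to show its progress.
* The discriminator is only updated on every fourth training step.
* It saves a checkpoint every 20 epochs and once more when training finishes.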

In [None]:
def fit(train_ds, epochs, test_ds):
  """Train the generator/discriminator pair for `epochs` epochs.

  Before each epoch, the first batch of `test_ds` is plotted with
  `generate_images`. The discriminator is updated on every fourth step
  only. A checkpoint is written every 20 epochs and once more at the end.

  Args:
    train_ds: dataset of (input_image, target) batches.
    epochs: number of passes over `train_ds`.
    test_ds: dataset of (input_image, target) batches used for the preview.
  """
  for epoch in range(epochs):
    start = time.time()

    display.clear_output(wait=True)

    for example_input, example_target in test_ds.take(1):
      generate_images(generator, example_input, example_target)
    print("Epoch: ", epoch)

    # train_step is a tf.function: a new Python int each epoch would retrace
    # (and recompile) it every epoch, so pass the epoch as a tensor instead.
    epoch_step = tf.constant(epoch, dtype=tf.int64)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
      print('.', end='')
      if (n + 1) % 100 == 0:
        print()
      # Keep update_discriminator a Python bool so train_step has exactly two
      # traces; the discriminator is updated on every fourth step.
      train_step(input_image, target, epoch_step,
                 update_discriminator=bool(n % 4 == 0))
    print()

    # saving (checkpoint) the model every 20 epochs
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                       time.time() - start))
  checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
#docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir {log_dir}

Execute this to start the training:

In [None]:
fit(train_dataset, epochs=EPOCHS, test_ds=validation_dataset)

## Restore checkpoint

In [None]:
# Show the folder checkpoints are written to and restored from.
print(checkpoint_dir)

In [None]:
# restoring the latest checkpoint in checkpoint_dir
# Restore the most recent checkpoint found in checkpoint_dir.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)

## Predict on Test Dataset

Check that the model was loaded successfully by running it on the validation dataset.

In [None]:
# Run the trained model on the validation dataset
# Sanity check: plot one validation batch with the restored generator.
for inp, tar in validation_dataset.take(1):
  generate_images(generator, inp, tar)

In [None]:
def make_prediction(input, verbose=1, *, model=None):
  """Run a model on every batch of `input` and return the squeezed outputs.

  Args:
    input: iterable of model inputs, e.g. a tf.data.Dataset batched with batch
      size 1. The leading batch axis of each output is squeezed away, so every
      batch must contain exactly one element (tf.squeeze raises otherwise).
    verbose: unused; kept so existing calls that pass it keep working.
    model: keyword-only; the model to run. Defaults to the notebook's global
      `generator`.

  Returns:
    A list with one prediction tensor per batch, without the batch axis.
  """
  if model is None:
    model = generator
  # training=True is kept from the original code (the TF pix2pix tutorial also
  # calls the generator this way at inference). NOTE(review): confirm this is
  # intended, as it changes BatchNorm/Dropout behaviour.
  return [tf.squeeze(model(el, training=True), [0]) for el in input]

Compute the predictions with test-time augmentation (as described in the report) on a 5x5 grid of overlapping 256x256 patches per test image.

In [None]:
NUM_TEST = 94

filenames = sorted(os.listdir(test_images_path))

# Accumulated full-resolution (608x608x1) predictions, one per test image.
# Every patch adds its weighted, zero-padded prediction into these tensors.
predictions = [tf.zeros([608, 608, 1], tf.float32) for _ in range(NUM_TEST)]

# Side length of the square crops fed to the generator, and of a test image.
_TTA_PATCH_SIZE = 256
_TTA_IMAGE_SIZE = 608

# Pixel offset of the patch at grid position 0..4 along either axis of the
# 5x5 grid (stride 88; the last patch ends flush with the image border).
_TTA_PATCH_OFFSETS = (0, 88, 176, 264, 352)

# Blending weights along one axis of a patch for grid position 0..4 along
# that axis, as (run_length, weight) segments. For every image pixel, the
# weights of all patches covering it sum to 1 along each axis.
_TTA_AXIS_WEIGHTS = (
    ((88, 1.0), (88, 0.75), (80, 0.25)),
    ((88, 0.25), (88, 0.5), (80, 0.25)),
    ((80, 0.25), (96, 0.5), (80, 0.25)),
    ((80, 0.25), (88, 0.5), (88, 0.25)),
    ((80, 0.25), (88, 0.75), (88, 1.0)),
)

# Test-time augmentations applied to every patch; None means "unaugmented".
# Index k here matches index k in _tta_undo_augmentations.
_TTA_AUGMENTATIONS = (
    None,
    augment_test_time_1,
    augment_test_time_2,
    augment_test_time_3,
    augment_test_time_4,
    augment_test_time_5,
    augment_test_time_6,
    augment_test_time_7,
)


def _tta_undo_augmentations(patch_predictions):
  """Rotate/flip the 8 augmented predictions of one patch back; return a list.

  Index k of `patch_predictions` is the output for _TTA_AUGMENTATIONS[k].
  These transforms are presumably the inverses of augment_test_time_k (defined
  elsewhere in the notebook) — they are copied unchanged from the original
  per-patch code.
  """
  p = patch_predictions
  return [p[0],
          tf.image.rot90(p[1], 3),
          tf.image.rot90(p[2], 2),
          tf.image.rot90(p[3], 1),
          tf.image.flip_left_right(p[4]),
          tf.image.flip_left_right(tf.image.rot90(p[5], 3)),
          tf.image.flip_left_right(tf.image.rot90(p[6], 2)),
          tf.image.flip_left_right(tf.image.rot90(p[7], 1))]


def _tta_column_weights(col):
  """Return the [256, 256, 1] weight mask (varying along columns) for grid column `col`."""
  return tf.concat(axis=1, values=[
      tf.fill([_TTA_PATCH_SIZE, width, 1], weight)
      for width, weight in _TTA_AXIS_WEIGHTS[col]])


def _tta_row_weights(row):
  """Return the [256, 256, 1] weight mask (varying along rows) for grid row `row`."""
  return tf.concat(axis=0, values=[
      tf.fill([height, _TTA_PATCH_SIZE, 1], weight)
      for height, weight in _TTA_AXIS_WEIGHTS[row]])


def _tta_accumulate_patch(crop_fn, row, col):
  """Predict grid patch (row, col) of all test images with TTA and add it to `predictions`.

  Args:
    crop_fn: dataset map function that crops the patch at (row, col) from a
      test image.
    row, col: grid position (0..4) of the patch.
  """
  print("Start: predict on patch [{}, {}] using 8 transformations as test time augmentation".format(row, col))
  patch_ds = test_dataset.map(crop_fn)
  # Presumably shaped [NUM_TEST, 8, 256, 256, 1]: per image, one prediction
  # per augmentation.
  patch_predictions = tf.stack(
      [make_prediction((patch_ds if aug is None else patch_ds.map(aug)).batch(1), verbose=1)
       for aug in _TTA_AUGMENTATIONS],
      axis=1)

  # These depend only on the grid position, so build them once per patch
  # instead of once per image.
  col_weights = _tta_column_weights(col)
  row_weights = _tta_row_weights(row)
  margin = _TTA_IMAGE_SIZE - _TTA_PATCH_SIZE
  row_offset = _TTA_PATCH_OFFSETS[row]
  col_offset = _TTA_PATCH_OFFSETS[col]
  paddings = [[row_offset, margin - row_offset],
              [col_offset, margin - col_offset],
              [0, 0]]

  for idx in range(NUM_TEST):
    prediction_patch = tf.math.reduce_mean(
        tf.stack(_tta_undo_augmentations(patch_predictions[idx])), axis=0)
    # Column weights first, then row weights (same order as the original code).
    prediction_patch = tf.math.multiply(prediction_patch, col_weights)
    prediction_patch = tf.math.multiply(prediction_patch, row_weights)
    prediction = tf.pad(prediction_patch, paddings, "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
  print("Done")


_tta_accumulate_patch(crop_test_0, 0, 0)
_tta_accumulate_patch(crop_test5x5_01, 0, 1)
_tta_accumulate_patch(crop_test_1, 0, 2)
_tta_accumulate_patch(crop_test5x5_12, 0, 3)
_tta_accumulate_patch(crop_test_2, 0, 4)
_tta_accumulate_patch(crop_test5x5_03, 1, 0)
_tta_accumulate_patch(crop_test5x5_04, 1, 1)
_tta_accumulate_patch(crop_test5x5_14, 1, 2)
_tta_accumulate_patch(crop_test5x5_24, 1, 3)
_tta_accumulate_patch(crop_test5x5_25, 1, 4)
_tta_accumulate_patch(crop_test_3, 2, 0)
_tta_accumulate_patch(crop_test5x5_34, 2, 1)
_tta_accumulate_patch(crop_test_4, 2, 2)
_tta_accumulate_patch(crop_test5x5_45, 2, 3)
print("Start: predict on patch [2, 4] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_5).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 88, 1], 0.75), tf.fill([256, 88, 1], 1.0)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([96, 256, 1], 0.5), tf.fill([80, 256, 1], 0.25)]))
    prediction = tf.pad(prediction_patch, [[176, 176], [352, 0], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [3, 0] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_36).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 88, 1], 1.0), tf.fill([256, 88, 1], 0.75), tf.fill([256, 80, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.5), tf.fill([88, 256, 1], 0.25)]))
    prediction = tf.pad(prediction_patch, [[264, 88], [0, 352], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [3, 1] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_37).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 88, 1], 0.25), tf.fill([256, 88, 1], 0.5), tf.fill([256, 80, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.5), tf.fill([88, 256, 1], 0.25)]))
    prediction = tf.pad(prediction_patch, [[264, 88], [88, 264], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [3, 2] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_47).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 96, 1], 0.5), tf.fill([256, 80, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.5), tf.fill([88, 256, 1], 0.25)]))
    prediction = tf.pad(prediction_patch, [[264, 88], [176, 176], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [3, 3] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_48).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 88, 1], 0.5), tf.fill([256, 88, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.5), tf.fill([88, 256, 1], 0.25)]))
    prediction = tf.pad(prediction_patch, [[264, 88], [264, 88], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [3, 4] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_58).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 88, 1], 0.75), tf.fill([256, 88, 1], 1.0)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.5), tf.fill([88, 256, 1], 0.25)]))
    prediction = tf.pad(prediction_patch, [[264, 88], [352, 0], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [4, 0] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_6).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 88, 1], 1.0), tf.fill([256, 88, 1], 0.75), tf.fill([256, 80, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.75), tf.fill([88, 256, 1], 1.0)]))
    prediction = tf.pad(prediction_patch, [[352, 0], [0, 352], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [4, 1] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_67).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 88, 1], 0.25), tf.fill([256, 88, 1], 0.5), tf.fill([256, 80, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.75), tf.fill([88, 256, 1], 1.0)]))
    prediction = tf.pad(prediction_patch, [[352, 0], [88, 264], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [4, 2] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_7).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 96, 1], 0.5), tf.fill([256, 80, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.75), tf.fill([88, 256, 1], 1.0)]))
    prediction = tf.pad(prediction_patch, [[352, 0], [176, 176], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [4, 3] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test5x5_78).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 88, 1], 0.5), tf.fill([256, 88, 1], 0.25)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.75), tf.fill([88, 256, 1], 1.0)]))
    prediction = tf.pad(prediction_patch, [[352, 0], [264, 88], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")
print("Start: predict on patch [4, 4] using 8 transformations as test time augmentation")
predictions_patch = []
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_1).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_2).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_3).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_4).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_5).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_6).batch(1), verbose=1))
predictions_patch.append(make_prediction(test_dataset.map(crop_test_8).map(augment_test_time_7).batch(1), verbose=1))
predictions_patch = tf.stack(predictions_patch, axis=1)
for idx in range (NUM_TEST):
    prediction_patch = tf.stack([predictions_patch[idx][0],
            tf.image.rot90(predictions_patch[idx][1], 3),
            tf.image.rot90(predictions_patch[idx][2], 2),
            tf.image.rot90(predictions_patch[idx][3], 1),
            tf.image.flip_left_right(predictions_patch[idx][4]),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][5], 3)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][6], 2)),
            tf.image.flip_left_right(tf.image.rot90(predictions_patch[idx][7], 1))])
    prediction_patch = tf.math.reduce_mean(prediction_patch, axis=0)
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=1, values=[
        tf.fill([256, 80, 1], 0.25), tf.fill([256, 88, 1], 0.75), tf.fill([256, 88, 1], 1.0)]))
    prediction_patch = tf.math.multiply(prediction_patch, tf.concat(axis=0, values=[
        tf.fill([80, 256, 1], 0.25), tf.fill([88, 256, 1], 0.75), tf.fill([88, 256, 1], 1.0)]))
    prediction = tf.pad(prediction_patch, [[352, 0], [352, 0], [0, 0]], "CONSTANT")
    predictions[idx] = tf.math.add(predictions[idx], prediction)
print("Done")


Check output images for correctness:

In [None]:
# Rescale the first blended prediction with the same "* 0.5 + 0.5" mapping
# used when writing to disk, then display it.
value_test = 0.5 + predictions[0] * 0.5
plt.imshow(tf.squeeze(value_test))
plt.show()

# Print the extremes (maximum first, then minimum); both should lie in [0, 1].
for extreme in (tf.reduce_max, tf.reduce_min):
    print(extreme(value_test))

## Save predictions
This cell writes the predictions to disk and creates the submission CSV file for Kaggle.

NOTE: Run the cell directly below (it defines the submission helper functions) before running this one!


In [None]:
# Write every blended prediction to Drive as a PNG, then build the Kaggle CSV.
import datetime  # also used by submit_predictions() for the CSV timestamp

output_dir = PATH + '/predictions_michi/'
os.makedirs(output_dir, exist_ok=True)  # tf.io.write_file fails if the directory is missing

print("Start: write predicted images to " + output_dir)
for idx in range(NUM_TEST):
    # Back to the 0-1 range that convert_image_dtype expects for float input.
    prediction = predictions[idx] * 0.5 + 0.5

    # saturate=True clips out-of-range values to [0, 255] instead of letting
    # them wrap around during the float -> uint8 cast.
    prediction = tf.image.convert_image_dtype(prediction, tf.uint8, saturate=True)
    img_prediction = tf.image.encode_png(prediction)

    plt.imshow(tf.squeeze(prediction))
    # Must come before show(): after show() the displayed figure is done and
    # axis('off') would act on a new, empty figure instead.
    plt.axis('off')
    plt.show()

    number = filenames[idx]
    tf.io.write_file(output_dir + str(number[5:8]) + ".png", img_prediction)
print("Done")

print("Start: create kaggle submission file")
submit_predictions(filenames)
print("Done")

In [None]:
import re
import matplotlib.image as mpimg
import numpy as np
# Mean-intensity threshold for a 16x16 patch: patches whose mean pixel value
# (mpimg.imread returns PNGs as floats in [0, 1]) is strictly greater than this
# are labeled foreground (1); all others are labeled 0.
foreground_threshold = 0.35

# assign a label to a patch
def patch_to_label(patch):
    """Return 1 if the mean of `patch` exceeds `foreground_threshold`, else 0.

    Args:
        patch: array of pixel values (any shape accepted by np.mean).

    Returns:
        int: 1 for foreground, 0 otherwise.
    """
    return 1 if np.mean(patch) > foreground_threshold else 0

def mask_to_submission_strings(image_filename):
    """Read one predicted mask image and yield its submission CSV rows.

    The image number is the first run of digits in the file's *base name*, so
    digits in parent directories (e.g. ".../run2/123.png") are ignored.

    Yields:
        One "<img:03d>_<j>_<i>,<label>" string per 16x16 patch, where j is the
        patch's column offset and i its row offset; patches are visited column
        by column.

    Raises:
        ValueError: if the base name contains no digits.
    """
    match = re.search(r"\d+", os.path.basename(image_filename))
    if match is None:
        raise ValueError("no image number found in file name: " + image_filename)
    img_number = int(match.group(0))
    im = mpimg.imread(image_filename)
    patch_size = 16
    for j in range(0, im.shape[1], patch_size):
        for i in range(0, im.shape[0], patch_size):
            patch = im[i:i + patch_size, j:j + patch_size]
            label = patch_to_label(patch)
            yield "{:03d}_{}_{},{}".format(img_number, j, i, label)


def masks_to_submission(submission_filename, *image_filenames):
    """Write a submission CSV built from the given mask images.

    The file starts with the header line "id,prediction", followed by every
    row that mask_to_submission_strings yields for each image, in argument
    order. An existing file at `submission_filename` is overwritten.
    """
    with open(submission_filename, 'w') as csv_file:
        csv_file.write('id,prediction\n')
        for image_filename in image_filenames:
            rows = mask_to_submission_strings(image_filename)
            csv_file.writelines(f'{row}\n' for row in rows)

#### Uses masks_to_submission to convert the predicted images to the submission output format.
#   Output format: each image is split into 16 x 16 pixel patches, and each patch is assigned a 0 or 1 label
#   based on the predicted pixel-wise values.
#   The public test score is based on these patch-wise predictions.
####

def submit_predictions(filenames, path=PATH + '/predictions_michi/', thresholding=False, num_images=94):
    """Create a timestamped submission CSV from the prediction PNGs in `path`.

    Args:
        filenames: original test file names; a 3-character image number is
            sliced out of each to locate "<number>.png" inside `path`.
        path: directory holding the prediction PNGs; the CSV is written there too.
        thresholding: if True, take the number from ``name[-7:-4]`` instead of
            ``name[5:8]``.
        num_images: maximum number of entries of `filenames` to use (defaults to
            94, the previously hard-coded count). Shorter lists no longer raise
            IndexError.

    Missing PNGs are reported with a "... not found" message and skipped.
    """
    # Local import: this cell must not rely on `datetime` having been imported
    # by another cell first.
    import datetime

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
    submission_filename = path + 'submission_' + timestamp + '.csv'
    image_filenames = []
    for number in filenames[:num_images]:
        if thresholding:
            filename = path + str(number[-7:-4]) + ".png"
        else:
            filename = path + str(number[5:8]) + ".png"
        if not os.path.isfile(filename):
            print(filename + " not found")
            continue
        image_filenames.append(filename)

    masks_to_submission(submission_filename, *image_filenames)
