# Mapping mask

In [None]:
'''
To create and filter synthetic maps
'''

import numpy as np
import cv2
import os
import os.path as osp

# Source folders: original (unmasked) faces and their synthetically masked counterparts.
images, images_masked = "dataset/no_mask", "dataset/mask"

# Progress counters reported at the end of the script.
counter = 0              # masked images consumed so far (successfully paired originals)
counter_elimination = 0  # originals deleted because no masked counterpart was found

def find_path(im):
    """Return the image files found at *im* as absolute paths.

    If *im* is a directory, every entry whose extension is ``.png``, ``.jpeg``
    or ``.jpg`` is returned, sorted by file name. Extensions are compared
    case-insensitively, as Keras' ``image_dataset_from_directory`` does.
    Sorting matters: callers pair files from different folders by position,
    and ``os.listdir`` returns entries in arbitrary order.

    If *im* is a single file, a one-element list containing it is returned.
    If *im* does not exist, a message is printed and the script exits.
    """
    import sys

    base = osp.realpath('.')
    try:
        entries = os.listdir(im)
    except NotADirectoryError:
        return [osp.join(base, im)]
    except FileNotFoundError:
        print("No file or directory with the name {}".format(im))
        sys.exit()
    allowed = ('.png', '.jpeg', '.jpg')
    return [osp.join(base, im, img) for img in sorted(entries)
            if osp.splitext(img)[1].lower() in allowed]

def name_check(img, img_mask):
    """Check whether *img_mask* is the masked counterpart of *img*.

    The original file's stem (its name without directory and extension) must
    appear somewhere in the masked file's name.

    Returns:
        ``(True, stem)`` when it does; otherwise prints "<stem> Not found!"
        and returns ``(False, stem)``.
    """
    # splitext handles extensions of any length. Slicing off 4 characters left
    # a trailing "." for ".jpeg" files, so those never matched.
    stem = os.path.splitext(os.path.basename(img))[0]
    masked_name = os.path.basename(img_mask)
    if stem in masked_name:
        return True, stem
    print(stem + " Not found!")
    return False, stem

def binarization(n, map, high=255):  # n was chosen empirically after several tests
    """Threshold *map* in place: values below *n* become 0, all others *high*.

    Works on arrays of any shape. The previous pixel loop was hard-coded to
    256x256: it raised IndexError on smaller images and skipped pixels of
    larger ones.

    Args:
        n: Threshold value.
        map: NumPy array, modified in place (nothing is returned).
        high: Value written where ``map >= n`` (255 for 8-bit maps).
    """
    below = map < n
    map[below] = 0
    map[~below] = high


def filter(image):
    """Return *image* after a morphological opening (erosion, then dilation).

    Uses a 3x3 rectangular structuring element; per OpenCV's definition of
    opening, this removes bright specks smaller than the kernel.
    """
    rect_3x3 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    opened = cv2.morphologyEx(image, cv2.MORPH_OPEN, rect_3x3)
    return opened

imlist = find_path(images)
imlist2 = find_path(images_masked)

# cv2.imwrite returns False (no exception) when the folder is missing, which
# would silently drop every map.
os.makedirs('dataset/maps', exist_ok=True)

for image in imlist:
    # Originals and masked images are paired by position: `counter` points at
    # the next masked image not yet consumed, so both lists must share the
    # same ordering.
    if counter < len(imlist2):
        image_masked = imlist2[counter]
        flag, name = name_check(image, image_masked)
    else:
        # Every masked image has already been paired (indexing further used to
        # raise IndexError): this original has no counterpart.
        name = osp.splitext(osp.basename(image))[0]
        print(name + " Not found!")
        flag = False
    if flag:
        counter += 1
        original = cv2.imread(image)
        masked = cv2.imread(image_masked)
        # Pixels altered by the synthetic mask differ between the two images.
        imMap = cv2.absdiff(masked, original)
        imMap = cv2.cvtColor(imMap, cv2.COLOR_BGR2GRAY)
        # Zero out differences below 6 (threshold chosen empirically), set the rest to 255.
        binarization(6, imMap)
        # Morphological opening removes isolated specks left by thresholding.
        openingMap = filter(imMap)
        cv2.imwrite('dataset/maps/' + name + '.jpg', openingMap)
    else:
        # No masked counterpart: delete the original so both folders stay aligned.
        os.remove(image)
        print("Removed")
        counter_elimination += 1

print("Elaborate " + str(counter) + " images")
print("Eliminate " + str(counter_elimination) + " images")

# Preparing dataset

In [None]:
'''
Preparation dataset for both modules
'''

import tensorflow as tf
from tensorflow.keras.preprocessing import *


# Every image is resized to this (height, width) when loaded.
dsize = (128, 128)

# Training directories.
images = "dataset/mask"            # masked faces (inputs)
images_map = "dataset/maps"        # binary mask maps
images_GAN_gt = "dataset/no_mask"  # unmasked faces (GAN ground truth)

# Testing directories.
path_test = "testing/test_mask"
path_test_map = "testing/test_map"
path_test_gt = "testing/test_nomask"

def normalize_seg(image):
    """Rescale pixel values from [0, 255] to [0, 1]."""
    return image / 255.0

def normalize_GAN(image):
    """Rescale pixel values from [0, 255] to [-1, 1] (the generator's tanh range)."""
    return image / 127.5 - 1

def prepare_tf_segmentation():
    """Build the (train, validation) datasets for the map module.

    Inputs are the masked faces in ``images``; targets are the maps in
    ``images_map``. Both are loaded as grayscale, resized to ``dsize``, split
    80/20 with the same seed (so inputs and targets stay aligned) and scaled
    to [0, 1].

    Returns:
        ``(dataset, dataval)``: zipped ``(input, target)`` datasets.
    """
    def load(directory, subset):
        split = image_dataset_from_directory(
            directory, image_size=dsize, color_mode="grayscale",
            label_mode=None, validation_split=0.2, subset=subset,
            shuffle=True, seed=1, interpolation="lanczos5")
        return split.map(normalize_seg)

    # Load order matches the original so Keras' "Found N files" logs are unchanged.
    x_train = load(images, "training")
    x_val = load(images, "validation")
    y_train = load(images_map, "training")
    y_val = load(images_map, "validation")
    return tf.data.Dataset.zip((x_train, y_train)), tf.data.Dataset.zip((x_val, y_val))

def prepare_tf_GAN():
    """Build the (train, test) datasets for the editing (GAN) module.

    Each element is a batch of 16 ``(masked face, mask map, unmasked ground
    truth)`` images, resized to ``dsize`` and scaled to [-1, 1]. 5% of each
    folder is held out as the test split.

    All six splits use the same seed. Keras shuffles the file list with
    ``seed`` *before* applying ``validation_split``, so the training and
    validation subsets of a folder are disjoint only when both are created
    with the same seed. (The test splits previously used seed 47, which made
    them overlap the seed-1 training split.)

    Returns:
        ``(dataset, testset)``
    """
    seed = 1

    def load(directory, subset, color_mode="rgb"):
        split = image_dataset_from_directory(
            directory, image_size=dsize, color_mode=color_mode,
            label_mode=None, validation_split=0.05, subset=subset,
            shuffle=True, seed=seed, interpolation="lanczos5", batch_size=16)
        return split.map(normalize_GAN)

    mask_ds = load(images, "training")
    mask_ds_test = load(images, "validation")
    map_ds = load(images_map, "training", color_mode="grayscale")
    map_ds_test = load(images_map, "validation", color_mode="grayscale")
    gt_ds = load(images_GAN_gt, "training")
    gt_ds_test = load(images_GAN_gt, "validation")
    dataset = tf.data.Dataset.zip((mask_ds, map_ds, gt_ds))
    testset = tf.data.Dataset.zip((mask_ds_test, map_ds_test, gt_ds_test))
    return dataset, testset
	
def prepare_tf_testseg():
    """Load the masked test faces for the map module and derive output names.

    Returns:
        ``(names, mask_seg)``:
        - ``mask_seg`` is an unshuffled dataset (batches of 16) of grayscale
          images from ``path_test``, resized to ``dsize`` and scaled to [0, 1].
        - ``names[i]`` is the part before the first "_" of the i-th file name
          (sorted) in ``path_test/test_masked``.
    """
    import os.path as osp

    mask_seg = image_dataset_from_directory(path_test, image_size=dsize, color_mode='grayscale', label_mode=None, shuffle=False, interpolation="lanczos5", batch_size=16)
    mask_seg = mask_seg.map(normalize_seg)
    # With shuffle=False Keras yields files in sorted order, so sort the names
    # the same way; os.listdir gives no ordering guarantee.
    # NOTE(review): assumes test_masked is the only image folder under path_test.
    imlist = sorted(find_path(path_test + "/test_masked"))
    # osp.basename is portable, unlike splitting on "/".
    names = [osp.basename(path).split('_')[0] for path in imlist]
    return names, mask_seg
	
def prepare_tf_testset():
    """Build the unshuffled test set for the editing module.

    Returns:
        Dataset of ``(masked face, mask map)`` batches of 16, resized to
        ``dsize`` and scaled to [-1, 1]. Faces come from ``path_test`` (RGB);
        maps come from ``path_test_map`` (grayscale).
    """
    common = dict(image_size=dsize, label_mode=None, shuffle=False,
                  interpolation="lanczos5", batch_size=16)
    map_ds = image_dataset_from_directory(path_test_map, color_mode='grayscale', **common)
    mask_ds = image_dataset_from_directory(path_test, **common)
    return tf.data.Dataset.zip((mask_ds.map(normalize_GAN), map_ds.map(normalize_GAN)))


# Editing module model

In [None]:
'''
Definition Editing Module
'''

import tensorflow as tf
from tensorflow.keras.utils import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *

def prepare_input_disc_mask(x):
    """Compose the input image of the mask-region discriminator.

    Pixels inside the mask region come from ``Igt_Iedit`` (ground truth or
    generated image); all other pixels come from ``Iinput`` (the masked input):

        Imask_region = Iinput * (1 - M) + Igt_Iedit * M

    Args:
        x: ``[Igt_Iedit, Imask_map, Iinput]``. ``Imask_map`` is expected in
            [-1, 1], as produced by ``normalize_GAN`` (mask region = +1). It is
            rescaled to a [0, 1] weight M, and its single channel broadcasts
            over the colour channels.
    """
    Igt_Iedit, Imask_map, Iinput = x[0], x[1], x[2]
    # The maps fed here are already normalised to [-1, 1]. Dividing by 255 (as
    # before) limited the generated image's share of the composite to ~1/255,
    # so real and fake composites were almost identical.
    weight = (Imask_map + 1.0) / 2.0
    # Plain tensor arithmetic; creating Multiply()/Add() layers inside a Lambda
    # function is unnecessary and discouraged by Keras.
    return Iinput * (1.0 - weight) + Igt_Iedit * weight

def se_block(in_block, ch, ratio=16):
    """Squeeze-and-Excitation block: reweight each channel of *in_block*.

    Args:
        in_block: Feature tensor for GlobalAveragePooling2D; assumes its
            channel count equals *ch*.
        ch: Number of channels of *in_block*.
        ratio: Reduction ratio of the hidden Dense bottleneck.

    Returns:
        *in_block* multiplied channel-wise by sigmoid weights.
    """
    squeezed = GlobalAveragePooling2D()(in_block)
    # Keep at least one hidden unit: ch // ratio is 0 whenever ch < ratio (e.g.
    # the commented-out call with ch=1), and Dense(0) is invalid.
    hidden = Dense(max(1, ch // ratio), activation='relu')(squeezed)
    weights = Dense(ch, activation='sigmoid')(hidden)
    return Multiply()([in_block, weights])

def generator():
    """Build the editing-module generator: a U-Net with a dilated bottleneck.

    Inputs: a 128x128x3 masked face and its 128x128x1 mask map, concatenated
    along the channels. Output: a 128x128x3 image in [-1, 1] (tanh).
    """
    def conv_lrelu_bn(x, filters, dilation=1):
        # Conv -> LeakyReLU -> BatchNorm, the repeated unit of this network.
        x = Conv2D(filters, 3, padding='same', dilation_rate=(dilation, dilation))(x)
        x = LeakyReLU()(x)
        return BatchNormalization()(x)

    masked_face = Input((128, 128, 3))
    mask_map = Input((128, 128, 1))
    x = concatenate([masked_face, mask_map])

    # Encoder level 1 (plain ReLU convolutions).
    skip1 = Conv2D(64, 3, activation='relu', padding='same')(x)
    skip1 = Conv2D(64, 3, activation='relu', padding='same')(skip1)
    x = MaxPooling2D(pool_size=(2, 2))(skip1)
    x = Dropout(0.25)(x)

    # Encoder level 2.
    skip2 = conv_lrelu_bn(x, 128)
    skip2 = conv_lrelu_bn(skip2, 128)
    x = MaxPooling2D(pool_size=(2, 2))(skip2)
    x = Dropout(0.5)(x)

    # Bottleneck: increasing dilation widens the receptive field without
    # further downsampling.
    for rate in (2, 4, 8, 16):
        x = conv_lrelu_bn(x, 256, dilation=rate)

    # Decoder level 2: upsample and fuse with skip2.
    x = Conv2DTranspose(128, 3, strides=(2, 2), padding='same')(x)
    x = concatenate([skip2, x])
    x = conv_lrelu_bn(x, 128)
    x = conv_lrelu_bn(x, 128)

    # Decoder level 1: upsample and fuse with skip1.
    x = Conv2DTranspose(64, 3, strides=(2, 2), padding='same')(x)
    x = concatenate([skip1, x])
    x = conv_lrelu_bn(x, 64)
    x = conv_lrelu_bn(x, 64)
    x = conv_lrelu_bn(x, 3)
    output = Conv2D(3, 1, activation='tanh')(x)

    return Model(inputs=[masked_face, mask_map], outputs=output)

def disc_whole_region():
    """Build the whole-image discriminator (pix2pix PatchGAN-style).

    Input: a 128x128x3 image. Output: a 14x14x1 map of raw logits (no
    activation), one per image patch.
    """
    init = tf.random_normal_initializer(0., 0.02)

    def down(x, filters, batchnorm=True):
        # Strided 4x4 conv halves the resolution.
        x = Conv2D(filters, 4, strides=2, padding='same', kernel_initializer=init, use_bias=False)(x)
        if batchnorm:
            x = BatchNormalization()(x)
        return LeakyReLU()(x)

    image = Input((128, 128, 3))
    x = down(image, 64, batchnorm=False)
    x = down(x, 128)
    x = down(x, 256)

    x = ZeroPadding2D()(x)
    x = Conv2D(512, 4, strides=1, kernel_initializer=init, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    x = ZeroPadding2D()(x)
    logits = Conv2D(1, 4, strides=1, kernel_initializer=init)(x)
    return Model(inputs=image, outputs=logits)

def disc_mask_region():
    """Build the mask-region discriminator.

    Inputs:
    - ``Igt_Iedit``: ground-truth or generated 128x128x3 image.
    - ``Imask_map``: 128x128x1 mask map.
    - ``Iinput``: the masked 128x128x3 input.

    A Lambda layer (``prepare_input_disc_mask``) composes one image from
    them. The rest is the same convolutional stack as
    ``disc_whole_region``, outputting raw logits per patch.
    """
    init = tf.random_normal_initializer(0., 0.02)

    def down(x, filters, batchnorm=True):
        # Strided 4x4 conv halves the resolution.
        x = Conv2D(filters, 4, strides=2, padding='same', kernel_initializer=init, use_bias=False)(x)
        if batchnorm:
            x = BatchNormalization()(x)
        return LeakyReLU()(x)

    Igt_Iedit = Input((128, 128, 3))  # ground truth or generated image
    Imask_map = Input((128, 128, 1))  # mask map
    Iinput = Input((128, 128, 3))     # masked input image
    composed = Lambda(prepare_input_disc_mask)([Igt_Iedit, Imask_map, Iinput])

    x = down(composed, 64, batchnorm=False)
    x = down(x, 128)
    x = down(x, 256)

    x = ZeroPadding2D()(x)
    x = Conv2D(512, 4, strides=1, kernel_initializer=init, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    x = ZeroPadding2D()(x)
    logits = Conv2D(1, 4, strides=1, kernel_initializer=init)(x)
    return Model(inputs=[Igt_Iedit, Imask_map, Iinput], outputs=logits)

def vgg19_model():
    """Return a frozen ImageNet VGG19 exposing three intermediate feature maps.

    The model takes 128x128x3 images and outputs the activations of
    ``block3_conv4``, ``block4_conv4`` and ``block5_conv4`` (in that order).
    These are used by the perceptual loss.
    """
    backbone = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=(128, 128, 3))
    backbone.trainable = False
    layer_names = ("block3_conv4", "block4_conv4", "block5_conv4")
    features = [backbone.get_layer(name).output for name in layer_names]
    return Model(backbone.input, features)

# Map module model

In [None]:
'''
Definition Map Module
'''

import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *

def seg_model():
    """Build, summarise and compile the map-module U-Net.

    Input: a 128x128x1 image. Output: a 128x128x1 map in (0, 1) (sigmoid).
    Compiled with Adam, binary cross-entropy and accuracy; the model summary
    is printed.
    """
    def conv_lrelu_bn(x, filters):
        # Conv -> LeakyReLU -> BatchNorm, the repeated unit of this network.
        x = Conv2D(filters, 3, padding='same')(x)
        x = LeakyReLU()(x)
        return BatchNormalization()(x)

    def double_conv(x, filters):
        x = conv_lrelu_bn(x, filters)
        return conv_lrelu_bn(x, filters)

    def down(x, rate):
        x = MaxPooling2D(pool_size=(2, 2))(x)
        return Dropout(rate)(x)

    def up(x, skip, filters):
        x = Conv2DTranspose(filters, 3, strides=(2, 2), padding='same')(x)
        return concatenate([skip, x])

    inputs = Input((128, 128, 1))

    # Encoder
    skip1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    skip1 = Conv2D(64, 3, activation='relu', padding='same')(skip1)
    x = down(skip1, 0.25)
    skip2 = double_conv(x, 128)
    x = down(skip2, 0.5)
    skip3 = double_conv(x, 256)
    x = down(skip3, 0.5)
    skip4 = double_conv(x, 512)
    x = down(skip4, 0.5)

    # Bottleneck
    x = double_conv(x, 1024)

    # Decoder
    x = double_conv(up(x, skip4, 512), 512)
    x = double_conv(up(x, skip3, 256), 256)
    x = double_conv(up(x, skip2, 128), 128)
    x = up(x, skip1, 64)
    x = Conv2D(64, 3, activation='tanh', padding='same')(x)
    x = Conv2D(64, 3, activation='tanh', padding='same')(x)
    x = Conv2D(2, 3, activation='tanh', padding='same')(x)
    outputs = Conv2D(1, 1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.summary()
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Train editing module

In [None]:
'''
Training Editing Module
'''

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import time
import tensorflow.keras.backend as K
from tensorflow.keras.utils import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from PIL import Image

# Allocate GPU memory on demand instead of reserving it all up front,
# to reduce out-of-memory errors.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Typically raised when the GPUs were already initialised.
    print(e)

# ---- Hyper-parameters ---------------------------------------------------------
count = 1     # index used to name the per-epoch result figures
EPOCHS = 100

# Loss weights: whole-image vs mask-region adversarial terms, and the weight
# applied to (reconstruction + perceptual) loss.
LAMBDA_whole = 0.3
LAMBDA_mask = 0.7
LAMBDA_rc = 100

# TensorBoard writer for losses and metrics.
log_dir = "logs/"
summary_writer = tf.summary.create_file_writer(log_dir + "fit/final_loss")

# ---- Models -------------------------------------------------------------------
gen = generator()
disc_whole = disc_whole_region()
disc_mask = disc_mask_region()
vgg_model = vgg19_model()
print("Model Created")

# ---- Data ---------------------------------------------------------------------
train_ds, test_ds = prepare_tf_GAN()
print("Data uploaded")

# ---- Loss and optimizers ------------------------------------------------------
# The discriminators end in a linear Conv2D, i.e. they output logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer, disc_whole_optimizer, disc_mask_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(3))

@tf.function
def perceptual_loss(gen_image, gt_image):
    """Perceptual loss on the feature maps returned by ``vgg_model``.

    For each feature map, the per-sample sum of squared differences is
    accumulated; the result is the mean over the batch.
    NOTE(review): images reach VGG19 in [-1, 1] without
    ``tf.keras.applications.vgg19.preprocess_input`` — confirm this is intended.
    """
    gen_features = vgg_model(gen_image)
    gt_features = vgg_model(gt_image)
    total = 0.0
    for f_gen, f_gt in zip(gen_features, gt_features):
        diff = K.batch_flatten(f_gen) - K.batch_flatten(f_gt)
        total += K.sum(K.square(diff), axis=-1)
    return tf.reduce_mean(total)

def disc_loss(disc_real_output, disc_gen_output):
    """Discriminator loss on logits for the non-saturating GAN objective.

    Real samples are pushed toward label 1 and generated samples toward 0;
    the two binary cross-entropies are summed.
    """
    real_term = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = cross_entropy(tf.zeros_like(disc_gen_output), disc_gen_output)
    return real_term + fake_term

def gen_loss(disc_gen_output):
    """Non-saturating adversarial loss: reward outputs the discriminator labels real (1)."""
    target = tf.ones_like(disc_gen_output)
    return cross_entropy(target, disc_gen_output)

def rec_loss(gen_output, Igt):
    """Reconstruction loss: mean L1 distance plus (1 - mean SSIM).

    ``max_val=2.0`` because images span [-1, 1].
    """
    l1 = tf.reduce_mean(tf.abs(Igt - gen_output))
    ssim = tf.reduce_mean(tf.image.ssim(gen_output, Igt, max_val=2.0))
    return l1 + (1 - ssim)

@tf.function
def first_train_cycle(input_image, input_map, Igt, epoch, n):
    """One optimisation step of the first training phase.

    Updates two networks; the mask-region discriminator is not used yet:
    - the generator, on adversarial loss + LAMBDA_rc * (reconstruction +
      perceptual) loss;
    - the whole-region discriminator.

    Args:
        input_image: batch of masked faces.
        input_map: batch of mask maps.
        Igt: batch of unmasked ground-truth faces.
        epoch: step value for the TensorBoard scalars.
        n: batch index (unused; kept for interface compatibility).
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_image = gen([input_image, input_map], training=True)

        # The discriminator must see ground truth as "real" and the generator
        # output as "fake". Using the masked input as the real sample taught it
        # that masked faces are real, pushing the generator to keep the mask.
        real_output = disc_whole(Igt, training=True)
        fake_output = disc_whole(generated_image, training=True)

        generator_loss = gen_loss(fake_output)
        discriminator_loss = disc_loss(real_output, fake_output)
        rc_loss = rec_loss(generated_image, Igt)
        perc_loss = perceptual_loss(generated_image, Igt)
        gen_tot_loss = LAMBDA_rc*(rc_loss + perc_loss) + generator_loss

    gradients_of_generator = gen_tape.gradient(gen_tot_loss, gen.trainable_variables)
    gradients_of_disc_whole = disc_tape.gradient(discriminator_loss, disc_whole.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
    disc_whole_optimizer.apply_gradients(zip(gradients_of_disc_whole, disc_whole.trainable_variables))

    with summary_writer.as_default():  # Plotting of losses with Tensorboard
        tf.summary.scalar('gan_loss', gen_tot_loss, step=epoch)
        tf.summary.scalar('generator_loss', generator_loss, step=epoch)
        tf.summary.scalar('reconstruction_loss', rc_loss, step=epoch)
        tf.summary.scalar('perceptual_loss', perc_loss, step=epoch)
        tf.summary.scalar('disc_whole_loss', discriminator_loss, step=epoch)

@tf.function
def second_train_cycle(input_image, input_map, Igt, epoch, n):
    """One optimisation step of the second training phase.

    Updates the generator, the whole-region discriminator and the mask-region
    discriminator. The adversarial terms are weighted by ``LAMBDA_whole`` and
    ``LAMBDA_mask``; reconstruction + perceptual loss is weighted by
    ``LAMBDA_rc``.

    Args:
        input_image: batch of masked faces.
        input_map: batch of mask maps.
        Igt: batch of unmasked ground-truth faces.
        epoch: step value for the TensorBoard scalars.
        n: batch index (unused; kept for interface compatibility).
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_whole_tape, tf.GradientTape() as disc_mask_tape:
        generated_image = gen([input_image, input_map], training=True)

        # Real sample for the whole-region discriminator is the ground truth,
        # not the masked input (which would reward keeping the mask).
        real_output_whole = disc_whole(Igt, training=True)
        fake_output_whole = disc_whole(generated_image, training=True)
        real_output_mask = disc_mask([Igt, input_map, input_image], training=True)
        fake_output_mask = disc_mask([generated_image, input_map, input_image], training=True)

        gen_loss_whole = gen_loss(fake_output_whole)
        disc_loss_whole = LAMBDA_whole * disc_loss(real_output_whole, fake_output_whole)
        gen_loss_mask = gen_loss(fake_output_mask)
        disc_loss_mask = LAMBDA_mask * disc_loss(real_output_mask, fake_output_mask)
        rc_loss = rec_loss(generated_image, Igt)
        perc_loss = perceptual_loss(generated_image, Igt)
        gen_tot_loss = LAMBDA_rc*(rc_loss + perc_loss) + LAMBDA_whole*(gen_loss_whole) + LAMBDA_mask*(gen_loss_mask)

    gradients_of_generator = gen_tape.gradient(gen_tot_loss, gen.trainable_variables)
    gradients_of_disc_whole = disc_whole_tape.gradient(disc_loss_whole, disc_whole.trainable_variables)
    gradients_of_disc_mask = disc_mask_tape.gradient(disc_loss_mask, disc_mask.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
    disc_whole_optimizer.apply_gradients(zip(gradients_of_disc_whole, disc_whole.trainable_variables))
    disc_mask_optimizer.apply_gradients(zip(gradients_of_disc_mask, disc_mask.trainable_variables))

    with summary_writer.as_default():  # Plotting of losses with Tensorboard
        tf.summary.scalar('gan_loss', gen_tot_loss, step=epoch)
        tf.summary.scalar('generator_loss', (gen_loss_mask+gen_loss_whole), step=epoch)
        tf.summary.scalar('reconstruction_loss', rc_loss, step=epoch)
        tf.summary.scalar('perceptual_loss', perc_loss, step=epoch)
        tf.summary.scalar('disc_whole_loss', disc_loss_whole, step=epoch)
        tf.summary.scalar('disc_mask_loss', disc_loss_mask, step=epoch)

def generate_images(model, test_input, test_map, tar, epoch):
    """Evaluate *model* on one test batch, log metrics and save result figures.

    Logs the batch-average SSIM and PSNR to TensorBoard at step *epoch*. Saves
    two figures named after the global counter ``count``, then increments
    ``count`` once:
    - ``result/GAN/final_loss/<count>.png``: first sample of the batch;
    - ``result/GAN/SSIM_PSNR/<count>.png``: samples with the best SSIM and
      the best PSNR.
    """
    global count
    # NOTE: training=True keeps BatchNorm on batch statistics and Dropout active
    # during evaluation (as in the TF pix2pix tutorial); kept so metrics stay comparable.
    prediction = model([test_input, test_map], training=True)
    score_SSIM = tf.image.ssim(prediction, tar, max_val=2.0)
    score_PSNR = tf.image.psnr(prediction, tar, max_val=2.0)
    np_score_PSNR = score_PSNR.numpy()
    average_PSNR = np.average(np_score_PSNR)
    np_score_SSIM = score_SSIM.numpy()
    average_SSIM = np.average(np_score_SSIM)
    with summary_writer.as_default():  # Plotting of metrics with Tensorboard
        tf.summary.scalar('SSIM', average_SSIM, step=epoch)
        tf.summary.scalar('PSNR', average_PSNR, step=epoch)

    # fig.savefig raises FileNotFoundError when the target folder is missing.
    os.makedirs("result/GAN/final_loss", exist_ok=True)
    os.makedirs("result/GAN/SSIM_PSNR", exist_ok=True)

    # Figure 1: first image of the batch with its SSIM and PSNR.
    fig = plt.figure(figsize=(15, 7.5))
    fig.text(0.5, 0.15, "SSIM: " + str(np_score_SSIM[0]), fontsize=20, horizontalalignment="center")
    fig.text(0.5, 0.1, "PSNR: " + str(np_score_PSNR[0]) + "dB", fontsize=20, horizontalalignment="center")
    plt.suptitle("Epoch " + str(epoch+1), fontsize=20, ha="center")
    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Test Input', 'Ground Truth', 'Predicted Image']
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i], fontsize=16)
        plt.imshow(display_list[i] * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
        plt.axis('off')
    fig.savefig("result/GAN/final_loss/" + str(count) + ".png")
    # pyplot keeps every figure alive until closed; two per epoch would pile up.
    plt.close(fig)

    # Figure 2: samples with the best SSIM (top row) and best PSNR (bottom row).
    index_SSIM = int(tf.argmax(score_SSIM))
    index_PSNR = int(tf.argmax(score_PSNR))
    fig = plt.figure(figsize=(30, 15))
    fig.text(0.5, 0.07, "SSIM: " + str(np_score_SSIM[index_SSIM]), fontsize=28, horizontalalignment="center")
    fig.text(0.5, 0.03, "PSNR: " + str(np_score_PSNR[index_PSNR]) + "dB", fontsize=28, horizontalalignment="center")
    plt.suptitle("Epoch " + str(epoch+1), fontsize=40, ha="center")
    display_list = [test_input[index_SSIM], tar[index_SSIM], prediction[index_SSIM], test_input[index_PSNR], tar[index_PSNR], prediction[index_PSNR]]
    title = ['SSIM_Test Input', 'SSIM_Ground Truth', 'SSIM_Predicted Image', 'PSNR_Test Input', 'PSNR_Ground Truth', 'PSNR_Predicted Image']
    for i in range(6):
        plt.subplot(2, 3, i+1)
        plt.title(title[i], fontsize=22)
        plt.imshow(display_list[i] * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
        plt.axis('off')
    fig.savefig("result/GAN/SSIM_PSNR/" + str(count) + ".png")
    plt.close(fig)
    count += 1

def fit(train_ds, epochs, test_ds):
    """Train the editing module for *epochs* epochs.

    - The first 40% of the epochs (rounded) run ``first_train_cycle``; the
      rest run ``second_train_cycle``, which also trains the mask-region
      discriminator.
    - After each epoch, one test batch is evaluated with ``generate_images``.
    - A checkpoint is saved through ``ckpt_manager`` every 5 epochs and once
      after the last epoch (not twice).
    """
    first_epochs = int(round(epochs * 0.4))  # epochs spent in the first training phase
    saved_last = False
    for epoch in range(0, epochs):
        first_phase = epoch < first_epochs
        print("First training cycle" if first_phase else "Second training cycle")
        counter = 0
        start = time.time()
        print("Epoch: ", str(epoch+1))
        # The train steps are tf.functions: a new Python int every epoch forces
        # a retrace, whereas a tensor argument reuses the traced graph.
        step = tf.constant(epoch, dtype=tf.int64)
        train_step = first_train_cycle if first_phase else second_train_cycle

        for n, (input_image, input_map, target) in train_ds.enumerate():
            counter += 1
            print(str(counter), end=' ', flush=True)
            train_step(input_image, input_map, target, step, n)
        for example_input, example_map, example_target in test_ds.take(1):
            generate_images(gen, example_input, example_map, example_target, epoch)

        # Save through the manager so its max_to_keep limit is honoured
        # (checkpoint.save bypassed it and kept every checkpoint).
        saved_last = (epoch + 1) % 5 == 0
        if saved_last:
            ckpt_manager.save()
            print("Checkpoint saved!")

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))

    # Final checkpoint, unless the last epoch was just saved.
    if not saved_last:
        ckpt_manager.save()

# ---- Checkpoint management ----------------------------------------------------
# Track the three networks and their optimizers so training can resume.
checkpoint_dir = './Checkpoints/GAN/final_loss'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    disc_whole_optimizer=disc_whole_optimizer,
    disc_mask_optimizer=disc_mask_optimizer,
    generator=gen,
    disc_whole=disc_whole,
    disc_mask=disc_mask,
)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

# Resume from the newest checkpoint when one exists; restore(None) loads nothing.
latest = ckpt_manager.latest_checkpoint
checkpoint.restore(latest)
print("Restored from {}".format(latest) if latest else "Initializing from scratch.")

# ---- Start training -----------------------------------------------------------
fit(train_ds, EPOCHS, test_ds)

# Train mapping module

In [None]:
'''
Training Map Module
'''

import tensorflow as tf
import matplotlib.pyplot as plt

# Let TensorFlow grow GPU memory usage as needed rather than grabbing it all,
# which helps avoid out-of-memory errors.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Typically raised if the GPUs have already been initialised.
    print(e)

import os

# ---- Data and training --------------------------------------------------------
model = seg_model()
dataset, dataval = prepare_tf_segmentation()
print("Uploaded the data")

history = model.fit(dataset, validation_data=dataval, epochs=20)

# Save to the path the testing stage loads from; saving to the working
# directory meant moving the file by hand before testing.
seg_model_path = "Checkpoints/Segmentation/segmentation_model_20k_20epoch.h5"
os.makedirs(os.path.dirname(seg_model_path), exist_ok=True)
model.save(seg_model_path)

# ---- Training curves ----------------------------------------------------------
print(history.history.keys())
for metric in ('accuracy', 'loss'):
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title('model ' + metric)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()


# Testing

In [None]:
'''
Testing whole architecture
'''

import tensorflow as tf
from tensorflow.keras.utils import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import matplotlib.pyplot as plt
import os
from tensorflow.keras.preprocessing import *
import os.path as osp
import numpy as np

# Enable GPU memory growth so TensorFlow allocates device memory lazily
# (avoids reserving everything and running out of memory).
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Typically raised when memory growth is set after GPU initialisation.
    print(e)

def binarization(n, map, high=1):
    """Threshold *map* in place: values below *n* become 0, all others *high*.

    Used on the segmentation network's predictions. Vectorised and
    shape-agnostic; the previous loop only visited the first 128x128 entries.

    Args:
        n: Threshold value.
        map: NumPy array, modified in place (nothing is returned).
        high: Value written where ``map >= n``.
    """
    below = map < n
    map[below] = 0
    map[~below] = high

def map_generation(images, names):
    """Predict, clean and save a binary mask map for every image in *images*.

    For each prediction of the global ``seg_model``:
    - threshold it at 0.01 (to 0/1) with ``binarization``;
    - clean it with a morphological opening (``filter``);
    - write it to ``testing/test_map/test/<name>.jpg``, where *names* gives
      the file name of each prediction in order.

    Raises:
        ValueError: if there are fewer names than predictions.
    """
    prediction = seg_model.predict(x=images)
    if len(names) < len(prediction):
        raise ValueError("Got {} names for {} predicted maps".format(len(names), len(prediction)))
    out_dir = "testing/test_map/test/"
    # save_img cannot create missing folders.
    os.makedirs(out_dir, exist_ok=True)
    for name, image in zip(names, prediction):
        binarization(0.01, image)
        image = filter(image)
        # cv2 may drop the trailing channel axis; restore it as (H, W, 1) for save_img.
        image = image.reshape(image.shape[0], image.shape[1], 1)
        tf.keras.preprocessing.image.save_img(out_dir + name + '.jpg', image)

def generate_image(model, test_input, test_map, names, save_figure=False):
    """Generate unmasked faces for one test batch and save every image.

    Each generated image is written to ``testing/results/<name>.jpg``. Names
    are taken from *names* starting at the global index ``count``, which
    advances by one per image. (Previously only the first image of each batch
    was saved, under ``names[batch_index]`` instead of its own name.)

    Args:
        model: generator taking ``[faces, maps]`` and returning images in [-1, 1].
        test_input: batch of masked faces in [-1, 1].
        test_map: batch of mask maps.
        names: output file names, one per test image, in dataset order.
        save_figure: if True, also save an input / map / prediction preview of
            the batch's first image as ``testing/results/<name>_figure.jpg``.
    """
    global count
    # training=True kept from the original (batch statistics for BatchNorm).
    prediction = model([test_input, test_map], training=True)
    out_dir = "testing/results/"
    os.makedirs(out_dir, exist_ok=True)

    if save_figure:
        fig = plt.figure(figsize=(15, 7.5))
        plt.suptitle("TEST", fontsize=24, ha="center")
        panels = [test_input[0, :, :, :], test_map[0, :, :, :], prediction[0, :, :, :]]
        titles = ['Test Input', 'Segmentation Map', 'Predicted Image']
        for i, (panel, title) in enumerate(zip(panels, titles)):
            plt.subplot(1, 3, i + 1)
            plt.title(title, fontsize=18)
            if i == 1:
                plt.imshow(panel, cmap='gray')
            else:
                plt.imshow(panel * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
            plt.axis('off')
        fig.savefig(out_dir + names[count] + '_figure.jpg')
        plt.close(fig)

    for image in prediction.numpy():
        # Map the tanh range [-1, 1] to [0, 255] directly. save_img's default
        # scale=True would min-max stretch each image and distort its colours.
        pixels = np.clip((image + 1.0) * 127.5, 0, 255)
        tf.keras.preprocessing.image.save_img(out_dir + names[count] + '.jpg', pixels, scale=False)
        count += 1

def compute_metrics(model, test_input, test_map, tar):
    """Return the batch-average ``(PSNR, SSIM)`` of *model*'s output against *tar*.

    ``max_val=2.0`` because images span [-1, 1].
    """
    prediction = model([test_input, test_map], training=True)
    ssim_scores = tf.image.ssim(prediction, tar, max_val=2.0).numpy()
    psnr_scores = tf.image.psnr(prediction, tar, max_val=2.0).numpy()
    return np.average(psnr_scores), np.average(ssim_scores)

# ---- Restore trained models ---------------------------------------------------
seg_model = load_model("Checkpoints/Segmentation/segmentation_model_20k_20epoch.h5")
gen = generator()
checkpoint_dir = "Checkpoints/GAN/final_loss"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator=gen)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)
# The training checkpoint also stores the discriminators and optimizers, which
# are deliberately not restored here. expect_partial() silences the warnings
# about those unmatched objects.
checkpoint.restore(ckpt_manager.latest_checkpoint).expect_partial()
if ckpt_manager.latest_checkpoint:
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# ---- Run the two-stage pipeline -----------------------------------------------
count = 0  # index into `names`, advanced by generate_image
names, mask_seg = prepare_tf_testseg()
map_generation(mask_seg, names)
print("Segmentation of masks: DONE")
testset = prepare_tf_testset()
for example_input, example_map in testset:
    generate_image(gen, example_input, example_map, names)
print("Generation of synthetical images: DONE")
