In [None]:
import os
import cv2
import time
# import wandb
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

import tensorflow as tf
from tensorflow import keras

from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Add, LeakyReLU
from keras import Model

# **Handle Data**

In [None]:

def load(target_path, input_path):
  """Read one pair of image files from disk and decode them.

  Args:
    target_path: file whose decoded image is returned as the *second* element.
    input_path: file whose decoded image is returned as the *first* element.

  Returns:
    (input_images, real_images): 3-channel uint8 image tensors decoded from
    `input_path` and `target_path` respectively. No dtype conversion or
    resizing happens here (see `normalize` / `resize`).
  """
  def _decode_rgb(path):
    # NOTE(review): the data loader pairs `.png` inputs; decode_jpeg is
    # documented to accept PNG as well, but tf.io.decode_image(...,
    # expand_animations=False) would state the intent more clearly.
    return tf.image.decode_jpeg(tf.io.read_file(path), channels=3)

  real_images = _decode_rgb(target_path)
  input_images = _decode_rgb(input_path)
  return input_images, real_images


def resize(input_image, real_image, height=256, width=256):
  """Bilinearly resize both images of a pair to (height, width).

  Returns:
    (input_image, real_image) resized, in the same order as given.
  """
  target_size = (height, width)
  return tuple(
      tf.image.resize(img, target_size, method='bilinear')
      for img in (input_image, real_image)
  )


def normalize(input_image, real_image):
  """Cast both images to float32 and map pixel values [0, 255] -> [-1, 1].

  Returns:
    (input_image, real_image) as float32 tensors, in the same order as given.
  """
  def _to_signed_unit(img):
    return tf.cast(img, tf.float32) / 127.5 - 1

  return _to_signed_unit(input_image), _to_signed_unit(real_image)

def load_images_train(target_path, input_path):
  """Training-time pipeline for one file pair: decode -> scale to [-1, 1] -> resize.

  Returns the pair in the order produced by `load`, i.e. (image decoded from
  `input_path`, image decoded from `target_path`), each float32 and resized to
  the `resize` defaults (256x256), 3 channels.
  """
  # Augmentation hook: a `random_jitter` step used to sit here but is disabled
  # (no such function is defined in this notebook).
  decoded_pair = load(target_path, input_path)
  scaled_pair = normalize(*decoded_pair)
  return resize(*scaled_pair)


def load_images_test(target_path, input_path):
  """Evaluation pipeline for one file pair: decode -> scale to [-1, 1] -> resize (no augmentation).

  Returns (image decoded from `input_path`, image decoded from `target_path`).
  """
  input_image, real_image = resize(*normalize(*load(target_path, input_path)))
  return input_image, real_image


def generate_images(model, test_input, tar):
  """Run `model` on a batch and plot input / ground truth / prediction for sample 0.

  Images are assumed to be in [-1, 1] (the `normalize` range) and are mapped
  back to [0, 1] for display. The model is called with training=True, matching
  how it is invoked during training in this notebook.
  """
  prediction = model(test_input, training=True)
  panels = (
      ('Input Image', test_input[0]),
      ('Ground Truth', tar[0]),
      ('Predicted Image', prediction[0]),
  )

  plt.figure(figsize=(15, 15))
  for position, (caption, img) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    # Undo the [-1, 1] scaling so pixel values land in [0, 1] for imshow.
    plt.imshow(img * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

# **Data Loader**

In [None]:
BATCH_SIZE = 16
BUFFER_SIZE = 500

# Fill these in before running: directories holding target and input images.
target_path_train = r""
input_path_train = r""
target_path_test = r""
input_path_test = r""


def pair_image_paths(input_dir, target_dir):
    """Pair every file in `target_dir` with its input image in `input_dir`.

    A target named ``<stem>_<suffix>.<ext>`` (or ``<stem>.<ext>``) is matched to
    ``<stem>.png`` in `input_dir`. Pairs are returned in ``os.listdir(target_dir)``
    order; targets without a matching input are skipped.

    Returns:
        (input_paths, target_paths): two equal-length lists of file paths.

    Raises:
        ValueError: if either directory path is empty (not configured yet).
    """
    if not input_dir or not target_dir:
        raise ValueError(
            "Set input_path_*/target_path_* to the image directories before building the datasets.")
    # List the input directory once; the previous nested loop re-listed it for
    # every target file (O(n*m) directory scans and comparisons).
    available_inputs = set(os.listdir(input_dir))
    input_paths, target_paths = [], []
    for name_target in os.listdir(target_dir):
        name_image = name_target.split('.')[0].split('_')[0] + '.png'
        if name_image in available_inputs:
            input_paths.append(f"{input_dir}/{name_image}")
            target_paths.append(f"{target_dir}/{name_target}")
    return input_paths, target_paths


def make_dataset(input_paths, target_paths, load_fn):
    """Build a shuffled, batched, prefetched tf.data pipeline over path pairs."""
    return (
        tf.data.Dataset.from_tensor_slices((input_paths, target_paths))
        # Shuffle the cheap path strings rather than decoded images, so the
        # shuffle buffer does not hold BUFFER_SIZE decoded image pairs in memory.
        .shuffle(BUFFER_SIZE)
        # NOTE(review): map passes (input_path, target_path) positionally into
        # load_fn(target_path, input_path). `load` therefore returns the file
        # from the *target* directory as the first (model input) element and
        # the file from the *input* directory as the second (ground truth).
        # Confirm this is the intended translation direction for your data.
        .map(load_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )


train_image, train_target = pair_image_paths(input_path_train, target_path_train)
test_image, test_target = pair_image_paths(input_path_test, target_path_test)
assert train_image, "no (input, target) pairs found in the training directories"
assert test_image, "no (input, target) pairs found in the test directories"

train_datasets = make_dataset(train_image, train_target, load_images_train)
test_datasets = make_dataset(test_image, test_target, load_images_test)

# **Check Data**

In [None]:
import matplotlib.pyplot as plt

# Show the first (input, target) pair of one training batch side by side.
fig = plt.figure(figsize=(20, 20))
for image, mask in train_datasets.take(1):
  for position, batch in enumerate((image, mask), start=1):
    plt.subplot(1, 2, position)
    # Map sample 0 from [-1, 1] back to [0, 1] for display.
    plt.imshow(batch[0] * 0.5 + 0.5)

# **Generator**





In [None]:
class spatial_attention(tf.keras.layers.Layer):
    """Spatial attention module from CBAM (https://arxiv.org/abs/1807.06521).

    The input is pooled over the channel axis (mean and max). A single-filter
    sigmoid convolution runs over the two pooled maps. Every channel of the
    input is then rescaled by the resulting (H, W, 1) mask, so the output
    shape equals the input shape.

    Args:
        kernel_size: spatial size of the attention convolution.
        trainable: forwarded to `Layer`; False freezes the attention conv.
    """
    def __init__(self, kernel_size=7, trainable=True, **kwargs):
        # `trainable` used to be accepted and then dropped, so
        # Generator_model(train_attention=False) never froze this layer.
        super(spatial_attention, self).__init__(trainable=trainable, **kwargs)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        # Attribute keeps its historical name `conv3d` (it is a Conv2D) so the
        # variable paths used by existing checkpoints stay the same.
        self.conv3d = tf.keras.layers.Conv2D(filters=1,
                                             kernel_size=self.kernel_size,
                                             strides=1,
                                             padding='same',
                                             activation='sigmoid',
                                             kernel_initializer='he_normal',
                                             use_bias=False)
        super(spatial_attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        # Plain tensor ops instead of constructing new Lambda/Concatenate/
        # Multiply layer objects on every call; the math is identical.
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        attention_mask = self.conv3d(tf.concat([avg_pool, max_pool], axis=-1))
        # (…, H, W, 1) mask broadcasts over the channel axis.
        return inputs * attention_mask

    def get_config(self):
        # Required for model.save(...) in HDF5 format: the base implementation
        # raises for layers whose __init__ takes extra arguments.
        config = super(spatial_attention, self).get_config()
        config.update({"kernel_size": self.kernel_size})
        return config
    

class Attention_map(tf.keras.layers.Layer):
    """Attention network used to fuse encoder/decoder features.

    As described in: https://arxiv.org/pdf/2112.01098.pdf
    Four 3x3 'same' convolutions with 4F, 8F, 4F and 2F filters
    (F = num_filters). The spatial size is preserved and the output has
    2F channels.

    NOTE(review): there is no activation between the convolutions, so the stack
    is a single affine map overall. Check against the paper whether
    nonlinearities (e.g. ReLU, or a sigmoid on the output) are intended.
    """
    def __init__(self, num_filters, trainable=True, **kwargs):
        # Initialise the base Layer first and forward `trainable`; it used to be
        # accepted and dropped, so train_attention=False had no effect.
        super(Attention_map, self).__init__(trainable=trainable, **kwargs)
        self.num_filters = num_filters

        self.conv4a = tf.keras.layers.Conv2D(4*self.num_filters,
                                             kernel_size=3,
                                             padding="same")
        self.conv8 = tf.keras.layers.Conv2D(8*self.num_filters,
                                            kernel_size=3,
                                            padding="same")
        self.conv4b = tf.keras.layers.Conv2D(4*self.num_filters,
                                             kernel_size=3,
                                             padding="same")
        self.conv2 = tf.keras.layers.Conv2D(2*self.num_filters,
                                            kernel_size=3,
                                            padding="same")

    def call(self, inputs, training=None, **kwargs):
        # `training` now defaults to None; it was a required positional argument.
        x = self.conv4a(inputs)
        x = self.conv8(x)
        x = self.conv4b(x)
        return self.conv2(x)

    def get_config(self):
        # Needed so models containing this layer can be saved/reloaded.
        config = super(Attention_map, self).get_config()
        config.update({"num_filters": self.num_filters})
        return config
    

def compute_fused(fenc, fdec, name_layer, trainable=True):
    """
    Compute fused module \n
    f_fused = fenc * A(fenc) + fdec * A(fdec), where A is one shared
    Attention_map network (same weights for both branches).

    The previous version concatenated fenc and fdec along axis 0 (the batch
    axis), ran the attention net once, and indexed the result with [0] and [1].
    Those indices select *batch samples*. The formula was therefore only right
    for batch size 1; for larger batches every sample was weighted by the
    attention of samples 0 and 1. Calling the shared layer on each tensor gives
    the same result at batch size 1 and the per-sample result otherwise.

    Args:
        fenc, fdec: feature tensors of identical shape.
        name_layer: name of the shared Attention_map layer.
        trainable: forwarded to the Attention_map layer.
    """
    attention = Attention_map(num_filters=64, trainable=trainable, name=name_layer)
    att_enc = attention(fenc)
    att_dec = attention(fdec)
    return fenc * att_enc + fdec * att_dec

In [None]:
class Residual_Block(keras.layers.Layer):
    """
    Residual block (ResNet style) that preserves spatial size.

    main path : conv(k) -> BN -> ReLU -> conv(k) -> BN -> ReLU
    shortcut  : 1x1 conv projecting the input to `num_filter` channels
    output    : ReLU(main + shortcut)

    Input (H, W, C) => Output (H, W, num_filter). All convolutions use stride 1
    and 'same' padding; any downsampling is done by the caller.

    Args:
        num_filter: number of output channels.
        kernel_size: kernel size of the two main-path convolutions.
        trainable: forwarded to `Layer`.
    """
    def __init__(self, num_filter, kernel_size=3, trainable=True, **kwargs):
        super(Residual_Block, self).__init__(trainable=trainable, **kwargs)
        self.filters = num_filter
        self.kernel_size = kernel_size

    def build(self, input_shape):
        self.x_skip = tf.keras.layers.Conv2D(self.filters,
                                             kernel_size=1,
                                             padding="same")

        self.conv2a = tf.keras.layers.Conv2D(self.filters,
                                             kernel_size=self.kernel_size,
                                             padding='same')
        self.bn2a = tf.keras.layers.BatchNormalization()

        self.conv2b = tf.keras.layers.Conv2D(filters=self.filters,
                                             kernel_size=self.kernel_size,
                                             padding='same')
        self.bn2b = tf.keras.layers.BatchNormalization()
        super(Residual_Block, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        # Spatial dims are preserved; the channel count becomes `filters`.
        return tf.TensorShape(input_shape)[:-1].concatenate([self.filters])

    def call(self, inputs, training=None):
        # The shortcut goes through the 1x1 projection built for it. It used to
        # reuse conv2a, so `x_skip` was never called.
        # NOTE: this adds the x_skip weights, so weight files saved from the old
        # graph will not match this layer's weight list.
        shortcut = self.x_skip(inputs)

        # BN follows the call's `training` flag. It used to receive
        # `self.trainable`, which pinned BN to batch statistics even at
        # inference. This notebook calls the model with training=True, so
        # training behaviour is unchanged.
        x = self.conv2a(inputs)
        x = self.bn2a(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2b(x)
        x = self.bn2b(x, training=training)
        x = tf.nn.relu(x)

        return tf.nn.relu(shortcut + x)

    def get_config(self):
        # Required for model.save(...) since __init__ takes extra arguments.
        config = super(Residual_Block, self).get_config()
        config.update({"num_filter": self.filters, "kernel_size": self.kernel_size})
        return config
    

class Decoder_Block(tf.keras.layers.Layer):
    """
    Decoder block: 2x2 stride-2 transposed conv (upsampling), then a Residual_Block.
    Input: (H, W, C) => Output: (2H, 2W, num_filter)

    Args:
        num_filter: output channels of both the upsampling conv and the residual block.
        kernel_size: kernel size used inside the residual block.
        trainable: forwarded to `Layer`.
    """
    def __init__(self, num_filter, kernel_size=3, trainable=True, **kwargs):
        # `trainable` used to be accepted and dropped; forward it to the base Layer.
        super(Decoder_Block, self).__init__(trainable=trainable, **kwargs)
        self.num_filter = num_filter
        self.kernel_size = kernel_size
        self.residual = Residual_Block(num_filter=self.num_filter,
                                       kernel_size=self.kernel_size)
        self.upsampling = tf.keras.layers.Conv2DTranspose(filters=self.num_filter,
                                                          kernel_size=2,
                                                          strides=2)

    def call(self, inputs, training=None, **kwargs):
        x = self.upsampling(inputs)
        # Pass the training flag down explicitly so the residual block's BN follows it.
        return self.residual(x, training=training)

    def get_config(self):
        # Required for model.save(...) since __init__ takes extra arguments.
        config = super(Decoder_Block, self).get_config()
        config.update({"num_filter": self.num_filter, "kernel_size": self.kernel_size})
        return config
  
    
class Ouput_layer(tf.keras.layers.Layer):
    """
    Output layer of the generator: 3x3 conv to 3 channels -> BN -> activation.

    Images and targets are scaled to [-1, 1] by `normalize`, so the default
    activation is tanh (range [-1, 1]). The previous ReLU output could never
    produce negative values, so the darker half of the target range was
    unreachable. Pass activation="relu" to reproduce the old behaviour.

    Args:
        trainable: forwarded to `Layer`.
        activation: any identifier accepted by tf.keras.activations.get.
    """
    def __init__(self, trainable=True, activation="tanh", **kwargs):
        super(Ouput_layer, self).__init__(trainable=trainable, **kwargs)
        self.activation_name = activation
        self.output_activation = tf.keras.activations.get(activation)
        self.conv2d = tf.keras.layers.Conv2D(filters=3,
                                             kernel_size=3,
                                             padding="same")
        self.bn2d = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None, **kwargs):
        output = self.conv2d(inputs)
        output = self.bn2d(output, training=training)
        return self.output_activation(output)

    def get_config(self):
        config = super(Ouput_layer, self).get_config()
        config.update({"activation": self.activation_name})
        return config

In [None]:
def Generator_model(shape=(256,256,3), train_attention=True):
    """Build the encoder-decoder generator with spatial attention and feature fusion.

    Encoder: five Residual_Block stages (32, 64, 128, 256, 512 filters), each
    followed by 2x2 max-pooling. The 128-filter stage output is also passed
    through spatial attention ("F_enc") before pooling.
    Bottleneck: 3x3 conv with 1024 filters.
    Decoder: five Decoder_Blocks (each upsamples 2x). After the third
    (128 filters), the features go through spatial attention ("F_dec"), are
    fused with F_enc by `compute_fused`, and are added back before the last two
    decoder stages.

    All layers come from tf.keras so they match the tf.keras-based custom layers
    above; mixing in the standalone `keras` package breaks on TF versions where
    the two are different libraries.

    Args:
        shape: input image shape (H, W, C). Presumably H and W must be divisible
            by 32 so that the five poolings and five upsamplings line up for the
            fusion/Add steps — confirm before using other sizes.
        train_attention: passed as `trainable` to the attention layers.

    Returns:
        A tf.keras.Model named "Generator".
    """
    layers = tf.keras.layers
    x_input = tf.keras.Input(shape, name="Input_layer")

    encoder0 = Residual_Block(num_filter=32, kernel_size=3, name="Encoder_0")(x_input)
    encoder0 = layers.MaxPooling2D((2, 2), name="MaxPooling2D_0")(encoder0)

    encoder1 = Residual_Block(num_filter=64, kernel_size=3, name="Encoder_1")(encoder0)
    encoder1 = layers.MaxPooling2D((2, 2), name="MaxPooling2D_1")(encoder1)

    encoder2 = Residual_Block(num_filter=128, kernel_size=3, name="Encoder_2")(encoder1)

    # F_enc: attention over the 128-filter encoder features (before pooling).
    f_enc = spatial_attention(trainable=train_attention, name="F_enc")(encoder2)

    encoder2 = layers.MaxPooling2D((2, 2), name="MaxPooling2D_2")(encoder2)

    encoder3 = Residual_Block(num_filter=256, kernel_size=3, name="Encoder_3")(encoder2)
    encoder3 = layers.MaxPooling2D((2, 2), name="MaxPooling2D_3")(encoder3)

    encoder4 = Residual_Block(num_filter=512, kernel_size=3, name="Encoder_4")(encoder3)
    encoder4 = layers.MaxPooling2D((2, 2), name="MaxPooling2D_4")(encoder4)

    bottleneck = layers.Conv2D(1024, 3, padding="same", name="bottleneck")(encoder4)

    decoder0 = Decoder_Block(num_filter=512, kernel_size=3, name="Decoder_0")(bottleneck)
    decoder1 = Decoder_Block(num_filter=256, kernel_size=3, name="Decoder_1")(decoder0)
    decoder2 = Decoder_Block(num_filter=128, kernel_size=3, name="Decoder_2")(decoder1)

    # F_dec: attention over the decoder features at the same resolution as F_enc.
    f_dec = spatial_attention(trainable=train_attention, name="F_dec")(decoder2)
    f_fused = compute_fused(f_enc, f_dec, trainable=train_attention, name_layer="Attention_map")

    input_decoder3 = layers.Add(name="Add_f_fused")([decoder2, f_fused])

    decoder3 = Decoder_Block(num_filter=64, kernel_size=3, name="Decoder_3")(input_decoder3)
    decoder4 = Decoder_Block(num_filter=32, kernel_size=3, name="Decoder_4")(decoder3)

    imageout = Ouput_layer(name="Ouput_layer")(decoder4)

    return tf.keras.Model(inputs=x_input, outputs=imageout, name="Generator")

generator = Generator_model(train_attention=True)
generator.summary()

# **Discriminator**

In [None]:
def downsample(filters, size, apply_batchnorm=True):
  """Stride-2 conv block: Conv2D (no bias) -> optional BatchNorm -> LeakyReLU.

  Halves the spatial size ('same' padding, stride 2). Kernel weights are drawn
  from N(0, 0.02).
  """
  stages = [
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=tf.random_normal_initializer(0., 0.02),
                             use_bias=False),
  ]
  if apply_batchnorm:
    stages.append(tf.keras.layers.BatchNormalization())
  stages.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(stages)

In [None]:
def Discriminator():
  """PatchGAN discriminator conditioned on the input image.

  Takes [input_image, target_image] (each 256x256x3) and concatenates them on
  the channel axis. Returns a (batch, 30, 30, 1) map of raw logits; the last
  conv has no activation.
  """
  init = tf.random_normal_initializer(0., 0.02)
  layers = tf.keras.layers

  inp = layers.Input(shape=[256, 256, 3], name='input_image')
  tar = layers.Input(shape=[256, 256, 3], name='target_image')

  x = layers.concatenate([inp, tar])  # (batch_size, 256, 256, 6)

  # Three stride-2 stages: 256 -> 128 -> 64 -> 32 spatially.
  for n_filters, use_batchnorm in ((64, False), (128, True), (256, True)):
    x = downsample(n_filters, 4, use_batchnorm)(x)

  x = layers.ZeroPadding2D()(x)  # (batch_size, 34, 34, 256)
  x = layers.Conv2D(512, 4, strides=1,
                    kernel_initializer=init,
                    use_bias=False)(x)  # (batch_size, 31, 31, 512)
  x = layers.BatchNormalization()(x)
  x = layers.LeakyReLU()(x)
  x = layers.ZeroPadding2D()(x)  # (batch_size, 33, 33, 512)
  logits = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=init)(x)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=logits)

discriminator = Discriminator()
discriminator.summary()

# **Config Loss**

In [None]:
# Both losses operate on raw values: the discriminator outputs logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)

def discriminator_loss(real_image, fake_image):
    """Discriminator BCE loss.

    Args:
        real_image: discriminator logits for (input, real target) pairs.
        fake_image: discriminator logits for (input, generated) pairs.

    Returns:
        Scalar: loss(real -> 1) + loss(fake -> 0), with 0.1 label smoothing.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_image), real_image)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_image), fake_image)
    return loss_on_real + loss_on_fake

In [None]:
def loss_SSIM(target_image, gen_image, max_val=2.0):
    """Per-image structural dissimilarity loss, ``1 - SSIM``.

    Args:
        target_image, gen_image: image batches with the same shape.
        max_val: dynamic range of the pixel values (max - min). `normalize`
            scales images to [-1, 1], so the range is 2.0. The previous
            hard-coded 1.0 mis-scaled SSIM's stability constants, which are
            derived from this range.

    Returns:
        Tensor with one loss value per image in the batch.
    """
    ssim = tf.image.ssim(target_image, gen_image, max_val=max_val, filter_size=11,
                         filter_sigma=1.5, k1=0.01, k2=0.03)
    return 1 - ssim


def generator_loss(target_image, gen_image):
    """Return (mean SSIM loss, mean absolute reconstruction error) for a batch."""
    # SSIM loss, averaged over the batch.
    l_ssim = tf.reduce_mean(loss_SSIM(target_image, gen_image))

    # Reconstruction (L1) loss.
    l_rec = mae(target_image, gen_image)
    return l_ssim, l_rec

# **Train Step**

In [None]:
# Create the checkpoint directory. Unlike `!mkdir`, this is idempotent (no
# error on re-run) and works the same on every OS.
os.makedirs("path_to_save_checkpoint", exist_ok=True)

In [None]:
# Optimizers: Adam with beta_1 = 0.5 for both networks; D learns slightly slower.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=9e-6, beta_1=0.5)

# Weights of the terms in the generator's total loss.
lambda_rec = 1.2   # L1 reconstruction
lambda_adv = 0.5   # adversarial term
lambda_ssim = 80   # 1 - SSIM
lambda_mask = 1    # NOTE(review): not referenced anywhere else in this notebook

In [None]:
# `path_to_checkpoint` was never defined in this notebook (NameError). Use the
# same placeholder directory name as the checkpoint-directory cell; change both
# together if you relocate checkpoints.
checkpoint_dir = "path_to_save_checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
log_dir = "logs/"

# NOTE(review): summary_writer is created but nothing in this notebook writes to it.
summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
def _batch_losses(input_batch, target_batch):
    """Run G and D on one batch and return (gen_total, ssim, rec, disc) losses.

    gen_total = lambda_rec * L1 + lambda_adv * adversarial + lambda_ssim * (1 - SSIM).
    The adversarial term asks D to label the generated image as real (1). The
    previous code added the *discriminator's* loss here instead. Through the
    generator, that term only depends on BCE(0, D(G(x))), so minimising it pushed
    D(G(x)) towards "fake".

    Both networks run with training=True, as elsewhere in this notebook.
    NOTE(review): during validation this also updates BatchNorm moving
    statistics with validation data; confirm that is intended.
    """
    gen_out = generator(input_batch, training=True)
    ssim_loss, rec_loss = generator_loss(target_batch, gen_out)

    disc_real_output = discriminator([input_batch, target_batch], training=True)
    disc_generated_output = discriminator([input_batch, gen_out], training=True)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    adv_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    gen_total = lambda_rec*rec_loss + lambda_adv*adv_loss + lambda_ssim*ssim_loss
    return gen_total, ssim_loss, rec_loss, disc_loss


def _mean_losses(per_batch):
    """Column-wise mean of per-batch (total, ssim, rec, disc) loss tuples."""
    if not per_batch:
        # No batches: return the same NaNs np.mean([]) gave, without the warning.
        return (float("nan"),) * 4
    table = np.asarray([[float(v) for v in row] for row in per_batch])
    return tuple(table.mean(axis=0))


def evaluate_val(valid_dataset):
    """Average (gen_total, ssim, rec, disc) losses over `valid_dataset`; no weight updates."""
    per_batch = [_batch_losses(batch_test_image, batch_test_target)
                 for batch_test_image, batch_test_target in valid_dataset]
    return _mean_losses(per_batch)


def train_step(train_datasets, update_D):
    """Run one epoch over `train_datasets`.

    G is updated on every batch. D is updated only when `update_D` is true.

    Returns:
        Mean (gen_total, ssim, rec, disc) losses over the epoch.
    """
    per_batch = []
    for step, (batch_train_image, batch_train_target) in enumerate(train_datasets):
        print(f"Batch: {step+1}...", end='\r')
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            batch_losses = _batch_losses(batch_train_image, batch_train_target)
        per_batch.append(batch_losses)
        loss_total, _, _, disc_loss = batch_losses

        generator_gradients = gen_tape.gradient(loss_total,
                                                generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(generator_gradients,
                                                generator.trainable_variables))

        if update_D:
            discriminator_gradients = disc_tape.gradient(disc_loss,
                                                         discriminator.trainable_variables)
            discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                        discriminator.trainable_variables))

    return _mean_losses(per_batch)

In [None]:
def model_fit(train_datasets, test_datasets, epochs, num_epoch_update_D):
    """Train the GAN for `epochs` epochs, validating after each one.

    - G is updated every epoch. D is updated only in epochs whose index is a
      positive multiple of `num_epoch_update_D` (never in epoch 0).
    - A checkpoint is written whenever the validation total loss is <= the best
      so far.
    - Sample predictions on one test batch are plotted every second epoch.
    - Train/val total-loss curves are plotted at the end.

    Raises:
        ValueError: if `num_epoch_update_D` is not positive (it is used as a modulus).
    """
    if num_epoch_update_D <= 0:
        raise ValueError(f"num_epoch_update_D must be positive, got {num_epoch_update_D}")

    loss_keys = ("loss_total", "loss_ssim", "loss_rec", "loss_disc")
    history_train = {key: [] for key in loss_keys}
    history_val = {key: [] for key in loss_keys}

    for epoch in range(epochs):
        start_time = time.time()

        update_D = epoch > 0 and epoch % num_epoch_update_D == 0
        if update_D:
            print(f"\nEpoch {epoch+1}/{epochs}:\nUpdate D!")
        else:
            print(f"\nEpoch {epoch+1}/{epochs}\nOny Update G!")

        # train_step takes (dataset, update_D); the old call also passed `epoch`,
        # which raised TypeError.
        train_losses = train_step(train_datasets, update_D)
        val_losses = evaluate_val(test_datasets)
        for key, train_value, val_value in zip(loss_keys, train_losses, val_losses):
            history_train[key].append(train_value)
            history_val[key].append(val_value)

        train_loss_total, train_ssim, train_rec, train_disc = train_losses
        val_total, val_ssim, val_rec, val_disc = val_losses
        print(f"  loss_total: {train_loss_total} - loss_ssim: {train_ssim} - loss_rec: {train_rec} - loss_discriminator: {train_disc}")
        print(f"  val_loss: {val_total} - val_ssim: {val_ssim} - val_rec: {val_rec} - val_Disc: {val_disc}")
        print("  Time taken: %.2fs" % (time.time() - start_time))

        if val_total <= min(history_val["loss_total"]):
            checkpoint.save(file_prefix=checkpoint_prefix)

        if epoch % 2 == 0:
            for sample_input, sample_target in test_datasets.take(1):
                generate_images(generator, sample_input, sample_target)

    fig, ax = plt.subplots()
    ax.plot(history_train['loss_total'], label='train')
    ax.plot(history_val['loss_total'], label='test')
    ax.set_title('model loss')
    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    ax.legend(loc='upper left')
    plt.show()

# **Train model**

In [None]:
epochs = 150             # total training epochs
num_epoch_update_D = 15  # D is updated only on positive multiples of this epoch index
model_fit(train_datasets, test_datasets,
          epochs=epochs, num_epoch_update_D=num_epoch_update_D)

## Test model

# Save the final-epoch generator (architecture + weights) in HDF5 format.
generator.save('generator.h5')

In [None]:
# Reload the weights written by generator.save(...) and plot one test batch.
# NOTE(review): this is the final-epoch model; the best-validation weights are in
# the tf.train.Checkpoint files (checkpoint.restore(tf.train.latest_checkpoint(...))).
generator.load_weights("generator.h5", by_name=True)
# Loop variables renamed so the builtin `input` is not shadowed at module level.
for sample_input, sample_target in test_datasets.take(1):
    generate_images(generator, sample_input, sample_target)