# In [None]:
import cv2 as cv
import numpy as np
from keras.models import Model, load_model
from scipy.stats import entropy
from six.moves import xrange
from matplotlib import pyplot as plt
import tensorflow as tf

# Function that generates random points in the latent space of the generator network
def generate_latent_points(latent_dim, n_samples):
    """Draw random latent vectors for the generator network.

    Args:
        latent_dim: length of each latent vector.
        n_samples: how many vectors to draw.

    Returns:
        A NumPy array of shape ``(n_samples, latent_dim)`` whose entries are
        sampled uniformly from the half-open interval [-1, 1).
    """
    shape = (n_samples, latent_dim)
    return np.random.uniform(low=-1, high=1, size=shape)

# Function that zeroes one horizontal half of an image: dir_='l' keeps the left half, dir_='r' keeps the right half
def masked_out_image(input_img, dir_):
    """Zero out one horizontal half of a channels-last image.

    The split column is derived from the input's width (its second-to-last
    axis) instead of a fixed ``(64, 128, 3)`` mask. For a 64x128x3 image the
    result is identical to the original fixed mask; other sizes and batched
    ``(N, H, W, C)`` inputs now work as well.

    Args:
        input_img: image array or array-like (e.g. an eager tensor) whose last
            two axes are (width, channels), such as ``(H, W, C)`` or
            ``(N, H, W, C)``.
        dir_: ``'l'`` keeps the left half and zeroes the right half;
            ``'r'`` keeps the right half and zeroes the left half.
            Any other value leaves the image unchanged.

    Returns:
        A new float NumPy array with the same shape as ``input_img``.
    """
    img = np.asarray(input_img)
    mask = np.ones(img.shape)
    mid = img.shape[-2] // 2  # 64 for the 128-pixel-wide images the old mask assumed
    if dir_ == 'l':
        mask[..., mid:, :] = 0.0
    elif dir_ == 'r':
        mask[..., :mid, :] = 0.0
    return np.multiply(img, mask)

# Function that calculates the Kullback-Leibler divergence between the grayscale intensity histograms of two BGR images
def kl_divergence(img1, img2):
    """Return the KL divergence D(P || Q) between two images' intensity histograms.

    Each image is converted to grayscale with ``cv.COLOR_BGR2GRAY`` and binned
    into a 255-bin histogram. The histogram minimum is subtracted from every
    bin and the result is divided by its sum to form a distribution. Bins where
    P is zero contribute 0; Q is offset by 1e-8 to avoid division by zero.

    NOTE(review): ``np.histogram`` is called without ``range``, so each image's
    bins span that image's own min..max. If the two images use different value
    scales, bin i of P and bin i of Q cover different intensities — confirm
    this implicit min-max normalisation is intended.
    NOTE(review): if every bin has the same count, the shifted histogram sums
    to 0 and the distribution (and the result) becomes NaN.

    Args:
        img1: image for P (any dtype ``cv.cvtColor`` accepts).
        img2: image for Q.

    Returns:
        A scalar ``tf`` tensor holding the divergence.
    """
    def _shifted_distribution(img):
        # Grayscale -> 255-bin histogram -> min-shifted, sum-normalised counts.
        gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        counts, _ = np.histogram(gray.flatten(), bins=255)
        shifted = counts - counts.min()
        return shifted / np.sum(shifted)

    p = _shifted_distribution(img1)
    q = _shifted_distribution(img2)
    terms = tf.where(p == 0, tf.zeros_like(p), p * tf.math.log(p / (q + 1e-8)))
    return tf.reduce_sum(terms)

# Function that calculates the Kullback-Leibler divergence between two probability distributions
def kl(p, q):
    """Return the KL divergence D(P || Q) between rows of two non-negative tensors.

    ``p`` and ``q`` are each normalised along their last axis so every row
    sums to 1. The divergence is then summed along that axis. For 2-D
    ``(batch, n)`` inputs this is the ``axis=1`` reduction the original
    targeted, giving a ``(batch,)`` result; 1-D inputs give a scalar.

    The previous normalisation tiled a rank-0 sum with a nested
    ``[1, tf.shape(p)]`` multiples list, which always raised.

    Args:
        p: non-negative tensor or array-like (unnormalised weights for P).
        q: non-negative tensor or array-like with the same shape as ``p``.

    Returns:
        Tensor of divergences, ``p``'s shape reduced along its last axis.

    Raises:
        An error from ``tf.assert_equal`` if the shapes of ``p`` and ``q`` differ.
    """
    p = tf.convert_to_tensor(p)
    q = tf.convert_to_tensor(q, dtype=p.dtype)
    tf.assert_equal(tf.shape(p), tf.shape(q))
    # Normalise each row by its own sum; keepdims makes the division broadcast.
    p_ = p / tf.reduce_sum(p, axis=-1, keepdims=True)
    q_ = q / tf.reduce_sum(q, axis=-1, keepdims=True)
    # xlogy(x, y) is 0 where x == 0, following the convention 0 * log 0 = 0.
    return tf.reduce_sum(tf.math.xlogy(p_, p_ / q_), axis=-1)

# Function that picks initial points in the latent space by random search, keeping the candidate whose generated image best matches the left 64 columns of the input sketch
def refine_init(input_sketch, num_samples=1, *, latent_dim=100, n_candidates=100):
    """Pick starting latent vectors by random search against a sketch.

    For each requested sample, ``n_candidates`` random latent vectors are drawn
    with ``generate_latent_points``. Each is decoded by the module-level
    ``generator`` and scored with ``kl_divergence`` between the first 64
    columns of the sketch and the first 64 columns of the generated image.
    The lowest-scoring vector is kept.

    Each sample runs an independent search. Previously the best score carried
    over between samples, so a later sample changed only if it beat every
    earlier one; otherwise it repeated the previous vector.

    Args:
        input_sketch: image array indexed as ``[row, column, channel]``; only
            its first 64 columns are compared.
        num_samples: number of independent searches (and returned vectors).
        latent_dim: length of each latent vector (keyword-only).
        n_candidates: random vectors evaluated per sample (keyword-only).

    Returns:
        A list of ``num_samples`` arrays, each of shape ``(1, latent_dim)``.

    Raises:
        ValueError: if ``n_candidates`` is less than 1.
    """
    if n_candidates < 1:
        raise ValueError("n_candidates must be at least 1")

    context = input_sketch[:, :64, :]
    best_points = []
    for _ in range(num_samples):
        best_z, best_score = None, np.inf  # reset for every sample
        for _ in range(n_candidates):
            candidate = generate_latent_points(latent_dim, 1)
            generated = generator.predict(candidate)
            score = float(kl_divergence(context, generated[0][:, :64, :]))
            if best_z is None:
                # Fallback so a vector is always returned, even if no score
                # beats inf (e.g. every score is NaN).
                best_z = candidate
            if score < best_score:
                best_z, best_score = candidate, score
        best_points.append(best_z)
    return best_points

def smoothImg(img, j, *, out_dir="Final results"):
    """Upscale an image 2x, Gaussian-blur it, and save it as a JPEG.

    The image is resized to twice its width and height, blurred with a 5x5
    Gaussian kernel, and its first and third channels are swapped. It is then
    mapped to pixels with ``x * 127.5 + 127.5``, clipped to [0, 255] and
    written to ``{out_dir}/smooth_male_image_{j}.jpg``.

    Args:
        img: float image array indexed as ``[row, column, channel]``. The
            pixel mapping suggests values in [-1, 1] — TODO confirm against
            the generator's output activation.
        j: identifier inserted into the output file name.
        out_dir: directory to write into (keyword-only); created if missing.

    Raises:
        OSError: if OpenCV fails to write the file (previously ignored).
    """
    import os

    height, width = img.shape[0], img.shape[1]
    img = cv.resize(img, (2 * width, 2 * height))
    img = cv.GaussianBlur(img, (5, 5), 0)

    # cv.imwrite expects BGR order; presumably the generator emits RGB, and
    # BGR2RGB performs the same channel swap — TODO confirm channel order.
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)

    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    pixels = np.clip(img * 127.5 + 127.5, 0, 255).astype(np.uint8)

    os.makedirs(out_dir, exist_ok=True)
    path = f"{out_dir}/smooth_male_image_{j}.jpg"
    if not cv.imwrite(path, pixels):
        raise OSError(f"cv.imwrite failed to write {path}")

# Function that optimizes a point in the latent space by gradient descent so the generated image completes the input image
def _completion_loss(generated_data, input_, contextual_weight, perceptual_weight):
    """Return the weighted contextual + perceptual loss for a generated batch.

    The perceptual term is ``log(1 - D(G(z)))`` from the module-level
    ``discriminator``. The contextual term is ``kl_divergence`` between the
    first 64 columns of ``input_`` and the first 64 columns of the first
    generated image, cast to float32. The contextual term is computed from
    NumPy arrays (``.numpy()``), so it carries no gradient back to the latent
    vector.
    """
    perceptual_loss = tf.math.log(1 - discriminator(generated_data))
    in_ = input_[:, :64, :]
    out_ = generated_data[0][:, :64, :].numpy()
    contextual_loss = tf.cast(kl_divergence(in_, out_), dtype=tf.float32)
    return contextual_weight * contextual_loss + perceptual_weight * perceptual_loss


def completeFun(zhats_init, input_, *, n_iters=100, learning_rate=0.001,
                contextual_weight=0.6, perceptual_weight=0.4):
    """Optimize a latent vector so the generated image completes ``input_``.

    Runs ``n_iters`` steps of plain (momentum-free) gradient descent on the
    loss from ``_completion_loss``. After each step the latent vector is
    clipped to [-1, 1], and the loss is printed every 10 iterations.

    NOTE(review): the contextual term goes through ``.numpy()``, so the tape
    sees no path from it to the latent vector. Only the perceptual term
    drives the update; the contextual term affects the reported loss only.

    Args:
        zhats_init: initial latent vector(s), e.g. one entry of
            ``refine_init``'s result.
        input_: image whose first 64 columns provide the context.
        n_iters: number of gradient steps (keyword-only).
        learning_rate: step size (keyword-only).
        contextual_weight: weight of the contextual loss (keyword-only).
        perceptual_weight: weight of the perceptual loss (keyword-only).

    Returns:
        ``(optimized_input_vector, losses)``: the final latent vector as a
        NumPy array and the list of per-iteration loss tensors.

    Raises:
        RuntimeError: if the loss has no gradient with respect to the latent vector.
    """
    input_vector = tf.convert_to_tensor(zhats_init)
    losses = []
    for i in range(n_iters):
        with tf.GradientTape() as tape:
            # input_vector is a plain tensor, not a Variable, so it must be watched.
            tape.watch(input_vector)
            generated_data = generator(input_vector)
            loss = _completion_loss(generated_data, input_,
                                    contextual_weight, perceptual_weight)
        losses.append(loss)

        gradients = tape.gradient(loss, input_vector)
        if gradients is None:
            raise RuntimeError("loss is not differentiable with respect to the latent vector")
        # Plain gradient-descent step, then keep the point inside the latent range.
        input_vector = tf.clip_by_value(input_vector - learning_rate * gradients, -1.0, 1.0)

        if i % 10 == 0:
            print(f"Iteration {i}: Loss = {loss.numpy()}")

    optimized_input_vector = input_vector.numpy()
    generated_data = generator(optimized_input_vector)
    loss = _completion_loss(generated_data, input_, contextual_weight, perceptual_weight)
    print("loss after completion phase ", loss)

    return optimized_input_vector, losses

# Load the generator and discriminator models
import os

# os.path.join keeps the model paths portable (the originals were Windows-only).
generator = load_model(os.path.join('Models', 'g_model_10,15.h5'))
discriminator = load_model(os.path.join('Models', 'd_model_10,15.h5'))

# Iterate over images in a directory and generate completed images for each one
output_dir = 'Final results'
os.makedirs(output_dir, exist_ok=True)  # savefig fails if the directory is missing

data = os.listdir(r"handwrite")
for file_name in data:
    pic = cv.imread(os.path.join(r"handwrite", file_name))
    if pic is None:
        # cv.imread returns None for directories and unreadable/non-image files.
        print(f"Skipping {file_name}: not a readable image")
        continue

    # Display the input image (imread loads BGR; convert for matplotlib).
    input_1 = cv.cvtColor(pic, cv.COLOR_BGR2RGB)
    plt.imshow(input_1.astype(np.uint8), cmap='gray')
    plt.show()

    # Pick initial latent points, then complete the image from each of them.
    zhats_init = refine_init(pic, num_samples=1)
    fig = plt.figure(figsize=(30, 30))
    for j, z_init in enumerate(zhats_init):
        optimized_input_vector, losses = completeFun(z_init, pic)

        generated_data = generator(optimized_input_vector)
        smoothImg(generated_data[0].numpy(), file_name)
        masked_img = masked_out_image(generated_data[0], 'r')

        plt.subplot(2, 2, j + 1)
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        plt.imshow(np.clip(masked_img * 127.5 + 127.5, 0, 255).astype(np.uint8), cmap='gray')
        plt.axis("off")

    plt.savefig(f'{output_dir}/completed_image_for_{file_name}.png')
    plt.show()
    plt.close(fig)  # release the figure; otherwise one accumulates per image
    print('================')

