In [None]:
import os
import glob
import time
import cv2


import numpy as np
import pandas as pd

import imageio
import matplotlib.pyplot as plt
from IPython.display import display, Image

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.preprocessing import image
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.layers import (Dense, Conv2DTranspose, Conv2D,
                                     BatchNormalization,MaxPool2D,AveragePooling2D,
                                     LeakyReLU, Dropout, Reshape, Flatten,
                                     ELU,Activation,Input,UpSampling2D)

from sklearn.model_selection import train_test_split
from functools import partial
from tqdm import tqdm
from typing import List , Union , Tuple

In [None]:
def load_img(path : str ,
             width : int = 1200) -> tf.Tensor:
    """Load an image, center-crop it to a square and resize it.

    The file is read with OpenCV, rescaled from [0, 255] to [-1, 1],
    cropped to a centered square whose side is 80% of the shorter image
    edge, and bilinearly resized to ``(width, width)``.

    Args:
        path: Path of the image file to read.
        width: Side length of the square output (default 1200).

    Returns:
        A float32 tensor of shape ``(width, width, channels)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    raw = cv2.imread(path)
    if raw is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path.
        raise FileNotFoundError('Could not read image: {!r}'.format(path))

    # NOTE(review): cv2.imread yields channels in BGR order; the order is
    # left untouched here -- confirm that downstream display code expects it.
    img = tf.constant(raw.astype(np.float32) / 127.5 - 1.)

    # Centered square covering 80% of the shorter edge.
    side = int(min(raw.shape[:2]) * 0.8)
    top = (raw.shape[0] - side) // 2
    left = (raw.shape[1] - side) // 2

    # tf.image.resize is applied on a batch of one, then the batch axis is dropped.
    square = tf.expand_dims(img[top:top + side, left:left + side] , axis = 0)
    resized = tf.image.resize(square, (width , width), method = tf.image.ResizeMethod.BILINEAR)
    return resized[0]

def load_all_img(ImageNamePath : List[str]):
    """Lazily yield the image loaded by ``load_img`` for each path, in order."""
    for img_path in ImageNamePath:
        yield load_img(img_path)
        
        
def rnd_crop_img(img : tf.Tensor ,
                 win_size : int = 256) -> tf.Tensor :
    """Cut a random ``win_size`` x ``win_size`` patch out of a 3-channel image.

    Args:
        img: Image tensor to crop from.
        win_size: Side length of the square patch (default 256).

    Returns:
        The patch produced by ``tf.image.random_crop`` with target shape
        ``(win_size, win_size, 3)``.
    """
    n_channels = 3  # colour channels kept in the crop
    return tf.image.random_crop(img , (win_size , win_size , n_channels))


def gen_task(img_loader : tf.data.Dataset , 
             win_size : int = 256 ,
             sample_idx : int = 64) -> (tf.Tensor , tf.Tensor):
    """Yield (random crop, word embedding) pairs for one selected sample.

    The image at position ``sample_idx`` of ``img_loader`` is repeated
    ``BATCH_SIZE * 20`` times; each repetition yields a fresh random crop of
    it together with the embedding stored at
    ``word_embedding['WordEmbedding'][sample_idx]``.

    Reads the module-level names ``BATCH_SIZE`` and ``word_embedding``.

    Args:
        img_loader: Dataset of images.
        win_size: Side length of the random square crop (default 256).
        sample_idx: Index of the image / embedding pair to use. Defaults to
            64, the value that was previously hard-coded in both lookups.

    Yields:
        Tuples ``(crop, word)``.
    """
    # Single image repeated many times; only the crop changes per element.
    repeated = img_loader.skip(sample_idx).take(1).repeat(BATCH_SIZE*20)
    # The embedding is the same for every crop, so convert it only once.
    word = tf.convert_to_tensor(word_embedding['WordEmbedding'][sample_idx])
    for img in repeated:
        # Support processing
        yield (rnd_crop_img(img, win_size), word)

In [None]:
class Generator(Model):
    """Upsampling generator mapping a ``z_dim`` vector to a 256x256 image.

    A bias-free dense layer produces an 8x8x64 feature map, which is doubled
    in resolution five times with bilinear upsampling, each doubling followed
    by a bias-free convolution. The last convolution emits ``CH`` channels
    through ``tanh``. ``CH`` is read from module scope.
    """

    def __init__(self, z_dim ):
        super().__init__()

        # [z_dim] => [8,8,64]
        layers = [
            Dense(8*8*64,
                  activation = tf.nn.leaky_relu ,
                  use_bias = False ,
                  input_shape = (z_dim ,)),
            Reshape((8, 8, 64)),
        ]

        # [8,8,64] => [128,128,64]: one doubling per power of two between 8 and 128
        n_doublings = int(np.log2(128) - np.log2(8))
        for _ in range(n_doublings):
            layers.append(UpSampling2D(size = (2, 2),
                                       interpolation = 'bilinear'))
            layers.append(Conv2D(filters = 64 ,
                                 kernel_size = 3 ,
                                 padding = 'same' ,
                                 activation = tf.nn.leaky_relu ,
                                 use_bias = False))

        # [128,128,64] => [256,256,CH], squashed by tanh
        layers.append(UpSampling2D(size = (2, 2) ,
                                   interpolation = 'bilinear'))
        layers.append(Conv2D(filters = CH ,
                             kernel_size = 5 ,
                             padding = 'same' ,
                             activation = 'tanh' ,
                             use_bias = False))

        self.model = Sequential(layers)

    def call(self, x):
        """Run the latent batch ``x`` through the layer stack."""
        return self.model(x)

class Discriminator(Model):
    """Critic made of five halving stages followed by a single dense unit.

    Every stage convolves its input, max-pools the convolved features,
    average-pools the stage input itself, and concatenates both along the
    channel axis. Stages two to five first squeeze their input back to 64
    channels with a 1x1 convolution. ``W`` and ``CH`` are read from module
    scope.
    """

    def __init__(self):
        super().__init__()

        def conv(kernel_size, **extra):
            # All convolutions share width, padding and activation.
            return Conv2D(filters = 64 ,
                          kernel_size = kernel_size ,
                          padding = "same" ,
                          activation = tf.nn.leaky_relu ,
                          **extra)

        # First stage: 5x5 convolution applied to the input image.
        self.conv0 = conv(5, input_shape = (W, W, CH))
        # 1x1 convolutions that bring the concatenated maps back to 64 channels.
        self.conv1dx = [conv(1) for _ in range(4)]
        # 3x3 bias-free convolutions, one per remaining stage.
        self.convx = [conv(3, use_bias = False) for _ in range(4)]
        # Flattened features => one unbounded score.
        self.dense = Dense(1)

    @staticmethod
    def _halve_and_merge(features, skip):
        """Max-pool ``features``, average-pool ``skip``, concatenate on channels."""
        pooled_max = MaxPool2D(pool_size = (2, 2))(features)
        pooled_avg = AveragePooling2D(pool_size = (2, 2))(skip)
        return tf.concat((pooled_max, pooled_avg), axis = -1)

    def call(self, x):
        """Score a batch of images; returns one value per sample."""
        merged = self._halve_and_merge(self.conv0(x), x)
        for squeeze, conv in zip(self.conv1dx, self.convx):
            merged = squeeze(merged)
            merged = self._halve_and_merge(conv(merged), merged)
        return self.dense(Flatten()(merged))

In [None]:
def gradient_penalty(real_images : tf.Tensor , 
                     fake_images : tf.Tensor) -> tf.Tensor :
    '''
    WGAN-GP gradient penalty on random interpolations of real and fake images.

    Each sample pair is blended with its own coefficient drawn from U[0, 1].
    The module-level ``discriminator`` is evaluated on the blend, and the
    penalty is the mean squared deviation of the per-sample gradient norm
    from 1.

    @params : real_images : tf.Tensor , batch of real images (rank 4)
    @params : fake_images : tf.Tensor , batch of generated images, same shape
    @return : scalar tf.Tensor , the gradient penalty
    '''
    # Take the batch size from the data so a smaller final batch still works.
    batch = tf.shape(real_images)[0]
    # One interpolation coefficient per sample, broadcast over H, W and C.
    epsilon = tf.random.uniform(shape = [batch, 1, 1, 1], minval = 0.0, maxval = 1.0)
    x_hat = epsilon * real_images + (1 - epsilon) * fake_images

    with tf.GradientTape() as t:
        # x_hat is a plain tensor, so it has to be watched explicitly.
        t.watch(x_hat)
        d_hat = discriminator(x_hat)
    gradients = t.gradient(d_hat, x_hat)

    # L2 norm of the whole gradient of each sample: reduce over height,
    # width AND channels, leaving one norm per batch element.
    g_norm = tf.sqrt(tf.reduce_sum(gradients ** 2, axis = [1, 2, 3]))
    return tf.reduce_mean((g_norm - 1.0) ** 2)


def gan_loss(d_real_output, d_fake_output, train_loader, fake_images):
    """
    Compute the critic and generator losses.

    d_loss -> mean fake score minus mean real score, plus the gradient
              penalty between ``train_loader[0]`` and ``fake_images``
              scaled by the module-level ``gp_weight``
    g_loss -> mean of the negated fake scores
    """
    score_gap = tf.reduce_mean(d_fake_output) - tf.reduce_mean(d_real_output)
    weighted_penalty = gradient_penalty(train_loader[0], fake_images) * gp_weight
    d_loss = score_gap + weighted_penalty

    g_loss = tf.reduce_mean(-d_fake_output)
    return d_loss, g_loss


@tf.function
def train_step(train_loader, generator, discriminator, g_optimizer, d_optimizer):
    """Run one simultaneous generator / discriminator update.

    ``train_loader`` is a single batch: element 0 is fed to the discriminator
    as the real images, element 1 is fed to the generator as its input.

    Returns:
        ``(d_loss, g_loss)`` as computed by ``gan_loss`` for this batch.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        generated = generator(train_loader[1], training=True)
        real_scores = discriminator(train_loader[0])
        fake_scores = discriminator(generated)

        d_loss, g_loss = gan_loss(real_scores, fake_scores, train_loader, generated)

    # Each tape is consumed exactly once, for its own network.
    g_vars = generator.trainable_variables
    g_grads = g_tape.gradient(g_loss, g_vars)
    d_vars = discriminator.trainable_variables
    d_grads = d_tape.gradient(d_loss, d_vars)

    g_optimizer.apply_gradients(zip(g_grads, g_vars))
    d_optimizer.apply_gradients(zip(d_grads, d_vars))
    return d_loss, g_loss


def train(train_loader, epochs):
    """Train for ``epochs`` passes over ``train_loader``.

    After every epoch the elapsed time and the losses of the last batch are
    printed and sample images are rendered; a checkpoint is written every
    25th epoch and once more when training ends.

    Relies on the module-level ``generator``, ``discriminator``,
    ``g_optimizer``, ``d_optimizer``, ``seed``, ``save_dir``, ``checkpoint``
    and ``checkpoint_prefix``.
    """
    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1  # 1-based number used for reporting
        tic = time.time()

        for image_batch in train_loader:
            d_loss, g_loss = train_step(image_batch, generator, discriminator,
                                        g_optimizer, d_optimizer)

        # Report on the epoch and render samples
        elapsed = time.time() - tic
        print('Time for epoch {} is {} sec'.format(epoch_no, elapsed))
        print('discriminator loss: %.5f' % d_loss)
        print('generator loss: %.5f' % g_loss)
        generate_and_save_images(generator, epoch_no, seed, save_dir)

        # Periodic checkpoint: every 25 epochs
        if epoch_no % 25 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

    # Final render and checkpoint once all epochs are done
    generate_and_save_images(generator, epochs, seed, save_dir)
    checkpoint.save(file_prefix = checkpoint_prefix)
    
    
def generate_and_save_images(model, epoch, train_loader, save_path):
    """Render the model's outputs in a grid and periodically save the figure.

    Args:
        model: Callable applied to ``train_loader``; its result is indexed as
            a batch of images and each image is mapped with ``(x + 1) * 0.5``
            before display.
        epoch: 1-based epoch number, as passed by ``train``. It drives both
            the save schedule and the file name.
        train_loader: Input handed to ``model`` unchanged.
        save_path: Directory for the PNG files; created when a figure is saved.

    The figure is written as ``image_at_epoch_NNNN.png`` every 10 epochs up
    to epoch 100, every 100 epochs up to epoch 1000, and every 1000 epochs
    after that. It is always shown and then closed.
    """
    predictions = model(train_loader)
    n_images = predictions.shape[0]

    # Keep the usual 4x4 grid, but grow it when there are more than 16
    # images so plt.subplot is never asked for a cell outside the grid.
    side = max(4, int(np.ceil(np.sqrt(n_images))))
    fig = plt.figure(figsize=(side, side))
    for i in range(n_images):
        plt.subplot(side, side, i + 1)
        plt.imshow((predictions[i, :, :, :].numpy()+1)*0.5)
        plt.axis('off')

    # `epoch` is already 1-based, so it is used as-is; adding 1 again would
    # save one epoch early under the next epoch's file name.
    if epoch <= 100:
        save_every = 10
    elif epoch <= 1000:
        save_every = 100
    else:
        save_every = 1000
    if epoch > 0 and epoch % save_every == 0:
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))

    plt.show()
    # Release the figure explicitly; this runs once per epoch.
    plt.close(fig)
    
    
def calculate_fid(real_embeddings, generated_embeddings):
    """Frechet distance between two sets of embeddings.

    Each set is summarised by its mean vector and covariance matrix, and the
    result is ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))``.

    Args:
        real_embeddings: Array-like of shape ``(n_real, n_features)``.
        generated_embeddings: Array-like of shape ``(n_gen, n_features)``.

    Returns:
        The distance as a float.

    Raises:
        ValueError: If an input is not 2-D, the feature sizes differ, or a
            set has fewer than two rows (its covariance is then undefined).
    """
    from scipy.linalg import sqrtm

    # Accept anything convertible to an array (lists, tensors, ...), not only
    # ndarrays, and do the arithmetic in float64.
    real = np.asarray(real_embeddings, dtype=np.float64)
    fake = np.asarray(generated_embeddings, dtype=np.float64)
    if real.ndim != 2 or fake.ndim != 2:
        raise ValueError('embeddings must be 2-D (samples, features); got shapes '
                         '{} and {}'.format(real.shape, fake.shape))
    if real.shape[1] != fake.shape[1]:
        raise ValueError('feature dimensions differ: {} vs {}'.format(
            real.shape[1], fake.shape[1]))
    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError('at least two samples per set are needed to estimate a covariance')

    # np.cov collapses to a 0-d array for a single feature; force 2-D so the
    # matrix product, sqrtm and trace below always see matrices.
    mu1, sigma1 = real.mean(axis=0), np.atleast_2d(np.cov(real, rowvar=False))
    mu2, sigma2 = fake.mean(axis=0), np.atleast_2d(np.cov(fake, rowvar=False))

    ssdiff = np.sum((mu1 - mu2)**2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        # Numerical error can leave a small imaginary part; keep the real part.
        covmean = covmean.real
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid