In [None]:
import tensorflow as tf

from tensorflow.keras.layers import Activation, Conv2D, BatchNormalization, Input, Layer, InputSpec, Add, Dropout, Lambda, UpSampling2D, Flatten, Dense, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.optimizers import Adam

import numpy as np
import keras.backend as K
import tqdm
import datetime
import cv2
import os
import glob
from sklearn.utils import shuffle
from imutils import build_montages
import torch

from google.colab.patches import cv2_imshow
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [None]:
class ReflectionPadding2D(Layer):
    """Reflection-pad the two spatial axes of a 4-D image tensor.

    Args:
        padding: int, or a pair ``(height_pad, width_pad)``. Each amount is
            added on both sides of its axis, so a spatial axis of size ``n``
            becomes ``n + 2 * pad``.
        data_format: ``'channels_last'`` (N, H, W, C) or ``'channels_first'``
            (N, C, H, W). ``None`` falls back to ``K.image_data_format()``.
        **kwargs: forwarded unchanged to ``Layer``.

    Raises:
        ValueError: if ``padding`` does not have exactly two entries or
            ``data_format`` is not one of the two supported layouts.
    """

    def __init__(self, padding=(1, 1), data_format=None, **kwargs):
        # Run the base constructor first: Layer.__init__ initialises its own
        # bookkeeping (including input_spec), which would otherwise discard
        # the spec assigned at the end of this method.
        super(ReflectionPadding2D, self).__init__(**kwargs)

        if isinstance(padding, int):
            padding = (padding, padding)
        padding = tuple(padding)
        if len(padding) != 2:
            raise ValueError(
                'padding must be an int or a pair (height_pad, width_pad); '
                'got {!r}'.format(padding))
        self.padding = padding

        # Honour an explicit data_format; only fall back to the backend
        # default when none was given.
        if data_format is None:
            data_format = K.image_data_format()
        self.data_format = data_format.lower()
        if self.data_format not in ('channels_first', 'channels_last'):
            raise ValueError(
                "data_format must be 'channels_first' or 'channels_last'; "
                'got {!r}'.format(data_format))

        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape for input shape ``s``.

        Unknown (``None``) dimensions stay ``None``.
        """
        h_pad, w_pad = self.padding

        def _grow(dim, pad):
            return None if dim is None else dim + 2 * pad

        if self.data_format == 'channels_first':
            return s[0], s[1], _grow(s[2], h_pad), _grow(s[3], w_pad)
        return s[0], _grow(s[1], h_pad), _grow(s[2], w_pad), s[3]

    def call(self, x):
        """Pad ``x`` by reflection along its height and width axes."""
        # Same (height, width) order as compute_output_shape, so the declared
        # and actual shapes agree even for asymmetric padding.
        h_pad, w_pad = self.padding
        if self.data_format == 'channels_first':
            pattern = [[0, 0], [0, 0], [h_pad, h_pad], [w_pad, w_pad]]
        else:
            pattern = [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]]

        # REFLECT mode does not repeat the edge pixel, so each pad amount must
        # be smaller than the corresponding spatial size.
        return tf.pad(x, pattern, 'REFLECT')

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({
            'padding': tuple(self.padding),
            'data_format': self.data_format,
        })
        return config


def ResBlock(input, filters, kernel_size=(3, 3), strides=(1, 1), dropout_rate=0.5):
    """Residual block of two reflection-padded Conv -> BN -> ReLU stages.

    The block output is ``input + f(input)``. The first stage is followed by
    dropout when ``dropout_rate`` is truthy.

    Args:
        input: Keras tensor. Its channel count must equal ``filters`` for the
            residual ``Add`` to be valid.
        filters: number of convolution filters in both stages.
        kernel_size: int or ``(kh, kw)``. Odd sizes are required so that the
            derived reflection padding keeps the spatial size unchanged.
        strides: convolution strides. Must stay ``(1, 1)`` for the residual
            shapes to match.
        dropout_rate: dropout rate after the first stage; ``0``/``None``
            disables dropout. The default 0.5 matches the previous behaviour.

    Returns:
        Keras tensor with the same shape as ``input``.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    # Pad by half the kernel on each side, given as (height, width), so each
    # 'valid' convolution preserves spatial size: 1 px for the default 3x3
    # kernel, exactly as before.
    pad = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)

    x = ReflectionPadding2D(pad)(input)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    x = ReflectionPadding2D(pad)(x)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return Add()([input, x])

class DCGAN:
    """Factory for the generator and discriminator Keras models."""

    @staticmethod
    def build_generator(image_shape, num_gen_filter, num_resblock):
        """Build the encoder / residual / decoder generator.

        The network downsamples by 8 (three stride-2 convolutions) and then
        upsamples by 16 (four ``UpSampling2D``), so the output is twice the
        input resolution with 3 channels. A global skip connection adds the 2x
        upsampled input to the tanh output and halves the sum.

        Args:
            image_shape: input shape ``(H, W, 3)``. The channel count must be 3
                for the skip ``Add``. H and W should be divisible by 8 so that
                the decoder output matches the upsampled input.
            num_gen_filter: base filter count; deeper layers use multiples of it.
            num_resblock: number of residual blocks at the bottleneck.

        Returns:
            ``Model`` named ``'Generator'`` mapping ``(H, W, 3)`` to ``(2H, 2W, 3)``.
        """
        inputs = Input(shape=image_shape)

        x = ReflectionPadding2D((3, 3))(inputs)
        x = Conv2D(filters=num_gen_filter, kernel_size=(7, 7), padding='valid')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        n_downsample = 2
        for i in range(n_downsample):
            mul = 2**i
            x = Conv2D(filters=num_gen_filter * mul * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)

        # Third stride-2 reduction via a 1x1 projection. Its filter count
        # (num_gen_filter * 8) must equal the residual blocks' width below.
        x = Conv2D(filters=num_gen_filter * 8, kernel_size=(1, 1), strides=2, padding='same')(x)

        mul = 2**n_downsample
        for _ in range(num_resblock):
            x = ResBlock(x, num_gen_filter * mul * 2)

        n_upsample = 4
        for i in range(n_upsample):
            mul = 2**(n_upsample - i)
            x = UpSampling2D()(x)
            x = Conv2D(filters=num_gen_filter * mul // 2, kernel_size=(3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)

        x = ReflectionPadding2D((3, 3))(x)
        x = Conv2D(filters=3, kernel_size=(7, 7), padding='valid')(x)
        x = Activation('tanh')(x)

        # Global skip connection: average the tanh output with the upsampled
        # input. This tensor is the model output; previously it was built and
        # then discarded in favour of `x`.
        outputs = Add()([x, UpSampling2D()(inputs)])
        outputs = Lambda(lambda z: z / 2)(outputs)

        return Model(inputs=inputs, outputs=outputs, name='Generator')

    @staticmethod
    def build_discriminator(image_shape, num_dis_filter, n_layers=3):
        """Build the convolutional discriminator.

        Args:
            image_shape: input image shape, e.g. ``(H, W, 3)``.
            num_dis_filter: base filter count; deeper layers use up to 8x it.
            n_layers: number of stride-2 Conv -> BN -> LeakyReLU stages after
                the first convolution (default 3, as before).

        Returns:
            ``Model`` named ``'Discriminator'`` producing one sigmoid score
            per image.
        """
        inputs = Input(shape=image_shape)

        x = Conv2D(filters=num_dis_filter, kernel_size=(4, 4), strides=2, padding='same')(inputs)
        x = LeakyReLU(0.2)(x)

        for n in range(n_layers):
            # Filter multiplier doubles per stage, capped at 8.
            nf_mult = min(2**n, 8)
            x = Conv2D(filters=num_dis_filter * nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
            x = BatchNormalization()(x)
            x = LeakyReLU(0.2)(x)

        nf_mult = min(2**n_layers, 8)
        x = Conv2D(filters=num_dis_filter * nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
        x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
        x = Flatten()(x)
        x = Dense(1024, activation='tanh')(x)
        x = Dense(1, activation='sigmoid')(x)

        return Model(inputs=inputs, outputs=x, name='Discriminator')

In [None]:
# Lazily built, shared VGG16 feature extractor for perceptual_loss.
_perceptual_loss_model = None


def _get_perceptual_loss_model():
    """Return a frozen VGG16 ``block3_conv3`` feature extractor, building it once."""
    global _perceptual_loss_model
    if _perceptual_loss_model is None:
        # Keras traces the loss inside its train step. init_scope lifts model
        # construction (variable creation and loading the pretrained weights)
        # out of that graph, so it happens eagerly and only once.
        with tf.init_scope():
            # include_top=False allows arbitrary spatial sizes (>= 32 px), so
            # the extractor matches the full-resolution generator output.
            vgg = VGG16(include_top=False, weights='imagenet', input_shape=(None, None, 3))
            loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
            loss_model.trainable = False
        _perceptual_loss_model = loss_model
    return _perceptual_loss_model


def perceptual_loss(y_true, y_pred):
    """Mean squared difference between VGG16 ``block3_conv3`` features.

    NOTE(review): images are fed to VGG16 as-is (this notebook scales them to
    [-1, 1]) without ``vgg16.preprocess_input``; confirm this is intended.

    Args:
        y_true: target image batch.
        y_pred: generated image batch of the same shape.

    Returns:
        Scalar loss tensor.
    """
    loss_model = _get_perceptual_loss_model()
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))

def wasserstein_loss(y_true, y_pred):
    """Wasserstein-style critic loss: the mean of ``y_true * y_pred``.

    In this notebook the labels are +1 / -1 batches, so the sign of each
    score's contribution is set by its label.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [None]:
def _load_image_array(path, normalize=True):
    """Load a tensor saved with ``torch.save`` and return it as a NumPy array.

    Args:
        path: file path of the saved tensor.
        normalize: when true, map pixel values linearly from [0, 255] to
            [-1, 1] (the generator's tanh range) as float32. The float32 cast
            halves memory versus an implicit float64 copy; Keras feeds float32
            anyway. Assumes the stored values are in [0, 255] -- TODO confirm.

    Returns:
        ``numpy.ndarray`` of images.
    """
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code -- only load files from a trusted source.
    images = torch.load(path).numpy()
    if normalize:
        images = (images.astype(np.float32) - 127.5) / 127.5
    return images


blurImages = _load_image_array('path to train blurimages')
sharpImages = _load_image_array('path to train sharpimages')
bt = _load_image_array('path to test blurimages')
# Test sharp images stay un-normalised: the evaluation compares them with
# predictions mapped back to the 0-255 scale.
st = _load_image_array('path to test sharpimages', normalize=False)

In [None]:
epoch_num = 5
batch_size = 64
critic_updates = 5  # discriminator updates per generator update
seed = 42           # makes the data shuffling reproducible

shape = (256, 256, 3)
# The generator outputs images at twice the input resolution, so the
# discriminator judges (2H, 2W, C) images.
dis_shape = (shape[0] * 2, shape[1] * 2, shape[2])

if len(blurImages) != len(sharpImages):
    raise ValueError('blur and sharp training sets must contain the same number of paired images')
rng = np.random.RandomState(seed)

# Shuffle blur and sharp images with ONE shared permutation so that each blur
# image stays paired with its sharp target. Shuffling them separately would
# scramble the pairs.
x_train, y_train = shuffle(blurImages, sharpImages, random_state=rng)

gen = DCGAN.build_generator(shape, 64, 9)
dis = DCGAN.build_discriminator(dis_shape, 64)
inputs = Input(shape=shape)
gen_image = gen(inputs)
outputs = dis(gen_image)
dis_on_gen = Model(inputs=inputs, outputs=[gen_image, outputs])

# Use `learning_rate`: the legacy `lr` alias is deprecated, and newer Keras
# versions ignore it (silently using the default rate) or reject it.
dis_opt = Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
dis_on_gen_opt = Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile the standalone discriminator as trainable, then freeze it while
# compiling the stacked model so that generator steps leave it unchanged.
dis.trainable = True
dis.compile(optimizer=dis_opt, loss=wasserstein_loss)
dis.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
dis_on_gen.compile(optimizer=dis_on_gen_opt, loss=loss, loss_weights=loss_weights)
dis.trainable = True

output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones((batch_size, 1))

batchesPerEpoch = len(x_train) // batch_size
if batchesPerEpoch == 0:
    raise ValueError('need at least batch_size={} training pairs, got {}'.format(batch_size, len(x_train)))

for epoch in tqdm.tqdm(range(epoch_num)):

    print("[INFO] starting epoch {} of {}...".format(epoch + 1, epoch_num))

    dis_losses = []
    dis_on_gen_losses = []
    # Re-shuffle every epoch, again with one shared permutation.
    x_train, y_train = shuffle(x_train, y_train, random_state=rng)

    for i in range(batchesPerEpoch):

        batch = slice(i * batch_size, (i + 1) * batch_size)
        image_blur_batch = x_train[batch]
        image_sharp_batch = y_train[batch]

        generated_images = gen.predict(x=image_blur_batch, batch_size=batch_size, verbose=0)

        for _ in range(critic_updates):
            dis_loss_real = dis.train_on_batch(image_sharp_batch, output_true_batch)
            dis_loss_fake = dis.train_on_batch(generated_images, output_false_batch)
            dis_loss = 0.5 * np.add(dis_loss_fake, dis_loss_real)
            dis_losses.append(dis_loss)

        # Keep the discriminator frozen while the stacked model updates the
        # generator, then unfreeze it for the next critic step.
        dis.trainable = False
        dis_on_gen_loss = dis_on_gen.train_on_batch(image_blur_batch, [image_sharp_batch, output_true_batch])
        dis_on_gen_losses.append(dis_on_gen_loss)
        dis.trainable = True

        # For this two-output model, train_on_batch reports the weighted total
        # first, followed by the per-output losses. Log the total instead of
        # averaging it with its own components.
        gen_total_loss = float(np.ravel(dis_on_gen_loss)[0])
        print("[INFO] Epoch: %d, Step: %d, discriminator_loss: %.6f, adversarial_loss: %.6f" % (epoch + 1, i, dis_loss, gen_total_loss))

    mean_dis_loss = np.mean(dis_losses)
    mean_gan_loss = np.mean([np.ravel(step_loss)[0] for step_loss in dis_on_gen_losses])
    print(mean_dis_loss, mean_gan_loss)
    with open('log.txt', 'a+') as f:
        f.write(
            'Epoch {} - Discriminator Loss {} - GaN Loss {}\n'.format(epoch, mean_dis_loss, mean_gan_loss))

In [None]:
# Run the generator on the normalised test blur images, then invert the
# (v - 127.5) / 127.5 normalisation so predictions share the raw pixel scale of `st`.
pred = 127.5 * gen.predict(bt, verbose=0) + 127.5

In [None]:
# Per-image SSIM and PSNR between the de-normalised predictions and the raw
# sharp test images, both on a 0-255 scale.
if len(pred) != len(st):
    raise ValueError('got {} predictions for {} target images'.format(len(pred), len(st)))

# data_range must be explicit: `pred` is floating point, and scikit-image
# either assumes a [-1, 1] range for float input (wrong constants for 0-255
# data) or refuses to guess. `channel_axis` replaces the deprecated
# `multichannel` flag.
ss = np.array([
    ssim(pred_img, true_img, data_range=255, channel_axis=-1)
    for pred_img, true_img in zip(pred, st)
])
print(np.mean(ss))

pp = np.array([
    psnr(true_img, pred_img, data_range=255)
    for pred_img, true_img in zip(pred, st)
])
print(np.mean(pp))