## CONTRASTIVE BLIND SUPER RESOLUTION

Importing Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as img
import cv2
import skimage

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
from keras.layers import Input

In [None]:
import os
from sklearn.feature_extraction.image import extract_patches_2d as patch_ex


### Data Preparation

In [None]:
# Synthesize LR training images from HR images following the degradation
# model of equation 1 (Gaussian blur -> bicubic downsampling -> noise), save
# them, and crop spatially aligned LR/HR patch pairs for training.

def Prep_LR(HR_path, LR_path, HR_patch_path, LR_patch_path, SR_scale,
            patch_size=64, n_patches=2, seed=23):
    """Create a degraded LR image and aligned LR/HR patches for every HR image.

    Args:
        HR_path: directory containing the source HR images.
        LR_path: directory where each degraded LR image is saved in PNG
            format under the same file name as its HR source.
        HR_patch_path: directory for HR patches, named 'p1<file>', 'p2<file>', ...
        LR_patch_path: directory for LR patches, named like the HR patches.
        SR_scale: integer downscaling factor.
        patch_size: side length of each LR patch. The matching HR patch covers
            the same image region and is patch_size * SR_scale on a side.
        n_patches: number of patch pairs cropped from each image.
        seed: seed of the RNG that picks patch locations (reproducibility).

    Raises:
        ValueError: if a downsampled image is smaller than patch_size.
    """
    scale = int(SR_scale)
    hr_patch_size = patch_size * scale
    rng = np.random.default_rng(seed)

    # plt.imsave fails if the destination directory does not exist yet.
    for out_dir in (LR_path, HR_patch_path, LR_patch_path):
        os.makedirs(out_dir, exist_ok=True)

    for file in sorted(os.listdir(HR_path)):
        src = os.path.join(HR_path, file)
        if not os.path.isfile(src):
            continue  # skip sub-directories
        HR = img.imread(src)

        # Gaussian blur
        gaussian_blurred = cv2.GaussianBlur(HR, (0, 0), 4.0)

        # Bicubic downsampling. cv2.resize takes dsize as (width, height),
        # i.e. (cols, rows), not (rows, cols).
        lr_h, lr_w = HR.shape[0] // scale, HR.shape[1] // scale
        bicubic_downsampled = cv2.resize(gaussian_blurred, (lr_w, lr_h),
                                         interpolation=cv2.INTER_CUBIC)

        # Additive noise
        LR = skimage.util.random_noise(bicubic_downsampled)

        # Save the full LR image
        plt.imsave(os.path.join(LR_path, file), LR, format='png')

        # Patch extraction: pick the LR patch location, then take the HR
        # region covering exactly the same part of the scene, so that each
        # saved (LR, HR) pair is a valid super-resolution training sample.
        if lr_h < patch_size or lr_w < patch_size:
            raise ValueError(
                f'{file}: downsampled size {lr_h}x{lr_w} is smaller than '
                f'patch_size={patch_size}')
        for k in range(1, n_patches + 1):
            y = int(rng.integers(0, lr_h - patch_size + 1))
            x = int(rng.integers(0, lr_w - patch_size + 1))
            name = f'p{k}{file}'
            lr_patch = LR[y:y + patch_size, x:x + patch_size]
            hr_patch = HR[y * scale:y * scale + hr_patch_size,
                          x * scale:x * scale + hr_patch_size]
            plt.imsave(os.path.join(LR_patch_path, name), lr_patch, format='png')
            plt.imsave(os.path.join(HR_patch_path, name), hr_patch, format='png')

In [None]:
# Super-resolution factor and the training-data directories used by Prep_LR.
scale = 4
H = '/home/bastin/PROJECT-main/Data/Train/HR/'
L = '/home/bastin/PROJECT-main/Data/Train/LR/'
H_patch = '/home/bastin/PROJECT-main/Data/Train/Patch_HR/'
L_patch = '/home/bastin/PROJECT-main/Data/Train/Patch_LR/'

Prep_LR(
    HR_path=H,
    LR_path=L,
    HR_patch_path=H_patch,
    LR_patch_path=L_patch,
    SR_scale=scale,
)

### Building the Model

In [None]:
#Defining the Contrastive loss (L_degrad)


def ContraLoss(zi, zj, tau=1):
    """NT-Xent contrastive loss between two batches of embeddings.

    Row k of ``zi`` and row k of ``zj`` form a positive pair. Every other row
    of the concatenated batch acts as a negative. For each of the 2N anchors
    ``a`` with positive ``p`` the term is

        -log( exp(sim(a, p) / tau) / sum_{k != a} exp(sim(a, k) / tau) )

    where ``sim`` is cosine similarity. The result is the mean over all 2N
    anchors.

    The computation is vectorised, so it also works when the batch size is
    only known at run time (e.g. inside ``tf.function``).

    Args:
        zi, zj: tensors of identical shape whose leading dimension N is the
            batch. Each sample is flattened before comparison.
        tau: temperature.

    Returns:
        Scalar float32 tensor.
    """
    z = tf.cast(tf.concat((zi, zj), 0), dtype=tf.float32)
    two_n = tf.shape(z)[0]
    # Flatten each sample and L2-normalise it so dot products are cosine
    # similarities.
    z = tf.math.l2_normalize(tf.reshape(z, (two_n, -1)), axis=1)

    # (2N, 2N) matrix of scaled pairwise similarities.
    logits = tf.matmul(z, z, transpose_b=True) / tau
    # Remove each anchor's similarity with itself from the denominator:
    # exp(x - 1e9) underflows to exactly 0 in float32.
    logits -= tf.eye(two_n) * 1e9

    # Positive partner of row r is row (r + N) mod 2N.
    positives = tf.roll(z, shift=tf.shape(zi)[0], axis=0)
    pos_logits = tf.reduce_sum(z * positives, axis=1) / tau

    # -log(num / den) = log(den) - log(num), via a stable log-sum-exp.
    per_anchor = tf.reduce_logsumexp(logits, axis=1) - pos_logits
    return tf.reduce_mean(per_anchor)

#### Encoder

In [None]:
# Degradation encoder body: three Conv -> BatchNorm -> LeakyReLU stages
# followed by a global average pool, giving one 256-d vector per image.
Encoder = keras.models.Sequential()

# Spatial size is left free: every layer up to the global pool is
# convolutional, so the encoder accepts the LR patches as well as larger images.
Encoder.add(keras.Input(shape=(None, None, 3)))

for n_filters in (64, 128, 256):
    Encoder.add(layers.Conv2D(filters=n_filters, kernel_size=3))
    Encoder.add(layers.BatchNormalization())
    Encoder.add(layers.LeakyReLU(0.1))

# Collapse the spatial dimensions into a single feature vector per image.
# AveragePooling2D(1) would be a 1x1 pool, i.e. a no-op.
Encoder.add(layers.GlobalAveragePooling2D())

In [None]:
# Projection head: flatten the encoder output and map it through a small MLP
# (Dense -> LeakyReLU -> Dense) to a 256-d embedding.
projection_head = keras.models.Sequential([
    layers.Flatten(),
    layers.Dense(256),
    layers.LeakyReLU(0.1),
    layers.Dense(256),
])


In [None]:
# Full degradation encoder: conv encoder followed by the projection head.
# Sequential takes a *list* of layers; a second positional argument would be
# bound to its `name` parameter instead of being added as a layer.
Degradation_Encoder = keras.models.Sequential([Encoder, projection_head], name='Degradation_Encoder')

### Super Resolution Generator

In [None]:
# Super-resolution generator conditioned on the degradation representation.
# Keras is channels-last, so image tensors are (batch, height, width, channels).
#   input1 - LR image batch, (batch, H, W, 3); spatial size is left free
#   input2 - degradation representation, (batch, DEG_DIM); DEG_DIM matches the
#            final Dense(256) of projection_head
# Output: SR image batch, (batch, H * SR_SCALE, W * SR_SCALE, 3).

SR_SCALE = 4   # upscaling factor; keep equal to the scale passed to Prep_LR
N_FEATS = 64   # feature channels inside the generator
DEG_DIM = 256  # length of the degradation representation

input1 = layers.Input(shape=(None, None, 3))
input2 = layers.Input(shape=(DEG_DIM,))

# Resize the degradation representation to N_FEATS channels and reshape it to
# (batch, 1, 1, N_FEATS) so that adding it broadcasts over every pixel.
resized_deg_rep = layers.Dense(N_FEATS)(input2)
resized_deg_rep = layers.LeakyReLU(0.1)(resized_deg_rep)
resized_deg_rep = layers.Dense(N_FEATS)(resized_deg_rep)
resized_deg_rep = layers.Reshape((1, 1, N_FEATS))(resized_deg_rep)


def _conditioned_block(x, deg):
    """Two 3x3 Conv + LeakyReLU layers, then add the degradation representation."""
    # padding='same' keeps the spatial size, so the final output is exactly
    # SR_SCALE times the LR input size.
    x = layers.Conv2D(filters=N_FEATS, kernel_size=3, padding='same')(x)
    x = layers.LeakyReLU(0.1)(x)
    x = layers.Conv2D(filters=N_FEATS, kernel_size=3, padding='same')(x)
    x = layers.LeakyReLU(0.1)(x)
    return x + deg


out = input1
for _ in range(3):
    out = _conditioned_block(out, resized_deg_rep)

# Upsampling by sub-pixel convolution: predict 3 * SR_SCALE**2 channels and
# rearrange them into an RGB image SR_SCALE times larger in each dimension.
out = layers.Conv2D(filters=3 * SR_SCALE ** 2, kernel_size=3, padding='same')(out)
out = layers.Lambda(lambda t: tf.nn.depth_to_space(t, SR_SCALE))(out)

Generator = keras.Model(inputs=[input1, input2], outputs=out)
# L1 reconstruction loss. NOTE(review): confirm loss/optimizer against the
# paper's training setup.
Generator.compile(optimizer='adam', loss='mean_absolute_error')



### Training

In [None]:
priming_epochs = 50    # initial epochs that train only the degradation encoder
epochs = 300           # total number of epochs
learning_rate = 0.01   # NOTE(review): confirm the intended rate for Adam
batch_size = 32


def _read_png(path):
    """Load one patch file as a float32 RGB tensor with values in [0, 1]."""
    # Prep_LR writes every patch in PNG format (whatever its file extension),
    # so decode_png applies; channels=3 drops any alpha channel.
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    return tf.image.convert_image_dtype(image, tf.float32)


def _patch_dataset(lr_dir, hr_dir):
    """Dataset of (lr_p1, lr_p2, hr_p1) triples built from Prep_LR's patch files.

    The 'p1' and 'p2' LR patches are cropped from the same degraded image, so
    they form the positive pair for ContraLoss. The other images in a batch
    act as negatives. 'hr_p1' is the HR target for the 'p1' LR patch.
    """
    names = sorted(f[2:] for f in os.listdir(lr_dir) if f.startswith('p1'))
    if not names:
        raise ValueError(f'no p1* patch files found in {lr_dir}')
    lr1 = [os.path.join(lr_dir, 'p1' + n) for n in names]
    lr2 = [os.path.join(lr_dir, 'p2' + n) for n in names]
    hr1 = [os.path.join(hr_dir, 'p1' + n) for n in names]
    ds = tf.data.Dataset.from_tensor_slices((lr1, lr2, hr1))
    return ds.map(lambda a, b, c: (_read_png(a), _read_png(b), _read_png(c)),
                  num_parallel_calls=tf.data.AUTOTUNE)


# Drop the last partial batch: ContraLoss needs the other samples of a batch
# as negatives, so every batch must be full-sized.
train_ds = (_patch_dataset(L_patch, H_patch)
            .shuffle(1024)
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))

# One optimizer per model, so each always updates the same variable set.
encoder_opt = keras.optimizers.Adam(learning_rate)
generator_opt = keras.optimizers.Adam(learning_rate)

for epoch in range(epochs):
    train_sr = epoch >= priming_epochs  # during priming only the encoder learns
    for lr1, lr2, hr1 in train_ds:
        with tf.GradientTape() as tape:
            z1 = Degradation_Encoder(lr1, training=True)
            z2 = Degradation_Encoder(lr2, training=True)
            loss_degrad = ContraLoss(z1, z2)
            loss = loss_degrad
            if train_sr:
                # assumes HR patches are aligned with, and SR-scale times larger
                # than, their LR patches - TODO confirm against Prep_LR output
                sr = Generator([lr1, z1], training=True)
                loss_sr = tf.reduce_mean(tf.abs(hr1 - sr))
                loss = loss + loss_sr

        enc_vars = Degradation_Encoder.trainable_variables
        gen_vars = Generator.trainable_variables if train_sr else []
        grads = tape.gradient(loss, enc_vars + gen_vars)
        encoder_opt.apply_gradients(zip(grads[:len(enc_vars)], enc_vars))
        if train_sr:
            generator_opt.apply_gradients(zip(grads[len(enc_vars):], gen_vars))

    # NOTE(review): no validation pass yet; losses below are from the last batch.
    msg = f'epoch {epoch + 1}/{epochs}: L_degrad={float(loss_degrad):.4f}'
    if train_sr:
        msg += f', L_sr={float(loss_sr):.4f}'
    print(msg)
        
        

