In [None]:
# Environment setup via shell commands (Colab-style notebook).
# Install the pinned TensorFlow version used by the rest of this notebook.
!pip install tensorflow==2.4.3
# WARNING: this deletes every non-hidden entry in the current working directory
# before the project is cloned. Only run it in a disposable workspace.
!rm -rf ./*
# Clone the project and move its contents into the working directory;
# presumably this is what makes the `datasets`, `models` and `utils`
# packages imported in the next cell available - TODO confirm against the repo.
!git clone https://github.com/gunateja5465/MajorProject.git
!mv MajorProject/* ./

In [None]:
import os
import glob
import numpy as np
from PIL import Image, ImageOps
import tensorflow as tf
import matplotlib.pyplot as plt

from datasets.div2k.parameters import Div2kParameters 
from models.srgan_generator import build_srresnet
from models.pretrained import pretrained_models
from utils.prediction import get_sr_image
from utils.config import config

In [None]:
# Key identifying which DIV2K variant to use; it is also combined with the
# model name further down to form `model_key`.
dataset_key = "bicubic_x4"

# Base data directory from the project config; falls back to "" when the
# config has no "data_path" entry.
data_path = config.get("data_path", "") 

# Absolute path of the "div2k" folder under the data directory.
div2k_folder = os.path.abspath(os.path.join(data_path, "div2k"))

# Dataset parameter object; its `scale` attribute is read later when the
# generator network is built.
dataset_parameters = Div2kParameters(dataset_key, save_data_directory=div2k_folder)

In [None]:
def load_image(path):
    """Load an image file and return its pixels as an RGB numpy array.

    Args:
        path: Path of the image file to open with PIL.

    Returns:
        A tuple ``(was_grayscale, pixels)`` where ``was_grayscale`` is True
        when the source image had a single, non-palette band, and ``pixels``
        is the image converted to RGB as a numpy array.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been copied out, instead of leaving it to the GC.
    with Image.open(path) as source:
        # Palette ("P") images also expose a single band, but that band indexes
        # a colour table, so they must not be reported as grayscale.
        was_grayscale = len(source.getbands()) == 1 and source.mode != "P"

        # Normalise every non-RGB mode (L, LA, P, RGBA, CMYK, YCbCr, ...) so the
        # caller always receives three colour channels.
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        pixels = np.array(rgb)

    return was_grayscale, pixels


In [None]:
# Which model to run; the evaluation cell at the end treats "srgan" as the
# single-model case and anything else as an SRGAN vs SRResNet comparison.
model_name = "srgan"
# Combined key used to name the weights directory, look up the pretrained
# download URL and name the output folder.
model_key = f"{model_name}_{dataset_key}"

In [None]:
# Locate the generator weights for the selected model and download a
# pretrained copy when they are not present locally.
weights_directory = os.path.abspath(f"weights/{model_key}")

file_path = os.path.join(weights_directory, "generator.h5")

if not os.path.exists(file_path):
    os.makedirs(weights_directory, exist_ok=True)

    print("Couldn't find file: ", file_path, ", attempting to download a pretrained model")

    if model_key not in pretrained_models:
        # `.keys()` rather than `.key()`: the misspelled call raised
        # AttributeError instead of printing this hint.
        print(f"Couldn't find pretrained model with key: {model_key}, available pretrained models: {list(pretrained_models.keys())}")
    else:
        download_url = pretrained_models[model_key]
        # basename() is separator-agnostic, unlike splitting on "/".
        file = os.path.basename(file_path)
        # An absolute cache_subdir makes Keras store the file in that exact
        # directory rather than under its default cache location.
        tf.keras.utils.get_file(file, download_url, cache_subdir=weights_directory)

In [None]:
# Build the generator network for the dataset's upscaling factor.
model = build_srresnet(scale=dataset_parameters.scale)

# Make sure the weights directory exists (no-op when it is already there).
os.makedirs(weights_directory, exist_ok=True)
# Path of the generator weights inside the weights directory.
weights_file = f'{weights_directory}/generator.h5'

# NOTE(review): this presumably fails if the download cell above could not
# provide generator.h5 - confirm the file exists before running.
model.load_weights(weights_file)

In [None]:
# Folder that receives the super-resolved images for this model/dataset key.
# The trailing slash matters: the inference loop appends the file name
# directly to this string.
results_path = f"output/{model_key}/"
os.makedirs(results_path, exist_ok=True)

In [None]:
# Super-resolve every file in "input/" and save the result under results_path.
# Sorting gives a deterministic processing order; directories are skipped
# because they cannot be opened as images.
image_paths = sorted(p for p in glob.glob("input/*") if os.path.isfile(p))

for image_path in image_paths:
    print(image_path)
    was_grayscale, lr = load_image(image_path)

    # presumably returns a PIL image (it is passed to ImageOps and saved
    # below) - TODO confirm against utils.prediction
    sr = get_sr_image(model, lr)

    # The model works on RGB input, so restore single-channel output for
    # images that were grayscale to begin with.
    if was_grayscale:
        sr = ImageOps.grayscale(sr)

    # basename()/join() handle any path separator, unlike splitting on "/".
    image_name = os.path.basename(image_path)
    sr.save(os.path.join(results_path, image_name))

In [None]:
# Bundle the whole "output" directory into images.zip so the results can be
# downloaded from Colab as a single file.

!zip -r images.zip output

In [None]:
from math import log10, sqrt 

import cv2 
from skimage.metrics import structural_similarity as ssim1
# from skimage.metrics import mean_squared_error as mse

  

def PSNR(original, compressed):
    """Compute the peak signal-to-noise ratio between two 8-bit images.

    Args:
        original: Reference image as an array-like of pixel values (0-255).
        compressed: Image to compare, same shape as ``original``.

    Returns:
        The PSNR in decibels, or 100 when both images are identical
        (the MSE is zero, so the ratio is undefined).

    Raises:
        TypeError: If either image is None (e.g. an image that failed to load).
    """
    if original is None or compressed is None:
        raise TypeError("PSNR needs two image arrays, got None")

    # Cast to float before subtracting: uint8 arithmetic wraps around modulo
    # 256, which silently corrupts both the difference and its square.
    difference = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mean_squared_error = np.mean(difference ** 2)

    if mean_squared_error == 0:
        # No noise at all between the two signals; return a fixed sentinel.
        return 100

    max_pixel = 255.0
    return 20 * log10(max_pixel / sqrt(mean_squared_error))
def mse(original, compressed):
    """Return the mean squared error between two images, scaled by 0.01.

    Args:
        original: Reference image as an array-like of pixel values.
        compressed: Image to compare, same shape as ``original``.

    Returns:
        ``mean((original - compressed) ** 2) * 0.01``.

    Raises:
        TypeError: If either image is None (e.g. an image that failed to load).
    """
    if original is None or compressed is None:
        raise TypeError("mse needs two image arrays, got None")

    # Cast to float first: uint8 subtraction/squaring wraps modulo 256.
    difference = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    # NOTE(review): the 0.01 factor is kept from the original code; presumably
    # it only rescales the value for the printed table - confirm intent.
    m = np.mean(difference ** 2) * 0.01
    return m

def ssim(img1, img2):
    """Return the structural similarity of two multi-channel 8-bit images.

    Thin wrapper around scikit-image's ``structural_similarity`` with the
    data range fixed to 255 and the images treated as multi-channel.
    """
    # NOTE(review): newer scikit-image releases replace `multichannel` with
    # `channel_axis` - check the installed version before upgrading.
    similarity = ssim1(img1, img2, data_range=255, multichannel=True)
    return similarity

# Print the header of the metrics table and pick the number of image panels
# per figure: input, SRGAN output and HR reference, plus the SRResNet output
# when a model other than "srgan" is selected.
if model_name == "srgan":
    cols = 3
    print("PSNR-HR  PSNR-SRGAN  MSE-SRGAN    SSIM-SRGAN")
else:
    cols = 4
    print("PSNR-HR PSNR-SRGAN  PSNR-SRRESNET  MSE-SRGAN  MSE-SRRESNET   SSIM-SRGAN  SSIM-SRRESNET")

# The loop assumes the images are named 0.png, 1.png, ... in "input/", "HR/"
# and the output folders; entries that cannot be read are skipped.
num_img = len(os.listdir("input"))
for i in range(num_img):
    original = cv2.imread(f"input/{i}.png")
    hr_normal = cv2.imread(f"HR/{i}.png", 1)
    hr_gan = cv2.imread(f"output/srgan_bicubic_x4/{i}.png", 1)

    required = [original, hr_normal, hr_gan]
    if model_name != "srgan":
        hr_resnet = cv2.imread(f"output/srresnet_bicubic_x4/{i}.png", 1)
        required.append(hr_resnet)

    # cv2.imread returns None instead of raising when a file is missing or
    # unreadable; skip such entries rather than crashing inside cv2.resize.
    if any(image is None for image in required):
        print(f"Skipping image {i}: one of the input/HR/output files could not be read")
        continue

    fig, axes = plt.subplots(nrows=1, ncols=cols, figsize=(15, 15))

    # Enlarge the low-resolution input so it is displayed at a size comparable
    # to the other panels.
    resized = cv2.resize(original, (384, 384))

    # OpenCV loads images in BGR order while matplotlib expects RGB, so the
    # channels are converted for display only. The metrics below are computed
    # on the arrays as loaded, where both operands share the same order.
    axes[0].imshow(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
    axes[0].set_title("Input (resized)")
    axes[1].imshow(cv2.cvtColor(hr_gan, cv2.COLOR_BGR2RGB))
    axes[1].set_title("SRGAN")
    axes[2].imshow(cv2.cvtColor(hr_normal, cv2.COLOR_BGR2RGB))
    axes[2].set_title("HR")

    if model_name == "srgan":
        print(f"{PSNR(hr_normal,hr_normal):.3f}    {PSNR(hr_normal, hr_gan):.3f}       {mse(hr_normal,hr_gan):.3f}        {ssim(hr_normal,hr_gan):.3f} ")
    else:
        axes[3].imshow(cv2.cvtColor(hr_resnet, cv2.COLOR_BGR2RGB))
        axes[3].set_title("SRResNet")
        print(f"{PSNR(hr_normal,hr_normal):.3f}    {PSNR(hr_normal, hr_gan):.3f}        {PSNR(hr_normal, hr_resnet):.3f}     {mse(hr_normal,hr_gan):.3f}       {mse(hr_normal,hr_resnet):.3f}            {ssim(hr_normal,hr_gan):.3f}         {ssim(hr_normal,hr_resnet):.3f}")

