In [None]:
import tensorflow as tf
from tensorflow import keras
from skimage import io
import numpy as np
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from PIL import Image
import onnx
from onnx_tf.backend import prepare
import os
import time
import tensorflow_hub as hub
os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
import torch
from torch.nn import functional as F
import cv2
import os.path
import logging
import re
from collections import OrderedDict
from scipy.io import loadmat
from utils.network_srmd import SRMD as net 
from utils import utils_deblur
from utils import utils_sisr as sr
from utils import utils_logger
from utils import utils_image as util
from utils import utils_model

In [None]:
def preprocess_srmd(img, target_size=(179, 179), kernel_width=0.01,
                    noise_level_model=0, srmd_pca_path=None):
    """Build the SRMD network input tensor for an image file.

    The image is read as 3-channel, resized, converted with the project's
    ``utils_image`` helpers and concatenated (along dim 1) with a degradation
    map that is constant over the spatial dimensions.

    Args:
        img: Path of the image file to read.
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        kernel_width: Width argument of the 15x15 Gaussian kernel passed to
            ``utils_deblur.fspecial``.
        noise_level_model: Noise level; it is divided by 255 and appended to
            the projected kernel vector.
        srmd_pca_path: Path of the ``.mat`` file holding the projection
            matrix ``P``. Defaults to ``kernels/srmd_pca_matlab.mat``.

    Returns:
        A CPU ``torch.Tensor``: the image tensor with the degradation map
        appended along dim 1.
    """
    if srmd_pca_path is None:
        srmd_pca_path = os.path.join('kernels', 'srmd_pca_matlab.mat')

    # Image branch. NOTE(review): uint2single / single2tensor4 are project
    # helpers; presumably they scale to [0, 1] and produce a 4-D tensor --
    # confirm in utils_image.
    img_L = util.imread_uint(img, n_channels=3)
    img_L = cv2.resize(img_L, tuple(target_size))
    img_L = util.uint2single(img_L)
    img_L = util.single2tensor4(img_L)

    # Degradation branch: project the column-major flattened blur kernel with
    # P, then append the normalised noise level as one extra entry.
    # The kernel size stays fixed at 15 because P must match the flattened
    # kernel length.
    kernel = utils_deblur.fspecial('gaussian', 15, kernel_width)
    P = loadmat(srmd_pca_path)['P']
    degradation_vector = np.dot(P, np.reshape(kernel, (-1), order="F"))
    degradation_vector = np.append(degradation_vector, noise_level_model / 255.)
    degradation_vector = torch.from_numpy(degradation_vector).view(1, -1, 1, 1).float()

    # Tile the vector over the image's last two dimensions and stack it onto
    # the image along dim 1.
    degradation_map = degradation_vector.repeat(1, 1, img_L.size(-2), img_L.size(-1))
    img_L = torch.cat((img_L, degradation_map), dim=1)
    return img_L.to("cpu")

def post_process(output, mod_pad_h, mod_pad_w, mod_scale=4, pre_pad=10, scale=1):
    """Crop padding from a batched image tensor and convert it to a uint8 image.

    Args:
        output: 4-D ``torch.Tensor`` laid out as (batch, channels, height,
            width) with a batch of one.
        mod_pad_h: Rows of divisibility padding to cut from the bottom.
        mod_pad_w: Columns of divisibility padding to cut from the right.
        mod_scale: When ``None`` the divisibility padding is left in place.
        pre_pad: Rows/columns of pre-padding to cut from the bottom/right;
            ``0`` skips this step.
        scale: Factor applied to every padding amount before cropping. The
            default ``1`` crops exactly the given amounts.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape (height, width,
        channels); values are clamped to [0, 1] and scaled to [0, 255].
        The input tensor is never modified.
    """
    # Remove the divisibility padding (bottom/right edges).
    if mod_scale is not None:
        _, _, h, w = output.size()
        output = output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
    # Remove the pre-padding (bottom/right edges).
    if pre_pad != 0:
        _, _, h, w = output.size()
        output = output[:, :, 0:h - pre_pad * scale, 0:w - pre_pad * scale]
    # squeeze(0) drops only the batch axis, so single-channel or 1-pixel-wide
    # images keep their remaining axes. clamp() (not clamp_()) is used because
    # the slices above are views: an in-place clamp would write through to the
    # caller's tensor whenever .float()/.cpu() are no-ops.
    output = output.detach().squeeze(0).float().cpu().clamp(0, 1).numpy()
    # Channel-first -> channel-last.
    output = np.transpose(output, (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    return output

def preprocess(img_path, pre_pad=10, mod_scale=4, device=None):
    """Read an image and turn it into a padded, batched RGB tensor.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.
        pre_pad: Pixels of reflect padding added to the bottom and right.
        mod_scale: After pre-padding, height and width are reflect-padded
            (bottom/right) up to the next multiple of this value. ``None``
            skips this step.
        device: Torch device for the result. Defaults to ``"cuda"`` when CUDA
            is available, otherwise ``"cpu"``.

    Returns:
        Tuple ``(img, mod_pad_h, mod_pad_w)``: ``img`` is a float tensor laid
        out as (1, 3, height, width) with values scaled by the detected bit
        depth; ``mod_pad_h`` / ``mod_pad_w`` are the rows / columns added by
        the ``mod_scale`` step.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("Could not read image: %s" % img_path)

    img = img.astype(np.float32)
    if np.max(img) > 256:  # 16-bit image
        max_range = 65535
        print('\tInput is a 16-bit image')
    else:
        max_range = 255
    img = img / max_range

    # With OpenCV's default read flag the image normally arrives as 3-channel
    # BGR; the gray and 4-channel branches are kept as defensive handling.
    if len(img.shape) == 2:  # gray image
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:  # image with alpha channel: alpha is discarded
        img = cv2.cvtColor(img[:, :, 0:3], cv2.COLOR_BGR2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # HWC -> CHW, then add the batch axis.
    img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)
    # F.pad order is (left, right, top, bottom): pad right and bottom only.
    img = F.pad(img, (0, pre_pad, 0, pre_pad), 'reflect')

    mod_pad_h, mod_pad_w = 0, 0
    if mod_scale is not None:
        _, _, h, w = img.size()
        # The outer modulo yields 0 when the size is already a multiple.
        mod_pad_h = (mod_scale - h % mod_scale) % mod_scale
        mod_pad_w = (mod_scale - w % mod_scale) % mod_scale
        img = F.pad(img, (0, mod_pad_w, 0, mod_pad_h), 'reflect')

    return img, mod_pad_h, mod_pad_w

def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one panel: the keyword, with underscores
    replaced by spaces and title-cased, is the panel title and its value is
    handed to ``plt.imshow``. Axis ticks are hidden.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()
    
def save_image(image, filename):
  """
    Writes an image to "<filename>.jpg" and prints the saved name.
    Args:
      image: Either a PIL image, which is saved as-is, or a 3D image
        tensor [height, width, channels], which is clipped to [0, 255]
        and cast to uint8 first.
      filename: Output name without the ".jpg" extension.
  """
  if isinstance(image, Image.Image):
    picture = image
  else:
    clipped = tf.clip_by_value(image, 0, 255)
    as_uint8 = tf.cast(clipped, tf.uint8)
    picture = Image.fromarray(as_uint8.numpy())
  picture.save("%s.jpg" % filename)
  print("Saved as %s.jpg" % filename)
    
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
def plot_image(image, title=""):
  """
    Draws an image on the current matplotlib axes with the axes hidden.
    Args:
      image: 3D image tensor or array-like. [height, width, channels].
        Values are clipped to [0, 255] and cast to uint8 before display.
      title: Title to display in the plot.
  """
  pixels = tf.clip_by_value(np.asarray(image), 0, 255)
  picture = Image.fromarray(tf.cast(pixels, tf.uint8).numpy())
  plt.imshow(picture)
  plt.axis("off")
  plt.title(title)

In [None]:
# Source ONNX model file and the directory the TensorFlow export is written to.
onnx_path = "srmd_x4.onnx"
tf_path = "tf_model/"

# Load the ONNX graph and run ONNX's model checker on it before converting.
onnx_model = onnx.load(onnx_path)  # load onnx model
onnx.checker.check_model(onnx_model)

# Convert with the onnx-tf backend and export the result to tf_path.
# This runs on every execution of the cell, overwriting any earlier export.
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_path)

In [None]:
# Same directory the conversion cell exports to.
tf_path = "tf_model/"
model = tf.saved_model.load(tf_path)
print(list(model.signatures.keys()))  # ["serving_default"]

# Callable used for inference below; print the structure of its outputs.
infer = model.signatures["serving_default"]
print(infer.structured_outputs)
# infer.inputs

In [None]:
# Input image. Set the SR_IMAGE_PATH environment variable to use another file;
# the fallback is the path this notebook was originally run with.
IMAGE_PATH = os.environ.get(
    "SR_IMAGE_PATH",
    "C:/Users/ayush/Rizzle/52_images_super_resolution_testing/Images/360p_images/people_4.jpg",
)
# Fail early with a clear message instead of deep inside the preprocessing helpers.
if not os.path.isfile(IMAGE_PATH):
    raise FileNotFoundError("Input image not found: %s" % IMAGE_PATH)

# Padded full-resolution tensor (used for the "original" plot) plus the amount
# of divisibility padding that was added to it.
preprocessed_image, mod_pad_h, mod_pad_w = preprocess(IMAGE_PATH)
# Separate network input for the SRMD model.
srmd_preprocessed_img = preprocess_srmd(IMAGE_PATH)
hr_np = preprocessed_image.cpu().numpy()
print(preprocessed_image.shape)

In [None]:
# Plot the original-resolution image.
# post_process crops the padding that preprocess added to this tensor and
# returns a uint8 channel-last array.
post_process_output = post_process(preprocessed_image, mod_pad_h, mod_pad_w)
print(post_process_output.shape)
# Shape of the still-padded tensor with size-1 axes removed, for comparison.
print(tf.squeeze(preprocessed_image.cpu().numpy()).shape)
plot_image(post_process_output, title="Original Image")
# Written to "Original Image.jpg" in the working directory.
save_image(post_process_output, filename="Original Image")

In [None]:
# Time the conversion to a TensorFlow tensor plus one call of the serving signature.
start = time.time()
x = tf.convert_to_tensor(srmd_preprocessed_img.numpy(), dtype=tf.float32, name="input0")
fake_image = infer(x)
print("Time Taken: %f" % (time.time() - start))
# infer returns a mapping; keep its "output0" entry with all size-1 axes removed.
fake_image = tf.squeeze(fake_image["output0"])


In [None]:
# Put a leading batch axis back and hand the result to PyTorch, because
# post_process works with torch tensor methods (.size(), .squeeze(), .clamp_()).
# NOTE(review): this cell rebinds fake_image from a TF tensor to a torch tensor,
# so re-running it without re-running the inference cell passes a torch tensor
# to tf.expand_dims -- confirm whether that is acceptable.
fake_image = tf.expand_dims(fake_image, 0)
fake_image = torch.from_numpy(fake_image.numpy())

In [None]:
# The model input came from preprocess_srmd, which resizes the image but adds no
# padding, so nothing may be cropped from the network output. mod_pad_h and
# mod_pad_w describe the tensor returned by preprocess and do not apply here.
post_process_output = post_process(fake_image, 0, 0, pre_pad=0)

In [None]:
# Display the shape of the post-processed super-resolution array.
post_process_output.shape

In [None]:
# Plot the super-resolution image and write it to "Super Resolution.jpg"
# in the working directory.
plot_image(post_process_output, title="Super Resolution")
save_image(post_process_output, filename="Super Resolution")