In [None]:
import tensorflow as tf
import math
from tqdm import tqdm
import cv2
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt
import torch

In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword name becomes that subplot's title: underscores are turned
    into spaces and the result is title-cased. Axis ticks are hidden.
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()
    
def pad_to_fit(image, pad_height, pad_width):
    """Zero-pad an H x W x C image on the bottom/right so it tiles exactly.

    Args:
        image: 3-D array (height, width, channels).
        pad_height, pad_width: tile size the output must be a multiple of.

    Returns:
        (padded, num_horizontal, num_vertical). ``padded`` keeps the input's
        dtype and has height ``num_horizontal * pad_height`` and width
        ``num_vertical * pad_width``. Note that ``num_horizontal`` is the
        number of tiles along the height and ``num_vertical`` the number
        along the width (names kept for caller compatibility).
    """
    src_h, src_w, channels = image.shape
    tiles_down = math.ceil(src_h / pad_height)
    tiles_across = math.ceil(src_w / pad_width)

    padded = np.zeros((tiles_down * pad_height, tiles_across * pad_width, channels), dtype=image.dtype)
    padded[:src_h, :src_w, :] = image
    return padded, tiles_down, tiles_across

def get_tiles(image, tile_height, tile_width):
    """Yield non-overlapping tiles of an H x W x C image in row-major order.

    Yields ``((column_idx, row_idx), (top, left, bottom, right), tile)``.
    Despite the names, ``column_idx`` is the 1-based index of the tile band
    stepping down the height, and ``row_idx`` is the 1-based index stepping
    across the width. ``bottom``/``right`` are ``top + tile_height`` and
    ``left + tile_width``; they can exceed the image bounds for edge tiles,
    in which case the yielded ``tile`` slice is smaller than requested.
    """
    image_height, image_width, _ = image.shape
    band_starts = range(0, image_height, tile_height)
    cell_starts = range(0, image_width, tile_width)
    for band_no, top in enumerate(band_starts, 1):
        bottom = top + tile_height
        for cell_no, left in enumerate(cell_starts, 1):
            right = left + tile_width
            tile = image[top:bottom, left:right, :]
            yield (band_no, cell_no), (top, left, bottom, right), tile
    

def common(img):
    """Normalize an OpenCV-loaded image to float32 RGB scaled by its bit depth.

    The image is cast to float32 and divided by 65535 when its maximum value
    exceeds 256 (treated as 16-bit), otherwise by 255. 2-D (grayscale) input
    is expanded to three channels. For 4-channel input the fourth (alpha)
    channel is discarded. The remaining three channels are converted from
    BGR to RGB.

    Returns:
        (img, h_input, w_input): the converted image plus the original
        height and width.
    """
    h_input, w_input = img.shape[:2]
    img = img.astype(np.float32)

    is_16bit = np.max(img) > 256
    if is_16bit:
        print('\tInput is a 16-bit image')
    img = img / (65535 if is_16bit else 255)

    if img.ndim == 2:
        # Grayscale: replicate the single channel into R, G and B.
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB), h_input, w_input
    if img.shape[2] == 4:
        # Keep only the color channels; alpha is not carried forward.
        img = img[:, :, 0:3]
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), h_input, w_input


def preprocess_image(img, pre_pad=10, mod_scale=4):
    """Turn an OpenCV image tile into a (1, C, H, W) float32 model input.

    Steps:
      1. normalize / convert to RGB with ``common``;
      2. reflect-pad ``pre_pad`` pixels on the bottom and right;
      3. reflect-pad again on the bottom/right so height and width become
         multiples of ``mod_scale``;
      4. move channels first and add a leading batch axis.

    Returns:
        (batch, mod_pad_h, mod_pad_w) where the two ints are the extra
        bottom/right padding from step 3, which ``post_process`` needs in
        order to crop it off again.
    """
    rgb, _, _ = common(img)
    # Extra context on the bottom/right edges; removed again after inference.
    padded = cv2.copyMakeBorder(rgb, 0, pre_pad, 0, pre_pad, cv2.BORDER_REFLECT, None, value=0)

    height, width, _ = padded.shape
    # Smallest padding that makes each side divisible by mod_scale
    # (0 when it already is).
    mod_pad_h = (-height) % mod_scale
    mod_pad_w = (-width) % mod_scale

    padded = cv2.copyMakeBorder(padded, 0, mod_pad_h, 0, mod_pad_w, cv2.BORDER_REFLECT, None, value=0)
    batch = np.expand_dims(np.transpose(padded, (2, 0, 1)), 0)
    return batch, mod_pad_h, mod_pad_w


def post_process(output, mod_pad_h, mod_pad_w, mod_scale=4, pre_pad=10, scale=1):
    """Convert a raw model output tensor into an H x W x C uint8 image.

    Drops the batch axis, moves channels last, strips the bottom/right
    padding added by ``preprocess_image``, clips to [0, 1] and rescales to
    the 0..255 uint8 range.

    Args:
        output: model output; assumed to be (1, C, H, W) (or (C, H, W))
            float data nominally in [0, 1] -- TODO confirm against the
            model's output details.
        mod_pad_h, mod_pad_w: bottom/right padding (input pixels) that was
            added to reach a multiple of ``mod_scale``.
        mod_scale: only compared against ``None``; when ``None`` the
            mod-padding is not removed.
        pre_pad: bottom/right reflect padding (input pixels) added before
            inference; 0 disables that crop.
        scale: multiplier applied to the pad amounts before cropping. The
            default 1 is the value this function previously hard-coded; for
            a network that upsamples by N, ``scale=N`` strips exactly the
            upsampled padding. (Callers in this notebook crop further
            afterwards.)

    Returns:
        uint8 array of shape (H', W', C).
    """
    # Remove only the batch axis. A full np.squeeze would also collapse a
    # 1-channel or 1-pixel-tall output and break the transpose below.
    if output.ndim == 4:
        output = np.squeeze(output, axis=0)
    # CHW -> HWC
    image = np.transpose(output, (1, 2, 0))

    # Strip the padding that made the input divisible by mod_scale.
    if mod_scale is not None:
        h, w, _ = image.shape
        image = image[0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale, :]
    # Strip the reflect pre-padding.
    if pre_pad != 0:
        h, w, _ = image.shape
        image = image[0:h - pre_pad * scale, 0:w - pre_pad * scale, :]

    image = np.clip(image, 0, 1)
    return (image * 255.0).round().astype(np.uint8)


def run_inference(image_path, model, input_height, input_width, NETSCALE=4):
    """Run tiled inference over an image file with a TFLite interpreter.

    The image is zero-padded to a whole number of tiles (``pad_to_fit``),
    split with ``get_tiles``, and each tile is preprocessed, run through
    ``model`` and post-processed. Each tile result is cropped to
    ``in_h * NETSCALE`` x ``in_w * NETSCALE``.

    Args:
        image_path: path passed to ``cv2.imread``.
        model: an interpreter with tensors already allocated, e.g.
            ``tf.lite.Interpreter``. It is now actually used; previously
            the function silently read the global ``interpreter`` instead.
        input_height, input_width: tile size in input pixels.
        NETSCALE: expected upscaling factor, used to crop each tile output.

    Returns:
        (results, num_horizontal, num_vertical, orig_height, orig_width)
        where ``results`` is a list of dicts with keys ``"array"`` (uint8
        tile), ``"coords"`` (top, left, bottom, right in output pixels) and
        ``"position"`` (1-based tile indices from ``get_tiles``).

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path}")

    # Remember the original height and width so the stitched result can be
    # cropped back to the required output dimension.
    orig_height, orig_width, _ = image.shape
    image, num_horizontal, num_vertical = pad_to_fit(image, pad_height=input_height, pad_width=input_width)
    total_tiles = num_horizontal * num_vertical
    height, width, _ = image.shape
    print(f"padding done ({orig_height} x {orig_width}) -> ({height} x {width})")

    input_index = model.get_input_details()[0]['index']
    output_index = model.get_output_details()[0]['index']

    start = time.time()
    results = []
    print("")

    tiles = get_tiles(image, tile_height=input_height, tile_width=input_width)
    for num_processed, ((column_idx, row_idx), (top, left, bottom, right), image_tile) in enumerate(tiles, 1):
        in_h, in_w, _ = image_tile.shape
        input_data, mod_pad_h, mod_pad_w = preprocess_image(image_tile)

        model.set_tensor(input_index, input_data)
        model.invoke()
        output_data = model.get_tensor(output_index)

        res_array = post_process(output_data, mod_pad_h, mod_pad_w)
        res_array = res_array[:in_h * NETSCALE, :in_w * NETSCALE, :]
        out_h, out_w, _ = res_array.shape
        h_ratio = out_h / in_h
        w_ratio = out_w / in_w

        # Map the tile's input-pixel box into output pixels: vertical
        # coordinates scale by h_ratio, horizontal ones by w_ratio
        # (previously `left` was wrongly scaled by h_ratio).
        coords = (int(top * h_ratio), int(left * w_ratio), int(bottom * h_ratio), int(right * w_ratio))
        results.append({"array": res_array, "coords": coords, "position": (column_idx, row_idx)})
        print(f"\rprocessed {num_processed}/{total_tiles} tiles")

    end = time.time()
    print("")
    print("inference took {} seconds".format(end - start))

    return results, num_horizontal, num_vertical, orig_height, orig_width

    
        


# Load the TFLite model once and allocate its tensors; later cells use
# `interpreter`, `input_details` and `output_details`.
TFLITE_MODEL_PATH = "tflite_model/srmd_android.tflite"
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()

# Per-tensor metadata dicts (index, shape, dtype, ...) for input and output.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Show the tensor shapes the model expects/produces so the tile size used
# below can be checked against them.
input_shape, output_shape = input_details[0]['shape'], output_details[0]['shape']
print(input_shape)
print(output_shape)

In [None]:

# Which images/people_<N>.jpg to process. This cell previously referenced an
# undefined loop variable `i` and failed with NameError on a fresh kernel.
IMAGE_NUMBER = 1
# Upscaling factor; defined before use and passed explicitly to run_inference
# so the tile crop and the final crop below agree.
NETSCALE = 4
# NOTE(review): 179 presumably chosen so that 179 + pre_pad (10), padded to a
# multiple of 4, gives 192 -- confirm against `input_shape` printed above.
TILE_SIZE = 179

IMAGE_PATH = f"images/people_{IMAGE_NUMBER}.jpg"
OUTPUT_PATH = f"real_esrgan_output/people_{IMAGE_NUMBER}.jpg"

result, num_horizontal, num_vertical, orig_height, orig_width = run_inference(
    IMAGE_PATH, interpreter, input_height=TILE_SIZE, input_width=TILE_SIZE, NETSCALE=NETSCALE
)

# pad_to_fit makes every input tile the same size, so the first output tile
# sizes the whole canvas.
sample_height, sample_width, _ = result[0]["array"].shape
result_array = np.zeros((sample_height * num_horizontal, sample_width * num_vertical, 3), dtype=np.uint8)

for r in result:
    top, left, bottom, right = r["coords"]
    result_array[top:bottom, left:right, :] = r["array"]

# Crop away the (upscaled) zero padding added by pad_to_fit.
result_array = result_array[0:orig_height * NETSCALE, 0:orig_width * NETSCALE, :]
# Tiles were fed to the model as RGB (see `common`); presumably the output is
# RGB too, while cv2.imwrite expects BGR. Swap the channel order back.
result_array = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
# cv2.imwrite reports failure (e.g. missing output directory) via its return value.
if not cv2.imwrite(OUTPUT_PATH, result_array):
    raise IOError(f"failed to write {OUTPUT_PATH}")