# Super Resolution

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pathlib
import os
from functools import reduce
from random import random

# Render every figure at high resolution (600 dpi) unless a figure overrides it.
mpl.rcParams.update({'figure.dpi': 600})


In [None]:
dataset_name = 'flickr30k_images'
# `tf.keras.utils.get_file` caches downloads under ~/.keras/datasets by default.
# Expanding '~' instead of hard-coding '/root' keeps this check pointed at that
# cache when the notebook is not run as root (the path is identical when it is).
data_dir = os.path.expanduser(f'~/.keras/datasets/{dataset_name}')

# Download and extract the archive only if it is not already present.
if not os.path.isdir(data_dir):
  dataset_url = f'https://datasets-349058029.s3.us-west-2.amazonaws.com/flickr/{dataset_name}.zip'
  tf.keras.utils.get_file(origin=dataset_url, extract=True)

data_dir = pathlib.Path(data_dir)
print(f"{len(list(data_dir.glob('*/*.jpg')))} images in dataset")

batch_size = 20
resolution_down_factor = 0.25  # side-length factor used when degrading images
patch_size = 256  # NOTE(review): not referenced in this notebook section
image_size = 256

# Unlabeled, shuffled batches; each image is cropped to a square aspect ratio
# and resized to image_size x image_size.
ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  labels=None,
  crop_to_aspect_ratio=True,
  shuffle=True,
  image_size=(image_size, image_size),
  batch_size=batch_size)

# Keep ~0.1% of the batches for faster processing, but never fewer than one:
# int() rounds down to 0 for datasets with fewer than 1000 batches, which would
# leave `ds` empty and make the later batch fetch fail.
ds = ds.take(max(1, int(0.001 * len(ds))))

print(f'dataset size: {len(ds)}')

In [None]:
# Degradation used to build input/target pairs: shrink each side to
# `resolution_down_factor` of its length, then resize back up to `image_size`
# so the degraded input and the target share the same spatial size.
reduce_resolution = tf.keras.Sequential(
  [
    layers.Resizing(
      int(image_size * resolution_down_factor),
      int(image_size * resolution_down_factor),
    ),
    layers.Resizing(image_size, image_size),
  ]
)

# Multiplies pixel values by 1/255 (maps [0, 255] onto [0, 1]).
scale = layers.Rescaling(1./255)


def transforms(input):
  """Turn a batch of images into an (input, target) pair.

  Returns:
    A tuple `(degraded, target)`: the resolution-reduced batch cast to uint8
    (fractional values are truncated by the cast) and the original batch
    multiplied by 1/255.
  """
  target = scale(input)
  degraded = tf.cast(reduce_resolution(input), tf.uint8)
  return degraded, target


AUTOTUNE = tf.data.AUTOTUNE

# Apply the transform in parallel and prefetch so batches are ready ahead of use.
test_ds = (
  ds.map(transforms, num_parallel_calls=AUTOTUNE)
  .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
# Grab the first (degraded input, target) batch produced by `test_ds`; the
# visualisation cells below index into X_batch / y_batch.
X_batch, y_batch = iter(test_ds).get_next()

In [None]:
# Sanity check: the input and target batches should have the same length.
print(f'{len(X_batch)} {len(y_batch)}')

In [None]:
def process(image, model, save_path=None):
  """Run `model` on a single image and convert its output to 8-bit pixels.

  The prediction for the one-image batch has its size-1 axes squeezed out,
  is clipped to [-1, 1] and mapped linearly onto [0, 255] (so the model is
  presumably tanh-activated — confirm), then truncated to uint8.

  Args:
    image: one image; it is wrapped in a batch of size 1 before prediction.
    model: object exposing a Keras-style `predict` method.
    save_path: optional file path; when given, the result is also written
      there with `plt.imsave`.

  Returns:
    The uint8 numpy array produced from the prediction.
  """
  prediction = np.squeeze(model.predict(np.array([image])))
  clipped = np.clip(prediction, -1.0, 1.0)
  pixels = ((clipped + 1) / 2 * 255).astype(np.uint8)

  if save_path is not None:
    plt.imsave(save_path, pixels)
  return pixels

In [None]:
# Show example pairs: on each row the degraded input sits to the left of its
# target, taken from every other sample (indices 1, 3, 5, 7) of the batch.
fig = plt.figure(figsize=(4, 2), dpi=300)
columns = 4
rows = 2
plt.title('input vs target', fontsize=4)
plt.axis('off')
for slot in range(1, columns * rows + 1, 2):
    pair = (X_batch[slot], y_batch[slot])
    for offset, img in enumerate(pair):
        fig.add_subplot(rows, columns, slot + offset)
        plt.axis('off')
        plt.imshow(img)
plt.show()

In [None]:
# Random index of the sample used by `plot_results`. It is drawn from the length
# of the batch actually fetched, which is shorter than `batch_size` when the
# dataset holds fewer images than one batch.
rand_batch_int = int(random() * len(X_batch))

In [None]:
def plot_results(model_arch):
  """Enhance one sample with a saved generator and show the comparison.

  Loads `models/{model_arch}.h5`, runs it on `X_batch[rand_batch_int]` via
  `process`, writes `original.jpg` (the degraded input) and `enhanced.jpg`
  (the model output) to `./result_imgs/{model_arch}/`, and displays the input,
  the ground truth and the result side by side.

  Args:
    model_arch: file stem of the saved model under `models/`; also names the
      output directory.
  """
  save_path = f'./result_imgs/{model_arch}'
  os.makedirs(save_path, exist_ok=True)

  low_res = X_batch[rand_batch_int]
  target = y_batch[rand_batch_int]

  # `transforms` already yields uint8 inputs in [0, 255]; multiplying such a
  # tensor by 255 (as this cell used to) wraps around in uint8 arithmetic and
  # saved a corrupted image. Only float inputs in [0, 1] need scaling by 255.
  original = np.asarray(low_res)
  if np.issubdtype(original.dtype, np.floating):
    original = np.clip(original * 255, 0, 255)
  plt.imsave(f'{save_path}/original.jpg', original.astype(np.uint8))

  model = tf.keras.models.load_model(f'models/{model_arch}.h5', compile=False)
  sr_image = process(low_res, model, f'{save_path}/enhanced.jpg')

  # Display results on a 2x2 grid (the fourth cell stays empty).
  fig = plt.figure(figsize=(2, 2), dpi=600)
  columns, rows = (2, 2)
  panels = (
    ('Input Image', low_res),
    ('Ground Truth', target),
    ('Result', sr_image),
  )
  for position, (title, image) in enumerate(panels, start=1):
    fig.add_subplot(rows, columns, position)
    plt.title(title, fontsize=4)
    plt.axis('off')
    plt.imshow(image)

  plt.show()


In [None]:
# Evaluate the `generator_4Xzoom_plossX0` checkpoint on the shared sample.
plot_results(model_arch='generator_4Xzoom_plossX0')

In [None]:
# Evaluate the `generator_4Xzoom_plossX0-1` checkpoint on the same sample.
plot_results(model_arch='generator_4Xzoom_plossX0-1')

In [None]:
def load_img(image_path):
  """Read an image file and return it as a (1, height, width, 3) uint8 tensor.

  Args:
    image_path: path of the image file to read.

  Returns:
    A uint8 tensor with a leading batch axis of size 1 and 3 colour channels.
  """
  img = tf.io.read_file(image_path)
  # expand_animations=False makes GIFs decode to their first frame as a 3-D
  # tensor; by default they decode to a 4-D stack of frames, and the batch
  # axis added below would then yield a 5-D tensor.
  img = tf.image.decode_image(img, channels=3, expand_animations=False)
  img = tf.image.convert_image_dtype(img, tf.uint8)
  img = img[tf.newaxis, :]
  return img

In [None]:
def _tiles_that_fit(length, tile, stride, limit):
  """Return how many `tile`-long windows spaced `stride` apart fit in `length`, at most `limit`."""
  if length < tile:
    return 0
  return min(limit, (length - tile) // stride + 1)


def _stitch_grid(tiles, tiles_h, tiles_w, seam):
  """Stitch a row-major list of equally sized (H, W, C) tiles into one image.

  Neighbouring tiles are assumed to share `seam` pixels; those pixels are
  averaged, first within each row (axis 1), then between rows (axis 0).
  """
  rows = []
  for i in range(0, tiles_h * tiles_w, tiles_w):
    row = tiles[i]
    for crop in tiles[i+1:i+tiles_w]:
      left = tf.slice(row, [0, 0, 0], [-1, row.shape[1] - seam, -1])
      right = tf.slice(crop, [0, seam, 0], [-1, -1, -1])
      overlap_left = tf.slice(row, [0, row.shape[1] - seam, 0], [-1, -1, -1])
      overlap_right = tf.slice(crop, [0, 0, 0], [-1, seam, -1])
      overlap_avg = (overlap_left + overlap_right) / 2
      row = tf.concat((left, overlap_avg, right), axis=1)
    rows.append(row)

  whole = rows[0]
  for row in rows[1:]:
    top = tf.slice(whole, [0, 0, 0], [whole.shape[0] - seam, -1, -1])
    bottom = tf.slice(row, [seam, 0, 0], [-1, -1, -1])
    overlap_top = tf.slice(whole, [whole.shape[0] - seam, 0, 0], [-1, -1, -1])
    overlap_bottom = tf.slice(row, [0, 0, 0], [seam, -1, -1])
    overlap_avg = (overlap_top + overlap_bottom) / 2
    whole = tf.concat((top, overlap_avg, bottom), axis=0)
  return whole


def enhance_w_overlap_trim(image_path, scale, overlap, trim, model, save_path):
  """Upscale an image by running `model` on overlapping tiles and stitching them.

  The image is resized so that a grid of 256x256 tiles, `scale` times as many
  per side as fit in the original's shorter side, can be cut from it with
  `overlap` pixels shared between neighbours. Each tile is passed through
  `model`, `trim` pixels are cut from the top and left edge of every
  prediction, and neighbouring predictions are averaged where they still
  overlap. The result is saved to `save_path`, and both the resized input and
  the result are displayed.

  Args:
    image_path: path of the image to enhance (read with `load_img`).
    scale: multiplier on the number of tiles per side.
    overlap: pixels shared by neighbouring input tiles.
    trim: pixels removed from the top and left edge of every predicted tile;
      must not exceed `overlap`.
    model: Keras-style model whose `predict` maps (N, 256, 256, 3) tiles to
      predictions at least 256x256 in size — presumably tanh-scaled to
      [-1, 1], given the de-normalisation below (confirm).
    save_path: file the enhanced image is written to with `plt.imsave`.

  Raises:
    ValueError: if `overlap`, `trim` or `scale` are out of range, or the image
      is too small to hold a single tile.
  """
  model_output_dim = 256
  # A negative seam (trim > overlap) would reach tf.slice as size -1, which
  # means "to the end" and silently corrupts the stitching.
  if not 0 <= trim <= overlap < model_output_dim:
    raise ValueError(
      f'expected 0 <= trim <= overlap < {model_output_dim}, '
      f'got trim={trim}, overlap={overlap}')
  if scale < 1:
    raise ValueError(f'scale must be at least 1, got {scale}')

  trimmed_model_output_dim = model_output_dim - trim
  img = load_img(image_path)
  b, height, width, d = img.shape
  numTiles = min([height, width]) // trimmed_model_output_dim
  if numTiles < 1:
    raise ValueError(
      f'image is {height}x{width}; its shorter side must be at least '
      f'{trimmed_model_output_dim} pixels')
  newDim = numTiles * scale
  # NOTE(review): the tile grid spans newDim*256 - (newDim-1)*overlap pixels,
  # which is less than upscaledDim whenever overlap > 0, so the bottom/right
  # margin of the resized image is never enhanced. Confirm whether
  # `- overlap * (newDim - 1)` was intended here.
  upscaledDim = newDim*model_output_dim + (overlap * (numTiles - 1))
  upscaled_img = tf.squeeze(tf.image.resize(
    img,
    [upscaledDim, upscaledDim],
    preserve_aspect_ratio=True))

  # preserve_aspect_ratio only makes the *longer* side equal upscaledDim, so a
  # newDim x newDim grid sliced out of bounds on non-square images. Count the
  # tiles that fit on each axis instead; the cap at newDim keeps square images
  # on exactly the original newDim x newDim grid.
  stride = model_output_dim - overlap
  tiles_h = _tiles_that_fit(upscaled_img.shape[0], model_output_dim, stride, newDim)
  tiles_w = _tiles_that_fit(upscaled_img.shape[1], model_output_dim, stride, newDim)
  if tiles_h == 0 or tiles_w == 0:
    raise ValueError(
      f'resized image is {upscaled_img.shape[0]}x{upscaled_img.shape[1]}; '
      f'both sides must be at least {model_output_dim} pixels')

  crops = []
  for h in range(tiles_h):
    for w in range(tiles_w):
      crop = tf.slice(upscaled_img, [h*stride, w*stride, 0], [model_output_dim, model_output_dim, -1])
      crops.append(crop)

  crops = model.predict(np.array(crops))
  crops = [np.squeeze(c) for c in crops]
  # Cut `trim` pixels from the top and left edge of every prediction; this
  # shrinks the overlap between neighbouring predictions to overlap - trim.
  crops = [tf.slice(c, [trim, trim, 0], [trimmed_model_output_dim, trimmed_model_output_dim, -1]) for c in crops]

  whole = _stitch_grid(crops, tiles_h, tiles_w, overlap - trim)

  # Map the stitched output linearly from [-1, 1] onto [0, 255].
  whole = whole + 1
  whole = whole / 2
  whole = whole * 255
  whole = np.clip(whole, 0, 255).astype(np.uint8)

  plt.title('Original')
  plt.imshow(tf.cast(upscaled_img, tf.uint8))
  plt.show()

  plt.title(f'{scale}x enhanced')
  plt.imsave(save_path, whole)
  plt.imshow(whole)
  plt.show()

# model = tf.keras.models.load_model('models/generator_4Xzoom_plossX0-1.h5', compile=False)
# enhance_w_overlap_trim('./dog.jpg', 4, 15, 10, model, './dog_generator_4Xzoom_plossX0-1.jpg')