# Super Resolution

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pathlib
import os
from functools import reduce
from random import random

# Render every matplotlib figure at 600 dpi unless a figure overrides it.
mpl.rc('figure', dpi=600)


In [None]:
dataset_name = 'flickr30k_images'
# Keras cache root for the current user. It is passed to get_file explicitly so the
# existence check below and the actual download location always agree (a
# hard-coded /root path only matched when running as root).
keras_cache_dir = os.path.expanduser('~/.keras')
data_dir = os.path.join(keras_cache_dir, 'datasets', dataset_name)

if not os.path.isdir(data_dir):
  dataset_url = f'https://datasets-349058029.s3.us-west-2.amazonaws.com/flickr/{dataset_name}.zip'
  tf.keras.utils.get_file(origin=dataset_url, extract=True, cache_dir=keras_cache_dir)

data_dir = pathlib.Path(data_dir)
print(f"{len(list(data_dir.glob('*/*.jpg')))} images in dataset")

batch_size = 20
resolution_down_factor = 0.25
patch_size = 224
image_size = 224

ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  labels=None,
  crop_to_aspect_ratio=True,
  shuffle=True,
  image_size=(image_size, image_size),
  batch_size=batch_size)

# Reduce the dataset size for faster processing, but always keep at least one
# batch: with fewer than 1000 batches int(0.001 * len(ds)) is 0, which would leave
# an empty dataset and make the later next(...) on it raise StopIteration.
subset_fraction = 0.001
ds = ds.take(max(1, int(subset_fraction * len(ds))))

print(f'dataset size: {len(ds)}')

In [None]:
# Degradation model: shrink each image by `resolution_down_factor`, then resize it
# back to `image_size`, giving a blurry input at the original resolution.
reduce_resolution = tf.keras.Sequential(
  [
    layers.Resizing(int(image_size*resolution_down_factor), int(image_size*resolution_down_factor)),
    layers.Resizing(image_size, image_size),
  ]
)

scale = layers.Rescaling(scale=1.0 / 255)


def transforms(input):
  """Turn a batch of images into a (degraded input, target) pair.

  Returns:
    A tuple (X, y) where X is the batch passed through `reduce_resolution` and
    cast to uint8 (tf.cast truncates the fractional part), and y is the
    original batch multiplied by 1/255 via `scale`.
  """
  degraded = reduce_resolution(input)
  return tf.cast(degraded, tf.uint8), scale(input)


AUTOTUNE = tf.data.AUTOTUNE

test_ds = (
  ds
  .map(transforms, num_parallel_calls=AUTOTUNE)
  .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
# Pull the first batch of (degraded input, target) pairs as NumPy arrays:
# X_batch holds the uint8 inputs, y_batch the targets rescaled by 1/255.
X_batch, y_batch = next(iter(test_ds.as_numpy_iterator()))

In [None]:
def process(image, model, save_path=None):
  """Run `model` on a single image and return its prediction clipped to [0, 1].

  The image gets a leading batch axis, is passed to `model.predict`, and the
  result is squeezed (dropping the batch axis and any other size-1 axes) and
  clipped to the [0, 1] range.

  Args:
    image: one image without a batch axis (e.g. an element of `X_batch`).
    model: an object with a Keras-style `predict` method.
    save_path: optional file path; when given, the clipped prediction is also
      written there with `plt.imsave`.

  Returns:
    The clipped prediction as a NumPy array.
  """
  batch = np.expand_dims(image, axis=0)
  prediction = np.squeeze(model.predict(batch))
  whole = np.clip(prediction, 0.0, 1.0)

  if save_path is not None:
    plt.imsave(save_path, whole)

  return whole

In [None]:
# Show example (degraded input, target) pairs side by side.
columns = 4
rows = 2
fig = plt.figure(figsize=(4, 2), dpi=300)
# suptitle titles the whole figure without creating an extra full-figure axes
# underneath the subplots (which is what plt.title on an empty figure does).
fig.suptitle('input vs target', fontsize=4)
for slot in range(1, columns*rows+1, 2):
    # `slot` is both the subplot position of the input and its batch index;
    # the matching target goes in the next subplot.
    low_res = X_batch[slot]
    target = y_batch[slot]
    fig.add_subplot(rows, columns, slot)
    plt.axis('off')
    plt.imshow(low_res)
    fig.add_subplot(rows, columns, slot + 1)
    plt.axis('off')
    plt.imshow(target)
plt.show()

In [None]:
# Index of the sample plot_epochs compares across checkpoints. Bounded by the real
# length of the fetched batch, which can be shorter than batch_size (e.g. when the
# dataset holds fewer than batch_size images).
rand_batch_int = int(random() * len(X_batch))

In [None]:
def plot_epochs(model_arch, epochs=4):
  """Compare one model's super-resolution output across its saved epochs.

  For every epoch k in 1..epochs this loads `models/{model_arch}_epoch_{k}.h5`,
  runs it on the sample `X_batch[rand_batch_int]` via `process`, and saves the
  prediction under `./result_imgs/{model_arch}/` as `<letter>_epoch_<k>.jpg`.
  The input itself is saved as `0.jpg`. Finally it shows a two-column grid:
  input and ground truth on the first row, then one panel per epoch.

  Args:
    model_arch: checkpoint file prefix; also used as the output sub-directory.
    epochs: number of consecutive checkpoints, starting at epoch 1, to compare.

  Raises:
    ValueError: if `epochs` is smaller than 1.
  """
  if epochs < 1:
    raise ValueError(f'epochs must be at least 1, got {epochs}')

  save_path = f'./result_imgs/{model_arch}'
  os.makedirs(save_path, exist_ok=True)

  input_image = X_batch[rand_batch_int]
  target = y_batch[rand_batch_int]
  plt.imsave(f'{save_path}/0.jpg', input_image)

  sr_images = []
  for epoch in range(1, epochs + 1):
    model = tf.keras.models.load_model(f'models/{model_arch}_epoch_{epoch}.h5', compile=False)
    # The letter prefix (a, b, c, ...) keeps the saved files sorted by epoch.
    prefix = chr(ord('a') + epoch - 1)
    sr_images.append(process(input_image, model, f'{save_path}/{prefix}_epoch_{epoch}.jpg'))

  # Two panels per row: input + ground truth first, then one panel per epoch.
  # For 4 epochs this is a 2x3 grid, for 6 epochs a 2x4 grid.
  columns = 2
  rows = 1 + (epochs + columns - 1) // columns
  fig = plt.figure(figsize=(columns, rows), dpi=600)

  panels = [('Input Image', input_image), ('Ground Truth', target)]
  panels += [(f'Epoch {epoch}', image) for epoch, image in enumerate(sr_images, start=1)]
  for position, (title, image) in enumerate(panels, start=1):
    fig.add_subplot(rows, columns, position)
    plt.title(title, fontsize=4)
    plt.axis('off')
    plt.imshow(image)

  plt.show()


In [None]:
# plot_epochs('in224_randcrop_x4zoom_plossX0_gramX0', epochs=6)

In [None]:
# plot_epochs('in224_randcrop_x4zoom_plossX0-001_gramX1')


In [None]:
# plot_epochs('in224_randcrop_x4zoom_plossX0-001_gramX0-001', epochs=6)

In [None]:
# Shortlisted: undecided between this run and the ploss 0.1 / gram 0.001 run below.
plot_epochs(model_arch='in224_randcrop_x4zoom_plossX0-01_gramX0-01', epochs=6)

In [None]:
# ploss 0.1 with gram 0.1.
plot_epochs(model_arch='in224_randcrop_x4zoom_plossX0-1_gramX0-1', epochs=6)

In [None]:
# Shortlisted: the other candidate, alongside the ploss 0.01 / gram 0.01 run above.
plot_epochs(model_arch='in224_randcrop_x4zoom_plossX0-1_gramX0-001', epochs=6)


In [None]:
# ploss 0.1 with gram 0.00001.
plot_epochs(model_arch='in224_randcrop_x4zoom_plossX0-1_gramX0-00001', epochs=6)

In [None]:
# ploss 0.1 with gram 0.01: checking whether this is a happy medium.
plot_epochs(model_arch='in224_randcrop_x4zoom_plossX0-1_gramX0-01', epochs=6)

In [None]:
def load_img(image_path):
  """Read an image file and return it as a uint8 tensor with a leading batch axis.

  Args:
    image_path: path to an image file that `tf.image.decode_image` can decode.

  Returns:
    A uint8 tensor of shape (1, height, width, 3).
  """
  raw = tf.io.read_file(image_path)
  # expand_animations=False makes GIFs decode to a single 3-D frame like every
  # other format; otherwise they come back 4-D and the batch axis added below
  # would produce a 5-D tensor.
  img = tf.image.decode_image(raw, channels=3, dtype=tf.uint8, expand_animations=False)
  return img[tf.newaxis, ...]

In [None]:
def enhance(image_path, scale, model, save_path):
  """Upscale an image by running it through `model` tile by tile and save the result.

  Each side of length L is resized (with `tf.image.resize`) to
  (L // 224) * scale tiles of 224 px. Every tile is passed to `model.predict`.
  The predictions are clipped to [0, 1] and stitched back into one image, which
  is displayed and written to `save_path`. The resized (pre-model) image is
  displayed first for comparison.

  Args:
    image_path: path of the image to enhance (read with `load_img`).
    scale: integer factor applied to the number of tiles per side.
    model: Keras-style model. It receives (n, 224, 224, 3) float tiles with the
      resized image's 0-255 values; presumably it outputs values in [0, 1]
      (outputs are clipped to that range) — confirm against the training setup.
    save_path: file path the stitched output is written to with `plt.imsave`.

  Raises:
    ValueError: if the tile grid would be empty, i.e. a side of the image is
      shorter than 224 px or `scale` is smaller than 1.
  """
  model_output_dim = 224
  img = load_img(image_path)
  _, height, width, _ = img.shape

  # Tile counts are computed per axis so non-square images get a grid that
  # actually fits the resized image; square images behave exactly as before.
  tiles_h = (height // model_output_dim) * scale
  tiles_w = (width // model_output_dim) * scale
  if tiles_h < 1 or tiles_w < 1:
    raise ValueError(
      f'cannot tile a {height}x{width} image with scale {scale}: both sides must be '
      f'at least {model_output_dim} px and scale must be >= 1')

  upscaled_img = tf.squeeze(
    tf.image.resize(img, [tiles_h * model_output_dim, tiles_w * model_output_dim]),
    axis=0)

  # Crops are produced in row-major order: tiles_w consecutive crops form one row.
  crops = [
    tf.image.crop_to_bounding_box(
      upscaled_img, h * model_output_dim, w * model_output_dim,
      model_output_dim, model_output_dim)
    for h in range(tiles_h)
    for w in range(tiles_w)
  ]

  predictions = model.predict(np.array(crops))
  crops = [np.squeeze(np.clip(c, 0.0, 1.0)) for c in predictions]

  rows = [
    tf.concat(crops[start:start + tiles_w], axis=1)
    for start in range(0, tiles_h * tiles_w, tiles_w)
  ]
  whole = tf.concat(rows, axis=0)

  plt.title('Original')
  plt.imshow(tf.cast(upscaled_img, tf.uint8))
  plt.show()

  plt.title(f'{scale}x enhanced')
  plt.imsave(save_path, whole.numpy())
  plt.imshow(whole)
  plt.show()




In [None]:

# Enhance ./dog.jpg with the epoch-6 checkpoint of the ploss 0.1 / gram 0.01 run.
model = tf.keras.models.load_model(
  'models/in224_randcrop_x4zoom_plossX0-1_gramX0-01_epoch_6.h5',
  compile=False,
)
enhance(image_path='./dog.jpg', scale=4, model=model, save_path='./dog_enhanced5.jpg')