# Super Resolution

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pathlib
import os
from functools import reduce
from random import random


In [None]:
dataset_name = 'flickr30k_images'
# Keras' get_file caches under ~/.keras/datasets by default, so resolve the
# current user's home directory instead of hard-coding /root.
data_dir = os.path.join(os.path.expanduser('~'), '.keras', 'datasets', dataset_name)

# Download and unpack the archive only when the directory is not on disk yet.
# NOTE(review): this relies on the archive unpacking into `data_dir` -- confirm
# against the Keras version in use.
if not os.path.isdir(data_dir):
  dataset_url = f'https://datasets-349058029.s3.us-west-2.amazonaws.com/flickr/{dataset_name}.zip'
  tf.keras.utils.get_file(origin=dataset_url, extract=True)

data_dir = pathlib.Path(data_dir)
print(f"{len(list(data_dir.glob('*/*.jpg')))} images in dataset")

batch_size = 20
# Fraction of the original side length used for the degraded input images.
resolution_down_factor = 0.2
# image_size is exactly 3 * patch_size, i.e. a 3x3 grid of patches per image.
patch_size = 224
image_size = 672

# Unlabelled, shuffled batches of square images; crop_to_aspect_ratio crops
# instead of distorting when the source aspect ratio differs.
ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  labels=None,
  crop_to_aspect_ratio=True,
  shuffle=True,
  image_size=(image_size, image_size),
  batch_size=batch_size)

# Reduce the dataset size for faster processing, but always keep at least one
# batch so the cells below that pull a batch do not hit an empty dataset.
ds = ds.take(max(1, int(0.001 * len(ds))))

print(f'dataset size: {len(ds)}')

In [None]:
# Side length of the intermediate, shrunken image.
low_res_side = int(image_size * resolution_down_factor)

# Shrink the image and then resize it back to the full size, so the output has
# the original dimensions but has passed through the smaller resolution.
reduce_resolution = tf.keras.Sequential([
  layers.Resizing(low_res_side, low_res_side),
  layers.Resizing(image_size, image_size),
])

# Multiplies every value by 1/255.
scale = layers.Rescaling(1./255)

def transforms(input):
  """Turn a dataset element into an (X, y) pair.

  X is `input` passed through `reduce_resolution` and cast to uint8;
  y is `input` passed through `scale` (multiplied by 1/255).
  """
  degraded = reduce_resolution(input)
  return (tf.cast(degraded, tf.uint8), scale(input))


AUTOTUNE = tf.data.AUTOTUNE

# Build the (X, y) pairs in parallel, then prefetch so that upcoming batches
# are prepared while the current one is being consumed.
test_ds = ds.map(transforms, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
# Pull the first batch out of the pipeline as NumPy arrays:
# X_batch holds the degraded inputs, y_batch the rescaled targets.
batch_iterator = test_ds.as_numpy_iterator()
X_batch, y_batch = next(batch_iterator)

In [None]:
def process(image, model, save_path=None, patch_size=224):
  """Run `model` over an image tile by tile and stitch the outputs together.

  The image is cut into a grid of non-overlapping `patch_size` x `patch_size`
  tiles in row-major order. All tiles go through `model.predict` as a single
  batch, each prediction is clipped to [0, 1], and the predictions are
  concatenated back together in the same grid layout.

  Args:
    image: A single image laid out as (height, width, channels) that exposes
      `.shape`. Rows/columns beyond a whole multiple of `patch_size` are
      ignored.
    model: Object with a `predict` method that accepts the batch of tiles.
      NOTE(review): presumably the model expects `patch_size`-sized inputs --
      confirm against the loaded model.
    save_path: Optional file path. When given, the stitched result is also
      written there with `plt.imsave`.
    patch_size: Side length of each square tile. Defaults to 224, which yields
      the 3x3 grid on a 672x672 image.

  Returns:
    The stitched image as a tensor.

  Raises:
    ValueError: If the image is smaller than one tile in either dimension.
  """
  grid_rows = int(image.shape[0]) // patch_size
  grid_cols = int(image.shape[1]) // patch_size
  if grid_rows == 0 or grid_cols == 0:
    raise ValueError(
      f'image of shape {tuple(image.shape)} is smaller than one '
      f'{patch_size}x{patch_size} patch')

  # Row-major list of tiles: all columns of row 0, then row 1, ...
  crops = [
    tf.image.crop_to_bounding_box(
      image, r * patch_size, c * patch_size, patch_size, patch_size)
    for r in range(grid_rows)
    for c in range(grid_cols)
  ]

  predictions = model.predict(np.array(crops))
  # Clip so the stitched result stays in the float range accepted for display
  # and saving.
  predictions = [np.squeeze(np.clip(p, 0.0, 1.0)) for p in predictions]

  # Join each row's tiles side by side, then stack the rows vertically.
  strips = [
    tf.concat(predictions[r * grid_cols:(r + 1) * grid_cols], axis=1)
    for r in range(grid_rows)
  ]
  whole = tf.concat(strips, axis=0)

  if save_path is not None:
    plt.imsave(save_path, whole.numpy())

  return whole

In [None]:
# Show example images: four (input, target) pairs on a 2x4 grid, with each
# input immediately followed by its target.
fig = plt.figure(figsize=(4, 2), dpi=300)
columns = 4
rows = 2
plt.title('input vs target (training)', fontsize=4)
plt.axis('off')
# `i` is both the index of the sample within the batch and the subplot slot of
# its input image; the matching target goes into slot i + 1. This needs a batch
# of at least columns*rows samples.
for i in range(1, columns*rows+1, 2):
  # Iterating over the pair avoids rebinding the builtin `input`.
  for offset, image in enumerate((X_batch[i], y_batch[i])):
    fig.add_subplot(rows, columns, i + offset)
    plt.axis('off')
    plt.imshow(image)
plt.show()

In [None]:
# Random index into the grabbed batch. Use the batch's actual length rather
# than the configured batch_size so a short batch cannot cause an IndexError.
rand_batch_int = int(random() * len(X_batch))

In [None]:
# Compare models saved under different perceptual-loss settings on one
# randomly chosen sample, saving every image and showing them side by side.
save_path = './result_imgs/a'
os.makedirs(save_path, exist_ok=True)

# Named `input_image` so the builtin `input` is not shadowed.
input_image = X_batch[rand_batch_int]
target = y_batch[rand_batch_int]

plt.imsave(f'{save_path}/0.jpg', input_image)

# (weights file stem, panel title), in display order. The n-th entry is saved
# as '<n>.jpg' (1-based); '0.jpg' is the input image.
model_variants = [
  ('input_224_4_blocks_mobile_perceptual_loss_0', 'Pixel-wise MSE'),
  ('input_224_4_blocks_mobile_perceptual_loss_0-01', '0.01x Perceptual Loss'),
  ('input_224_4_blocks_mobile_perceptual_loss_0-1', '0.1x Perceptual Loss'),
  ('input_224_4_blocks_mobile_perceptual_loss_1', '1x Perceptual Loss'),
]

panels = [('Input Image', input_image), ('Ground Truth', target)]
for index, (model_name, title) in enumerate(model_variants, start=1):
  # Models are loaded one at a time; compile=False because they are only used
  # for prediction here.
  model = tf.keras.models.load_model(f'models/{model_name}.h5', compile=False)
  sr_image = process(input_image, model, f'{save_path}/{index}.jpg')
  panels.append((title, sr_image))

# Display results
fig = plt.figure(figsize=(2, 3), dpi=600)
columns, rows = (2, 3)

for slot, (title, image) in enumerate(panels, start=1):
  fig.add_subplot(rows, columns, slot)
  plt.title(title, fontsize=4)
  plt.axis('off')
  plt.imshow(image)

plt.show()


In [None]:
# Compare models with different block counts on one randomly chosen sample,
# showing each result next to the ground truth.
save_path = './result_imgs/b'
os.makedirs(save_path, exist_ok=True)

# Named `input_image` so the builtin `input` is not shadowed.
input_image = X_batch[rand_batch_int]
target = y_batch[rand_batch_int]

plt.imsave(f'{save_path}/0.jpg', input_image)

# (weights file stem, output file name, panel title, subplot slot).
# Each result occupies `slot` and its ground truth `slot + 1`. The VGG
# variants are currently disabled; their slots are left empty in the figure.
model_variants = [
  ('input_224_4_blocks_mobile_perceptual_loss_1', '1.jpg', '4 block mobile', 1),
  # ('input_224_4_blocks_vgg_perceptual_loss_1', '4.jpg', '4 block vgg', 3),
  ('input_224_6_blocks_mobile_perceptual_loss_1', '2.jpg', '6 block mobile', 5),
  # ('input_224_6_blocks_vgg_perceptual_loss_1', '3.jpg', '6 block vgg', 7),
]

panels = []
for model_name, file_name, title, slot in model_variants:
  # Models are loaded one at a time; compile=False because they are only used
  # for prediction here.
  model = tf.keras.models.load_model(f'models/{model_name}.h5', compile=False)
  sr_image = process(input_image, model, f'{save_path}/{file_name}')
  panels.append((slot, title, sr_image))
  panels.append((slot + 1, 'Ground Truth', target))
# The degraded input goes in the first column of the last row.
panels.append((9, 'Input Image', input_image))

# Display results
fig = plt.figure(figsize=(2, 5), dpi=600)
columns, rows = (2, 5)

for slot, title, image in panels:
  fig.add_subplot(rows, columns, slot)
  plt.title(title, fontsize=4)
  plt.axis('off')
  plt.imshow(image)

plt.show()


In [None]:
# def enhance(image, scale, model, save_path):
#   # blow up by scale
#   # convolve model across image
#   # save image
#   pass