# PA228 Project - machine learning in image processing

Author: Petr Kadlec, UČO: 485208

## Loading the dataset:

Reference implementation used for guidance: https://github.com/krasserm/super-resolution

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

%matplotlib inline

In [None]:
#%load_ext nb_mypy

In [None]:
import tensorflow as tf

In [None]:
# List the GPUs TensorFlow can see and switch on memory growth for each one
# (per the TensorFlow docs, memory is then allocated as needed instead of
# being reserved up front).
gpus = tf.config.list_physical_devices('GPU')
print(f'Detected gpus: {gpus}')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
def ishow(img,
          cmap='viridis',
          title='',
          fig_size=(8,6),
          colorbar=True,
          interpolation='none'):
    """Display `img` in a newly created matplotlib figure.

    Args:
        img: Array-like image; `img.shape[0]` and `img.shape[1]` are used as
            the height and width of the plotted extent.
        cmap: Colormap name forwarded to `imshow`.
        title: Title placed above the image.
        fig_size: Figure size forwarded to `plt.subplots`.
        colorbar: When true, a vertical colorbar is attached to the figure.
        interpolation: Interpolation mode forwarded to `imshow`.
    """
    height, width = img.shape[0], img.shape[1]
    figure, axes = plt.subplots(figsize=fig_size)

    # The extent runs 0..width horizontally and height..0 vertically, so the
    # axis ticks are in pixel units with row 0 at the top.
    mappable = axes.imshow(
        img,
        extent=(0, width, height, 0),
        cmap=cmap,
        interpolation=interpolation,
    )

    axes.set_frame_on(False)
    plt.title(title)
    plt.tight_layout()
    if colorbar:
        figure.colorbar(mappable, orientation='vertical')
    plt.show()

In [None]:
# Root folder of the DIV2K data and the shared train/validation name prefixes.
dataset_location: str = "./../dataset/"

training_prefix = dataset_location + "DIV2K_train_"
validation_prefix = dataset_location + "DIV2K_valid_"

# High-resolution originals; shared by every degradation variant below.
original_train = training_prefix + "HR"
original_test = validation_prefix + "HR"


def _dataset_dirs(variant: str) -> list[str]:
    """Return [train degraded, train HR, validation degraded, validation HR] for `variant`."""
    return [
        training_prefix + variant,
        original_train,
        validation_prefix + variant,
        original_test,
    ]


set_difficult: list[str] = _dataset_dirs("LR_difficult")
set_mild: list[str] = _dataset_dirs("LR_mild")
set_wild: list[str] = _dataset_dirs("LR_wild")
set_x8: list[str] = _dataset_dirs("LR_x8")

In [None]:
# Shared size settings for data loading and the model.
crop_size = 512  # side length passed as image_size when the datasets are loaded
upscale_factor = 4  # super-resolution factor handed to the model and the resize helpers
input_size = crop_size // upscale_factor  # side length of the low-resolution model input
batch_size = 8  # number of images per dataset batch

In [None]:
# Degradation variant used by the rest of the notebook; assign set_difficult,
# set_wild or set_x8 instead to switch variants.
current_set: list[str] = set_mild

In [None]:
def _load_image_batches(directory):
    """Load the images in `directory` as an unlabeled, batched `tf.data.Dataset`.

    Images are requested at `crop_size` x `crop_size` with nearest-neighbour
    interpolation and grouped into batches of `batch_size`; the fixed seed
    keeps the shuffling reproducible.
    """
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        labels=None,
        label_mode="categorical",
        image_size=(crop_size, crop_size),
        batch_size=batch_size,
        interpolation="nearest",
        seed=1,
    )


# current_set layout: [train degraded, train HR, validation degraded, validation HR].
test_orig_ds = _load_image_batches(current_set[3])
test_mod_ds = _load_image_batches(current_set[2])
train_orig_ds = _load_image_batches(current_set[1])
train_mod_ds = _load_image_batches(current_set[0])


In [None]:
import os

def get_images_in_dir(dir_name: str) -> list[str]:
    """Return the sorted paths of all entries in `dir_name` whose name ends with ``.png``.

    Each path is `dir_name` followed by ``/`` and the entry name; the suffix
    match is case-sensitive.
    """
    folder = dir_name + "/"
    png_paths = (
        os.path.join(folder, entry)
        for entry in os.listdir(folder)
        if entry.endswith(".png")
    )
    return sorted(png_paths)

In [None]:
# Read only the first batch; list(...) would load the entire dataset just to
# show a single image.
next(iter(test_mod_ds))[0]

In [None]:
# Read only the first batch; list(...) would load the entire dataset just to
# show a single image.
ishow(next(iter(test_mod_ds))[0])

Rescale the pixel values of all datasets (cast to float32, then divide by 255):

In [None]:
def scaling(input_image):
    """Return `input_image` divided by 255.0."""
    return input_image / 255.0

#tmp = ds_test_img.map(lambda x: tf.cast(x, tf.float32))
#ds_test_img = tmp.map(scaling)

def _to_unit_float(dataset):
    """Cast every element of `dataset` to float32, then apply `scaling`."""
    as_float = dataset.map(lambda batch: tf.cast(batch, tf.float32))
    return as_float.map(scaling)


test_mod_ds = _to_unit_float(test_mod_ds)
test_orig_ds = _to_unit_float(test_orig_ds)
train_mod_ds = _to_unit_float(train_mod_ds)
train_orig_ds = _to_unit_float(train_orig_ds)

In [None]:
# Read only the first batch; list(...) would load the entire dataset just to
# show a single image.
next(iter(test_orig_ds))[0]

In [None]:
# Read only the first batch; list(...) would load the entire dataset just to
# show a single image.
ishow(next(iter(test_orig_ds))[0])

## Crop and resize images

In [None]:
# Use TF Ops to process.
def process_input(input, input_size, upscale_factor):
    """Build the model input: the first YUV channel resized to `input_size`.

    Args:
        input: RGB image tensor; the colour channels are on the last axis.
        input_size: Side length of the square output.
        upscale_factor: Accepted for call compatibility; not used here.

    Returns:
        The first channel of the YUV conversion, resized to
        `input_size` x `input_size` with the "area" method.
    """
    yuv = tf.image.rgb_to_yuv(input)
    channel_axis = len(yuv.shape) - 1
    # Only the first of the three channel slices is needed.
    luma = tf.split(yuv, 3, axis=channel_axis)[0]
    target_size = [input_size, input_size]
    return tf.image.resize(luma, target_size, method="area")


def process_target(input):
    """Build the training target: the first YUV channel at full size.

    Args:
        input: RGB image tensor; the colour channels are on the last axis.

    Returns:
        The first channel of the YUV conversion, with the channel axis kept.
    """
    yuv = tf.image.rgb_to_yuv(input)
    channel_axis = len(yuv.shape) - 1
    # Only the first of the three channel slices is returned.
    return tf.split(yuv, 3, axis=channel_axis)[0]

In [None]:
# TODO: this pairing step is the place to rewrite. Input and target are both
# derived from the same image batch (the input is a downscaled copy).

def _to_training_pair(images):
    """Map an image batch to a (low-resolution input, full-size target) tuple."""
    return process_input(images, input_size, upscale_factor), process_target(images)


def _pair_and_prefetch(dataset):
    """Apply `_to_training_pair` to `dataset` and prefetch with a buffer of 32."""
    return dataset.map(_to_training_pair).prefetch(buffer_size=32)


train_mod_ds = _pair_and_prefetch(train_mod_ds)
train_orig_ds = _pair_and_prefetch(train_orig_ds)
test_mod_ds = _pair_and_prefetch(test_mod_ds)
test_orig_ds = _pair_and_prefetch(test_orig_ds)

In [None]:
from tensorflow.keras.preprocessing.image import array_to_img
from IPython.display import display

# Show every image of the first training batch: all inputs first, then all targets.
for batch in train_mod_ds.take(1):
    inputs, targets = batch[0], batch[1]
    for group in (inputs, targets):
        for img in group:
            display(array_to_img(img))


In [None]:
def get_model(upscale_factor=3, channels=1):
    """Build the convolutional upscaling model.

    Four `Conv2D` layers (ReLU, orthogonal initialisation, "same" padding)
    are followed by `tf.nn.depth_to_space`, which rearranges the
    `channels * upscale_factor ** 2` output channels into spatial blocks of
    size `upscale_factor`.

    Args:
        upscale_factor: Block size used by `depth_to_space`.
        channels: Number of channels of the model input.

    Returns:
        An uncompiled `tf.keras.Model` accepting inputs of shape
        (None, None, channels).
    """
    shared_conv_kwargs = {
        "activation": "relu",
        "kernel_initializer": "Orthogonal",
        "padding": "same",
    }
    # (filters, kernel size) for each convolution, in order.
    layer_specs = (
        (64, 5),
        (64, 3),
        (32, 3),
        (channels * (upscale_factor ** 2), 3),
    )

    inputs = tf.keras.Input(shape=(None, None, channels))
    features = inputs
    for filters, kernel_size in layer_specs:
        features = tf.keras.layers.Conv2D(filters, kernel_size, **shared_conv_kwargs)(features)
    outputs = tf.nn.depth_to_space(features, upscale_factor)

    return tf.keras.Model(inputs, outputs)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from tensorflow.keras.preprocessing.image import img_to_array
import PIL


def plot_results(img, prefix, title):
    """Show `img` together with a 2x magnified inset of a fixed region.

    Args:
        img: Image accepted by `img_to_array`.
        prefix: Only referenced by the disabled `savefig` call near the end.
        title: Title set on the plot.
    """
    pixels = img_to_array(img).astype("float32") / 255.0
    # Rows are reversed and drawn with origin="lower"; both axes share this view.
    flipped = pixels[::-1]

    figure, main_axes = plt.subplots()
    main_axes.imshow(flipped, origin="lower")
    plt.title(title)

    # Zoom factor 2, anchored at location code 2 (upper left).
    inset_axes = zoomed_inset_axes(main_axes, 2, loc=2)
    inset_axes.imshow(flipped, origin="lower")

    # Region of the image shown inside the inset.
    x_range, y_range = (200, 300), (100, 200)
    inset_axes.set_xlim(*x_range)
    inset_axes.set_ylim(*y_range)

    # Hide tick labels on the current pyplot axes.
    plt.yticks(visible=False)
    plt.xticks(visible=False)

    # Draw the box and connector lines linking the region to the inset.
    mark_inset(main_axes, inset_axes, loc1=1, loc2=3, fc="none", ec="blue")
    # plt.savefig(str(prefix) + "-" + title + ".png")
    plt.show()


def get_lowres_image(img, upscale_factor):
    """Return `img` shrunk by `upscale_factor` per side with bicubic resampling.

    Each side length is integer-divided by `upscale_factor`.
    """
    reduced_width = img.size[0] // upscale_factor
    reduced_height = img.size[1] // upscale_factor
    return img.resize((reduced_width, reduced_height), PIL.Image.Resampling.BICUBIC)


def upscale_image(model, img):
    """Upscale `img` with `model` and return the result as an RGB image.

    Only the Y channel of the YCbCr conversion is passed through the model;
    Cb and Cr are resized bicubically to the predicted size and merged back.

    Args:
        model: Object whose `predict` maps a batch of Y-channel arrays
            to upscaled ones.
        img: PIL image to upscale.

    Returns:
        The upscaled PIL image in RGB mode.
    """
    luma, cb, cr = img.convert("YCbCr").split()
    luma_arr = img_to_array(luma).astype("float32") / 255.0

    # Add a leading batch axis before prediction.
    batch = np.expand_dims(luma_arr, axis=0)
    predicted = model.predict(batch)

    # Scale back to 0..255, clamp, and drop the trailing channel axis.
    sr_luma = (predicted[0] * 255.0).clip(0, 255)
    out_height, out_width = np.shape(sr_luma)[0], np.shape(sr_luma)[1]
    sr_luma = sr_luma.reshape((out_height, out_width))
    sr_luma_img = PIL.Image.fromarray(np.uint8(sr_luma), mode="L")

    # Bring the chroma channels to the predicted size and restore RGB.
    sr_cb = cb.resize(sr_luma_img.size, PIL.Image.Resampling.BICUBIC)
    sr_cr = cr.resize(sr_luma_img.size, PIL.Image.Resampling.BICUBIC)
    merged = PIL.Image.merge("YCbCr", (sr_luma_img, sr_cb, sr_cr))
    return merged.convert("RGB")


In [None]:
# Sorted .png paths from the degraded validation folder of the selected set
# (current_set[2]); the bare name below displays the list as the cell output.
modified_test_paths = get_images_in_dir(current_set[2])
modified_test_paths

In [None]:
from tensorflow.keras.preprocessing.image import load_img
import math

class ESPCNCallback(tf.keras.callbacks.Callback):
    """Report the mean PSNR of each epoch and periodically plot a prediction.

    PSNR values are collected from the loss reported at the end of every test
    (validation) batch as ``10 * log10(1 / loss)``. Every 20th epoch
    (starting with epoch 0) the model upscales a fixed test image and the
    result is plotted.
    """

    def __init__(self):
        super().__init__()
        # Alternative test image: a downscaled copy of an original validation image.
        #self.test_img = get_lowres_image(load_img(original_test + "/0801.png"), upscale_factor)
        self.test_img = load_img(modified_test_paths[0])
        # Created here as well, so the test-batch hook also works when no
        # training epoch has started (e.g. with `model.evaluate`).
        self.psnr = []

    # Store PSNR value in each epoch.
    def on_epoch_begin(self, epoch, logs=None):
        """Start a fresh PSNR list for the new epoch."""
        self.psnr = []

    def on_epoch_end(self, epoch, logs=None):
        """Print the epoch's mean PSNR and plot a prediction every 20th epoch."""
        # Without any test batches there is nothing to average; report NaN
        # directly instead of triggering numpy's empty-mean warning.
        mean_psnr = np.mean(self.psnr) if self.psnr else float("nan")
        print("Mean PSNR for epoch: %.2f" % mean_psnr)
        if epoch % 20 == 0:
            prediction = upscale_image(self.model, self.test_img)
            plot_results(prediction, "epoch-" + str(epoch), "prediction")

    def on_test_batch_end(self, batch, logs=None):
        """Record the PSNR derived from the loss reported for this test batch."""
        loss = logs["loss"]
        if loss == 0:
            # A zero loss corresponds to infinite PSNR; avoid dividing by zero.
            self.psnr.append(math.inf)
        else:
            self.psnr.append(10 * math.log10(1 / loss))


In [None]:
# Stop training once the training loss has not improved for 10 epochs.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=10)

# Location the checkpoint callback writes to and the weights are later loaded from.
checkpoint_filepath = "./tmp/checkpoint"

# Save weights only, keeping just the checkpoint with the lowest training loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor="loss",
    mode="min",
    save_best_only=True,
)

# Single-channel model using the notebook-wide upscale factor.
model = get_model(upscale_factor=upscale_factor, channels=1)
model.summary()

# Callbacks passed to training; the commented line is the variant without the
# PSNR/plotting callback.
callbacks = [ESPCNCallback(), early_stopping_callback, model_checkpoint_callback]
#callbacks = [early_stopping_callback, model_checkpoint_callback]

# Mean squared error loss, optimised with Adam at a learning rate of 0.001.
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)


In [None]:
# Number of training epochs.
epochs = 2

model.compile(
    optimizer=optimizer, loss=loss_fn,
)

# Train on the pairs built from the original training images and validate on
# the pairs built from the original validation images.
model.fit(
    train_orig_ds, epochs=epochs, callbacks=callbacks, validation_data=test_orig_ds, verbose=2
)

# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)


In [None]:
# Sorted .png paths of the high-resolution validation images.
test_img_paths = get_images_in_dir(original_test)

In [None]:
# Display the collected paths as the cell output.
test_img_paths

In [None]:
# Compare, for a fixed slice of the validation images, the PSNR of a plain
# resize of the low-resolution input against the PSNR of the model output.
total_bicubic_psnr = 0.0
total_test_psnr = 0.0

# Images evaluated in this run; the averages below are taken over this list.
eval_img_paths = test_img_paths[10:20]

for index, test_img_path in enumerate(eval_img_paths):
    img = load_img(test_img_path)
    lowres_input = get_lowres_image(img, upscale_factor)
    # Size the reference to an exact multiple of the low-resolution size so
    # that all three compared images have the same dimensions.
    w = lowres_input.size[0] * upscale_factor
    h = lowres_input.size[1] * upscale_factor
    highres_img = img.resize((w, h))
    prediction = upscale_image(model, lowres_input)
    # Baseline: the low-resolution input resized back up without the model.
    lowres_img = lowres_input.resize((w, h))
    lowres_img_arr = img_to_array(lowres_img)
    highres_img_arr = img_to_array(highres_img)
    predict_img_arr = img_to_array(prediction)
    bicubic_psnr = tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255)
    test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)

    total_bicubic_psnr += bicubic_psnr
    total_test_psnr += test_psnr

    print(
        "PSNR of low resolution image and high resolution image is %.4f" % bicubic_psnr
    )
    print("PSNR of predict and high resolution is %.4f" % test_psnr)
    plot_results(lowres_img, index, "lowres")
    plot_results(highres_img, index, "highres")
    plot_results(prediction, index, "prediction")

# Average over the images actually evaluated rather than a fixed 10, so the
# result stays correct when fewer than 20 test images exist. With no images
# the totals are 0.0 and a divisor of 1 keeps the printed averages at 0.
num_evaluated = max(len(eval_img_paths), 1)
print("Avg. PSNR of lowres images is %.4f" % (total_bicubic_psnr / num_evaluated))
print("Avg. PSNR of reconstructions is %.4f" % (total_test_psnr / num_evaluated))
