In [None]:
import os
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from patchify import (patchify, unpatchify)
from utils import center_crop, reconnect_patches

In [None]:
# Display a list with the available models and ask the user to choose which to use
models_path = 'models'
# os.listdir returns entries in arbitrary order; sorting keeps the
# number -> model mapping identical from one run (or machine) to the next.
models_list = sorted(os.listdir(models_path))
if not models_list:
    # Without this guard the prompt below could never be answered correctly.
    raise FileNotFoundError(f'No models found in {models_path!r}.')

# create a dictionary with numbers (as strings) as keys and the model names as values
models2dict = {str(key): name for key, name in enumerate(models_list)}

print('List of Models', '--------------', sep='\n')
for item in models2dict.items():
    print(*item, sep=' --> ')

print('\nChoose a model from the list by typing the number of its key + Enter:')
# Ask again instead of failing with a bare KeyError when the reply is not one
# of the listed keys; surrounding whitespace in the reply is ignored.
choice = input().strip()
while choice not in models2dict:
    print(f'{choice!r} is not a valid key. Choose one of: {", ".join(models2dict)}')
    choice = input().strip()
model_name = models2dict[choice]
print(f'\n{model_name} was selected.')

In [None]:
# Load the model chosen above from the models directory
model_path = os.path.join(models_path, model_name)
model = keras.models.load_model(model_path)

In [None]:
# Take the model's first layer (its input layer) and read the image width
# from its input shape; the images are assumed square, so the same value is
# also used as the height.
input_layer = model.get_layer(index=0)
# input_shape is indexed as a list of shapes here: entry 0 is the shape tuple,
# and element 1 of that tuple is the first dimension after the batch axis.
input_layer_shape = input_layer.input_shape[0]
input_size = input_layer_shape[1]

In [None]:
# Load the input and ground truth images and crop both to the same size
CROP_SIZE_W = 2560
CROP_SIZE_H = 2560
crop_shape = (CROP_SIZE_H, CROP_SIZE_W)

# NOTE(review): matplotlib's imread returns integer arrays for JPG files and
# 0-1 float arrays for PNG files, so the two images may use different value
# ranges - confirm this is what the model and the later comparison expect.
x_initial_valid = center_crop(plt.imread('validation/201-INPUT.jpg'), crop_shape)
y_initial_valid = center_crop(plt.imread('validation/201-OUTPUT-GT.png'), crop_shape)

In [None]:
# Sanity check: shapes of the cropped input and ground truth images, one per line
print(x_initial_valid.shape, y_initial_valid.shape, sep='\n')

In [None]:
# Cut the input and ground truth images into square patches of side input_size.
# STEP is the stride between consecutive patches; a stride smaller than the
# patch side makes neighbouring patches overlap.
STEP = 256
rgb_patch_shape = (input_size, input_size, 3)
gt_patch_shape = (input_size, input_size)

# np.squeeze drops the singleton axes from patchify's result.
input_patches = np.squeeze(patchify(x_initial_valid, rgb_patch_shape, step=STEP))
ground_truth_patches = np.squeeze(patchify(y_initial_valid, gt_patch_shape, step=STEP))

In [None]:
# Flatten the patch grid into a single batch of (input_size, input_size, 3)
# images, run the model on it, then give the predictions the same layout as
# the ground truth patches.
patch_batch = np.reshape(input_patches, (-1, input_size, input_size, 3))
raw_predictions = model.predict(patch_batch)
# assumes the model emits one value per input pixel, otherwise the reshape
# below fails - TODO confirm
predictions = np.reshape(np.squeeze(raw_predictions), ground_truth_patches.shape)

In [None]:
# TODO: To delete in the future. I can make this check in the "reconnect_patches" function in the utils.py
# Number of overlapping predictions expected per pixel; it has to be a whole
# number, and the rest of the notebook (2x2 plot grid) relies on it being 4.
overlapped_images = (2 * input_size) / STEP

assert overlapped_images.is_integer()
overlapped_images = int(overlapped_images)
assert overlapped_images == 4

In [None]:
# Target shape: the ground truth image shape plus one trailing axis with an
# entry per overlapping prediction.
unified_shape = (*y_initial_valid.shape, overlapped_images)
# presumably reconnect_patches places every patch prediction back at its image
# position, one overlap per slice of the last axis - verify against utils.py
unified_predictions = reconnect_patches(
    predictions,
    unified_shape,
    STEP,
    input_size,
)

In [None]:
# Show the first four slices along the last axis of unified_predictions in a
# 2x2 grid, filled row by row.
fig, ax = plt.subplots(2, 2, dpi=300)
for i, panel in enumerate(ax.flat):
    panel.imshow(unified_predictions[:, :, i])
plt.show()

In [None]:
# Drop a STEP-wide border on every side of the first two axes.
# presumably because border pixels are not covered by all overlapping
# patches - TODO confirm
interior = slice(STEP, -STEP)
new_pred = unified_predictions[interior, interior]

# Standard deviation across the overlapping predictions (third axis) of each pixel
pixelwise_std = np.std(new_pred, axis=2)

In [None]:
# Average of the per-pixel standard deviation over the whole trimmed image
pixelwise_std.mean()

In [None]:
# Plot the inverted standard deviation map: with the reversed grey colormap,
# pixels where the overlapping predictions agree (low std) appear bright and
# pixels where they disagree appear dark.
inverted_std = 1 - pixelwise_std
plt.figure(dpi=300)
plt.imshow(inverted_std, cmap='Greys_r')
plt.show()