In [2]:
import os
import time
import ntpath
import numpy as np
from PIL import Image
from os.path import join, exists
from keras.models import model_from_json
from sklearn.metrics import mean_squared_error
import math  # Import math for log10

## local libs
from utils.data_utils import getPaths, read_and_resize, preprocess, deprocess

## for testing arbitrary local data

data_dir = "../data/test/A/"
from utils.data_utils import get_local_test_data
test_paths = getPaths(data_dir)

# Remove duplicates from the test_paths list
test_paths = list(set(test_paths))

print("{0} test images are loaded".format(len(test_paths)))


## create dir for log and (sampled) validation data
samples_dir = "../data/output/"
if not exists(samples_dir): os.makedirs(samples_dir)

## test funie-gan
checkpoint_dir  = 'models/gen_p/'
model_name_by_epoch = "model_15320_" 
## test funie-gan-up
#checkpoint_dir  = 'models/gen_up/'
#model_name_by_epoch = "model_35442_" 

model_h5 = checkpoint_dir + model_name_by_epoch + ".h5"  
model_json = checkpoint_dir + model_name_by_epoch + ".json"
# sanity
assert (exists(model_h5) and exists(model_json))

# load model
with open(model_json, "r") as json_file:
    loaded_model_json = json_file.read()
funie_gan_generator = model_from_json(loaded_model_json)
# load weights into the new model
funie_gan_generator.load_weights(model_h5)
print("\nLoaded data and model")

# Initialize lists to store MSE and PSNR values for all images
all_mse_values = []
all_psnr_values = []

# Testing loop
times = []
for img_path in test_paths:
    # Prepare data
    inp_img = read_and_resize(img_path, (256, 256))
    im = preprocess(inp_img)
    im = np.expand_dims(im, axis=0)  # (1,256,256,3)
    
    # Generate enhanced image
    s = time.time()
    gen = funie_gan_generator.predict(im)
    gen_img = deprocess(gen)[0]
    tot = time.time() - s
    times.append(tot)
    
    # Calculate MSE between the original image and enhanced image
    mse = mean_squared_error(inp_img.flatten(), gen_img.flatten())
    all_mse_values.append(mse)
    
    # Calculate PSNR using the calculated MSE
    psnr = 10 * math.log10((255 ** 2) / mse)
    all_psnr_values.append(psnr)
    
    # Save output images
    img_name = ntpath.basename(img_path)
    out_img = np.hstack((inp_img, gen_img)).astype('uint8')
    Image.fromarray(out_img).save(join(samples_dir, img_name))

# Some statistics
num_test = len(test_paths)
if num_test == 0:
    print("\nFound no images for test")
else:
    print("\nTotal images: {0}".format(num_test))
    Ttime, Mtime = np.sum(times[1:]), np.mean(times[1:])
    print("Time taken: {0} sec at {1} fps".format(Ttime, 1. / Mtime))
    print("\nSaved generated images in {0}\n".format(samples_dir))

# Print or store the MSE and PSNR values for each image
for i, img_path in enumerate(test_paths):
    img_name = ntpath.basename(img_path)
    print(f"Image {i+1}: {img_name}, MSE: {all_mse_values[i]}, PSNR: {all_psnr_values[i]}")


22 test images are loaded

Loaded data and model

Total images: 22
Time taken: 2.4248616695404053 sec at 8.66028782746202 fps

Saved generated images in ../data/output/

Image 1: frame4_0001.png, MSE: 177.99993896484375, PSNR: 25.626605074760732
Image 2: frame3_0001.png, MSE: 2726.20654296875, PSNR: 13.775216050869473
Image 3: 1.png, MSE: 534.1290283203125, PSNR: 20.854341796470365
Image 4: frame014_0001.png, MSE: 92.98616790771484, PSNR: 28.446620106809533
Image 5: frame016_0001.png, MSE: 2173.0048828125, PSNR: 14.760196586724927
Image 6: frame011_0001.png.png, MSE: 2946.330322265625, PSNR: 13.437989255251825
Image 7: frame025_0001.png, MSE: 2957.3095703125, PSNR: 13.4218357214111
Image 8: frame5_0001.png, MSE: 270.3298645019531, PSNR: 23.811863341408777
Image 9: frame019_0001.png, MSE: 1966.9864501953125, PSNR: 15.192789926240566
Image 10: frame9_0001.png, MSE: 213.0283203125, PSNR: 24.846430178148033
Image 11: frame_0001.png, MSE: 825.7954711914062, PSNR: 18.962078640705386
Image 12