In [1]:
import os
import numpy as np

from PIL import Image
from skimage.metrics import structural_similarity as skimage_ssim
from skimage.metrics import peak_signal_noise_ratio as skimage_psnr


In [2]:
def _load_rgb_array(path):
    """Read the image file at ``path`` and return it as an RGB ``uint8`` array."""
    # The context manager closes the underlying file handle as soon as the
    # pixels have been copied, so handles do not pile up when this is called
    # once per image over a whole directory.
    with Image.open(path) as img:
        return np.array(img.convert("RGB"), dtype=np.uint8)


def compare_results(real_image, torch_image, jax_image):
    """Score the Torch and JAX outputs against a reference image.

    Args:
        real_image: Path to the reference (ground-truth) image file.
        torch_image: Path to the image produced by the Torch model.
        jax_image: Path to the image produced by the JAX model.

    Returns:
        A pair ``((torch_ssim, jax_ssim), (torch_psnr, jax_psnr))`` where each
        metric compares the corresponding model output to the reference.

    Note:
        Errors from ``Image.open`` (for example a missing file) and from the
        scikit-image metrics propagate to the caller unchanged.
    """
    reference = _load_rgb_array(real_image)
    torch_output = _load_rgb_array(torch_image)
    jax_output = _load_rgb_array(jax_image)

    # Images are (H, W, 3) after the RGB conversion, so the channel axis is 2.
    torch_ssim = skimage_ssim(reference, torch_output, channel_axis=2)
    jax_ssim = skimage_ssim(reference, jax_output, channel_axis=2)

    torch_psnr = skimage_psnr(reference, torch_output)
    jax_psnr = skimage_psnr(reference, jax_output)

    return (torch_ssim, jax_ssim), (torch_psnr, jax_psnr)

In [3]:
# Reference images that the model outputs are compared against.
dataset_path = "../datasets/test/LOL/high/"

# Model outputs, grouped by training stage: (torch, jax) for each.
torch_pretrained, jax_pretrained = (
    "../output/torch/pretrained_results/",
    "../output/jax/pretrained_results/",
)
torch_finetuned, jax_finetuned = (
    "../output/torch/finetune_results/",
    "../output/jax/finetune_results/",
)

In [4]:
# os.listdir returns entries in arbitrary order; sorting fixes the evaluation
# order so the accumulated metrics are reproducible across machines and runs.
torch_pretrained_images = sorted(os.listdir(torch_pretrained))
jax_pretrained_images = sorted(os.listdir(jax_pretrained))
torch_finetuned_images = sorted(os.listdir(torch_finetuned))
jax_finetuned_images = sorted(os.listdir(jax_finetuned))

In [5]:
print("Pretrained model results")

# The loop below is driven by the Torch file list, so every average must be
# taken over that same count. Dividing the JAX sums by the size of the JAX
# directory would skew them whenever the two directories differ in size.
num_images = len(torch_pretrained_images)
if num_images == 0:
    raise ValueError(f"No images found in {torch_pretrained}")

# Running sums of the per-image metrics.
torch_ssim = 0
jax_ssim = 0
torch_psnr = 0
jax_psnr = 0
for image in torch_pretrained_images:
    # The same file name is looked up in the Torch, JAX and reference folders.
    torch_image = os.path.join(torch_pretrained, image)
    jax_image = os.path.join(jax_pretrained, image)
    real_image = os.path.join(dataset_path, image)
    ssim, psnr = compare_results(real_image, torch_image, jax_image)
    torch_ssim += ssim[0]
    jax_ssim += ssim[1]
    torch_psnr += psnr[0]
    jax_psnr += psnr[1]

# Turn the sums into per-image means.
torch_ssim /= num_images
jax_ssim /= num_images
torch_psnr /= num_images
jax_psnr /= num_images
print(f"SSIM: Torch: {torch_ssim}, Jax: {jax_ssim}")
print(f"PSNR: Torch: {torch_psnr}, Jax: {jax_psnr}")

Pretrained model results
SSIM: Torch: 0.6972939948490605, Jax: 0.677763565747528
PSNR: Torch: 19.258117811144565, Jax: 18.333172860999262


In [6]:
print("Finetuned model results")

# The loop below is driven by the Torch file list, so every average must be
# taken over that same count. Dividing the JAX sums by the size of the JAX
# directory would skew them whenever the two directories differ in size.
num_images = len(torch_finetuned_images)
if num_images == 0:
    raise ValueError(f"No images found in {torch_finetuned}")

# Running sums of the per-image metrics.
torch_ssim = 0
jax_ssim = 0
torch_psnr = 0
jax_psnr = 0
for image in torch_finetuned_images:
    # The same file name is looked up in the Torch, JAX and reference folders.
    torch_image = os.path.join(torch_finetuned, image)
    jax_image = os.path.join(jax_finetuned, image)
    real_image = os.path.join(dataset_path, image)
    ssim, psnr = compare_results(real_image, torch_image, jax_image)
    torch_ssim += ssim[0]
    jax_ssim += ssim[1]
    torch_psnr += psnr[0]
    jax_psnr += psnr[1]

# Turn the sums into per-image means.
torch_ssim /= num_images
jax_ssim /= num_images
torch_psnr /= num_images
jax_psnr /= num_images
print(f"SSIM: Torch: {torch_ssim}, Jax: {jax_ssim}")
print(f"PSNR: Torch: {torch_psnr}, Jax: {jax_psnr}")

Finetuned model results
SSIM: Torch: 0.846008663771394, Jax: 0.8086479745572868
PSNR: Torch: 21.94250476459355, Jax: 20.456905417044503
