In [1]:
import os
import numpy as np

from PIL import Image
from skimage.metrics import structural_similarity as skimage_ssim
from skimage.metrics import peak_signal_noise_ratio as skimage_psnr


In [2]:
def _load_rgb(path):
    """Load the image at ``path`` as a ``uint8`` RGB array of shape (H, W, 3)."""
    # The context manager releases the underlying file handle deterministically
    # instead of relying on Pillow's lazy-load cleanup.
    with Image.open(path) as img:
        # uint8 matters: skimage infers data_range=255 from the dtype.
        return np.array(img.convert("RGB"), dtype=np.uint8)


def compare_results(real_image, torch_paper_image, torch_mine_image, jax_image):
    """Score three reconstructions of one image against its ground truth.

    Parameters
    ----------
    real_image : path-like
        Reference image the other three are compared with.
    torch_paper_image, torch_mine_image, jax_image : path-like
        Outputs of the original (paper) PyTorch model, the reproduced PyTorch
        model and the JAX model, in that order.

    Returns
    -------
    tuple
        ``((ssim_paper, ssim_mine, ssim_jax), (psnr_paper, psnr_mine, psnr_jax))``.

    Raises
    ------
    ValueError
        Propagated from skimage when an output's shape differs from the
        reference.
    """
    reference = _load_rgb(real_image)
    candidates = [
        _load_rgb(path) for path in (torch_paper_image, torch_mine_image, jax_image)
    ]

    # channel_axis=2: arrays are (H, W, 3), so SSIM is computed per colour channel.
    ssims = tuple(
        skimage_ssim(reference, candidate, channel_axis=2) for candidate in candidates
    )
    psnrs = tuple(skimage_psnr(reference, candidate) for candidate in candidates)
    return ssims, psnrs

In [3]:
# All paths are relative to the notebook's working directory.

# Ground-truth images; each output file name is looked up here for comparison.
dataset_path = "../datasets/test/LOL/high/"

# Pretrained-model outputs: original authors' PyTorch code, the reproduced
# PyTorch implementation, and the JAX implementation.
torch_their_pretrained = "../output/torch/their/pretrained_results/"
torch_mine_pretrained = "../output/torch/mine/pretrained_results/"
jax_pretrained = "../output/jax/pretrained_results/"

# Finetuned-model outputs for the same three implementations.
torch_their_finetuned = "../output/torch/their/finetune_results/"
torch_mine_finetuned = "../output/torch/mine/finetune_results/"
jax_finetuned = "../output/jax/finetune_results/"

In [4]:
def _list_images(directory):
    """Return the sorted names of the non-hidden regular files in ``directory``."""
    # os.listdir order is arbitrary, so sort for a reproducible evaluation order.
    # Hidden entries (e.g. ``.ipynb_checkpoints``, ``.DS_Store``) and
    # sub-directories are not images and would break ``Image.open`` downstream.
    return sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    )


torch_mine_pretrained_images = _list_images(torch_mine_pretrained)
torch_their_pretrained_images = _list_images(torch_their_pretrained)
jax_pretrained_images = _list_images(jax_pretrained)
torch_mine_finetuned_images = _list_images(torch_mine_finetuned)
torch_their_finetuned_images = _list_images(torch_their_finetuned)
jax_finetuned_images = _list_images(jax_finetuned)

In [5]:
print("Pretrained model results")
# One (their, mine, jax) tuple per image for each metric. The file list comes
# from the original authors' output directory; each name is looked up in the
# other two output directories and in the ground-truth dataset.
pretrained_ssim_scores = []
pretrained_psnr_scores = []
for image in torch_their_pretrained_images:
    ssim, psnr = compare_results(
        os.path.join(dataset_path, image),
        os.path.join(torch_their_pretrained, image),
        os.path.join(torch_mine_pretrained, image),
        os.path.join(jax_pretrained, image),
    )
    pretrained_ssim_scores.append(ssim)
    pretrained_psnr_scores.append(psnr)

# All three averages must share the same denominator: the number of images
# actually compared. Dividing the reproduced/JAX sums by the length of their own
# directory listings would skew them whenever those listings differ in size.
num_pretrained_images = len(pretrained_ssim_scores)
if num_pretrained_images == 0:
    raise ValueError(f"No images to evaluate in {torch_their_pretrained!r}")
torch_their_ssim, torch_mine_ssim, jax_ssim = (
    sum(column) / num_pretrained_images for column in zip(*pretrained_ssim_scores)
)
torch_their_psnr, torch_mine_psnr, jax_psnr = (
    sum(column) / num_pretrained_images for column in zip(*pretrained_psnr_scores)
)
print(
    f"SSIM: Torch Original: {torch_their_ssim}, Reproduced: {torch_mine_ssim}, Jax: {jax_ssim}"
)
print(
    f"PSNR: Torch Original: {torch_their_psnr}, Reproduced: {torch_mine_psnr}, Jax: {jax_psnr}"
)

Pretrained model results
SSIM: Torch Original: 0.6972939948490605, Reproduced: 0.6753012227723739, Jax: 0.677763565747528
PSNR: Torch Original: 19.258117811144565, Reproduced: 16.925361788953147, Jax: 18.333172860999262


In [6]:
print("Finetuned model results")
# One (their, mine, jax) tuple per image for each metric. The file list comes
# from the original authors' output directory; each name is looked up in the
# other two output directories and in the ground-truth dataset.
finetuned_ssim_scores = []
finetuned_psnr_scores = []
for image in torch_their_finetuned_images:
    ssim, psnr = compare_results(
        os.path.join(dataset_path, image),
        os.path.join(torch_their_finetuned, image),
        os.path.join(torch_mine_finetuned, image),
        os.path.join(jax_finetuned, image),
    )
    finetuned_ssim_scores.append(ssim)
    finetuned_psnr_scores.append(psnr)

# All three averages must share the same denominator: the number of images
# actually compared. Dividing the reproduced/JAX sums by the length of their own
# directory listings would skew them whenever those listings differ in size.
num_finetuned_images = len(finetuned_ssim_scores)
if num_finetuned_images == 0:
    raise ValueError(f"No images to evaluate in {torch_their_finetuned!r}")
torch_their_ssim, torch_mine_ssim, jax_ssim = (
    sum(column) / num_finetuned_images for column in zip(*finetuned_ssim_scores)
)
torch_their_psnr, torch_mine_psnr, jax_psnr = (
    sum(column) / num_finetuned_images for column in zip(*finetuned_psnr_scores)
)
print(
    f"SSIM: Torch Original: {torch_their_ssim}, Reproduced: {torch_mine_ssim}, Jax: {jax_ssim}"
)
print(
    f"PSNR: Torch Original: {torch_their_psnr}, Reproduced: {torch_mine_psnr}, Jax: {jax_psnr}"
)

Finetuned model results
SSIM: Torch Original: 0.846008663771394, Reproduced: 0.7959400883183082, Jax: 0.8086479745572868
PSNR: Torch Original: 21.94250476459355, Reproduced: 19.533023404713106, Jax: 20.456905417044503
