In [1]:
import os
import cv2
import numpy as np

from PIL import Image
from skimage.metrics import structural_similarity as skimage_ssim
from skimage.metrics import peak_signal_noise_ratio as skimage_psnr

In [2]:
def angular_error(image_A, image_B):
    """
    Fonte: ChatGPT (https://chat.openai.com/share/866185a9-3fe7-4876-b0eb-1c14d37f5bbf)
    O algoritmo foi revisado e faz completo sentido segundo o slide (https://drive.google.com/file/d/1CosA4BcKcUIuwFSlMvSt2RD6edWIyzH5/view?usp=sharing)

    Calculate angular error between corresponding pixels in two images.

    Parameters:
        image_A: array_like
            First image (BGR or RGB format).
        image_B: array_like
            Second image (BGR or RGB format).

    Returns:
        angular_error: float
            Mean angular error between the two images.
    """
    # Convert images to XYZ color space
    XYZ_A = cv2.cvtColor(image_A, cv2.COLOR_BGR2XYZ) / 255.0
    XYZ_B = cv2.cvtColor(image_B, cv2.COLOR_BGR2XYZ) / 255.0

    # Compute norms
    norm_A = np.linalg.norm(XYZ_A, axis=-1)
    norm_B = np.linalg.norm(XYZ_B, axis=-1)
    
    # Check for zeros in norms
    zero_norms_mask = (norm_A == 0) | (norm_B == 0)

    # Compute dot product
    dot_product = np.sum(XYZ_A * XYZ_B, axis=-1)

    # Compute angular error
    cos_theta = np.clip(dot_product / (norm_A * norm_B), -1, 1)  # Clip to handle numerical errors
    error = np.arccos(cos_theta)

    # Set angular error to zero where division by zero occurred
    error[zero_norms_mask] = 0.0

    angular_errors = np.degrees(error)

    # Compute mean angular error
    mean_angular_error = np.mean(angular_errors)

    return mean_angular_error

In [3]:
def compare_results(real_image, torch_paper_image, torch_mine_image, jax_image):
    real_image = Image.open(real_image).convert("RGB")
    torch_paper_image = Image.open(torch_paper_image).convert("RGB")
    torch_mine_image = Image.open(torch_mine_image).convert("RGB")
    jax_image = Image.open(jax_image).convert("RGB")

    real_image = np.array(real_image).astype(np.uint8)
    torch_paper_image = np.array(torch_paper_image).astype(np.uint8)
    torch_mine_image = np.array(torch_mine_image).astype(np.uint8)
    jax_image = np.array(jax_image).astype(np.uint8)

    torch_ssim = skimage_ssim(real_image, torch_paper_image, channel_axis=2)
    torch_mine_ssim = skimage_ssim(real_image, torch_mine_image, channel_axis=2)
    jax_ssim = skimage_ssim(real_image, jax_image, channel_axis=2)

    torch_psnr = skimage_psnr(real_image, torch_paper_image)
    torch_mine_psnr = skimage_psnr(real_image, torch_mine_image)
    jax_psnr = skimage_psnr(real_image, jax_image)

    torch_angular_error = angular_error(real_image, torch_paper_image)
    torch_mine_angular_error = angular_error(real_image, torch_mine_image)
    jax_angular_error = angular_error(real_image, jax_image)

    return (
        (torch_ssim, torch_mine_ssim, jax_ssim),
        (torch_psnr, torch_mine_psnr, jax_psnr),
        (torch_angular_error, torch_mine_angular_error, jax_angular_error),
    )

# On synthetic ataset

In [4]:
dataset_path = "../datasets/test/to_report/high/"
torch_mine_finetuned = "../output/torch/mine/synthetic_results/"
torch_their_finetuned = "../output/torch/their/synthetic_results/"
jax_finetuned = "../output/jax/synthetic_results/"

In [5]:
torch_mine_finetuned_images = os.listdir(torch_mine_finetuned)
torch_their_finetuned_images = os.listdir(torch_their_finetuned)
jax_finetuned_images = os.listdir(jax_finetuned)

In [6]:
print("Pretrained model results")
torch_their_ssim = 0
torch_mine_ssim = 0
jax_ssim = 0

torch_their_psnr = 0
torch_mine_psnr = 0
jax_psnr = 0

torch_mine_angular_error = 0
torch_their_angular_error = 0
jax_angular_error = 0

for image in torch_their_finetuned_images:
    torch_their_image = os.path.join(torch_their_finetuned, image)
    torch_mine_image = os.path.join(torch_mine_finetuned, image)
    jax_image = os.path.join(jax_finetuned, image)
    real_image = os.path.join(dataset_path, image)
    ssim, psnr, ang_err = compare_results(
        real_image, torch_their_image, torch_mine_image, jax_image
    )
    torch_their_ssim += ssim[0]
    torch_mine_ssim += ssim[1]
    jax_ssim += ssim[2]

    torch_their_psnr += psnr[0]
    torch_mine_psnr += psnr[1]
    jax_psnr += psnr[2]

    torch_their_angular_error += ang_err[0]
    torch_mine_angular_error += ang_err[1]
    jax_angular_error += ang_err[2]

torch_their_ssim /= len(torch_their_finetuned_images)
torch_mine_ssim /= len(torch_mine_finetuned_images)
jax_ssim /= len(jax_finetuned_images)

torch_their_psnr /= len(torch_their_finetuned_images)
torch_mine_psnr /= len(torch_mine_finetuned_images)
jax_psnr /= len(jax_finetuned_images)

torch_their_angular_error /= len(torch_their_finetuned_images)
torch_mine_angular_error /= len(torch_mine_finetuned_images)
jax_angular_error /= len(jax_finetuned_images)

print(
    f"SSIM: Torch Original: {torch_their_ssim}, Reproduced: {torch_mine_ssim}, Jax: {jax_ssim}"
)
print(
    f"PSNR: Torch Original: {torch_their_psnr}, Reproduced: {torch_mine_psnr}, Jax: {jax_psnr}"
)
print(
    f"Angular Error: Torch Original: {torch_their_angular_error}, Reproduced: {torch_mine_angular_error}, Jax: {jax_angular_error}"
)

Pretrained model results


  cos_theta = np.clip(dot_product / (norm_A * norm_B), -1, 1)  # Clip to handle numerical errors


SSIM: Torch Original: 0.8993351620447376, Reproduced: 0.8445246565261463, Jax: 0.9054582661351688
PSNR: Torch Original: 25.19262867969908, Reproduced: 21.273988825866766, Jax: 27.77077605088171
Angular Error: Torch Original: 2.52325156597444, Reproduced: 3.3811651382092585, Jax: 3.2380003984086487


# On LOL dataset

In [7]:
dataset_path = "../datasets/test/LOL/high/"
torch_mine_finetuned = "../output/torch/mine/finetune_results/"
torch_their_finetuned = "../output/torch/their/finetune_results/"
jax_finetuned = "../output/jax/finetune_results/"

In [8]:

torch_mine_finetuned_images = os.listdir(torch_mine_finetuned)
torch_their_finetuned_images = os.listdir(torch_their_finetuned)
jax_finetuned_images = os.listdir(jax_finetuned)

In [9]:
print("Finetuned model results")
torch_their_ssim = 0
torch_mine_ssim = 0
jax_ssim = 0

torch_their_psnr = 0
torch_mine_psnr = 0
jax_psnr = 0

torch_mine_angular_error = 0
torch_their_angular_error = 0
jax_angular_error = 0

for image in torch_their_finetuned_images:
    torch_their_image = os.path.join(torch_their_finetuned, image)
    torch_mine_image = os.path.join(torch_mine_finetuned, image)
    jax_image = os.path.join(jax_finetuned, image)
    real_image = os.path.join(dataset_path, image)
    ssim, psnr, ang_err = compare_results(
        real_image, torch_their_image, torch_mine_image, jax_image
    )
    torch_their_ssim += ssim[0]
    torch_mine_ssim += ssim[1]
    jax_ssim += ssim[2]

    torch_their_psnr += psnr[0]
    torch_mine_psnr += psnr[1]
    jax_psnr += psnr[2]

    torch_their_angular_error += ang_err[0]
    torch_mine_angular_error += ang_err[1]
    jax_angular_error += ang_err[2]

torch_their_ssim /= len(torch_their_finetuned_images)
torch_mine_ssim /= len(torch_mine_finetuned_images)
jax_ssim /= len(jax_finetuned_images)

torch_their_psnr /= len(torch_their_finetuned_images)
torch_mine_psnr /= len(torch_mine_finetuned_images)
jax_psnr /= len(jax_finetuned_images)

torch_their_angular_error /= len(torch_their_finetuned_images)
torch_mine_angular_error /= len(torch_mine_finetuned_images)
jax_angular_error /= len(jax_finetuned_images)

print(
    f"SSIM: Torch Original: {torch_their_ssim}, Reproduced: {torch_mine_ssim}, Jax: {jax_ssim}"
)
print(
    f"PSNR: Torch Original: {torch_their_psnr}, Reproduced: {torch_mine_psnr}, Jax: {jax_psnr}"
)
print(
    f"Angular Error: Torch Original: {torch_their_angular_error}, Reproduced: {torch_mine_angular_error}, Jax: {jax_angular_error}"
)

Finetuned model results


  cos_theta = np.clip(dot_product / (norm_A * norm_B), -1, 1)  # Clip to handle numerical errors


SSIM: Torch Original: 0.846008663771394, Reproduced: 0.7959400883183082, Jax: 0.8086479745572868
PSNR: Torch Original: 21.94250476459355, Reproduced: 19.533023404713106, Jax: 20.456905417044503
Angular Error: Torch Original: 3.672147276143386, Reproduced: 4.757357429316483, Jax: 4.908904760874908
