In [1]:
import jax.numpy as jnp
import numpy as np
import torch

from skimage.metrics import structural_similarity as skimage_ssim
from skimage.metrics import peak_signal_noise_ratio as skimage_psnr
from dm_pix import ssim, psnr

from dln.data import get_Low_light_training_set

from dln.jax_tv import total_variation
from dln.utils import TVLoss

In [2]:
train_set = get_Low_light_training_set(
    upscale_factor=1, patch_size=128, data_augmentation=True
)

In [3]:
low_light, normal_light = train_set[0]
jaxed_low_light = jnp.transpose(jnp.array(low_light), (1, 2, 0))
jaxed_normal_light = jnp.transpose(jnp.array(normal_light), (1, 2, 0))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [4]:
skimage_ssim_index, _ = skimage_ssim(
    (np.array(jaxed_normal_light) * 255).astype(np.uint8),
    (np.array(jaxed_low_light) * 255).astype(np.uint8),
    full=True,
    channel_axis=2,
)
print("SSIM index (scikit-image):", skimage_ssim_index)

SSIM index (scikit-image): 0.0932764238579652


In [5]:
jax_ssim_index = ssim(jaxed_normal_light, jaxed_low_light)
print("SSIM index (JAX):", jax_ssim_index)

SSIM index (JAX): 0.09816425


In [6]:
skimage_psnr_index = skimage_psnr(
    (np.array(jaxed_normal_light) * 255).astype(np.uint8),
    (np.array(jaxed_low_light) * 255).astype(np.uint8),
)
print("PSNR index (scikit-image):", skimage_psnr_index)

PSNR index (scikit-image): 13.263555511491747


In [7]:
jax_psnr_index = psnr(jaxed_normal_light, jaxed_low_light)
print("PSNR index (JAX):", jax_psnr_index)

PSNR index (JAX): 13.263566


In [8]:
torch_low_light = torch.tensor(np.array(jaxed_low_light)).permute(2, 0, 1).unsqueeze(0)
torch_normal_light = (
    torch.tensor(np.array(jaxed_normal_light)).permute(2, 0, 1).unsqueeze(0)
)
original_tv = TVLoss()
print("Original TV Loss:", original_tv(torch_normal_light))

Original TV Loss: tensor(0.0097)


In [9]:
batched_low_light = jnp.expand_dims(jnp.array(jaxed_low_light), axis=0)
batched_normal_light = jnp.expand_dims(jnp.array(jaxed_normal_light), axis=0)
print("JAX TV Loss:", total_variation(batched_normal_light))

JAX TV Loss: 0.009681859
