Image Quality Assessment (IQA) library for JAX.
The implementations are jax.numpy ports of the original NumPy-based BasicSR code.
Not all functions have been tested. Functions marked as tested below have been verified to produce output consistent with BasicSR (MATLAB).
```shell
pip install iqa-jax
```
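```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

from iqa.metrics import psnr

# Two batches of 16 random 256x256 RGB images with 8-bit values (integer bounds for randint).
inputs_1 = jnp.array(np.random.randint(0, 256, size=(16, 256, 256, 3), dtype=np.uint8))
inputs_2 = jnp.array(np.random.randint(0, 256, size=(16, 256, 256, 3), dtype=np.uint8))

# Bind the keyword arguments with partial so jax.jit sees them as fixed Python values, not traced inputs.
# crop_border=0: no border pixels are cropped; test_y=False: presumably compare RGB rather than the Y channel.
metric = jax.jit(partial(psnr, crop_border=0, test_y=False))
psnr_val = metric(inputs_1, inputs_2)
```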
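For a quick sanity check of what the example above computes, here is a hand-written reference PSNR in jax.numpy. `psnr_reference` is a hypothetical helper, not part of iqa-jax. It assumes 8-bit inputs (peak value 255), averages the squared error over height, width and channels of each image, and skips the `crop_border`/`test_y` handling. Whether the library's `psnr` returns one value per image or a batch mean is not stated here, so compare shapes before comparing values.

```python
import jax
import jax.numpy as jnp


def psnr_reference(img1: jnp.ndarray, img2: jnp.ndarray, max_val: float = 255.0) -> jnp.ndarray:
    """Hypothetical per-image PSNR for batches shaped (N, H, W, C); no border crop, no Y conversion."""
    # Promote to float32 before subtracting so uint8 inputs do not wrap around.
    diff = img1.astype(jnp.float32) - img2.astype(jnp.float32)
    # Mean squared error over H, W and C, leaving one value per image.
    mse = jnp.mean(diff ** 2, axis=(1, 2, 3))
    # PSNR = 10 * log10(MAX^2 / MSE); identical images give +inf.
    return 10.0 * jnp.log10(max_val ** 2 / mse)


a = jnp.zeros((2, 8, 8, 3), dtype=jnp.uint8)
b = jnp.full((2, 8, 8, 3), 255, dtype=jnp.uint8)
# MSE = 255**2 here, so each PSNR is 10 * log10(1) = 0 dB.
print(jax.jit(psnr_reference)(a, b))
```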
- PSNR
- SSIM
- NIQE
- FID
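To illustrate the SSIM computation, here is a single-channel sketch in jax.numpy following the standard formulation, with constants that, to our understanding, match BasicSR's: an 11x11 Gaussian window with sigma = 1.5, C1 = (0.01 * 255)^2, C2 = (0.03 * 255)^2, and "valid" filtering so pixels without a full window are dropped. `gaussian_window`, `ssim_single_channel` and `ssim_rgb` are hypothetical helpers, not iqa-jax functions. They run in float32, whereas BasicSR computes in float64, so values will agree only approximately.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_window(size: int = 11, sigma: float = 1.5) -> jnp.ndarray:
    # 1-D Gaussian normalised to sum 1; the outer product gives the 2-D window.
    x = jnp.arange(size, dtype=jnp.float32) - (size - 1) / 2.0
    g = jnp.exp(-(x ** 2) / (2.0 * sigma ** 2))
    g = g / jnp.sum(g)
    return jnp.outer(g, g)


def ssim_single_channel(img1, img2, max_val: float = 255.0):
    """Hypothetical SSIM for two 2-D arrays (H, W) with values in [0, max_val]."""
    c1 = (0.01 * max_val) ** 2
    c2 = (0.03 * max_val) ** 2
    img1 = img1.astype(jnp.float32)
    img2 = img2.astype(jnp.float32)
    win = gaussian_window()

    # "valid" mode keeps only windows that lie fully inside the image.
    # The window is symmetric, so convolution equals correlation here.
    def filt(z):
        return convolve2d(z, win, mode="valid")

    mu1, mu2 = filt(img1), filt(img2)
    sigma1_sq = filt(img1 * img1) - mu1 ** 2
    sigma2_sq = filt(img2 * img2) - mu2 ** 2
    sigma12 = filt(img1 * img2) - mu1 * mu2
    ssim_map = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / (
        (mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2)
    )
    return jnp.mean(ssim_map)


def ssim_rgb(img1, img2):
    # Map over the channel axis of (H, W, C) inputs and average the per-channel SSIM.
    return jnp.mean(jax.vmap(ssim_single_channel, in_axes=(2, 2))(img1, img2))
```

Averaging per-channel SSIM is one common choice for color images; computing SSIM on the Y channel instead is another.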
- PSNR
- SSIM
- NIQE
- FID
- InceptionV3
- RGB2Y Conversion
- RGB2Gray Conversion
- MATLAB's .5 scale bicubic resize
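Given the stated goal of matching BasicSR (MATLAB), the RGB2Y and RGB2Gray conversions listed above presumably follow MATLAB's conventions. The sketch below shows that arithmetic, assuming RGB channel order and input values in [0, 255]. `rgb_to_y` and `rgb_to_gray` are hypothetical names, not the library's API. The Y weights are those of MATLAB's `rgb2ycbcr` (ITU-R BT.601, output range [16, 235]), i.e. Y = 16 + (65.481 R + 128.553 G + 24.966 B) / 255; the gray weights are the ones documented for MATLAB's `rgb2gray`. Whether iqa-jax uses exactly these weights is an assumption.

```python
import jax.numpy as jnp


def rgb_to_y(img: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical MATLAB-style rgb2ycbcr Y channel for RGB values in [0, 255].

    Returns float32 values in [16, 235], keeping a trailing channel axis of size 1.
    """
    img = img.astype(jnp.float32) / 255.0
    # BT.601 luma weights for (R, G, B), scaled so that Y spans [16, 235].
    weights = jnp.array([65.481, 128.553, 24.966], dtype=jnp.float32)
    y = jnp.dot(img, weights) + 16.0
    return y[..., None]


def rgb_to_gray(img: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical grayscale conversion with MATLAB rgb2gray weights; drops the channel axis."""
    weights = jnp.array([0.2989, 0.5870, 0.1140], dtype=jnp.float32)
    return jnp.dot(img.astype(jnp.float32), weights)
```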