In [2]:
import numpy as np

from os import listdir, path

from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.io import imread
import jax.numpy as jnp
from scico import functional, linop, loss
from scico.optimize import ProximalADMM

from src.forward_model.operators import cfa_operator

In [3]:
# ---- Experiment configuration ---------------------------------------------
# CFA patterns under test, iterated in alphabetical order.
CFAS = sorted((
    'bayer_GRBG', 'quad_bayer', 'gindele', 'chakrabarti', 'hamilton',
    'honda', 'kaizu', 'kodak', 'sony', 'sparse_3', 'wang', 'yamagami',
    'yamanaka',
))
# One value per RGB channel, handed to cfa_operator; presumably the channel
# centre wavelengths in nm -- TODO confirm against src.forward_model.operators.
RGB_SPECTRAL_STENCIL = np.asarray((650, 525, 480))
# Standard deviation of the additive Gaussian measurement noise, expressed in
# percent (the evaluation cell divides it by 100); 0 disables the noise.
NOISE_LEVEL = 0

# Directory holding the ground-truth test images.
INPUT_DIR = 'input/test/'

# Per-CFA "mean +/- std" summary strings, filled by the evaluation cell.
res_psnr, res_ssim = dict(), dict()

In [4]:
# TV-regularised demosaicing benchmark.
#
# For every CFA in CFAS and every image in INPUT_DIR, simulate a mosaicked
# measurement y = clip(C x + noise, 0, 1) and reconstruct x with scico's
# ProximalADMM. The objective is a squared-L2 data term on C x plus TV_WEIGHT
# times an L21 (vectorial TV) norm of D x. Here C is the CFA forward operator
# and D the finite-difference operator over the two spatial axes. Mean +/- std
# of PSNR and SSIM over the images are stored per CFA in res_psnr / res_ssim.

TV_WEIGHT = 0.001    # weight of the L21 / TV regulariser
ADMM_RHO = 0.005     # ADMM penalty parameter
ADMM_MAXITER = 1000  # fixed iteration budget per image
SEED = 0             # seeds the measurement noise so reruns are reproducible

rng = np.random.default_rng(SEED)


def reconstruct_tv(x, cfa):
    """Simulate a CFA measurement of ``x`` and reconstruct it with TV-regularised ProximalADMM.

    Parameters
    ----------
    x : np.ndarray
        Ground-truth image scaled to [0, 1]; assumed (H, W, 3) with the colour
        axis last -- TODO confirm for every file in INPUT_DIR (e.g. no alpha).
    cfa : str
        Name of the CFA pattern passed to ``cfa_operator``.

    Returns
    -------
    np.ndarray
        Reconstruction clipped to [0, 1], with the same dtype as ``x``.
    """
    cfa_op = cfa_operator(cfa, x.shape, RGB_SPECTRAL_STENCIL)

    # Mosaicked measurement with additive Gaussian noise (NOISE_LEVEL is in
    # percent of the [0, 1] range), clipped back to the valid range.
    noise = rng.normal(0, NOISE_LEVEL / 100, cfa_op.output_shape)
    y = np.clip(cfa_op.direct(x) + noise, 0, 1)

    # Wrap cfa_op for scico. Its direct/adjoint results are converted to JAX
    # arrays explicitly. The declared output shape is taken from y itself so it
    # always matches the data term built on y below.
    C = linop.LinearOperator(
        input_shape=x.shape,
        output_shape=y.shape,
        eval_fn=lambda v: jnp.array(cfa_op.direct(v)),
        adj_fn=lambda v: jnp.array(cfa_op.adjoint(v)),
    )
    D = linop.FiniteDifference(input_shape=x.shape, append=0, axes=(0, 1))
    A = linop.VerticalStack((C, D))

    # g acts blockwise on A x = (C x, D x): data fidelity on C x, weighted L21
    # norm on D x. l2_axis=(0, 3) presumably groups the gradient-direction axis
    # and the colour axis of D x (vectorial TV) -- assumes D stacks directions
    # on axis 0.
    g = functional.SeparableFunctional((
        loss.SquaredL2Loss(y=jnp.array(y)),
        TV_WEIGHT * functional.L21Norm(l2_axis=(0, 3)),
    ))

    # mu / nu are derived from the norm of the operator passed in, so they must
    # be estimated for the full constraint operator A = [C; D] that the solver
    # uses, not for D alone (which ignores C's contribution).
    mu, nu = ProximalADMM.estimate_parameters(A)

    solver = ProximalADMM(
        f=functional.ZeroFunctional(),
        g=g,
        A=A,
        B=None,
        rho=ADMM_RHO,
        mu=mu,
        nu=nu,
        x0=jnp.array(cfa_op.adjoint(y)),  # warm start from the adjoint of y
        maxiter=ADMM_MAXITER,
    )
    # Cast back to x's dtype so both metric inputs share a dtype.
    return np.clip(np.asarray(solver.solve(), dtype=x.dtype), 0, 1)


# listdir order is platform-dependent; sorting keeps the seeded noise draws
# attached to the same images on every run. Hidden files (e.g. .DS_Store)
# are skipped because imread cannot decode them.
image_names = sorted(name for name in listdir(INPUT_DIR) if not name.startswith('.'))

for cfa in CFAS:
    psnr_list = []
    ssim_list = []

    for image_name in image_names:
        # Dividing by 255 assumes 8-bit images -- TODO confirm the test set's bit depth.
        x = imread(path.join(INPUT_DIR, image_name)) / 255
        res = reconstruct_tv(x, cfa)
        psnr_list.append(peak_signal_noise_ratio(x, res, data_range=1))
        ssim_list.append(structural_similarity(x, res, data_range=1, channel_axis=2))

    psnr_mean, psnr_std = np.mean(psnr_list), np.std(psnr_list)
    ssim_mean, ssim_std = np.mean(ssim_list), np.std(ssim_list)

    res_psnr[cfa] = f'{psnr_mean:.2f}dB +/- {psnr_std:.2f}dB'
    res_ssim[cfa] = f'{ssim_mean:.2f} +/- {ssim_std:.2f}'

In [None]:
# Per-CFA summary table (mean +/- std over the test images). One row per CFA
# shows PSNR and SSIM side by side, instead of two dict reprs that must be
# cross-referenced by key.
cfa_col_width = max([len('CFA'), *(len(name) for name in res_psnr)])
print(f"{'CFA':<{cfa_col_width}}  {'PSNR':>18}  {'SSIM':>13}")
for cfa_name, psnr_summary in res_psnr.items():
    # .get guards against a partially completed evaluation run.
    ssim_summary = res_ssim.get(cfa_name, 'n/a')
    print(f"{cfa_name:<{cfa_col_width}}  {psnr_summary:>18}  {ssim_summary:>13}")