In [1]:
import numpy as np
from os import listdir, path
import csv
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.io import imread
import jax.numpy as jnp
from scico import functional, linop, loss
from scico.optimize import ProximalADMM

from src.forward_model.cfa_operator import cfa_operator

In [2]:
CFAS_OLD = sorted(['bayer_GRBG', 'quad_bayer', 'gindele', 'chakrabarti', 'hamilton', 'honda', 'kaizu', 'kodak', 'sony', 'sparse_3', 'wang', 'yamagami', 'yamanaka'])
CFAS_NEW = sorted(['bayer_RGGB', 'lukac', 'xtrans', 'luo'])
CFAS = CFAS_OLD + CFAS_NEW
# Spectral stencil handed to `cfa_operator` for the R, G, B channels (presumably wavelengths in nm — TODO confirm).
RGB_SPECTRAL_STENCIL = np.array([650, 525, 480])
# Std. dev. of the additive Gaussian measurement noise, in percent of the [0, 1] intensity range.
NOISE_LEVEL = 0
# Seed for the global NumPy RNG used to draw measurement noise, so runs with NOISE_LEVEL > 0 are reproducible.
SEED = 0
np.random.seed(SEED)

IMG_DIR = 'input/test/'


def psnr(x, x_hat):
    """Peak signal-to-noise ratio between a reference image and its estimate, both in [0, 1]."""
    return peak_signal_noise_ratio(x, x_hat, data_range=1)


def ssim(x, x_hat):
    """Structural similarity between two [0, 1] images whose channels lie along axis 2."""
    return structural_similarity(x, x_hat, data_range=1, channel_axis=2)


res = []
# Sorted so the image order (and hence the sequence of noise draws) does not depend on
# the file system; hidden files such as `.DS_Store` are skipped instead of crashing imread.
img_names = sorted(name for name in listdir(IMG_DIR) if not name.startswith('.'))
# Dividing by 255 assumes 8-bit images — TODO confirm for every file in IMG_DIR.
images = [imread(path.join(IMG_DIR, img_name)) / 255 for img_name in img_names]

In [3]:
def reconstruction(cfa, x, lambd=0.001, rho=0.005, maxiter=400, noise_level=None):
    """Simulate a noisy CFA acquisition of ``x`` and reconstruct it with TV-regularised proximal ADMM.

    Parameters
    ----------
    cfa : str
        Name of the colour filter array, forwarded to ``cfa_operator``.
    x : np.ndarray
        Ground-truth image with values in [0, 1]. The last axis is assumed to be the
        colour-channel axis — TODO confirm against ``cfa_operator``.
    lambd : float, optional
        Weight of the L21 (total-variation) regulariser.
    rho : float, optional
        ADMM penalty parameter.
    maxiter : int, optional
        Number of proximal ADMM iterations.
    noise_level : float, optional
        Std. dev. of the additive Gaussian noise in percent of the [0, 1] range.
        Defaults to the global ``NOISE_LEVEL`` read at call time.

    Returns
    -------
    np.ndarray
        Reconstructed image with the same shape as ``x``, clipped to [0, 1].
    """
    if noise_level is None:
        noise_level = NOISE_LEVEL

    cfa_op = cfa_operator(cfa, x.shape, RGB_SPECTRAL_STENCIL)
    noise = np.random.normal(0, noise_level / 100, cfa_op.output_shape)
    y = np.clip(cfa_op.direct(x) + noise, 0, 1)
    # The adjoint of the measurements is used as the solver's starting point.
    x_baseline = jnp.array(cfa_op.adjoint(y))

    def forward_pass(img):
        return jnp.array(cfa_op.direct(img))

    def adjoint_pass(meas):
        return jnp.array(cfa_op.adjoint(meas))

    # The output shape is taken from the actual measurements so it always agrees with `y`.
    C = linop.LinearOperator(input_shape=x.shape, output_shape=y.shape, eval_fn=forward_pass, adj_fn=adjoint_pass)
    # Spatial finite differences along the two image axes (zero-appended at the border).
    D = linop.FiniteDifference(input_shape=x.shape, append=0, axes=(0, 1))
    A = linop.VerticalStack((C, D))

    # Data fidelity on C x and an L21 norm on D x; presumably axis 0 of D x indexes the
    # difference direction and axis 3 the channel, giving a channel-coupled isotropic TV — verify.
    g_0 = loss.SquaredL2Loss(y=jnp.array(y))
    g_1 = functional.L21Norm(l2_axis=(0, 3))

    # NOTE(review): mu/nu are estimated from D alone, not from the stacked A = [C; D];
    # confirm this still satisfies the bound ProximalADMM requires on mu.
    mu, nu = ProximalADMM.estimate_parameters(D)

    g = functional.SeparableFunctional((g_0, lambd * g_1))

    solver_TV = ProximalADMM(
        f=functional.ZeroFunctional(),
        g=g,
        A=A,
        B=None,
        rho=rho,
        mu=mu,
        nu=nu,
        x0=x_baseline,
        maxiter=maxiter
    )

    return np.clip(np.array(solver_TV.solve()), 0, 1)

In [4]:
# Reset so that re-running this cell does not append duplicate rows to `res`.
res = []
for cfa in tqdm(CFAS):
    outputs = [reconstruction(cfa, x) for x in images]
    psnr_vals, ssim_vals = np.array([[psnr(x, output), ssim(x, output)] for x, output in zip(images, outputs)]).T
    # The doubled backslash yields a literal "\pm" for LaTeX; a single one is an invalid escape sequence.
    res.append([cfa, f'${np.mean(psnr_vals):.2f} \\pm {np.std(psnr_vals):.2f}$',
                f'${np.mean(ssim_vals):.3f} \\pm {np.std(ssim_vals):.3f}$'])

100%|██████████| 17/17 [59:12<00:00, 208.99s/it]


In [5]:
# One row per CFA: name, "mean \pm std" PSNR, "mean \pm std" SSIM.
# newline='' is required by the csv module; without it Windows output gets blank lines between rows.
with open('metrics.csv', 'w', newline='', encoding='utf-8') as csv_file:
    csv.writer(csv_file).writerows(res)