In [1]:
import lightning.pytorch as pl
import numpy as np
import csv
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from src.lightning_classes import UnrolledSystem
from src.data_loader import RGBDataset
from src.utils import format_output, get_dataloader

In [None]:
# Evaluation configuration: CFA patterns to test, test-set location, and the trained
# checkpoint to evaluate. Each CFA group is sorted; the "old" group is evaluated first.
CFAS_OLD = sorted(['bayer_GRBG', 'gindele', 'chakrabarti', 'hamilton', 'honda', 'kaizu', 'kodak', 'sparse_3', 'wang', 'yamagami', 'yamanaka'])
CFAS_NEW = sorted(['quad_bayer', 'lukac', 'xtrans', 'sony'])
CFAS = CFAS_OLD + CFAS_NEW
TEST_DIR = 'images/test'
NOISE_STD = 0  # forwarded to RGBDataset as `std`; presumably the added-noise level — TODO confirm
BATCH_SIZE = 16


def psnr(x, x_hat):
    """Peak signal-to-noise ratio of estimate `x_hat` against reference `x`.

    Uses ``data_range=1``, i.e. assumes image intensities span [0, 1].
    """
    return peak_signal_noise_ratio(x, x_hat, data_range=1)


def ssim(x, x_hat):
    """Structural similarity of estimate `x_hat` against reference `x`.

    Uses ``data_range=1`` and treats axis 2 as the colour-channel axis
    (i.e. assumes (H, W, C) images with intensities in [0, 1]).
    """
    return structural_similarity(x, x_hat, data_range=1, channel_axis=2)


# NOTE(review): 'lukac' and 'xtrans' are evaluated (CFAS) but do not appear in this
# experiment name; presumably they are held-out patterns — confirm this is intended.
experiment = 'bayer_GRBG-chakrabarti-gindele-hamilton-honda-kaizu-kodak-quad_bayer-sony-sparse_3-wang-yamagami-yamanaka-4V'
version = 'version_0'
CHECKPOINT_PATH = f'logs/{experiment}/{version}/checkpoints/best.ckpt'
model = UnrolledSystem.load_from_checkpoint(CHECKPOINT_PATH)

# Inference-only trainer; logging disabled so evaluation does not create log directories.
trainer = pl.Trainer(logger=False)

In [None]:
# Evaluate the checkpoint on each CFA pattern separately and collect one row per CFA:
# [cfa, "$PSNR mean \pm std$", "$SSIM mean \pm std$"] (LaTeX-formatted table cells).


def _mean_pm_std(values, decimals):
    r"""Format `values` as the LaTeX string ``$mean \pm std$`` with `decimals` fractional digits."""
    # Raw f-string: `\pm` must reach the output literally and is not a valid Python escape.
    return rf'${np.mean(values):.{decimals}f} \pm {np.std(values):.{decimals}f}$'


res = []

for cfa in CFAS:
    # A single-CFA test set per iteration, so metrics are reported per pattern.
    dataset = RGBDataset(TEST_DIR, [cfa], cfa_variants=0, std=NOISE_STD)
    dataloader = get_dataloader(dataset, BATCH_SIZE)

    predictions = trainer.predict(model=model, dataloaders=dataloader)

    gt_list, x_hat_outputs = format_output(predictions)
    # Only the last entry is scored; presumably the output of the final unrolled stage — TODO confirm.
    x_hat_list = x_hat_outputs[-1]

    # NOTE(review): zip() silently truncates if gt_list and x_hat_list differ in length.
    # Row 0 holds per-image PSNR, row 1 per-image SSIM.
    metrics = np.array([[psnr(x, y), ssim(x, y)] for x, y in zip(gt_list, x_hat_list)]).T

    res.append([cfa, _mean_pm_std(metrics[0], 2), _mean_pm_std(metrics[1], 3)])

In [None]:
# Write the per-CFA results table; experiments whose name ends in "V" get a "_V" suffix.
metrics_path = f'metrics{"_V" if experiment.endswith("V") else ""}.csv'
# newline='' is required by the csv module so it controls line endings itself
# (otherwise blank rows appear between records on Windows).
with open(metrics_path, 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerows(res)