In [None]:
import lightning.pytorch as pl

from src.lightning_classes import UnrolledSystem
from src.data_loader import RGBDataset
from src.utils import format_output, get_dataloader, plot_psnr_stages, plot_results, plot_error_maps

In [None]:
# CFA patterns for the first evaluation set; sorted so their order (and any name built
# from them) is stable regardless of how the literal below is written.
CFAS = sorted((
    'bayer_GRBG', 'quad_bayer', 'gindele', 'chakrabarti', 'hamilton',
    'honda', 'kaizu', 'kodak', 'sony', 'sparse_3',
    'wang', 'yamagami', 'yamanaka',
))
# CFA patterns for the second evaluation set, also sorted.
CFAS_NEW = sorted((
    'bayer_RGGB', 'lukac', 'xtrans', 'luo',
))

IMAGE_PATH = 'images/test/28083.jpg'  # the single image both datasets are built from
NOISE_STD = 0                         # forwarded as `std=` to RGBDataset (presumably 0 = no added noise)
BATCH_SIZE = 8                        # forwarded to get_dataloader

In [None]:
# Build one dataset/dataloader per CFA set from the same test image and noise level.
dataset = RGBDataset(IMAGE_PATH, CFAS, cfa_variants=0, std=NOISE_STD)
dataloader = get_dataloader(dataset, BATCH_SIZE)

dataset_new = RGBDataset(IMAGE_PATH, CFAS_NEW, cfa_variants=0, std=NOISE_STD)
dataloader_new = get_dataloader(dataset_new, BATCH_SIZE)

# The checkpoint's experiment directory is the sorted CFAS names joined by '-', plus a suffix.
# Deriving it from CFAS (instead of hand-copying the 13 names) keeps the loaded checkpoint
# consistent with the CFA set evaluated above; editing CFAS now points at the matching run,
# or fails loudly if no such run exists.
# NOTE(review): the meaning of the '4V' suffix is not visible here (possibly a training-time
# cfa_variants count) — confirm against the training script.
EXPERIMENT_SUFFIX = '4V'
experiment = '-'.join(CFAS) + '-' + EXPERIMENT_SUFFIX
version = 'version_0'
model = UnrolledSystem.load_from_checkpoint(f'logs/{experiment}/{version}/checkpoints/best.ckpt')

# logger=False: evaluation only, so nothing is written to the logs directory.
trainer = pl.Trainer(logger=False)

In [None]:
# Evaluate the checkpoint on both CFA sets: dataloader index 0 is CFAS, index 1 is CFAS_NEW.
# The trailing ';' suppresses the cell's display of the returned metrics list.
trainer.test(model=model, dataloaders=[dataloader, dataloader_new]);

In [None]:
x_hat_list, x_hat_list_new = trainer.predict(model=model, dataloaders=[dataloader, dataloader_new])

In [None]:
gt_list, x_hat_list = format_output(x_hat_list)
gt_list_new, x_hat_list_new = format_output(x_hat_list_new)

In [None]:
# Reconstructions for the CFAS set; stage=-1 presumably selects the last unrolled stage.
plot_results(gt_list, x_hat_list, CFAS, dataset.cfa_idx, stage=-1)

In [None]:
# Error maps for the CFAS set at stage=-1; gain=10 presumably amplifies errors for visibility — confirm in plot_error_maps.
plot_error_maps(gt_list, x_hat_list, CFAS, dataset.cfa_idx, stage=-1, gain=10)

In [None]:
# Per-stage quality curve for the CFAS set (PSNR per stage, going by the helper's name).
plot_psnr_stages(gt_list, x_hat_list, CFAS, dataset.cfa_idx)

In [None]:
# Reconstructions for the CFAS_NEW set; stage=-1 presumably selects the last unrolled stage.
plot_results(gt_list_new, x_hat_list_new, CFAS_NEW, dataset_new.cfa_idx, stage=-1)

In [None]:
# Error maps for the CFAS_NEW set at stage=-1; gain=10 presumably amplifies errors for visibility — confirm in plot_error_maps.
plot_error_maps(gt_list_new, x_hat_list_new, CFAS_NEW, dataset_new.cfa_idx, stage=-1, gain=10)

In [None]:
# Per-stage quality curve for the CFAS_NEW set (PSNR per stage, going by the helper's name).
plot_psnr_stages(gt_list_new, x_hat_list_new, CFAS_NEW, dataset_new.cfa_idx)