In [None]:
import lightning.pytorch as pl

from src.lightning_classes import UnrolledSystem
from src.data_loader import RGBDataset
from src.utils import format_output, get_dataloader, plot_psnr_stages, plot_results

In [None]:
# CFA pattern names for the two evaluation splits. None of the CFAS_TEST names
# appear in the experiment (checkpoint) name that is loaded for evaluation.
# NOTE(review): 'random' is listed here but is not part of that experiment
# name — confirm it was actually seen during training.
CFAS_TRAIN = [
    'bayer_GRBG', 'binning', 'chakrabarti', 'gindele', 'hamilton',
    'honda', 'kaizu', 'kodak', 'quad_bayer', 'random',
    'sparse_3', 'wang', 'yamagami', 'yamanaka',
]
CFAS_TEST = ['honda2', 'lukac', 'luo', 'sony', 'xtrans']

# Single source image and loader batch size, shared by both splits.
IMAGE_PATH = 'input/28083.jpg'
BATCH_SIZE = 8

In [None]:
# Train-split data: the same image rendered through every CFA in CFAS_TRAIN.
# NOTE(review): the meaning of cfa_variants=0 is defined inside RGBDataset — TODO document.
dataset = RGBDataset(IMAGE_PATH, CFAS_TRAIN, cfa_variants=0)
dataloader = get_dataloader(dataset, BATCH_SIZE)

# Test-split data: the same image, CFAs from CFAS_TEST.
dataset_test = RGBDataset(IMAGE_PATH, CFAS_TEST, cfa_variants=0)
dataloader_test = get_dataloader(dataset_test, BATCH_SIZE)

# Checkpoint to evaluate.
# NOTE(review): CFAS_TRAIN contains 'random', which is not among the CFA names
# in this experiment string — confirm this checkpoint matches the train split.
experiment = 'bayer_GRBG-binning-chakrabarti-gindele-hamilton-honda-kaizu-kodak-quad_bayer-sparse_3-wang-yamagami-yamanaka-6V'
version = 'version_0'
checkpoint_path = f'weights/{experiment}/{version}/checkpoints/best.ckpt'
model = UnrolledSystem.load_from_checkpoint(checkpoint_path)

# Evaluation-only trainer; logging is disabled.
trainer = pl.Trainer(logger=False)

In [None]:
# Run the test loop over both splits (dataloader index 0 = train CFAs,
# index 1 = test CFAs); the trailing ';' suppresses the returned metrics repr.
trainer.test(
    model=model,
    dataloaders=[dataloader, dataloader_test],
);

In [None]:
x_hat_list, x_hat_list_new = trainer.predict(model=model, dataloaders=[dataloader, dataloader_test])

In [None]:
gt_list, x_hat_list = format_output(x_hat_list)
gt_list_new, x_hat_list_new = format_output(x_hat_list_new)

In [None]:
# Train-split CFAs: reconstructions vs. ground truth.
# stage=-1 presumably selects the last unrolled stage — confirm in plot_results.
plot_results(
    gt_list,
    x_hat_list,
    CFAS_TRAIN,
    dataset.cfa_idx,
    stage=-1,
)

In [None]:
# Train-split CFAs: presumably PSNR as a function of unrolled stage.
plot_psnr_stages(
    gt_list,
    x_hat_list,
    CFAS_TRAIN,
    dataset.cfa_idx,
)

In [None]:
# Test-split CFAs: reconstructions vs. ground truth.
# stage=-1 presumably selects the last unrolled stage — confirm in plot_results.
plot_results(
    gt_list_new,
    x_hat_list_new,
    CFAS_TEST,
    dataset_test.cfa_idx,
    stage=-1,
)

In [None]:
# Test-split CFAs: presumably PSNR as a function of unrolled stage.
plot_psnr_stages(
    gt_list_new,
    x_hat_list_new,
    CFAS_TEST,
    dataset_test.cfa_idx,
)