In [None]:
import torch
import lightning.pytorch as pl
from skimage.metrics import peak_signal_noise_ratio
from skimage.io import imread
import numpy as np
import matplotlib.pyplot as plt

from src.forward_operator.operators import cfa_operator
from src.lightning_classes import UnrolledSystem, DataModule
from src.data_loader import RGBDataset

In [None]:
# Evaluation settings shared by the cells below.
TEST_DIR = 'images/test'  # directory handed to RGBDataset for the test split
BATCH_SIZE = 16           # batch size handed to DataModule
CFAS = [                  # CFA pattern names, in the order they are evaluated and plotted
    'bayer',
    'quad_bayer',
    'sony',
    'kodak',
    'sparse_3',
]
# Prefer the GPU when one is visible to torch.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Build the test split and wrap it in the project's DataModule.
test_dataset = RGBDataset(TEST_DIR, CFAS)
data_module = DataModule(BATCH_SIZE, test_dataset=test_dataset)

# Restore UnrolledSystem weights from lightning_logs/<version>/checkpoints/best.ckpt.
version = 'all'
checkpoint_path = f'lightning_logs/{version}/checkpoints/best.ckpt'
model = UnrolledSystem.load_from_checkpoint(checkpoint_path)

# Trainer with default settings.
trainer = pl.Trainer()

In [None]:
# Evaluate the restored model on the DataModule's test split; the return value
# is left as the cell's displayed output.
trainer.test(model=model, datamodule=data_module)

In [None]:
def prepare_input(path, cfas, wavelengths=(650, 525, 480), spectral_model='dirac'):
    """Load an RGB image and build one network input per CFA pattern.

    Args:
        path: Image file read with ``skimage.io.imread``. Pixel values are
            divided by 255, i.e. an 8-bit image is assumed (TODO confirm for
            non-JPEG inputs).
        cfas: Iterable of CFA names accepted by ``cfa_operator``.
        wavelengths: Passed to ``cfa_operator`` as its third positional
            argument (as a list). Default ``(650, 525, 480)``; presumably the
            R/G/B filter wavelengths in nm — TODO confirm.
        spectral_model: Passed to ``cfa_operator`` as its fourth positional
            argument. Default ``'dirac'``.

    Returns:
        ``(x, inputs)`` where ``x`` is the normalised image cast to
        ``float32`` and ``inputs`` stacks, one entry per CFA in ``cfas``
        order, the forward-operator output ``op.direct(x)`` (as one leading
        channel) concatenated with the CFA mask moved to channels-first.
    """
    x = imread(path) / 255
    inputs = []

    for cfa in cfas:
        op = cfa_operator(cfa, x.shape, list(wavelengths), spectral_model)
        # float32 here matches what the legacy torch.Tensor(...) constructor produced.
        y = torch.as_tensor(op.direct(x), dtype=torch.float32)[None]
        # The mask is stored channels-last; the network input is channels-first.
        mask = torch.as_tensor(op.cfa_mask, dtype=torch.float32).permute(2, 0, 1)
        inputs.append(torch.cat([y, mask]))

    return x.astype(np.float32), torch.stack(inputs)

In [None]:
# Reconstruct one test image with every CFA in CFAS, keeping the estimate of
# every stage. Model and inputs are both placed on DEVICE so the forward pass
# cannot hit a CPU/GPU device mismatch.
path = f'{TEST_DIR}/28083.jpg'
x, input_data = prepare_input(path, CFAS)

model = model.to(DEVICE)
model.eval()
with torch.no_grad():
    # model(...) yields a sequence of per-stage tensors; after stacking and
    # permuting, x_hat_list is indexed (cfa, stage, H, W, channel) — presumably
    # from per-stage (cfa, channel, H, W) outputs; TODO confirm in UnrolledSystem.
    stage_outputs = torch.stack(model(input_data.to(DEVICE)))
    x_hat_list = np.clip(stage_outputs.permute(1, 0, 3, 4, 2).numpy(force=True), 0, 1)

In [None]:
# Trim CROP pixels from every image border before computing metrics.
# Each array is cropped only while it still has the full (H, W) size of the
# network input, so re-running this cell does not crop an already-cropped image
# again (which would silently change every PSNR below).
CROP = 2
full_hw = tuple(input_data.shape[-2:])
if x.shape[:2] == full_hw:
    x = x[CROP:-CROP, CROP:-CROP]
if x_hat_list.shape[2:4] == full_hw:
    x_hat_list = x_hat_list[:, :, CROP:-CROP, CROP:-CROP]

In [None]:
# Ground truth (first panel) next to each CFA's reconstruction at `stage`,
# titled with its PSNR. Both images lie in [0, 1] (imread / 255 and np.clip),
# hence data_range=1.
stage = -1  # index along the stage axis of x_hat_list; -1 = last stage
N_COLS = 3
n_panels = len(CFAS) + 1
n_rows = (n_panels + N_COLS - 1) // N_COLS

fig, axs = plt.subplots(n_rows, N_COLS, figsize=(20, 6 * n_rows), squeeze=False)
ax_list = axs.ravel()

ax_list[0].imshow(x)
ax_list[0].set_title('Ground truth')

# x_hat_list's first axis follows the CFA order used to build the inputs.
for ax, cfa, x_hat_stages in zip(ax_list[1:], CFAS, x_hat_list):
    x_hat = x_hat_stages[stage]
    psnr = peak_signal_noise_ratio(x, x_hat, data_range=1)
    ax.imshow(x_hat)
    ax.set_title(f'CFA: {cfa}, PSNR: {psnr:.2f}')

# Hide ticks everywhere, including any unused trailing panels.
for ax in ax_list:
    ax.axis('off')
plt.show()

In [None]:
# One curve per CFA: PSNR of every stage's estimate against the ground truth.
fig, ax = plt.subplots()
for idx, cfa_name in enumerate(CFAS):
    stage_estimates = x_hat_list[idx]
    psnr_curve = [peak_signal_noise_ratio(x, estimate) for estimate in stage_estimates]
    ax.plot(psnr_curve, label=cfa_name)

ax.set_title('PSNR in functions of the stages')
ax.set_xlabel('Stages')
ax.set_ylabel('PSNR (dB)')
ax.legend()
plt.show()