In [None]:
import torch
import lightning.pytorch as pl
from skimage.metrics import peak_signal_noise_ratio
from skimage.io import imread
import numpy as np
import matplotlib.pyplot as plt

from src.forward_model.operators import cfa_operator
from src.lightning_classes import UnrolledSystem, DataModule
from src.data_loader import RGBDataset

In [None]:
# Colour-filter-array patterns to evaluate. Sorting fixes their order; it matches
# the ordering used in the experiment name loaded below.
CFAS = sorted(('bayer_GRBG', 'quad_bayer', 'sony', 'kodak'))
TEST_DIR = 'images/test'  # folder holding the test images
BATCH_SIZE = 16  # batch size for the test DataModule

In [None]:
# Test data: every image in TEST_DIR, evaluated for each CFA in CFAS.
test_dataset = RGBDataset(TEST_DIR, CFAS)
data_module = DataModule(BATCH_SIZE, test_dataset=test_dataset)

# Trained model to evaluate: logs/<experiment>/<version>/checkpoints/best.ckpt
experiment = 'bayer_GRBG-kodak-quad_bayer-sony-4'
version = 'version_0'
checkpoint_path = f'logs/{experiment}/{version}/checkpoints/best.ckpt'
model = UnrolledSystem.load_from_checkpoint(checkpoint_path)

# Evaluation-only trainer; no experiment logger is attached.
trainer = pl.Trainer(logger=False)

In [None]:
# Evaluate the restored checkpoint on the test dataset held by `data_module`.
trainer.test(model, datamodule=data_module)

In [None]:
def prepare_input(path, cfas, wavelengths=(650, 525, 480)):
    """Load an image and build one stacked network input per CFA.

    For every CFA name in ``cfas``, a ``cfa_operator`` is built for the image's
    shape. The image goes through ``op.direct`` to get the measurement ``y``.
    ``y`` (with a new leading axis) is then concatenated with the channel-first
    CFA mask along dim 0.

    Args:
        path: Path of the image, read with ``skimage.io.imread``.
        cfas: Iterable of CFA names accepted by ``cfa_operator``.
        wavelengths: Third argument forwarded (as a list) to ``cfa_operator``.
            The default reproduces the previously hard-coded ``[650, 525, 480]``
            (presumably R/G/B wavelengths in nm; TODO confirm against
            ``cfa_operator``).

    Returns:
        A tuple ``(x, inputs)``:

        - ``x`` is the image as a float32 array.
        - ``inputs`` is a float32 tensor stacking one ``cat([y[None], mask])``
          entry per CFA, in the order of ``cfas``.
    """
    image = imread(path)
    # Normalise integer images by their dtype's full-scale value. For uint8
    # this is exactly the previous `/ 255`; for uint16 it keeps values in
    # [0, 1] instead of up to ~257. Non-integer images keep the old `/ 255`.
    if np.issubdtype(image.dtype, np.integer):
        scale = np.iinfo(image.dtype).max
    else:
        scale = 255
    x = image / scale

    inputs = []
    for cfa in cfas:
        op = cfa_operator(cfa, x.shape, list(wavelengths))
        # Leading singleton axis so the measurement can be concatenated with
        # the channel-first mask along dim 0.
        y = torch.as_tensor(op.direct(x), dtype=torch.float32).unsqueeze(0)
        # Mask moved from channel-last to channel-first (permute(2, 0, 1)).
        mask = torch.as_tensor(op.cfa_mask, dtype=torch.float32).permute(2, 0, 1)
        inputs.append(torch.cat([y, mask]))

    return x.astype(np.float32), torch.stack(inputs)

In [None]:
# Reconstruct one test image with every CFA and keep the output of every stage.
path = 'images/test/28083.jpg'
x, input_data = prepare_input(path, CFAS)

model.eval()
with torch.no_grad():
    # Move the input to wherever the model lives so the cell also works when the
    # model sits on a GPU. The model's return value is stacked along a new
    # leading axis: presumably one tensor per unrolled stage (TODO confirm).
    outputs_by_stage = torch.stack(model(input_data.to(model.device)))

# Reorder to (image, stage, H, W, C) for plotting and scoring. This assumes
# the model emits channel-first images (TODO confirm). force=True lets .numpy()
# work regardless of the tensor's device. Clip reconstructions to [0, 1].
x_hat_list = np.clip(outputs_by_stage.permute(1, 0, 3, 4, 2).numpy(force=True), 0, 1)

In [None]:
# Crop a 2-pixel border from the reference image and from every reconstruction
# before scoring. Axes 2 and 3 of x_hat_list are presumably rows/columns, given
# the permute in the inference cell (TODO confirm).
# NOTE: x and x_hat_list are rebound to their cropped versions, so running this
# cell twice crops twice. Re-run the inference cell before re-running this one.
border = slice(2, -2)
x = x[border, border]
x_hat_list = x_hat_list[:, :, border, border]

In [None]:
# Show the selected stage's reconstruction for every CFA, titled with its PSNR
# against the (cropped) reference image `x`.
stage = -1  # stage to display; -1 = last. Renamed from `iter`, which shadowed the builtin.
nb_images = len(x_hat_list)
nb_cols = -(-nb_images // 2)  # ceil(nb_images / 2) so all images fit on two rows

# squeeze=False keeps `axs` 2-D even when nb_cols == 1 (one or two CFAs).
# Without it, plt.subplots(2, 1) returns a 1-D array and [row, col] indexing fails.
fig, axs = plt.subplots(2, nb_cols, figsize=(20, 12), squeeze=False)

# axs.flat walks the grid row by row: panel i is axs[i // nb_cols, i % nb_cols].
for ax, cfa, x_hat in zip(axs.flat, CFAS, x_hat_list[:, stage]):
    ax.imshow(x_hat)
    # data_range=1 matches the [0, 1] range of `x` and of the clipped
    # reconstructions, instead of letting skimage infer it from the float dtype.
    psnr = peak_signal_noise_ratio(x, x_hat, data_range=1)
    ax.set_title(f'CFA: {cfa}, PSNR: {psnr:.2f}')

# Hide axes on every panel, including the unused one when nb_images is odd.
for ax in axs.flat:
    ax.axis('off')

plt.show()

In [None]:
# PSNR of the reconstruction after each stage, one curve per CFA.
# x_hat_list is ordered like CFAS (prepare_input iterates CFAS in order).
fig, ax = plt.subplots()
for cfa, per_stage in zip(CFAS, x_hat_list):
    # data_range=1 matches the [0, 1] range of `x` and of the clipped reconstructions.
    psnr_per_stage = [peak_signal_noise_ratio(x, x_hat, data_range=1) for x_hat in per_stage]
    ax.plot(psnr_per_stage, label=cfa)

ax.set_title('PSNR as a function of the stage')
ax.set_xlabel('Stages')
ax.set_ylabel('PSNR (dB)')
ax.legend()
plt.show()