In [None]:
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from skimage.metrics import peak_signal_noise_ratio
from skimage.io import imread
import numpy as np
import matplotlib.pyplot as plt

from src.forward_operator.operators import cfa_operator
from src.lightning_classes import UnrolledSystem, DataModule
from src.data_loader import RGBDataset

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CFAS = ['bayer', 'quad_bayer', 'kodak']  # CFA patterns trained/evaluated side by side
TRAIN_DIR = 'images/train'
VAL_DIR = 'images/val'
TEST_DIR = 'images/test'
PATCH_SIZE = 32      # patch side length passed to RGBDataset (presumably pixels)
NB_STAGES = 8        # number of stages passed to UnrolledSystem
NB_CHANNELS = 16     # channel count passed to UnrolledSystem
BATCH_SIZE = 512
LEARNING_RATE = 1e-1
NB_EPOCHS = 100      # upper bound; the EarlyStopping callback may end training sooner
SEED = 0             # change to explore run-to-run variance

# Seed the torch RNG (weight initialisation, any DataLoader shuffling) and NumPy's global RNG
# so that Restart & Run All reproduces the same training run (up to nondeterministic GPU kernels).
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# One RGBDataset per split. The training split gets PATCH_SIZE // 2 as its 4th argument
# (presumably the patch stride, i.e. overlapping patches); validation and test use PATCH_SIZE.
# NOTE(review): confirm the meaning of RGBDataset's 4th argument.
train_dataset, val_dataset, test_dataset = (
    RGBDataset(split_dir, CFAS, PATCH_SIZE, fourth_arg)
    for split_dir, fourth_arg in (
        (TRAIN_DIR, PATCH_SIZE // 2),
        (VAL_DIR, PATCH_SIZE),
        (TEST_DIR, PATCH_SIZE),
    )
)

data_module = DataModule(train_dataset, val_dataset, test_dataset, BATCH_SIZE)
model = UnrolledSystem(LEARNING_RATE, NB_STAGES, NB_CHANNELS)

In [None]:
# Stop once the validation loss stops improving and keep the best checkpoint on disk as 'best'.
# mode='min' spells out the default of both callbacks: a lower 'Loss/Val' is better.
early_stop = EarlyStopping(monitor='Loss/Val', mode='min', min_delta=1e-5, patience=10)
save_best = ModelCheckpoint(monitor='Loss/Val', mode='min', filename='best')
trainer = pl.Trainer(callbacks=[early_stop, save_best], max_epochs=NB_EPOCHS)

In [None]:
trainer.fit(model=model, datamodule=data_module)  # the save_best callback records the best checkpoint

In [None]:
# Evaluate the best checkpoint (lowest 'Loss/Val', as tracked by ModelCheckpoint) instead of the
# weights left in memory after the last epoch: early stopping only halts after the validation loss
# has failed to improve for `patience` epochs, so the final weights are usually not the best ones.
trainer.test(model, datamodule=data_module, ckpt_path='best')

In [None]:
def prepare_input(path, cfas):
    """Load an image and build one network input per CFA pattern.

    Args:
        path: Image file readable by ``skimage.io.imread``.
        cfas: Iterable of CFA names accepted by ``cfa_operator`` (e.g. ``CFAS``).

    Returns:
        ``(x, inputs)``:

        * ``x`` is the loaded image as a float NumPy array. Integer images are divided by
          their dtype's maximum (255 for 8-bit), so they lie in [0, 1].
        * ``inputs`` is a float32 tensor on ``DEVICE`` stacking one entry per CFA. Each entry
          is ``op.direct(x)`` with a leading axis added, concatenated with ``op.cfa_mask``
          with its last axis moved first.
    """
    image = imread(path)
    # Scale integer images by their dtype's maximum instead of assuming 8-bit data.
    # This is identical to the previous `/ 255` for uint8 images.
    if np.issubdtype(image.dtype, np.integer):
        x = image / np.iinfo(image.dtype).max
    else:
        x = image / 255  # keep the original convention for non-integer images

    per_cfa_inputs = []
    for cfa in cfas:
        # NOTE(review): [650, 525, 480] and 'dirac' are forwarded unchanged to cfa_operator.
        # They look like per-channel wavelengths (nm) and a spectral-response model; confirm.
        op = cfa_operator(cfa, x.shape, [650, 525, 480], 'dirac')
        # Add a leading axis so the measurement can be concatenated with the mask channels.
        y = torch.as_tensor(op.direct(x), dtype=torch.float32)[None]
        # Move the mask's last (channel) axis first to match y's channels-first layout.
        mask = torch.as_tensor(op.cfa_mask, dtype=torch.float32).permute(2, 0, 1)
        per_cfa_inputs.append(torch.cat([y, mask]))

    return x, torch.stack(per_cfa_inputs).to(DEVICE)

In [None]:
# Reload the best weights for the qualitative evaluation below.
# Lightning writes every fit to a new lightning_logs/version_N directory, so a hard-coded version
# goes stale as soon as training is re-run. Prefer the path recorded by the `save_best` checkpoint
# callback; fall back to CKPT_VERSION when training was not run in this kernel session.
CKPT_VERSION = 0
ckpt_path = getattr(globals().get('save_best'), 'best_model_path', '') or (
    f'lightning_logs/version_{CKPT_VERSION}/checkpoints/best.ckpt'
)
# map_location puts the weights on DEVICE, where prepare_input places the inputs.
# It also lets a GPU-trained checkpoint load on a CPU-only machine.
model = UnrolledSystem.load_from_checkpoint(ckpt_path, map_location=DEVICE)

In [None]:
# Reconstruct one test image with every CFA and keep the output of every stage.
path = f'{TEST_DIR}/28083.jpg'
x, input_data = prepare_input(path, CFAS)

model.eval()
model.to(DEVICE)  # make sure the weights live on the same device as input_data
with torch.no_grad():
    # Given the permute below, the stacked tensor is (stages, cfas, channels, height, width).
    stage_outputs = torch.stack(model(input_data))

# x_hat_list.shape == (cfas, stages, height, width, channels), clipped to the valid range [0, 1]
x_hat_list = np.clip(stage_outputs.permute(1, 0, 3, 4, 2).numpy(force=True).astype(float), 0, 1)

In [None]:
# Drop a BORDER-pixel frame from the ground truth and from every reconstruction before scoring.
# This is presumably to exclude boundary artifacts; confirm.
# This cell rebinds `x` and `x_hat_list`, so cropping them directly would shave off another border
# on every re-run. Instead, remember the uncropped arrays from the inference cell and always crop
# from those. They are refreshed whenever `x` is a new array, i.e. the inference cell was re-run.
BORDER = 2
if globals().get('_x_cropped') is not x:
    _x_uncropped, _x_hat_uncropped = x, x_hat_list
crop = slice(BORDER, -BORDER or None)  # `or None` keeps BORDER = 0 from yielding an empty slice
x = _x_uncropped[crop, crop]
x_hat_list = _x_hat_uncropped[:, :, crop, crop]
_x_cropped = x

In [None]:
# Ground truth next to each CFA's reconstruction at one stage, titled with its PSNR.
stage = -1  # index along the stage axis of x_hat_list; -1 is the last stage
n_rows = (len(CFAS) + 2) // 2  # ground truth + one panel per CFA, two panels per row
fig, axs = plt.subplots(n_rows, 2, figsize=(16, 5.5 * n_rows), squeeze=False)
panels = list(axs.flat)  # row-major order: ground truth first, then CFAS in order

panels[0].imshow(x)
panels[0].set_title('Ground truth')
for ax, cfa, x_hat_stages in zip(panels[1:], CFAS, x_hat_list):
    x_hat = x_hat_stages[stage]
    ax.imshow(x_hat)
    ax.set_title(f'CFA: {cfa}, PSNR: {peak_signal_noise_ratio(x, x_hat):.2f}')

for ax in panels:
    ax.axis('off')  # also blanks the spare panel when len(CFAS) is even
plt.show()

In [None]:
# PSNR of each stage output against the ground truth, one curve per CFA.
fig, ax = plt.subplots()
for cfa, x_hat_stages in zip(CFAS, x_hat_list):
    psnr_per_stage = [peak_signal_noise_ratio(x, x_hat) for x_hat in x_hat_stages]
    ax.plot(psnr_per_stage, label=cfa)

ax.set_title('PSNR as a function of the stage')
ax.set_xlabel('Stages')
ax.set_ylabel('PSNR (dB)')
ax.legend()
plt.show()