In [None]:
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.loggers import TensorBoardLogger
from skimage.metrics import peak_signal_noise_ratio, mean_squared_error
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from src.forward_operator.operators import cfa_operator
from src.baseline_method.inversion_baseline import Inverse_problem
from src.lightning_classes import U_PDHG_system, DataModule
from src.data_loader import RGBDataset, RGB_SPECTRAL_STENCIL

In [None]:
# Device for the standalone inference cell; Lightning's Trainer picks its own accelerator.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Colour filter array pattern, shared by the forward operator, the baseline and the model.
CFA = 'bayer'

# Image folders for each data split.
TRAIN_DIR = 'images/train'
VAL_DIR = 'images/val'
TEST_DIR = 'images/test'

# Downsampling factor passed to RGBDataset and applied to the evaluation image.
SCALE = 2

# Unrolled-network hyper-parameters forwarded to U_PDHG_system.
NB_STAGES = 6
NB_CHANNELS = 32
KERNEL_SIZE = 3

# Optimisation settings.
BATCH_SIZE = 56
LEARNING_RATE = 1e-2
NB_EPOCHS = 500

# Global seed (Python, NumPy, torch and dataloader workers) so that weight
# initialisation and data sampling are reproducible across kernel restarts.
SEED = 0

# Shape the operators are built for, laid out like the image array that is fed
# to OP.direct (rows, columns, channels). NOTE(review): must match every image
# after SCALE-downsampling — confirm for all splits.
IMAGE_SHAPE = [160, 240, 3]

pl.seed_everything(SEED, workers=True)

OP = cfa_operator(CFA, IMAGE_SHAPE, RGB_SPECTRAL_STENCIL, 'dirac')
baseline_inversion = Inverse_problem(CFA, IMAGE_SHAPE, RGB_SPECTRAL_STENCIL, 'dirac')

In [None]:
# Every split gets the same scale factor and the operator's direct/adjoint maps;
# only the source folder differs (built in train, val, test order).
train_dataset, val_dataset, test_dataset = (
    RGBDataset(split_dir, SCALE, OP.direct, OP.adjoint)
    for split_dir in (TRAIN_DIR, VAL_DIR, TEST_DIR)
)

data_module = DataModule(train_dataset, val_dataset, test_dataset, BATCH_SIZE)

# Arguments are positional, in the order U_PDHG_system expects them.
model = U_PDHG_system(
    LEARNING_RATE,
    NB_STAGES,
    CFA,
    RGB_SPECTRAL_STENCIL,
    NB_CHANNELS,
    KERNEL_SIZE,
)

In [None]:
# TensorBoard logs go under tb_logs/; the automatic hp_metric entry is disabled.
logger = TensorBoardLogger('tb_logs', default_hp_metric=False)

# Callbacks:
#  - stop when val_loss has not improved by at least 1e-6 for 50 epochs,
#  - record the learning rate,
#  - keep the checkpoint with the best val_loss, saved as 'best.ckpt'.
early_stop = EarlyStopping(monitor='val_loss', min_delta=1e-6, patience=50)
lr_monitor = LearningRateMonitor()
save_best = ModelCheckpoint(filename='best', monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=NB_EPOCHS,
    callbacks=[early_stop, lr_monitor, save_best],
    logger=logger,
)

# Optional (disabled): run Lightning's learning-rate finder and batch-size
# scaler before training by uncommenting the lines below.
# tuner = Tuner(trainer)
# tuner.lr_find(model, datamodule=data_module)
# print(model.lr)
# tuner.scale_batch_size(model, datamodule=data_module, init_val=16)

In [None]:
# Train with the configured callbacks (early stopping, LR monitoring, best-checkpoint saving).
trainer.fit(model=model, datamodule=data_module)

In [None]:
# Evaluate the best checkpoint (lowest val_loss) rather than the in-memory weights.
# The in-memory weights come from the last epoch run, which with early stopping
# can be many epochs past the best one; this also keeps the reported test metrics
# consistent with the checkpoint reloaded for inference.
trainer.test(model, datamodule=data_module, ckpt_path='best')

In [None]:
# Reload the best checkpoint recorded by the ModelCheckpoint callback. Its
# best_model_path avoids rebuilding the logger's directory layout by hand, and
# map_location places the weights on DEVICE even if training ran on another device.
model = U_PDHG_system.load_from_checkpoint(save_best.best_model_path, map_location=DEVICE)
model.eval()

# Evaluation image, downsampled by SCALE (presumably mirroring RGBDataset) and
# scaled to [0, 1]. convert('RGB') guarantees three channels for the operator,
# and the context manager closes the file handle.
with Image.open('images/val/3096.jpg') as img:
    rgb_img = img.convert('RGB')
x = np.array(rgb_img.resize((rgb_img.size[0] // SCALE, rgb_img.size[1] // SCALE))) / 255

# 2-D mosaic measurement and its baseline reconstruction.
y = OP.direct(x)
x_baseline = baseline_inversion(y)

# Network input: mosaic stacked with the baseline estimate on the last axis,
# plus a leading batch axis, on the same device as the model.
input_data = torch.tensor(
    np.concatenate((y[:, :, None], x_baseline), axis=2),
    dtype=torch.float,
    device=model.device,
)[None]

with torch.no_grad():
    x_hat = model(input_data)[0].numpy(force=True).astype(float)

In [None]:
def print_metrics(name, reference, estimate):
    """Print the MSE and PSNR of `estimate` against `reference`.

    data_range=1 is given explicitly because the reference image is scaled to [0, 1].
    """
    mse = mean_squared_error(reference, estimate)
    psnr = peak_signal_noise_ratio(reference, estimate, data_range=1)
    print(f'{name}:\n\tMSE: {mse:.6f}, PSNR: {psnr:.2f}')


# Reuse x_baseline rather than re-running the baseline inversion for each metric.
print_metrics('Baseline', x, x_baseline)
print_metrics('U-PDHG', x, x_hat)

fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)

# Reconstructions are clipped to [0, 1] for display only (imshow would clip anyway,
# with a warning); the metrics above use the unclipped estimates.
panels = (
    (axs[0, 0], x, 'Ground truth', None),
    (axs[0, 1], y, 'Input', 'gray'),
    (axs[1, 0], np.clip(x_baseline, 0, 1), 'Baseline', None),
    (axs[1, 1], np.clip(x_hat, 0, 1), 'U-PDHG', None),
)
for ax, image, title, cmap in panels:
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
plt.show()