In [1]:
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.tuner import Tuner
from skimage.metrics import peak_signal_noise_ratio
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from src.forward_operator.operators import cfa_operator
from src.lightning_classes import U_PDHG_system, DataModule
from src.data_loader import RGBDataset, RGB_SPECTRAL_STENCIL

In [2]:
# Compute device: CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Image folders for the three dataset splits.
TRAIN_DIR = 'src/images/train'
VAL_DIR = 'src/images/val'
TEST_DIR = 'src/images/test'

# Training settings.
TRAIN_FLAG = True  # NOTE(review): not read by any cell visible here — confirm it is still needed
NB_STAGES = 8
LEARNING_RATE = 1e-3
NB_EPOCHS = 2

# Forward operator built for the selected colour filter array.
CFA = 'bayer'
# NOTE(review): [26, 40, 3] is presumably the (height, width, channels) input shape — TODO confirm
OP = cfa_operator(CFA, [26, 40, 3], RGB_SPECTRAL_STENCIL, 'dirac')

In [3]:
# One dataset per split, each receiving the operator's `direct` method.
# The construction order (train, test, val) is kept as in the original cell.
train_dataset, test_dataset, val_dataset = (
    RGBDataset(split_dir, OP.direct) for split_dir in (TRAIN_DIR, TEST_DIR, VAL_DIR)
)

# Bundle the three splits for the trainer.
data_module = DataModule(train_dataset, val_dataset, test_dataset)

# NOTE(review): the trailing 3 is a magic number — presumably a channel count; confirm
# against the U_PDHG_system signature.
model = U_PDHG_system(LEARNING_RATE, NB_STAGES, CFA, RGB_SPECTRAL_STENCIL, 3)

In [4]:
# Both callbacks watch the same logged metric, 'val_loss'.
early_stop = EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.0001,
)
save_best = ModelCheckpoint(monitor='val_loss', filename='best')

# Trainer capped at NB_EPOCHS epochs, with early stopping and checkpointing attached.
trainer = pl.Trainer(
    callbacks=[early_stop, save_best],
    max_epochs=NB_EPOCHS,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
tuner = Tuner(trainer)

# Probe batch sizes starting from 16.
tuner.scale_batch_size(model, datamodule=data_module, init_val=16)

# Keep the finder object rather than letting its bare repr (a memory address)
# become the cell output, and report the suggested learning rate instead.
# NOTE(review): in the recorded run the search ended after 8 of 100 steps because
# the trainer reached max_epochs, so the suggestion may be unreliable — consider
# raising NB_EPOCHS or lowering `num_training`.
lr_finder = tuner.lr_find(model, datamodule=data_module, early_stop_threshold=None)
suggested_lr = lr_finder.suggestion() if lr_finder is not None else None
print(f'Suggested learning rate: {suggested_lr}')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=3` reached.
Batch size 16 succeeded, trying batch size 32
`Trainer.fit` stopped: `max_steps=3` reached.
Batch size 32 succeeded, trying batch size 64
`Trainer.fit` stopped: `max_steps=3` reached.
Batch size 64 succeeded, trying batch size 128
Batch size 128 failed, trying batch size 64
Finished batch size finder, will continue with full run using batch size 64
Restoring states from the checkpoint path at /home/mullemat/code/unrolled_demosaicking/.scale_batch_size_92515ae1-6a61-404c-b7ba-5c63953f0900.ckpt
Restored all states from the checkpoint at /home/mullemat/code/unrolled_demosaicking/.scale_batch_size_92515ae1-6a61-404c-b7ba-5c63953f0900.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Finding best initial lr:   8%|▊         | 8/100 [00:08<01:18,  1.17it/s]`Trainer.fit` stopped: `max_epochs=2` reached.
LR finder stopped early after 8 steps due to diverging loss.
Finding best initial lr:   8%|▊         | 8/100

<lightning.pytorch.tuner.lr_finder._LRFinder at 0x7f407818a090>

In [None]:
# Train the model; the data module is passed by keyword so its role is explicit.
trainer.fit(model, datamodule=data_module)

In [None]:
# Evaluate on the test split supplied by the data module.
trainer.test(model, datamodule=data_module)

In [None]:
# Prefer the checkpoint path recorded by the ModelCheckpoint callback in this
# session. The 'version_0' directory is only correct for the very first run, so
# it is kept solely as a fallback for when no checkpoint has been written yet.
best_ckpt_path = save_best.best_model_path or 'lightning_logs/version_0/checkpoints/best.ckpt'

# Load the weights onto DEVICE so the model and the input tensor below are
# guaranteed to live on the same device.
model = U_PDHG_system.load_from_checkpoint(best_ckpt_path, map_location=DEVICE)
model.to(DEVICE)
model.eval()

# Load one validation image, downscale it by a factor of 12 and scale to [0, 1].
# The context manager closes the underlying file once the pixels are copied out.
# NOTE(review): the factor 12 presumably brings this image to the shape OP was
# built for — confirm.
with Image.open('src/images/val/3096.jpg') as img:
    x = np.array(img.resize((img.size[0] // 12, img.size[1] // 12))).astype(np.float32) / 255

# Simulated acquisition of the ground-truth image through the forward operator.
y = torch.tensor(OP.direct(x), dtype=torch.float, device=DEVICE)

# Run the reconstruction on a batch of one; no gradients are tracked here, so
# the output can be converted to NumPy directly.
with torch.no_grad():
    x_hat = model(y[None])[0].cpu().numpy()

In [None]:
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)

# x was scaled to [0, 1] when it was loaded, so the dynamic range is stated
# explicitly instead of letting skimage guess it from the float dtype.
print(peak_signal_noise_ratio(x, x_hat, data_range=1.0))

axs[0, 0].imshow(x)
axs[0, 0].set_title('Ground truth')

axs[0, 1].imshow(y.cpu(), cmap='gray')
axs[0, 1].set_title('Input')

axs[1, 0].imshow(OP.adjoint(OP.direct(x)))
axs[1, 0].set_title('Adjoint of the input')

# The network output is not guaranteed to stay inside [0, 1]; clip it for display
# so matplotlib does not have to clip it implicitly (with a warning).
axs[1, 1].imshow(np.clip(x_hat, 0, 1))
axs[1, 1].set_title('Reconstruction')

# Ends the cell on a None-returning call so no Text repr is echoed.
fig.tight_layout()