In [None]:
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from src.lightning_classes import UnrolledSystem, DataModule
from src.data_loader import RGBDataset

In [None]:
# --- Experiment configuration: every value a reader might want to tune ---

# Reproducibility: seed PyTorch's global RNG before any dataset or model is
# built so that torch-driven randomness repeats across "Restart & Run All".
# Only PyTorch is seeded here; Python's `random` and NumPy are not.
SEED = 42
torch.manual_seed(SEED)

# Compute device: CUDA when a GPU is visible to PyTorch, CPU otherwise.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data. CFAS and the image directories are handed to RGBDataset
# (presumably colour-filter-array pattern names -- confirm in src.data_loader).
CFAS = ['bayer', 'kodak']
TRAIN_DIR = 'images/train'
VAL_DIR = 'images/val'
TEST_DIR = 'images/test'
PATCH_SIZE = 32

# Model. NB_STAGES and NB_CHANNELS are forwarded to UnrolledSystem.
NB_STAGES = 8
NB_CHANNELS = 16

# Optimisation.
BATCH_SIZE = 512
# NOTE(review): 1e-1 is a large step size for Adam-style optimisers -- confirm
# it matches the optimiser/scheduler configured inside UnrolledSystem.
LEARNING_RATE = 1e-1
NB_EPOCHS = 100  # upper bound on the number of training epochs

In [None]:
# Datasets: one RGBDataset per split. The training split passes
# PATCH_SIZE // 2 as its last argument while validation/test pass PATCH_SIZE
# (presumably the patch stride, i.e. overlapping training patches --
# confirm against RGBDataset).
train_dataset = RGBDataset(
    TRAIN_DIR,
    CFAS,
    PATCH_SIZE,
    PATCH_SIZE // 2,
)
val_dataset = RGBDataset(
    VAL_DIR,
    CFAS,
    PATCH_SIZE,
    PATCH_SIZE,
)
test_dataset = RGBDataset(
    TEST_DIR,
    CFAS,
    PATCH_SIZE,
    PATCH_SIZE,
)

# Bundle the three splits together with the batch size for the Trainer.
data_module = DataModule(
    train_dataset,
    val_dataset,
    test_dataset,
    BATCH_SIZE,
)

# Lightning module under training.
model = UnrolledSystem(
    LEARNING_RATE,
    NB_STAGES,
    NB_CHANNELS,
)

# Callbacks, both watching the logged 'Loss/Val' metric:
# - early_stop ends training once the metric has not improved by at least
#   min_delta for `patience` consecutive validation checks;
# - save_best writes the checkpoint for the best metric value as 'best'.
early_stop = EarlyStopping(
    monitor='Loss/Val',
    min_delta=1e-5,
    patience=10,
)
save_best = ModelCheckpoint(
    filename='best',
    monitor='Loss/Val',
)

# Trainer capped at NB_EPOCHS epochs; early stopping may end the run sooner.
trainer = pl.Trainer(
    max_epochs=NB_EPOCHS,
    callbacks=[
        early_stop,
        save_best,
    ],
)

In [None]:
# Launch training; the data module supplies the dataloaders for the run.
trainer.fit(
    model=model,
    datamodule=data_module,
)