In [None]:
import glob
import os

import lightning as L
import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision.transforms import v2

from dataset import GoProDataset
from model import BaseLineUnet
from NAFnet_baseline.Baseline_arch import Baseline

In [None]:
def create_dict_of_files(patterns=("**/*.py", "**/*.ipynb")):
    """Snapshot project source files as text so they can be stored in a checkpoint.

    Each glob pattern is expanded recursively relative to the current working
    directory. Only regular files are read; directories that happen to match a
    pattern are skipped.

    Args:
        patterns: Iterable of glob patterns. Defaults to every Python module and
            Jupyter notebook below the working directory.

    Returns:
        dict[str, str]: Mapping from matched path (as returned by ``glob``) to
        the file's full UTF-8 text, inserted in sorted path order so repeated
        snapshots of the same tree are identical.

    Raises:
        UnicodeDecodeError: If a matched file is not valid UTF-8.
        OSError: If a matched file cannot be opened.
    """
    # A set removes duplicates when patterns overlap; sorting makes the order
    # deterministic (glob's order is filesystem-dependent).
    matched = sorted(
        {path for pattern in patterns for path in glob.glob(pattern, recursive=True)}
    )
    snapshot = {}
    for path in matched:
        if not os.path.isfile(path):
            continue
        with open(path, "r", encoding="utf-8") as handle:
            snapshot[path] = handle.read()
    return snapshot

In [None]:
class RandomPerceptualLoss(nn.Module):
    """Multi-scale feature L1 loss computed with a frozen, randomly initialised CNN.

    Four conv stages (3->32->64->128->256 channels; each Conv3x3 + LeakyReLU(0.1)
    + 2x2 max-pool) embed ``input`` and ``target``. The L1 distance between their
    features after each stage is summed with weights 1, 2, 4 and 8, so deeper,
    coarser stages count more. The conv weights of *all* stages are re-drawn
    (Kaiming-normal) at the start of every forward pass, so each call compares
    the images under a fresh random feature extractor.

    Stage parameters are frozen (``requires_grad=False``); gradients still flow
    back to ``input`` through the features.
    """

    # Per-stage multipliers, applied in order to net, net2, net3, net4.
    _STAGE_WEIGHTS = (1, 2, 4, 8)

    def __init__(self) -> None:
        super().__init__()
        # Attribute names and Sequential layout are kept so state_dict keys are unchanged.
        self.net = self._make_stage(3, 32)
        self.net2 = self._make_stage(32, 64)
        self.net3 = self._make_stage(64, 128)
        self.net4 = self._make_stage(128, 256)
        # The extractor is never optimised; only its random features are used.
        for stage in self._stages():
            stage.requires_grad_(False)

        self.loss_fn = nn.L1Loss()

    @staticmethod
    def _make_stage(in_channels, out_channels):
        """Build one Conv3x3 -> LeakyReLU(0.1) -> MaxPool(2) feature stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),
        )

    def _stages(self):
        """Return the four feature stages in forward order."""
        return (self.net, self.net2, self.net3, self.net4)

    def reset(self):
        """Re-draw the conv weights of every stage from a Kaiming-normal distribution.

        Previously only the first stage was re-randomised, leaving stages 2-4 at
        their construction-time weights. Biases keep their construction-time values.
        """
        for stage in self._stages():
            for layer in stage.modules():
                if isinstance(layer, nn.Conv2d):
                    nn.init.kaiming_normal_(layer.weight)

    def forward(self, input, target):
        """Return the weighted multi-scale L1 feature distance as a scalar tensor.

        Assumes ``input`` and ``target`` are same-shaped 3-channel image batches
        (N, 3, H, W) — TODO confirm; H and W must be at least 16 to survive the
        four 2x pooling steps.
        """
        self.reset()
        total = 0
        for weight, stage in zip(self._STAGE_WEIGHTS, self._stages()):
            input = stage(input)
            target = stage(target)
            total = total + self.loss_fn(input, target) * weight
        return total

In [None]:
class LPIPSLoss(nn.Module):
    """LPIPS perceptual distance wrapped as a frozen loss module.

    Inputs are clamped to [0, 1] and rescaled to [-1, 1] before being passed to
    ``lpips.LPIPS``; the per-sample distances are averaged to a scalar.

    Args:
        net: Backbone name forwarded to ``lpips.LPIPS`` (default ``'alex'``).
    """

    def __init__(self, net='alex'):
        super().__init__()
        self.loss = lpips.LPIPS(net=net)
        # LPIPS is a fixed metric: never optimise its weights.
        self.loss.requires_grad_(False)
        self.loss.eval()

    def train(self, mode=True):
        """Set training mode on this wrapper while keeping the LPIPS network in eval mode.

        ``nn.Module.train()`` recurses into children, so when this loss lives
        inside a model being trained the LPIPS network would otherwise be
        switched to train mode too. NOTE(review): lpips' default
        ``use_dropout=True`` puts dropout in its linear heads, which would make
        the loss stochastic in train mode — confirm against the installed lpips.
        """
        super().train(mode)
        self.loss.eval()
        return self

    def forward(self, input, target):
        """Return the mean LPIPS distance between ``input`` and ``target``.

        Assumes both are image batches nominally in [0, 1]; values outside that
        range are clamped (which also zeroes their gradient).
        """
        # LPIPS expects images in [-1, 1]: clamp to [0, 1], then map x -> 2x - 1.
        input = torch.clamp(input, 0, 1) * 2 - 1
        target = torch.clamp(target, 0, 1) * 2 - 1
        return self.loss(input, target).mean()

In [None]:
class WeightedLoss(nn.Module):
    """Weighted sum of pixel-wise MSE and LPIPS perceptual distance.

    Args:
        mse_weight: Multiplier for the MSE term (default 0.8).
        lpips_weight: Multiplier for the LPIPS term (default 0.2).
    """

    def __init__(self, mse_weight=0.8, lpips_weight=0.2):
        super().__init__()
        self.lpipsloss = LPIPSLoss()
        self.mseloss = nn.MSELoss()
        self.mse_weight = mse_weight
        self.lpips_weight = lpips_weight

    def forward(self, input, target):
        """Return ``mse_weight * MSE + lpips_weight * LPIPS`` for the given pair."""
        return (
            self.mse_weight * self.mseloss(input, target)
            + self.lpips_weight * self.lpipsloss(input, target)
        )

In [None]:
class PSNRloss(nn.Module):
    """Peak signal-to-noise ratio (in dB) between a prediction and a target.

    Despite the name this is a quality *metric*: higher is better. Do not
    minimise it directly; negate it if it is used as a training objective.

    Args:
        data_range: Peak-to-peak range of the signal (e.g. ``1.0`` for images in
            [0, 1]). If ``None`` (default), the dynamic range of ``target``
            (``max - min``) is used, as before. That makes the value depend on
            each target's content and gives ``-inf`` for a constant target.
    """

    def __init__(self, data_range=None):
        super().__init__()
        self.data_range = data_range

    def forward(self, pred, target):
        """Return ``10 * log10(peak**2 / MSE(pred, target))`` as a scalar tensor.

        Identical ``pred`` and ``target`` (MSE of 0) yield ``+inf``.
        """
        if self.data_range is None:
            peak = torch.max(target) - torch.min(target)
        else:
            peak = self.data_range
        return 10 * torch.log10(peak ** 2 / F.mse_loss(pred, target))

In [None]:
class LightningWrapper(L.LightningModule):
    """Lightning training module wrapping ``BaseLineUnet`` with an MSE objective.

    At construction it snapshots the project's ``.py``/``.ipynb`` sources with
    ``create_dict_of_files`` and stores them in every checkpoint under the
    ``"model_code"`` key, so a checkpoint carries the code that produced it.

    Args:
        lr: Initial Adam learning rate.
        t_max: ``T_max`` of the cosine-annealing schedule, in epochs (the
            scheduler steps once per epoch). Keep it equal to the trainer's
            ``max_epochs`` so the LR reaches ``eta_min`` at the end of training.
        eta_min: Floor learning rate of the cosine schedule.
    """

    def __init__(self, lr=1e-3, t_max=200, eta_min=1e-6):
        super().__init__()
        self.model = BaseLineUnet()
        self.model_code = create_dict_of_files()
        self.loss_fn = nn.MSELoss()
        self.lr = lr
        self.t_max = t_max
        self.eta_min = eta_min

    def on_save_checkpoint(self, checkpoint):
        """Embed the source-code snapshot in the checkpoint dict."""
        checkpoint["model_code"] = self.model_code

    def on_load_checkpoint(self, checkpoint):
        """Restore the source-code snapshot saved with the checkpoint.

        Checkpoints without a ``"model_code"`` entry keep the snapshot taken at
        construction instead of raising ``KeyError``.
        """
        self.model_code = checkpoint.get("model_code", self.model_code)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the MSE between the network output and ``y`` for one batch.

        Assumes ``batch`` is an ``(x, y)`` pair of input/target tensors. Logs
        the epoch-averaged loss as ``"train_loss"``.
        """
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam with a per-epoch cosine-annealing learning-rate schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.t_max, eta_min=self.eta_min
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                # 'monitor' matters only for metric-driven schedulers such as
                # ReduceLROnPlateau; kept as in the original configuration.
                'monitor': 'train_loss',
                'strict': True,
            },
        }
    
# Build the training module, then display a torchinfo layer-by-layer summary
# for an input batch of shape (16, 3, 256, 256).
model = LightningWrapper()
summary(model, input_size=(16, 3, 256, 256))

In [None]:
# Training data; batches are unpacked as (x, y) pairs in LightningWrapper.training_step.
train_dataset = GoProDataset()
# Cap the worker count at the machine's CPU count: a fixed 11 workers
# oversubscribes smaller machines (PyTorch warns when num_workers exceeds its
# suggested maximum). os.cpu_count() may return None, hence the fallback to 1,
# which also keeps persistent_workers valid (it requires num_workers > 0).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=min(11, os.cpu_count() or 1),
    persistent_workers=True,  # reuse worker processes across epochs
)

In [None]:
# Allow faster reduced-precision (e.g. TF32) kernels for float32 matmuls.
torch.set_float32_matmul_precision("high")

# Track the checkpoint with the best (lowest, ModelCheckpoint's default mode)
# train_loss, and always keep the most recent one as `last.ckpt`.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best-checkpoint-{epoch:02d}-{train_loss:.6f}',
    save_last=True,
)
# Record the scheduler's learning rate once per epoch.
lr_callback = LearningRateMonitor(logging_interval='epoch')

trainer = L.Trainer(
    max_epochs=200,  # NOTE(review): keep in sync with the CosineAnnealingLR T_max used by the model
    precision='bf16-mixed',
    callbacks=[checkpoint_callback, lr_callback],
)
trainer.fit(model, train_dataloader)