In [None]:
import glob
import os

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision.transforms import v2

from dataset import GoProDataset
from model import BaseLineUnet
from NAFnet_baseline.Baseline_arch import Baseline

In [None]:
def create_dict_of_files(patterns=("**/*.py", "**/*.ipynb")):
    """Snapshot project source files so they can be stored inside a checkpoint.

    Patterns are globbed recursively relative to the current working directory.

    Args:
        patterns: glob patterns (``**`` allowed) selecting the files to capture.
            Defaults to every ``.py`` and ``.ipynb`` file.

    Returns:
        dict mapping each matched file path (as returned by ``glob.glob``) to its
        text content, inserted in sorted path order so the snapshot is
        deterministic across runs and machines.
    """
    paths = set()
    for pattern in patterns:
        # glob also matches directories whose names fit the pattern (e.g. "x.py/");
        # those cannot be opened as files, so keep regular files only.
        paths.update(p for p in glob.glob(pattern, recursive=True) if os.path.isfile(p))

    data = {}
    for path in sorted(paths):
        # errors="replace": one non-UTF-8 file (e.g. a Latin-1 script) must not make
        # model construction fail; undecodable bytes become U+FFFD instead.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            data[path] = f.read()
    return data

In [None]:
from losses import LPIPSLoss
class WeightedLoss(nn.Module):
    """Weighted sum of pixel-wise MSE and the ``LPIPSLoss`` from ``losses``.

    Args:
        mse_weight: multiplier for the MSE term (default 0.8).
        lpips_weight: multiplier for the LPIPS term (default 0.2).
    """

    def __init__(self, mse_weight=0.8, lpips_weight=0.2):
        super().__init__()
        self.mse_weight = mse_weight
        self.lpips_weight = lpips_weight
        self.mse = nn.MSELoss()
        self.lpipsloss = LPIPSLoss()

    def forward(self, input, target):
        """Return ``mse_weight * MSE(input, target) + lpips_weight * LPIPS(input, target)``."""
        mse_term = self.mse(input, target)
        lpips_term = self.lpipsloss(input, target)
        return self.mse_weight * mse_term + self.lpips_weight * lpips_term

In [None]:
from losses import ResizeLoss
class LightningWrapper(L.LightningModule):
    """Lightning training wrapper around ``BaseLineUnet`` optimised with ``ResizeLoss``.

    At construction the text of the project's ``.py``/``.ipynb`` files is captured
    with ``create_dict_of_files`` and written into every checkpoint under the
    ``"model_code"`` key, so a checkpoint can be traced back to the code that made it.

    Args:
        lr: initial Adam learning rate (default 1e-3).
        t_max: CosineAnnealingLR period in epochs; keep it equal to the Trainer's
            ``max_epochs`` (default 500).
        eta_min: lower bound of the cosine schedule (default 1e-6).
        log_images: if True, log input/target/prediction grids for batch 0 of each
            epoch through ``self.logger.experiment.add_image`` (assumes a
            TensorBoard-style logger).
    """

    def __init__(self, lr=1e-3, t_max=500, eta_min=1e-6, log_images=False):
        super().__init__()
        self.model = BaseLineUnet()
        self.model_code = create_dict_of_files()
        self.loss_fn = ResizeLoss()
        self.lr = lr
        self.t_max = t_max
        self.eta_min = eta_min
        self.log_images = log_images

    def on_save_checkpoint(self, checkpoint):
        """Embed the source-code snapshot in the checkpoint dict."""
        checkpoint["model_code"] = self.model_code

    def on_load_checkpoint(self, checkpoint):
        """Restore the source-code snapshot from a checkpoint, if it has one."""
        # .get: checkpoints written without this hook (older runs, other tools) lack
        # the key; keep the snapshot taken at construction instead of raising KeyError.
        self.model_code = checkpoint.get("model_code", self.model_code)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute ResizeLoss on one ``(x, y)`` batch and log its epoch mean as ``train_loss``."""
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        if self.log_images and batch_idx == 0 and self.logger is not None:
            self._log_image_grids(x, y, y_hat)
        return loss

    def _log_image_grids(self, x, y, y_hat):
        """Write [0, 1]-clamped image grids of inputs, targets and predictions to the logger."""
        experiment = self.logger.experiment
        for tag, images in (("train_x", x), ("train_y", y), ("train_y_hat", y_hat)):
            # detach(): logging must not hold the autograd graph.
            # float(): outputs produced under bf16-mixed cannot be converted to numpy.
            grid = torchvision.utils.make_grid(torch.clamp(images.detach().float(), 0, 1))
            experiment.add_image(tag, grid, self.global_step)

    def configure_optimizers(self):
        """Adam with a cosine-annealing learning rate stepped once per epoch."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.t_max, eta_min=self.eta_min)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                # 'monitor'/'strict' are omitted: they exist for metric-driven
                # schedulers such as ReduceLROnPlateau; CosineAnnealingLR reads no metric.
            },
        }
    
# Build the Lightning module, then show a torchinfo layer summary for a dummy
# batch of 16 RGB 256x256 images (torchinfo runs a forward pass to get shapes).
model = LightningWrapper()
summary(model, input_size=(16, 3, 256, 256))

In [None]:
def strip_wrapper_prefix(state_dict, prefix="model.", drop_prefixes=("loss_fn.",)):
    """Turn a LightningWrapper ``state_dict`` into one loadable by the inner network.

    Entries whose key starts with any of ``drop_prefixes`` (the loss module's
    weights) are discarded, and a single *leading* ``prefix`` is stripped from the
    remaining keys. Only the prefix is removed: inner names that merely contain
    ``"model."`` (e.g. ``"enc.submodel.weight"``) stay intact, which a plain
    ``str.replace`` would have mangled.

    Args:
        state_dict: mapping of parameter/buffer names to values.
        prefix: attribute name of the wrapped network followed by a dot.
        drop_prefixes: iterable of key prefixes to discard entirely.

    Returns:
        A new dict; ``state_dict`` itself is not modified.
    """
    drop = tuple(drop_prefixes)  # str.startswith needs a str or a tuple
    weights = {}
    for key, value in state_dict.items():
        if key.startswith(drop):
            continue
        if key.startswith(prefix):
            key = key[len(prefix):]
        weights[key] = value
    return weights


# Warm-start: set LOAD_PRETRAINED = True and PRETRAINED_CKPT to a checkpoint written
# by this notebook to initialise the network before training.
LOAD_PRETRAINED = False
PRETRAINED_CKPT = ""
if LOAD_PRETRAINED:
    if not PRETRAINED_CKPT:
        raise ValueError("LOAD_PRETRAINED is True but PRETRAINED_CKPT is empty")
    # map_location="cpu" lets a checkpoint saved on GPU load on any machine; Lightning
    # moves the module to the training device during trainer.fit.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(PRETRAINED_CKPT, map_location="cpu")
    model.model.load_state_dict(strip_wrapper_prefix(checkpoint["state_dict"]))

In [None]:
# GoPro training split. Set the GOPRO_ROOT environment variable to point at the
# dataset instead of editing the machine-specific default path below.
GOPRO_ROOT = os.environ.get("GOPRO_ROOT", "E:\\Downloads\\GOPRO_Large")
BATCH_SIZE = 16
# Up to 11 loader workers, but never more than the machine has CPUs (always >= 1).
NUM_WORKERS = min(11, os.cpu_count() or 1)

train_dataset = GoProDataset(root_dir=os.path.join(GOPRO_ROOT, 'train'), addnoise=False, mode='train')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Keep worker processes alive between epochs; valid because NUM_WORKERS >= 1.
    persistent_workers=True,
)

In [None]:
# Allow TF32-style fast float32 matmuls on supporting GPUs.
torch.set_float32_matmul_precision("high")

# Keep the best checkpoint by epoch-mean training loss, plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best-checkpoint-{epoch:02d}-{train_loss:.6f}',
    save_last=True,
)
lr_callback = LearningRateMonitor(logging_interval='epoch')

# max_epochs should match the 500-epoch CosineAnnealingLR period used by LightningWrapper.
trainer = L.Trainer(
    max_epochs=500,
    precision='bf16-mixed',
    callbacks=[checkpoint_callback, lr_callback],
)
trainer.fit(model, train_dataloader)