In [1]:
import glob
import os

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision.transforms import v2

from dataset import GoProDataset
from model import BaseLineUnet
from NAFnet_baseline.Baseline_arch import Baseline

In [4]:
def create_dict_of_files(root=None, patterns=("**/*.py", "**/*.ipynb")):
    """Read source files into a ``{path: text}`` dictionary.

    Args:
        root: Directory to search. ``None`` (default) searches the current
            working directory and keys are the paths exactly as ``glob``
            returns them. When given, keys are paths relative to ``root``.
        patterns: Glob pattern(s) evaluated with ``recursive=True``. A single
            string is accepted as shorthand for a one-element tuple.

    Returns:
        dict mapping each matched file path to its UTF-8 decoded contents.

    Raises:
        UnicodeDecodeError: if a matched file is not valid UTF-8.
    """
    if isinstance(patterns, str):
        patterns = (patterns,)
    if root is not None:
        root = os.fspath(root)

    # Collect into a set so a file matched by several patterns is read once.
    matched = set()
    for pattern in patterns:
        if root is None:
            search = pattern
        else:
            # Escape the root so glob metacharacters in the directory name
            # are treated literally.
            search = os.path.join(glob.escape(root), pattern)
        matched.update(glob.glob(search, recursive=True))

    data = {}
    # Sorted iteration makes the key order independent of filesystem
    # enumeration order.
    for path in sorted(matched):
        # A directory can match e.g. "*.py"; opening it would raise.
        if not os.path.isfile(path):
            continue
        key = path if root is None else os.path.relpath(path, root)
        with open(path, "r", encoding="utf-8") as f:
            data[key] = f.read()
    return data

In [None]:
class LightningWrapper(L.LightningModule):
    """Lightning training wrapper around ``BaseLineUnet``.

    Besides the network, the module holds ``model_code``, the dictionary
    returned by ``create_dict_of_files()``, and writes it into every
    checkpoint so a saved model can be traced back to the code it came from.

    Args:
        lr: Initial learning rate for Adam.
        t_max: ``T_max`` of the cosine schedule. The scheduler is stepped once
            per epoch, so this is a number of epochs; keep it in sync with the
            trainer's ``max_epochs`` if one full annealing cycle is intended.
        eta_min: Minimum learning rate of the cosine schedule.
    """

    def __init__(self, lr=1e-3, t_max=200, eta_min=1e-6):
        super().__init__()
        self.lr = lr
        self.t_max = t_max
        self.eta_min = eta_min

        self.model = BaseLineUnet()
        self.model_code = create_dict_of_files()

        # Alternative backbone kept for reference (Baseline from
        # NAFnet_baseline.Baseline_arch):
        # img_channel = 3
        # width = 32
        # dw_expand = 1
        # ffn_expand = 2
        # enc_blks = [1, 1, 1, 28]
        # middle_blk_num = 1
        # dec_blks = [1, 1, 1, 1]
        # self.model = Baseline(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
        #             enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, dw_expand=dw_expand, ffn_expand=ffn_expand)

    def on_save_checkpoint(self, checkpoint):
        """Store the source snapshot in the checkpoint under ``"model_code"``."""
        checkpoint["model_code"] = self.model_code

    def on_load_checkpoint(self, checkpoint):
        """Restore the source snapshot from the checkpoint.

        A checkpoint without a ``"model_code"`` entry leaves the snapshot
        taken in ``__init__`` in place instead of raising ``KeyError``.
        """
        self.model_code = checkpoint.get("model_code", self.model_code)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one batch and log it as ``train_loss``.

        ``batch`` is unpacked as ``(x, y)``: ``x`` is fed to the network and
        ``y`` is the regression target.
        NOTE(review): presumably degraded input / clean target pairs from
        GoProDataset; confirm against the dataset implementation.
        """
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat, y)
        # Only the epoch-level aggregate is logged, not per-step values.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return Adam with a cosine-annealing schedule stepped every epoch."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.t_max, eta_min=self.eta_min
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                # No 'monitor'/'strict' entries: Lightning only consults them
                # for ReduceLROnPlateau-style schedulers.
            },
        }
    
# Instantiate the Lightning module, then show a torchinfo summary for an
# input of size (16, 3, 256, 256), presumably (batch, channels, height,
# width) to mirror the training batches; confirm against the dataset.
model = LightningWrapper()
summary(model, input_size=(16, 3, 256, 256))

In [6]:
# Preprocessing applied to every sample: wrap as an image tensor, then cast
# to float32 with value scaling enabled (scale=True).
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# Dataset location. Set the GOPRO_TRAIN_DIR environment variable to run on
# another machine; the fallback is the path this notebook was written with.
train_dir = os.environ.get("GOPRO_TRAIN_DIR", r"E:\Downloads\GOPRO_Large\train")

# size=256 is forwarded to GoProDataset unchanged.
# NOTE(review): presumably the crop/patch size; confirm in dataset.py.
train_dataset = GoProDataset(train_dir, transform=transform, size=256)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=11,
    # Keep worker processes alive between epochs instead of respawning them.
    persistent_workers=True,
)

In [None]:
# Global float32 matmul precision setting, applied before the trainer is built.
torch.set_float32_matmul_precision("high")

# Checkpointing monitors the epoch-level `train_loss` and, with
# save_last=True, also keeps the most recent checkpoint.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best-checkpoint-{epoch:02d}-{train_loss:.4f}',
    save_last=True,
)
# Log the learning rate once per epoch.
lr_callback = LearningRateMonitor(logging_interval='epoch')

trainer = L.Trainer(
    max_epochs=200,
    precision='bf16-mixed',
    callbacks=[checkpoint_callback, lr_callback],
)
trainer.fit(model, train_dataloaders=train_dataloader)