In [5]:
import lightning as L
import torch
from torch.utils.data import DataLoader
from model import BaseLineUnet
from dataset import GoProDataset
import torch.nn.functional as F
from lightning.pytorch.callbacks import ModelCheckpoint
from torchinfo import summary
from torchvision.transforms import v2

In [6]:
class L_BaseLineUnet(L.LightningModule):
    """Lightning wrapper that trains a ``BaseLineUnet`` with an L1 loss.

    The wrapped network is stored on ``self.model``. Optimisation uses Adam
    together with a ``ReduceLROnPlateau`` schedule driven by the logged
    ``train_loss`` metric.
    """

    def __init__(self):
        super().__init__()
        self.model = BaseLineUnet()

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the L1 loss for one batch.

        Args:
            batch: A pair that unpacks into ``(inputs, targets)``.
            batch_idx: Index of the batch; not used here.

        Returns:
            The L1 loss between the network's prediction and the targets.
        """
        inputs, targets = batch
        predictions = self.model(inputs)
        l1 = F.l1_loss(predictions, targets)
        # Logged once per epoch (not per step) under the key that the LR
        # scheduler config below monitors.
        self.log("train_loss", l1, on_step=False, on_epoch=True)
        return l1

    def configure_optimizers(self):
        """Build the Adam optimizer and its plateau-based LR schedule.

        Returns:
            A dict with the ``optimizer`` and an ``lr_scheduler`` config that
            steps every epoch on the ``train_loss`` metric.
        """
        adam = torch.optim.Adam(self.parameters(), lr=1e-3)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            adam,
            mode='min',
            factor=0.5,
            patience=5,
        )
        scheduler_config = dict(
            scheduler=plateau,
            interval='epoch',
            frequency=1,
            monitor='train_loss',
            strict=True,
        )
        return dict(optimizer=adam, lr_scheduler=scheduler_config)
    
# Seed the Python, NumPy and PyTorch RNGs before the network is constructed so
# that any random weight initialisation (and, when the cells are run in order,
# the DataLoader shuffling) is repeatable across kernel restarts.
# workers=True also lets Lightning derive per-worker seeds for the dataloaders.
SEED = 42
L.seed_everything(SEED, workers=True)

# Module-level Lightning wrapper instance; this is the object handed to the trainer.
model = L_BaseLineUnet()

In [7]:
# Preprocessing pipeline: wrap each sample as an image tensor, then cast it to
# float32 with value rescaling enabled (scale=True).
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# GoProDataset is project code; only its transform argument is set here.
train_dataset = GoProDataset(transform=transform)

# Shuffled batches of 16, loaded by 11 worker processes that are kept alive
# between epochs (persistent_workers=True).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=11,
    persistent_workers=True,
)

In [4]:
# Select the "medium" float32 matmul precision setting before training starts.
torch.set_float32_matmul_precision("medium")

# Checkpoint on the logged 'train_loss' metric and additionally keep the most
# recent checkpoint (save_last=True).
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best-checkpoint-{epoch:02d}-{train_loss:.2f}',
    save_last=True,
)

# Up to 300 epochs with bf16 mixed precision.
trainer = L.Trainer(
    max_epochs=300,
    precision='bf16-mixed',
    callbacks=[checkpoint_callback],
)

# Relies on `model` and `train_dataloader` created in the earlier cells.
trainer.fit(model, train_dataloader)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type         | Params | Mode 
-----------------------------------------------
0 | model | BaseLineUnet | 20.6 M | train
-----------------------------------------------
20.6 M    Trainable params
0         Non-trainable params
20.6 M    Total params
82.445    Total estimated model params size (MB)
410       Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]