In [None]:
from dataset import TorinoAquaDataset
from model import AutoEncoder, Unet
import os
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from utils import DiceLoss
import torchvision

# --- Run configuration -------------------------------------------------------
LOAD_MODEL = False  # when True, resume weights from 'model.pth' before training
epochs = 50

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Network, losses, optimiser ----------------------------------------------
model = Unet().to(device)
lossfn1 = DiceLoss()   # project-defined segmentation loss (see utils.DiceLoss)
lossfn2 = nn.L1Loss()  # mean absolute error
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
# Scale the learning rate by 0.8 when the monitored quantity (passed in via
# scheduler.step) has not improved for more than 3 consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim, mode='min', factor=0.8, patience=3,
)

# --- Data --------------------------------------------------------------------
dataset = TorinoAquaDataset()
# One sample per batch, reshuffled every epoch, loaded by 2 worker processes.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

# Directory for the preview grids written periodically by the training loop.
os.makedirs('TrainImg', exist_ok=True)
if LOAD_MODEL:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved from a CUDA run can be restored on a CPU-only machine
    # (without it, torch.load tries to recreate them on the original device).
    state_dict = torch.load('model.pth', map_location=device)
    model.load_state_dict(state_dict)

# Lowest per-epoch summed training loss seen so far. Starting at +inf (instead
# of an arbitrary large constant) guarantees the first finite epoch is saved
# however large its loss is.
loss_sum_min = float('inf')
for epoch in range(epochs):
    print('Epoch:', epoch + 1)
    # The model starts in train mode; set it explicitly each epoch so the loop
    # stays correct if an evaluation pass (model.eval()) is ever added.
    model.train()
    loss_sum = 0
    for i, data in enumerate(tqdm(dataloader)):
        inputs = data['input'].to(device)
        mask = data['mask'].to(device)
        label = data['label'].to(device)
        logits = model(inputs)
        # Output channel 0 is scored against `mask` with the Dice loss;
        # channels 1.. are scored against `label` with L1.
        loss = (
            lossfn1(logits[:, 0, :, :], mask)
            + lossfn2(logits[:, 1:, :, :], label)
        )
        optim.zero_grad()
        loss.backward()
        optim.step()

        batch_loss = loss.item()  # single host/device sync per step
        loss_sum += batch_loss

        if (i + 1) % 200 == 0:
            print(batch_loss)
            # Preview row of five tiles for the first sample in the batch:
            # input | mask | label | predicted channel 0 | predicted channels 1..
            # Single-channel maps are repeated to 3 channels so they can sit
            # beside the others (assumes inputs/label are 3-channel -- TODO confirm).
            grid = torchvision.utils.make_grid(
                [
                    inputs[0],
                    mask[0].repeat(3, 1, 1),
                    label[0],
                    logits[0, 0, :, :].detach().repeat(3, 1, 1),
                    logits[0, 1:, :, :].detach(),
                ],
                nrow=5,
            )
            torchvision.utils.save_image(
                grid,
                f'TrainImg/Epoch{epoch+1}_{i+1}.png'
            )

    # Drive the plateau scheduler with the epoch's summed training loss.
    scheduler.step(loss_sum)
    # Keep the weights from the epoch with the lowest training loss (there is
    # no validation split in this script). A NaN loss never compares lower.
    if loss_sum < loss_sum_min:
        loss_sum_min = loss_sum
        torch.save(model.state_dict(), 'model.pth')
        print("Model saved!")