In [None]:
import sys
sys.path.append('/content/PixelatedImageSegmentation')

from dataset import TorinoAquaDataset
from model import AutoEncoder
import os
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from utils import DiceLoss

# Set to True to resume training from the weights previously saved to 'model.pth'.
LOAD_MODEL = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoEncoder().to(device)
lossfn = DiceLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
# Scale the LR by 0.8 when the monitored value (the epoch loss passed to
# scheduler.step) has not decreased for 3 consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim, mode='min', factor=0.8, patience=3
)

dataset = TorinoAquaDataset()
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader rejects num_workers=None; fall back to main-process loading (0).
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=os.cpu_count() or 0,
)
epochs = 50

# Directory for the periodic input/label/prediction snapshots written during training.
os.makedirs('TrainImg', exist_ok=True)
if LOAD_MODEL:
    # map_location lets a checkpoint saved from a GPU run be deserialized on a
    # CPU-only machine (and vice versa) before being copied into the model.
    model.load_state_dict(torch.load('model.pth', map_location=device))
# Lowest epoch loss seen so far; starting at +inf guarantees the first epoch's
# weights are always checkpointed.
loss_sum_min = float('inf')
for epoch in range(epochs):
    print('Epoch:', epoch+1)
    model.train()
    loss_sum = 0
    for i, data in enumerate(tqdm(dataloader)):
        inputs = data['input'].to(device)
        label = data['mask'].to(device)
        logits = model(inputs)
        # Compute probabilities once; they feed both the loss and the snapshot.
        # NOTE(review): assumes DiceLoss expects probabilities in [0, 1], not logits.
        probs = torch.sigmoid(logits)
        loss = lossfn(probs, label)
        optim.zero_grad()
        loss.backward()
        loss_sum += loss.item()
        optim.step()
        if (i+1) % 50 == 0:
            print(loss.item())
            # Save a side-by-side snapshot of the first sample in the batch.
            # NOTE(review): assumes each tensor is (batch, C, H, W) so [0] is CHW
            # and transpose(1, 2, 0) yields HWC for imshow — confirm against dataset.
            fig, axes = plt.subplots(1, 3, figsize=(50, 20))
            panels = (('input', inputs), ('label', label), ('predict', probs))
            for ax, (title, tensor) in zip(axes, panels):
                ax.set_title(title)
                ax.imshow(tensor[0].detach().cpu().numpy().transpose(1, 2, 0))
            fig.savefig(f'TrainImg/Epoch{epoch+1}_{i+1}')
            plt.close(fig)
    # The plateau scheduler monitors the summed training loss of the epoch.
    scheduler.step(loss_sum)
    if loss_sum < loss_sum_min:
        loss_sum_min = loss_sum
        torch.save(model.state_dict(), 'model.pth')
        print("Model saved!")