In [None]:
from dataset import TorinoAquaDataset
from models.Unet import Discriminator
from models.Lama import Lama
import os
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from utils import MaskL1Loss, DiceLoss, VGGLoss
import torchvision
from torch.utils.tensorboard import SummaryWriter

# --- Training configuration -------------------------------------------------
LOAD_CHECKPOINT = True  # resume from the last saved checkpoint when it exists
SAVE_EVERY_EPOCH = 1    # write a checkpoint every N epochs

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Generator (`model`) and discriminator (`modelD`) trained adversarially below.
model = Lama(1/8, 4, 3).to(device)
modelD = Discriminator().to(device)
lossfn = nn.MSELoss()
vggloss = VGGLoss(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
optimD = torch.optim.Adam(modelD.parameters(), lr=1e-4)

# Per-epoch summed losses of the generator / discriminator (restored on resume).
plot_array = []
plot_arrayD = []
start_epoch = 0  # last completed epoch; training continues at start_epoch + 1

dataset = TorinoAquaDataset(no_mask=-1)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

writer = SummaryWriter()

os.makedirs('TrainImg', exist_ok=True)
os.makedirs('TrainCheckpoint', exist_ok=True)

checkpoint_path = 'TrainCheckpoint/' + model.__class__.__name__ + 'CheckpointLast.pth'
if LOAD_CHECKPOINT and os.path.exists(checkpoint_path):
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can be resumed on a CPU-only one.
    checkpoint_file = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint_file['model'])
    modelD.load_state_dict(checkpoint_file['modelD'])
    optim.load_state_dict(checkpoint_file['optim'])
    optimD.load_state_dict(checkpoint_file['optimD'])
    plot_array = checkpoint_file['plot_array']
    plot_arrayD = checkpoint_file['plot_arrayD']
    start_epoch = checkpoint_file['epoch']
    print('Checkpoint loaded!')

def plot_loss():
    """Plot the per-epoch generator and discriminator loss histories.

    Draws ``plot_array`` (generator, blue, label ``Loss``) and ``plot_arrayD``
    (discriminator, red, label ``LossD``) on two side-by-side axes: the left
    one with a linear y-scale, the right one with a log y-scale. Each series
    is plotted against its own 1-based epoch index, so the two histories do
    not need to have the same length.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 9))
    # Both panels show identical data; only the y-scale differs.
    for ax, yscale in zip(axes, ('linear', 'log')):
        ax.plot(range(1, 1 + len(plot_array)), plot_array,
                linestyle='-', marker='o', color='b', label='Loss')
        ax.plot(range(1, 1 + len(plot_arrayD)), plot_arrayD,
                linestyle='-', marker='o', color='r', label='LossD')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_yscale(yscale)
        ax.grid(True)
        ax.legend()
    plt.show()

def _save_checkpoint(epoch):
    """Write the training state reached after ``epoch`` to the "last" checkpoint.

    The payload is written to a temporary file first and then moved over the
    previous checkpoint with ``os.replace``. Interrupting the save therefore
    cannot leave a truncated ``...CheckpointLast.pth`` behind that would break
    resuming.
    """
    path = 'TrainCheckpoint/' + model.__class__.__name__ + 'CheckpointLast.pth'
    tmp_path = path + '.tmp'
    torch.save({
        'model': model.state_dict(),
        'modelD': modelD.state_dict(),
        'optim': optim.state_dict(),
        'optimD': optimD.state_dict(),
        'plot_array': plot_array,
        'plot_arrayD': plot_arrayD,
        'epoch': epoch
    }, tmp_path)
    os.replace(tmp_path, path)


def train(epochs=50):
    """Run ``epochs`` epochs of adversarial training, continuing after ``start_epoch``.

    For each batch, the generator ``model`` receives the input concatenated
    with its mask along dim 1. Its output is kept where ``mask > 0.95`` and
    replaced by the label elsewhere.

    The generator loss is
    ``10 * MSE(D(fake), 1) + 2 * vggloss(fake, label) + MSE(fake, label)``.
    The discriminator ``modelD`` is then updated with
    ``MSE(D(fake.detach()), 0) + MSE(D(real), 1)``.

    After each epoch:

    - the summed losses are appended to ``plot_array`` / ``plot_arrayD`` and
      logged to TensorBoard;
    - a checkpoint is written every ``SAVE_EVERY_EPOCH`` epochs;
    - the module-level ``start_epoch`` is advanced, so a later ``train()`` call
      in the same session continues the epoch numbering instead of restarting
      it (and overwriting images / checkpoints).

    Args:
        epochs: Number of epochs to run.
    """
    global start_epoch
    # The dataloader shuffles, so it is backed by a sized sampler and has a length.
    steps_per_epoch = len(dataloader)
    for epoch in range(start_epoch + 1, start_epoch + 1 + epochs):
        print('Epoch:', epoch)
        loss_sum = 0
        loss_sumD = 0
        for i, data in enumerate(tqdm(dataloader)):
            inputs = data['input'].to(device)
            label = data['label'].to(device)
            mask = data['mask'].to(device)

            # --- Generator step ---
            logits = model(torch.cat([inputs, mask], dim=1))
            logits = torch.where(mask > 0.95, logits, label)
            # This backward also accumulates gradients into modelD's
            # parameters; optimD.zero_grad() below discards them before the
            # discriminator's own backward pass.
            outputD_fake = modelD(logits)
            loss = (10 * lossfn(outputD_fake, torch.ones_like(outputD_fake))
                    + 2 * vggloss(logits, label)
                    + lossfn(logits, label))
            optim.zero_grad()
            loss.backward()
            loss_sum += loss.item()
            optim.step()

            # --- Discriminator step: real labels vs. detached generator output ---
            outputD_real = modelD(label)
            outputD_fake = modelD(logits.detach())
            lossD = (lossfn(outputD_fake, torch.zeros_like(outputD_fake))
                     + lossfn(outputD_real, torch.ones_like(outputD_real)))
            optimD.zero_grad()
            lossD.backward()
            loss_sumD += lossD.item()
            optimD.step()

            if (i + 1) % 200 == 0:
                print('Loss:', loss_sum, 'LossD:', loss_sumD)
                grid = torchvision.utils.make_grid([
                    inputs[0],
                    label[0],
                    logits[0].detach()],
                    3
                )
                # Several snapshots can be taken per epoch; a per-iteration
                # step keeps them from sharing (and hiding each other at)
                # a single TensorBoard step.
                writer.add_image('Images', grid, (epoch - 1) * steps_per_epoch + i + 1)
                torchvision.utils.save_image(
                    grid,
                    f'TrainImg/Epoch{epoch}_{i+1}.png'
                )

        plot_array.append(loss_sum)
        plot_arrayD.append(loss_sumD)
        writer.add_scalar('Loss', loss_sum, epoch)
        writer.add_scalar('LossD', loss_sumD, epoch)

        if epoch % SAVE_EVERY_EPOCH == 0:
            _save_checkpoint(epoch)
            print('Checkpoint saved!')
        # The loop range was fixed on entry, so updating the global here is safe.
        start_epoch = epoch

# Launch training for 200 epochs (continuing after a loaded checkpoint, if any).
train(epochs=200)