In [None]:
from dataset import TorinoAquaDataset
from models.Unet import UResNet
import os
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from utils import VGGLoss
import torchvision
from torch.utils.tensorboard import SummaryWriter

# --- Run configuration ---------------------------------------------------
LOAD_CHECKPOINT = True  # resume from the "last" checkpoint when it exists
SAVE_EVERY_EPOCH = 1    # write the "last" checkpoint every N epochs

# --- Model, loss, optimizer ----------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the meaning of 0.5 is defined by UResNet (presumably a
# dropout rate) — confirm in models/Unet.
model = UResNet(0.5).to(device)
lossfn = VGGLoss(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

# --- Training state (overwritten below when a checkpoint is restored) ----
plot_array = []      # one summed training loss per completed epoch
loss_sum_min = 2e9   # lowest epoch loss so far; starts at a large sentinel
start_epoch = 0      # last completed epoch; train() starts at start_epoch + 1

# --- Data ----------------------------------------------------------------
# NOTE(review): semantics of no_mask=-1 are defined by TorinoAquaDataset.
dataset = TorinoAquaDataset(no_mask=-1)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

writer = SummaryWriter()

os.makedirs('TrainImg', exist_ok=True)
os.makedirs('TrainCheckpoint', exist_ok=True)

# --- Optional resume -----------------------------------------------------
_last_checkpoint_path = (
    'TrainCheckpoint/' + model.__class__.__name__ + 'CheckpointLast.pth'
)
if LOAD_CHECKPOINT and os.path.exists(_last_checkpoint_path):
    # map_location remaps stored tensors onto the device chosen above, so a
    # checkpoint written on a GPU machine can still be restored on a
    # CPU-only one (and vice versa).
    checkpoint_file = torch.load(_last_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint_file['model'])
    optim.load_state_dict(checkpoint_file['optim'])
    loss_sum_min = checkpoint_file['loss_sum_min']
    plot_array = checkpoint_file['plot_array']
    start_epoch = checkpoint_file['epoch']
    print('Checkpoint loaded!')

def plot_loss():
    """Show the recorded per-epoch loss on a linear and a log-scaled axis.

    Reads the module-level ``plot_array`` and draws the same curve twice,
    side by side. Epoch numbers on the x-axis start at 1. Displays the
    figure with ``plt.show()`` and returns nothing.
    """
    epoch_numbers = range(1, len(plot_array) + 1)
    curve_style = dict(linestyle='-', marker='o', color='b', label='Loss')

    _, (linear_ax, log_ax) = plt.subplots(1, 2, figsize=(16, 9))

    # Left panel: linear y-axis.
    linear_ax.plot(epoch_numbers, plot_array, **curve_style)
    linear_ax.set_xlabel('Epoch')
    linear_ax.set_ylabel('Loss')
    linear_ax.grid(True)
    linear_ax.legend()

    # Right panel: same data on a log y-axis (no y-label is set here).
    log_ax.plot(epoch_numbers, plot_array, **curve_style)
    log_ax.set_xlabel('Epoch')
    log_ax.set_yscale('log')
    log_ax.grid(True)
    log_ax.legend()

    plt.show()

def _save_checkpoint(tag, epoch):
    """Save model/optimizer state and loss history under the given tag.

    The file is written to
    ``TrainCheckpoint/<ModelClassName>Checkpoint<tag>.pth``.
    """
    torch.save({
        'model': model.state_dict(),
        'optim': optim.state_dict(),
        'loss_sum_min': loss_sum_min,
        'plot_array': plot_array,
        'epoch': epoch
    }, 'TrainCheckpoint/'+model.__class__.__name__+'Checkpoint'+tag+'.pth')


def train(epochs=50):
    """Train the module-level ``model`` for ``epochs`` further epochs.

    Epoch numbering continues after the module-level ``start_epoch``, which
    is advanced after every completed epoch so that calling ``train`` again
    in the same session keeps counting instead of reusing epoch numbers
    (which would overwrite TensorBoard steps and saved images).

    Per epoch:
      * runs one optimizer step per batch from ``dataloader``;
      * every 200 batches prints the batch loss and logs/saves an image
        grid of the first sample's input, label and model output;
      * appends the summed loss to ``plot_array`` and logs it as 'Loss';
      * saves a "Best" checkpoint when the summed loss beats
        ``loss_sum_min``, and a "Last" checkpoint every
        ``SAVE_EVERY_EPOCH`` epochs.

    Args:
        epochs: Number of epochs to run in this call.
    """
    global loss_sum_min, start_epoch
    first_epoch = start_epoch + 1
    for epoch in range(first_epoch, first_epoch + epochs):
        print('Epoch:', epoch)
        loss_sum = 0
        for i, data in enumerate(tqdm(dataloader)):
            inputs = data['input'].to(device)
            label = data['label'].to(device)
            logits = model(inputs)
            loss = lossfn(logits, label)
            optim.zero_grad()
            loss.backward()
            loss_sum += loss.item()
            optim.step()
            if (i+1)%200==0:
                print(loss.item())
                # Side-by-side preview: input | label | prediction.
                grid = torchvision.utils.make_grid([
                    inputs[0],
                    label[0],
                    logits[0].detach()],
                    3
                )
                writer.add_image('Images', grid, epoch)
                torchvision.utils.save_image(
                    grid,
                    f'TrainImg/Epoch{epoch}_{i+1}.png'
                )

        plot_array.append(loss_sum)
        writer.add_scalar('Loss', loss_sum, epoch)
        # Keep the resume point in step with plot_array: only epochs that
        # ran to completion are counted.
        start_epoch = epoch

        if loss_sum<loss_sum_min:
            loss_sum_min = loss_sum
            _save_checkpoint('Best', epoch)
            print('Best model checkpoint saved!')

        if epoch % SAVE_EVERY_EPOCH == 0:
            _save_checkpoint('Last', epoch)
            print('Checkpoint saved!')

        # The writer is never closed in this script, so push pending events
        # to disk once per epoch.
        writer.flush()
    return

In [None]:
# Run 100 training epochs; epoch numbering starts at start_epoch + 1, so a
# restored checkpoint continues where it left off.
train(100)