In [3]:
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from dataset import TorinoAquaDataset
from model import Unet, UnetReconstruct
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.auto import tqdm
from utils import DiceLoss, MaskL1Loss

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Segmentation outputs strictly greater than this value are treated as masked
# pixels and replaced by the reconstruction in the composite panel.
MASK_THRESHOLD = 0.95
# Number of evaluation grids written to TrainImg/.
NUM_EVAL_IMAGES = 10


def make_eval_panels(image, mask, recon, threshold=MASK_THRESHOLD):
    """Build the four panels shown in one evaluation grid.

    Args:
        image: the input sample (a single image, batch dimension removed).
        mask: the Unet output for that sample. It is repeated 3x along its
            first dimension so it can be shown next to RGB panels
            (presumably a single-channel mask; confirm against model.Unet).
        recon: the UnetReconstruct output for that sample.
        threshold: mask values strictly above this select `recon` over
            `image` in the composite panel.

    Returns:
        A list ``[image, mask (3x repeated), recon, composite]`` where
        ``composite = where(mask > threshold, recon, image)``.
    """
    return [
        image,
        mask.repeat(3, 1, 1),
        recon,
        torch.where(mask > threshold, recon, image),
    ]


# Only run the visual evaluation when both trained checkpoints are present.
if os.path.exists('modelUnet.pth') and os.path.exists('modelUnetReconstruct.pth'):
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only host.
    modelUnet = Unet(0.5).to(device)
    modelUnet.load_state_dict(torch.load('modelUnet.pth', map_location=device))
    # eval() disables training-only behaviour (dropout, batch-norm statistic
    # updates); torch.inference_mode() below does NOT do this by itself.
    modelUnet.eval()

    modelUnetReconstruct = UnetReconstruct(0.75).to(device)
    modelUnetReconstruct.load_state_dict(
        torch.load('modelUnetReconstruct.pth', map_location=device)
    )
    modelUnetReconstruct.eval()

    dataset = TorinoAquaDataset()
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    os.makedirs('TrainImg', exist_ok=True)

    with torch.inference_mode():
        # A single pass over one iterator yields distinct random samples
        # (re-calling iter() each time could draw the same sample twice).
        # Stops early if the dataset has fewer than NUM_EVAL_IMAGES samples.
        for i, batch in zip(range(NUM_EVAL_IMAGES), dataloader):
            data = batch['input'].to(device)
            out1 = modelUnet(data)
            out2 = modelUnetReconstruct(data)
            # nrow=2 lays the four panels out as a 2x2 grid.
            grid = torchvision.utils.make_grid(
                make_eval_panels(data[0], out1[0], out2[0]),
                nrow=2,
            )
            torchvision.utils.save_image(grid, f'TrainImg/Eval{i+1}.png')