In [None]:
from dataset import RawImageDataset
from models.Unet import Unet
from models.Lama import LamaUnet
import os
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from utils import MaskL1Loss, DiceLoss, VGGLoss
import torchvision
from torch.utils.tensorboard import SummaryWriter

LOAD_CHECKPOINT = True

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
modelUnet = Unet(0.5).to(device)
modelDecen = LamaUnet(1).to(device)

dataset = RawImageDataset()
# One image per batch, in dataset order (no shuffling).
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

if LOAD_CHECKPOINT:
    # map_location remaps tensors that were saved on a GPU onto the current
    # device; without it, loading a CUDA-saved checkpoint on a CPU-only
    # machine raises an error even though `device` already fell back to 'cpu'.
    checkpoint_file = torch.load('TrainCheckpoint/UnetCheckpointBest.pth', map_location=device)
    modelUnet.load_state_dict(checkpoint_file['model'])
    checkpoint_file = torch.load('TrainCheckpoint/LamaUnetCheckpointLast30epoch.pth', map_location=device)
    modelDecen.load_state_dict(checkpoint_file['model'])
    print('Checkpoint loaded!')

In [None]:
data = next(iter(dataloader))


def _tile_bounds(length, patch_size):
    """Return ``(start, stop)`` windows of at most ``patch_size`` covering ``[0, length)``.

    Windows sit on a ``patch_size`` grid. The last one is shifted back so it ends
    exactly at ``length`` (overlapping its neighbour) instead of running past the
    edge. Starts are clamped at 0, so an axis shorter than ``patch_size`` yields
    the single window ``(0, length)``. A negative start would instead make Python
    slicing wrap around to the end of the axis. No duplicate window is produced
    when ``length`` is an exact multiple of ``patch_size``.
    """
    bounds = []
    for start in range(0, length, patch_size):
        stop = min(start + patch_size, length)
        bounds.append((max(0, stop - patch_size), stop))
    return bounds


def _tile_windows(height, width, patch_size):
    """Yield ``(row_slice, col_slice)`` pairs that together tile a height x width plane."""
    for row_start, row_stop in _tile_bounds(height, patch_size):
        for col_start, col_stop in _tile_bounds(width, patch_size):
            yield slice(row_start, row_stop), slice(col_start, col_stop)


def inference_unet(model, data, patch_size=512, threshold=0.70):
    """Run the segmentation model tile by tile and return a binary mask.

    ``data`` is indexed as ``[:, :, rows, cols]``, i.e. a 4-D tensor whose last
    two axes are height and width. Each tile's raw output is thresholded at
    ``threshold``. The output is compared directly against the threshold, so
    presumably the model ends with a sigmoid (TODO confirm). Overlapping tiles
    are merged with an element-wise max, so a pixel is 1 if any tile flags it.
    The result has ``data``'s shape and dtype.
    """
    output = torch.zeros_like(data)
    height, width = data.shape[2:]
    for rows, cols in _tile_windows(height, width, patch_size):
        logits = model(data[:, :, rows, cols])
        mask = (logits > threshold).to(output.dtype)
        output[:, :, rows, cols] = torch.maximum(output[:, :, rows, cols], mask)
    return output


def inference_decen(model, data, patch_size=512):
    """Run the restoration model tile by tile and blend overlaps by a uniform mean.

    Every pixel ends up as the plain average of all tile predictions covering it.
    The result has ``data``'s shape. The model output is assumed to be
    broadcast-compatible with its input tile (TODO confirm channel count).
    """
    total = torch.zeros_like(data)
    count = torch.zeros_like(data)
    height, width = data.shape[2:]
    for rows, cols in _tile_windows(height, width, patch_size):
        total[:, :, rows, cols] += model(data[:, :, rows, cols])
        count[:, :, rows, cols] += 1
    # clamp_min guards the degenerate empty-image case; tiling otherwise covers every pixel.
    return total / count.clamp_min(1)


@torch.inference_mode()
def decen(modelUnet, modelDecen, data, patch_size=512, threshold=0.70, out_dir='TrainImg'):
    """Segment ``data``, restore it, composite the two, and save all three images.

    Writes ``inferenceSeg.png`` (binary mask), ``inferenceDecen.png`` (restored
    image) and ``inferenceFinal.png`` into ``out_dir``. The final image takes
    restored pixels where the mask is 1 and original pixels elsewhere.
    Returns None.
    """
    # Put dropout / batch-norm layers (if any) into inference behaviour.
    modelUnet.eval()
    modelDecen.eval()
    os.makedirs(out_dir, exist_ok=True)
    data = data.to(device)

    output_seg = inference_unet(modelUnet, data, patch_size, threshold)
    torchvision.utils.save_image(output_seg, os.path.join(out_dir, 'inferenceSeg.png'))
    torch.cuda.empty_cache()

    output_decen = inference_decen(modelDecen, data, patch_size)
    torchvision.utils.save_image(output_decen, os.path.join(out_dir, 'inferenceDecen.png'))
    torch.cuda.empty_cache()

    output = torch.where(output_seg == 1, output_decen, data)
    torchvision.utils.save_image(output, os.path.join(out_dir, 'inferenceFinal.png'))


decen(modelUnet, modelDecen, data)