# Visualization of some images and their masks (both are resized to the same size so the mask can be overlaid on the image)


# In[24]:
from oxford_pets_dataset import OxfordPetsDataset
from torchvision.transforms import Compose, Resize, CenterCrop, Pad
import matplotlib.pyplot as plt
from unet import UNet
import torch
from train import get_resize_transform


# Build the model on the GPU when one is available, otherwise fall back to the CPU
# (the previous hard-coded `.cuda()` crashed on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(in_channels=3, out_channels=1).to(device)
# Try to load trained weights. If the checkpoint is missing we deliberately keep
# going with the freshly constructed (untrained) model, so predictions shown
# below will not be meaningful.
try:
    # map_location lets a checkpoint saved during a GPU run load on a CPU-only host.
    unet.load_state_dict(torch.load("checkpoints/unet.pth", map_location=device))
except FileNotFoundError:
    print("Model not found, please train it first")


# Image and mask are both resized to 572x572 so they line up pixel-for-pixel.
# The mask is then center-cropped to 388x388, presumably to match the spatial
# size of the U-Net's output for a 572x572 input — TODO confirm against unet.UNet.
_eval_transforms = {
    "image": Compose([get_resize_transform((572, 572))]),
    "mask": Compose(
        [
            get_resize_transform((572, 572)),
            CenterCrop((388, 388)),
        ]
    ),
}
dataset = OxfordPetsDataset("data", mode="test", transform=_eval_transforms)


def compare(N):
    """Plot dataset sample ``N`` alongside the model's thresholded prediction.

    Three panels are drawn side by side in one figure:

    1. the full (resized) input image;
    2. the input image center-cropped to the ground-truth mask's height/width,
       with the mask (thresholded at 0.5) overlaid at 50% opacity;
    3. the model output thresholded at 0.5.

    Args:
        N: Index of the sample to show from the module-level ``dataset``.
    """
    image, mask = dataset[N]

    # Send the input to whichever device holds the model's weights instead of
    # hard-coding CUDA.
    device = next(unet.parameters()).device
    unet.eval()
    with torch.no_grad():
        pred = unet(image.unsqueeze(0).to(device))

    # The dataset transform center-crops the mask but not the image, so crop the
    # image to exactly the mask's spatial size; a mismatched crop (previously a
    # hard-coded 378 vs the mask's 388) misaligns the overlay.
    mask_hw = tuple(mask.shape[-2:])
    cropped_image = CenterCrop(mask_hw)(image)

    # Separate axes so images of different sizes no longer draw on top of each other.
    _, (ax_input, ax_truth, ax_pred) = plt.subplots(1, 3, figsize=(15, 5))

    ax_input.imshow(image.permute(1, 2, 0))
    ax_input.set_title("input")

    ax_truth.imshow(cropped_image.permute(1, 2, 0))
    ax_truth.imshow(mask.squeeze().cpu() > 0.5, alpha=0.5, cmap="gray")
    ax_truth.set_title("input (cropped) + ground-truth mask")

    # NOTE(review): raw model outputs are thresholded at 0.5; if UNet returns
    # logits (no final sigmoid), 0.0 would be the equivalent threshold — confirm.
    ax_pred.imshow(pred.squeeze().cpu() > 0.5, cmap="gray")
    ax_pred.set_title("prediction (> 0.5)")

    plt.show()


# Visualize the test sample at index 8.
compare(N=8)
