In [None]:
import logging
import os
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, random_split

from unet.unet_model_xB import UNet
from utils.petsReconDataset_multiloss_pl import PetsReconDataset

### Set device

In [None]:
# Prefer an NVIDIA GPU, then Apple-silicon MPS, and fall back to the CPU.
# The MPS check only runs when CUDA is unavailable, as in an if/elif chain.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

### Set seed

In [None]:
# Seed every RNG this notebook may draw from so a fresh "Restart & Run All"
# reproduces the same results. random_split (used when building the loaders)
# draws from PyTorch's global RNG when no generator is passed.
# NumPy and Python's `random` are seeded as well in case the dataset or helpers
# use them (not visible here — seeding them is cheap insurance).
manual_seed = 0
random.seed(manual_seed)
np.random.seed(manual_seed)
# Trailing ';' keeps Jupyter from displaying the returned torch.Generator.
torch.manual_seed(manual_seed);

### Load Model

Select a checkpoint, build the UNet, and load its trained weights.

In [None]:
# Checkpoint (multi-loss training run) whose weights are loaded into the UNet
# below; the path is relative to the notebook's working directory.
# Alternative checkpoint from the PASCAL VOC run:
# model_path = 'checkpoints/pascalVOC/multiloss/04-30/17-25-10/CP_epoch2.pth'
model_path = '../pets_final/multiloss/CP_Trial22_Epoch60.pth'

In [None]:
# Build the network and load the trained checkpoint.
# NOTE(review): the constructor arguments must match those used when the
# checkpoint was trained; otherwise load_state_dict (strict by default) raises
# on mismatched keys/shapes.
net = UNet(n_channels=3, n_classes=1, bilinear=True)

# torch.load unpickles the file — only load checkpoints from a trusted source.
net.load_state_dict(torch.load(model_path, map_location=device))

# map_location only decides where the loaded tensors live; load_state_dict then
# copies them into the module's own parameters (created on CPU by default), so
# the module itself has to be moved to the selected device.
net.to(device)

# Evaluation only: switch layers such as dropout/batch-norm (if present) to
# inference behaviour.
net.eval()

# No logging configuration is visible in this notebook, so logging.info would be
# filtered out at the default WARNING level; print so the message is shown.
print(f'Model loaded from {model_path}')

### Load dataset

In [None]:
# Repository root: the parent of the directory the notebook is running in.
root_dir = Path('.').resolve().parent

# Samples per DataLoader batch; read by get_dataloaders and the preview loop.
batch_size = 4

print(root_dir)

In [None]:
def get_dataloaders(root_dir,
                    val_percent=0.1,
                    batch_size=None,
                    num_workers=2,
                    img_size=224,
                    seed=None):
    """Split the pets dataset into train/validation subsets and wrap each in a DataLoader.

    Images are read from ``<root_dir>/Datasets/petsData/images/`` and masks
    from ``<root_dir>/Datasets/petsData/annotations/trimaps/``.

    Side effect: sets the module-level globals ``n_train`` and ``n_val`` to the
    sizes of the two subsets.

    Args:
        root_dir: Directory that contains the ``Datasets`` folder.
        val_percent: Fraction of samples held out for validation (expected in
            ``[0, 1)``); the count is truncated with ``int()``.
        batch_size: Samples per batch. ``None`` (default) falls back to the
            notebook-level ``batch_size`` variable, as the original did.
        num_workers: Worker processes for each DataLoader.
        img_size: Size argument passed to ``PetsReconDataset`` (previously
            hard-coded to 224; presumably the side length images are resized
            to — TODO confirm).
        seed: If given, the split uses a dedicated ``torch.Generator`` seeded
            with this value, so it no longer depends on how much of PyTorch's
            global RNG earlier cells (e.g. model construction) consumed.
            ``None`` keeps the original behaviour of using the global RNG.

    Returns:
        ``(train_loader, val_loader)``. The training loader reshuffles every
        epoch; the validation loader iterates in a fixed order.

    NOTE(review): this split only matches the one used during training if the
    training script split the same dataset the same way — confirm before
    reporting validation metrics.
    """
    global n_train, n_val

    if batch_size is None:
        # Backward compatibility: use the notebook-level setting.
        batch_size = globals()['batch_size']

    dir_img = os.path.join(root_dir, 'Datasets/petsData/images/')
    dir_mask = os.path.join(root_dir, 'Datasets/petsData/annotations/trimaps/')

    # The third positional argument is passed as None, exactly as before; its
    # meaning is defined by PetsReconDataset (not visible here).
    dataset = PetsReconDataset(dir_img, dir_mask, None, img_size)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train, val = random_split(dataset, [n_train, n_val], **split_kwargs)

    # Pinned host memory mainly speeds up host-to-CUDA copies; without CUDA,
    # PyTorch may warn that it has no effect. Apply it to both loaders only
    # when CUDA is present.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, val_loader

In [None]:
# Only the validation split is previewed below, so keep just the second loader
# (and avoid rebinding `_`, which IPython uses for the last cell output).
val_loader = get_dataloaders(root_dir)[1]

In [None]:
def torchToPIL_img(img):
    """Convert an RGB image tensor to a PIL ``'RGB'`` image for display.

    Args:
        img: Tensor holding one RGB image. After ``squeeze()`` it must be
            channels-first ``(3, H, W)``; singleton dimensions such as a
            leading batch dimension of 1 are dropped. Values are expected in
            ``[0, 1]`` (assumption — TODO confirm against the dataset's
            normalization); anything outside is clipped.

    Returns:
        ``PIL.Image.Image`` in mode ``'RGB'``.
    """
    # detach() so tensors that track gradients (e.g. network outputs) can be
    # converted; .numpy() raises on tensors with requires_grad=True.
    arr = img.detach().squeeze().cpu().numpy()
    # CHW -> HWC, the layout PIL expects.
    arr = arr.transpose((1, 2, 0))
    # Clip before scaling, matching torchToPIL_mask: casting out-of-range floats
    # to uint8 wraps around and produces speckled colours.
    arr = np.clip(arr, 0, 1)
    return Image.fromarray((arr * 255).astype(np.uint8), 'RGB')

In [None]:
def torchToPIL_mask(mask):
    """Convert a single-channel mask tensor to an 8-bit grayscale (``'L'``) PIL image.

    Args:
        mask: Tensor that squeezes to a 2-D ``(H, W)`` array (e.g. shape
            ``(1, H, W)``). Values are clipped to ``[0, 1]`` and scaled to
            ``[0, 255]``, so any value above 1 (such as an integer label > 1)
            saturates to white.

    Returns:
        ``PIL.Image.Image`` in mode ``'L'``.
    """
    # detach() lets this accept tensors that track gradients (e.g. raw network
    # predictions); .numpy() raises on tensors with requires_grad=True.
    arr = mask.detach().squeeze().cpu().numpy()
    arr = np.clip(arr, 0, 1)
    return Image.fromarray((arr * 255).astype(np.uint8), 'L')

In [None]:
# batch = next(iter(val_loader))

In [None]:
# display(torchToPIL_img(batch['image'][0]), torchToPIL_mask(batch['mask'][0]))

In [None]:
# display(torchToPIL_mask(batch['mask'][0]))

In [None]:
# Preview each validation image next to its ground-truth mask.
# Set `max_batches` to an int to show only the first few batches; None shows
# the whole validation split (two images per sample, so output grows quickly).
max_batches = None

for batch_idx, batch in enumerate(val_loader):
    if max_batches is not None and batch_idx >= max_batches:
        break
    print(batch_idx)

    # The last batch can hold fewer than `batch_size` samples (the loader keeps
    # the default drop_last=False), so iterate over the samples actually present
    # instead of range(batch_size), which raised IndexError on a short batch.
    for i in range(len(batch['image'])):
        print(batch['image_ID'][i])
        print(i)

        # TODO: run inference here, e.g. `pred_recon_img, pred_mask = net(...)`
        # (left commented out originally, where the input `img` was never defined).

        # `display` is IPython's rich-display function, available in Jupyter
        # without an import.
        display(torchToPIL_img(batch['image'][i]), torchToPIL_mask(batch['mask'][i]))