# Predictions

## Imports

In [1]:
# load custom scripts
import config

# import the necessary packages
from skimage import io
from torchvision import transforms

import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchvision.transforms.functional as TF

## Plot function

In [2]:
def prepare_plot(image, gtMask, predMask):
    """Show an input image next to its ground-truth and predicted masks.

    Args:
        image: Normalized, channels-first image batch; only ``image[0]`` is
            displayed. It may live on any device: it is copied to the CPU
            before plotting.
        gtMask: Ground-truth mask, as a CPU tensor or a NumPy array.
        predMask: Predicted mask, as a NumPy array or a CPU tensor.

    Both masks are drawn with the ``tab20`` colormap on the same fixed
    0-255 value range, so equal values get equal colours in both panels.
    """
    # Same statistics as the ``transforms.Normalize`` applied in
    # ``make_predictions``; undoing it is simply ``x * std + mean``.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Un-normalize on the CPU and clamp to [0, 1]: imshow expects float RGB
    # data in that range and would otherwise clip it and emit a warning.
    rgb = image[0].detach().cpu() * std + mean
    rgb = rgb.clamp(0.0, 1.0).permute(1, 2, 0).numpy()

    # initialize our figure
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))

    # plot the original image, its mask, and the predicted mask
    ax[0].imshow(rgb)
    ax[1].imshow(np.asarray(gtMask), cmap="tab20", vmin=0, vmax=255)
    ax[2].imshow(np.asarray(predMask), cmap="tab20", vmin=0, vmax=255)

    # set the titles of the subplots
    ax[0].set_title("Image")
    ax[1].set_title("Original Mask")
    ax[2].set_title("Predicted Mask")

    # set the layout of the figure and display it
    figure.tight_layout()
    figure.show()

## Prediction function

In [3]:
def make_predictions(model, imagePath):
    """Predict a segmentation mask for one image and plot it with the ground truth.

    The image and its ground-truth mask receive the *same* random resized
    crop (output 128x128), so the two stay aligned. The image is then
    normalized with ImageNet statistics and fed to ``model``. The per-pixel
    argmax of the output is shown with ``prepare_plot``. Because the crop
    window is random, repeated calls on the same image show different
    regions.

    Args:
        model: Segmentation network, called on a ``(1, 3, 128, 128)`` float
            tensor on ``config.DEVICE``. Its squeezed output is argmax-ed
            over dim 0, so it presumably returns ``(1, num_classes, H, W)``
            scores with ``num_classes > 1`` -- TODO confirm.
        imagePath: Path to the input image. The ground-truth mask is read
            from ``config.MASK_DATASET_PATH`` using the image's base name
            with a ``.png`` extension.
    """
    # set model to evaluation mode
    model.eval()

    # turn off gradient tracking
    with torch.no_grad():
        # Load the image, scale pixels to [0, 1] (assumes 8-bit H x W x 3
        # RGB input -- TODO confirm), move channels first and add a batch dim.
        image = io.imread(imagePath).astype("float32") / 255.0
        image = torch.from_numpy(np.transpose(image, (2, 0, 1))).unsqueeze(0)

        # Draw ONE random crop window; it is reused for the mask below so
        # that image and mask stay pixel-aligned.
        i, j, h, w = transforms.RandomResizedCrop.get_params(
            image, scale=(0.08, 1.0), ratio=(0.75, 1.33))
        image = TF.resized_crop(image, i, j, h, w, (128, 128),
                                transforms.InterpolationMode.BILINEAR)

        # normalize with ImageNet statistics and move to the target device
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image = normalize(image).to(config.DEVICE)

        # The mask shares the image's base name but is stored as .png.
        # basename/splitext handle any extension case and platform separator.
        stem = os.path.splitext(os.path.basename(imagePath))[0]
        groundTruthPath = os.path.join(config.MASK_DATASET_PATH, stem + ".png")

        # Load the mask as int64 labels with a leading channel dim (assumes a
        # single-channel H x W mask -- TODO confirm). NEAREST interpolation
        # keeps labels discrete: bilinear would blend neighbouring labels
        # into values that are not valid classes at region boundaries.
        gtMask = io.imread(groundTruthPath)
        gtMask = torch.from_numpy(gtMask.astype(np.int64)).unsqueeze(0)
        gtMask = TF.resized_crop(gtMask, i, j, h, w, (128, 128),
                                 transforms.InterpolationMode.NEAREST).squeeze()

        # Per-pixel class = argmax over the class dimension of the output.
        scores = model(image).squeeze()
        predMask = torch.argmax(scores, 0).cpu().numpy().astype(np.uint8)

        # prepare a plot for visualization
        prepare_plot(image, gtMask, predMask)

## Make predictions

In [None]:
# Load the image paths listed in our testing file and randomly select up to
# 10 *distinct* ones (sampling without replacement, so no image is plotted
# twice and a list shorter than 10 does not raise).
print("[INFO] loading up test image paths...")
with open(config.TEST_PATHS) as pathsFile:
    # splitlines() also copes with "\r\n" line endings; blank lines are skipped
    imagePaths = [p.strip() for p in pathsFile.read().splitlines() if p.strip()]
imagePaths = np.random.choice(imagePaths, size=min(10, len(imagePaths)),
                              replace=False)

# Load our model from disk and move it to the current device. map_location
# lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles a full model object -- only load trusted
# files. On PyTorch >= 2.6 the default weights_only=True rejects pickled
# nn.Module objects; pass weights_only=False there if loading fails.
print("[INFO] load up model...")
unet = torch.load(config.MODEL_PATH, map_location=config.DEVICE).to(config.DEVICE)

# iterate over the randomly selected test image paths
for path in imagePaths:
    make_predictions(unet, path)