# Looking at mask predictions

Running the test data through a saved model to look at the predicted masks and classification outputs

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import sys
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

sys.path.append('../src')
from UNet2D import UNet2D
from UNetMultiTask import UNetMultiTask
from datasets_only_segmentation import MycetomaDataset
from metrics import batch_dice_coeff, bce_dice_loss, dice_coefficient
from postprocessing import threshold_mask, post_process_binary_mask

In [None]:
# Root data directory, relative to this notebook's working directory
DATA_DIR = "../data"

In [None]:
# Collect the test images as paths relative to DATA_DIR with the file extension
# removed (e.g. 'test_dataset/<name>'); MycetomaDataset later receives these
# together with DATA_DIR.
# os.path.splitext strips only the final extension, so file names containing
# extra dots stay intact (str.split('.')[0] would truncate them). Sorting makes
# the order reproducible instead of depending on the filesystem's glob order.
test_paths = np.array(sorted(
    os.path.splitext(os.path.relpath(path, DATA_DIR))[0]
    for path in glob.glob(f'{DATA_DIR}/test_dataset/*.jpg')
))

print(f"Test length: {len(test_paths)}")

In [None]:
# Select the compute device: CUDA when torch can see a GPU, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Wrap the test image paths in the project dataset.
# NOTE(review): test_flag=True presumably makes each item just the image (the
# inference loop treats every batch as the inputs alone) -- confirm in MycetomaDataset.
test_dataset = MycetomaDataset(
    test_paths,
    DATA_DIR,
    test_flag=True,
)

In [None]:
# One image per batch, in dataset order (no shuffling), so each batch can be
# matched back to its entry in test_paths.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Plot an image and prediction
def plot_image(im, pred):

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(im)
    ax[0].set_title('Original Image')
    ax[0].axis('off')

    ax[1].imshow(pred)
    ax[1].set_title('prediction')
    ax[1].axis('off')

    plt.show()

In [None]:
# Create the multi-task UNet and load the saved weights onto the selected device.
# NOTE(review): constructor arguments (3, 1, 8) presumably mean input channels,
# mask output channels and base width -- confirm against UNetMultiTask.
model = UNetMultiTask(3, 1, 8)

# `device` is already a torch.device, so it is passed to map_location as-is.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load('../model_saves/train_bs_32_lr_5e-4_lw_0.2_best_model.pth', map_location=device)

# Sometimes the saved keys carry a 'module.' prefix (typically added when the
# model was wrapped in DataParallel for training) which the bare model does not
# expect. The flag and the prefix are separate names so neither changes type.
remove_prefix = True
prefix = 'module.'

if remove_prefix:
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

model.load_state_dict(state_dict)
model = model.to(device)

# Evaluation mode for inference; the trailing semicolon suppresses the
# notebook's echo of the returned model.
model.eval();

In [None]:
# Put the test data through the model: show each image next to its predicted
# mask and print the classification output. The test set has no ground-truth
# masks, so only predictions are displayed.

# Decision threshold applied to the classification output
threshold = 0.5

# Inference only, so no gradients are tracked
with torch.no_grad():
    # test_loader uses batch_size=1 and shuffle=False, so batch idx is the
    # image at test_paths[idx] (replaces the separate manual counter).
    for idx, inputs in enumerate(test_loader):
        inputs = inputs.to(device)

        outputs, class_out = model(inputs)

        # First (only) image of the batch, reordered CHW -> HWC for imshow.
        # NOTE(review): if inputs are normalised, imshow will clip the values.
        im = inputs[0].cpu().permute(1, 2, 0).numpy()
        # Channel 0 of the first mask output, passed through threshold_mask
        # (presumably binarises it -- confirm in postprocessing).
        pred = threshold_mask(outputs[0][0].cpu().numpy())

        # Post-process mask, clipped to [0, 1].
        # NOTE(review): computed but not displayed; plot_image receives `pred`.
        post_proc_mask = np.clip(post_process_binary_mask(pred, threshold_fraction=0.05), 0, 1)

        print("\n-------------------------------------------")
        print(f"Test image {test_paths[idx]}")
        # Plot prediction
        plot_image(im, pred)

        # NOTE(review): assumes class_out is already a probability; if the model
        # returns a raw logit, apply torch.sigmoid before comparing with 0.5.
        class_prob = class_out.squeeze().item()
        class_pred = "BM" if class_prob > threshold else "FM"

        print(f"Classification prediction: {class_prob} -> {class_pred}")