# Looking at mask predictions

Run the validation data through a saved model to inspect the predicted masks alongside the ground truth.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import sys
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

sys.path.append('../src')
from UNet2D import UNet2D
from UNetMultiTask import UNetMultiTask
from datasets import MycetomaDataset
from metrics import batch_dice_coeff, bce_dice_loss, dice_coefficient
from postprocessing import threshold_mask, post_process_binary_mask

In [None]:
# Root directory of the datasets; image paths are stored relative to this folder.
DATA_DIR = '../data'

In [None]:
def _relative_stems(pattern, root):
    """Return extension-less paths, relative to *root*, of the files matching *pattern*.

    Args:
        pattern: Glob pattern selecting the image files.
        root: Directory the returned paths are made relative to.

    Returns:
        A NumPy array of path strings with the file extension removed.
    """
    # os.path.splitext strips only the final extension, so a dot elsewhere in
    # the path cannot truncate the stem (which `split('.')[0]` would do).
    # sorted(): glob returns files in a filesystem-dependent order.
    return np.array([
        os.path.splitext(os.path.relpath(path, root))[0]
        for path in sorted(glob.glob(pattern))
    ])


# NOTE: glob is called without recursive=True, so `**` behaves like `*` and
# matches exactly one sub-directory level below each dataset folder.
train_paths = _relative_stems(f'{DATA_DIR}/corrected_training_dataset/**/*.jpg', DATA_DIR)
val_paths = _relative_stems(f'{DATA_DIR}/corrected_validation_dataset/**/*.jpg', DATA_DIR)

# Drop the validation samples listed as problematic
problem_val_paths = np.array(['corrected_validation_dataset/FM/FM10_1'])
val_paths = np.setdiff1d(val_paths, problem_val_paths)

print(f"Train length: {len(train_paths)}")
print(f"Val length: {len(val_paths)}")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Build one dataset per split (training first, then validation) from its path list
train_dataset, val_dataset = (
    MycetomaDataset(split_paths, DATA_DIR) for split_paths in (train_paths, val_paths)
)

In [None]:
# Validation is served one sample at a time in a fixed order; training uses
# shuffled batches of 8.
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

In [None]:
def plot_image(im, pred, gt):
    """Show an image, its prediction and its ground truth side by side.

    Args:
        im: Image drawn in the left panel.
        pred: Prediction drawn in the middle panel.
        gt: Ground truth drawn in the right panel.
    """
    panels = (
        ('Original Image', im),
        ('prediction', pred),
        ('GT', gt),
    )

    _, axes = plt.subplots(1, len(panels), figsize=(10, 5))

    # Each panel gets its data, a title, and hidden axis ticks
    for axis, (title, data) in zip(axes, panels):
        axis.imshow(data)
        axis.set_title(title)
        axis.axis('off')

    plt.show()

In [None]:
# Create the model and restore its saved weights
MODEL_PATH = '../model_saves/more_dropout_jitter_bs_32_lr_1e-3_lw_0.2_best_model.pth'

# NOTE(review): the positional arguments are presumably input channels, output
# channels and base feature width — confirm against UNetMultiTask.
model = UNetMultiTask(3, 1, 8)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. `device` is already a torch.device and is passed as-is.
state_dict = torch.load(MODEL_PATH, map_location=device)

# Sometimes the state-dict keys contain a 'module.' prefix which we don't want
# (presumably added by a parallel wrapper when the model was saved — confirm).
remove_prefix = True
key_prefix = 'module.'

if remove_prefix:
    # Strip the prefix only from keys that actually start with it
    state_dict = {
        (key[len(key_prefix):] if key.startswith(key_prefix) else key): value
        for key, value in state_dict.items()
    }

model.load_state_dict(state_dict)
model = model.to(device)

# Switch to evaluation mode; the trailing semicolon keeps the notebook from
# echoing the model repr.
model.eval();

In [None]:
# Put validation data through the model, plotting the image, the post-processed
# prediction and the ground truth for every sample, and accumulate dice scores
# before and after post-processing.
threshold = 0.5
dice_coeff = 0.0
post_dice_coeff = 0.0
gts = []
preds = []
n = 0

# Perform loop without computing gradients
with torch.no_grad():
    for inputs, targets, labels in val_loader:

        inputs = inputs.to(device)
        targets = targets.to(device)

        # The model returns the segmentation output and the classification output
        outputs, class_out = model(inputs)

        dice_coeff += batch_dice_coeff(outputs>threshold, targets).detach().cpu().numpy()
        n += 1

        # Only element 0 of the batch is inspected (val_loader uses batch_size=1).
        # assumes channel-first image tensors, moved to channel-last for imshow — TODO confirm
        im = inputs[0].detach().cpu().permute(1,2,0).numpy()
        # NOTE(review): threshold_mask is called with its default threshold, not
        # `threshold` — confirm the two agree.
        pred = threshold_mask(outputs[0][0].detach().cpu().numpy())
        gt = targets[0][0].detach().cpu().numpy()
        gt_tensor = torch.from_numpy(gt).float()

        dice = dice_coefficient(torch.from_numpy(pred).float(), gt_tensor)

        # Post-process mask and clamp the result to the [0, 1] range
        post_proc_mask = np.clip(post_process_binary_mask(pred, threshold_fraction=0.05), 0, 1)

        post_proc_dice = dice_coefficient(torch.from_numpy(post_proc_mask).float(), gt_tensor)

        post_dice_coeff += post_proc_dice

        # Extract the scalar classification output and label once per sample
        class_pred = class_out.squeeze().item()
        class_gt = labels.item()
        gts.append(class_gt)
        preds.append(class_pred)

        # Plot image, post-processed prediction and ground truth
        panels = (('Image', im), ('Post-Proc Mask', post_proc_mask), ('GT', gt))
        fig, ax = plt.subplots(1, len(panels), figsize=(10, 5))
        for axis, (title, data) in zip(ax, panels):
            axis.imshow(data)
            axis.set_title(title)
            axis.axis('off')

        plt.show()

        print(f"Classification prediction: {class_pred}, GT: {class_gt}")
        print(f"Dice score before postproc: {dice} vs after: {post_proc_dice}")

# Fail with a clear message instead of a bare ZeroDivisionError
if n == 0:
    raise ValueError("val_loader yielded no samples; cannot average dice scores")

pre_proc_dice_av = dice_coeff/n
post_proc_dice_av = post_dice_coeff/n
print("Av. dice score before postproc: ", pre_proc_dice_av, "vs post: ", post_proc_dice_av)

In [None]:
# Binarise the classification outputs: strictly above 0.5 becomes 1, everything else 0
preds_binary = [int(score > 0.5) for score in preds]

In [None]:
from metrics import accuracy

# NOTE(review): the raw classification outputs (not preds_binary) are passed
# here — confirm that accuracy() applies its own thresholding.
score_tensor = torch.tensor(preds)
label_tensor = torch.tensor(gts)

# Last expression of the cell, so the notebook displays the result
accuracy(score_tensor, label_tensor)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Pin the class order so the matrix is always 2x2, even if one class is missing
# from the ground truth and the predictions; otherwise it would not line up
# with the two display labels used when plotting.
# NOTE(review): assumes the class labels are encoded as 0 and 1 — confirm.
# Pass normalize='true' to get row-normalised rates instead of raw counts.
cm = confusion_matrix(gts, preds_binary, labels=[0, 1])

In [None]:
# Draw the confusion matrix with named classes on a blue colour map
class_names = ["Fungal", "Bacterial"]
cm_displayed = ConfusionMatrixDisplay(cm, display_labels=class_names)
cm_displayed.plot(cmap=plt.cm.Blues)

# To keep a copy on disk, call plt.savefig('cm_multitask_firstgo.png') before plt.show()
plt.show()