# Results visualisations

* plot per-image metrics and the distribution of each metric
* plot the correlation between ground-truth and predicted lesion volume and lesion count
* show example predictions and uncertainty maps

In [None]:
import os
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
sns.set(style='white', font_scale=1.5)
import nibabel as nib
import re

from monai.transforms import CropForeground, Crop
from monai.visualize.utils import blend_images

In [None]:
# Name of the sub-directory of ../test_predictions whose outputs this notebook visualises.
model_name = "xunet-loss-ce-lr"

In [None]:
# Load the saved metric arrays for every subset. Each list ends up holding one
# array per subset, in the order dev_in, eval_in, dev_out.
ndsc, f1, aac = [], [], []
for set_ in ['dev_in', 'eval_in', 'dev_out']:
    set_dir = os.path.join('..', 'test_predictions', model_name, set_)
    for target, fname in ((ndsc, 'nDSC.npy'), (f1, 'f1.npy'), (aac, 'nDSC_R-AUC.npy')):
        target.append(np.load(os.path.join(set_dir, fname)))

## Metrics distribution

In [None]:
def plot_distribution(data, metric_name, plot_title, labels=('dev_in', 'eval_in', 'dev_out')):
    """Draw a box plot of one metric on the current axes, one box per subset.

    Args:
        data: sequence of 1-D arrays, one per subset, in the same order as
            ``labels``.
        metric_name: y-axis label.
        plot_title: axes title.
        labels: x tick labels, one per entry of ``data``. The default keeps the
            original three-subset labelling.

    Raises:
        ValueError: if ``labels`` and ``data`` have different lengths. Without
            this check the ticks would silently mislabel the boxes.

    The y-axis is fixed to [0, 1].
    """
    if len(labels) != len(data):
        raise ValueError(
            f'got {len(data)} data series but {len(labels)} labels'
        )
    plt.title(plot_title)
    sns.boxplot(data=data)
    # One tick per box, derived from the labels rather than hard-coded.
    plt.xticks(range(len(labels)), list(labels))
    plt.ylabel(metric_name)
    plt.ylim(0, 1)
    
# One wide figure holding three side-by-side box plots, one per metric.
plt.figure(figsize=(30, 9))
for subplot_pos, (metric_values, ylabel, heading) in zip(
    (131, 132, 133),
    [
        (ndsc, 'nDSC', 'nDSC distribution'),
        (f1, 'lesion f1', 'Lesion f1 distribution'),
        (aac, 'nDSC R-AUC', 'nDSC R-AUC distribution'),
    ],
):
    plt.subplot(subplot_pos)
    plot_distribution(metric_values, ylabel, heading)

## Metrics for each image

In [None]:
# One bar chart per (metric, subset) pair; bars are numbered images starting at 1.
subset_names = ['dev_in', 'eval_in', 'dev_out']
for metric, metric_name in zip([ndsc, f1, aac], ['nDSC', 'lesion f1', 'nDSC R-AUC']):
    for values, set_ in zip(metric, subset_names):
        plt.figure(figsize=(20, 7))
        image_ids = list(range(1, values.shape[0] + 1))
        sns.barplot(y=values, x=image_ids)
        plt.ylim(0, 1)
        plt.ylabel(metric_name)
        plt.xlabel('Image')
        plt.title(f'{metric_name} for each image - {set_}')

## Correlation

In [None]:
def _identity_segment(pred, gt, extra_lim):
    """Return ``(xs, ys)`` endpoints of the y = x reference line.

    The line runs from the origin to the larger of the two axis upper limits,
    so it spans the whole visible plot area whichever axis is longer.
    """
    upper = max(np.max(pred), np.max(gt)) + extra_lim
    return [0, upper], [0, upper]


def plot_correlation(value, unit, extra_lim=100, model=None):
    """Plot predicted vs. ground-truth lesion ``value``, one figure per subset.

    For each of dev_in, eval_in and dev_out, this loads ``gt_<value>.npy`` and
    ``pred_<value>.npy`` from ``../test_predictions/<model>/<subset>/``. It then
    draws a new figure with a seaborn regression plot and a dashed y = x
    reference line marking perfect agreement.

    Args:
        value: quantity name used in the file names and labels, e.g. 'volume'
            or 'count'.
        unit: unit string appended to the axis labels (may be empty).
        extra_lim: margin added above the largest value on each axis.
        model: prediction sub-directory. When None, the notebook-level
            ``model_name`` is used (looked up at call time).
    """
    if model is None:
        model = model_name
    for set_ in ['dev_in', 'eval_in', 'dev_out']:
        set_dir = os.path.join('..', 'test_predictions', model, set_)
        gt = np.load(os.path.join(set_dir, 'gt_' + value + '.npy'))
        pred = np.load(os.path.join(set_dir, 'pred_' + value + '.npy'))

        plt.figure(figsize=(15, 9))
        sns.regplot(x=pred, y=gt, truncate=False)
        plt.xlabel(f'Pred {value} {unit}')
        plt.ylabel(f'GT {value} {unit}')
        plt.xlim(0, np.max(pred) + extra_lim)
        plt.ylim(0, np.max(gt) + extra_lim)
        plt.title(f'Correlation between ground truth and segmented lesion {value} - ' + set_)
        # Draw a true identity line. The previous endpoints
        # (max(pred)+e, max(gt)+e) traced the axes-box diagonal, which only
        # coincides with y = x when max(pred) == max(gt).
        xs, ys = _identity_segment(pred, gt, extra_lim)
        plt.plot(xs, ys, linestyle='dashed', label='GT')
        plt.legend()

In [None]:
plot_correlation(value='volume', unit='[voxels]')

In [None]:
plot_correlation(value='count', unit='', extra_lim=10)

## Example predictions

In [None]:
# Shared MONAI transforms (default settings) used to crop example volumes.
crop_foreground, crop = CropForeground(), Crop()

In [None]:
def _numeric_sort_key(path):
    """Return the integer formed by all digits in the file name of *path*.

    Only the base name is used, so digits in directory names (dataset folder,
    model name) cannot influence the ordering.
    """
    # Raw string: '\D' in a plain literal is an invalid escape sequence.
    return int(re.sub(r'\D', '', os.path.basename(path)))


def _sorted_by_number(pattern):
    """Glob *pattern* (recursively) and sort the matches by their file number."""
    return sorted(glob(pattern, recursive=True), key=_numeric_sort_key)


def get_paths(model_name, subset):
    """Collect index-aligned NIfTI paths for one subset.

    Args:
        model_name: sub-directory of ``../test_predictions`` with the predictions.
        subset: subset name, e.g. 'dev_in'.

    Returns:
        ``(imgs, gts, preds, pred_probs, pred_uncs)``: lists of paths, each sorted
        by the number in the file name so that entries at the same index refer
        to the same case.
    """
    pred_dir = f'../test_predictions/{model_name}/{subset}/predictions'
    imgs = _sorted_by_number(f'../data/**/{subset}/flair/*.nii.gz')
    gts = _sorted_by_number(f'../data/**/{subset}/gt/*.nii.gz')
    preds = _sorted_by_number(f'{pred_dir}/*pred_seg.nii.gz')
    pred_probs = _sorted_by_number(f'{pred_dir}/*pred_prob.nii.gz')
    pred_uncs = _sorted_by_number(f'{pred_dir}/*uncs.nii.gz')

    return imgs, gts, preds, pred_probs, pred_uncs


def _load_channel_first(path):
    """Load a NIfTI volume with nibabel and prepend a singleton channel axis."""
    return np.expand_dims(nib.load(path).get_fdata(), 0)


def _validate_image_number(img_num, n_images):
    """Raise ``IndexError`` unless ``1 <= img_num <= n_images``.

    ``img_num`` is 1-based, so without this check ``img_num=0`` would silently
    pick the last image through Python's negative indexing.
    """
    if not 1 <= img_num <= n_images:
        raise IndexError(
            f'img_num must be between 1 and {n_images} (1-based), got {img_num}'
        )


def plot_example_predictions(img_num, imgs, gts, preds, pred_probs, pred_uncs, title,
                             slice_num=100, save_path=None):
    """Plot GT overlay, prediction overlay and uncertainty map for one slice.

    Draws three side-by-side panels into the current figure:
      1. the FLAIR image blended with the ground-truth mask,
      2. the FLAIR image blended with the predicted segmentation,
      3. the uncertainty map, with a horizontal colour bar.
    All panels are cropped to the foreground bounding box of panel 1.

    Args:
        img_num: 1-based index into the path lists.
        imgs, gts, preds, pred_uncs: index-aligned lists of NIfTI paths (FLAIR
            image, GT mask, predicted segmentation, uncertainty map), e.g. as
            returned by ``get_paths``.
        pred_probs: accepted for interface compatibility; currently unused.
        title: y-axis label of the first panel.
        slice_num: index along the third spatial axis of the *cropped* volume.
        save_path: if given, the figure is also written there via ``plt.savefig``.

    Raises:
        IndexError: if ``img_num`` is outside the valid 1-based range.
    """
    _validate_image_number(img_num, min(len(imgs), len(gts), len(preds), len(pred_uncs)))
    idx = img_num - 1

    # Leading channel axis (channel-first) before blending and cropping.
    img = _load_channel_first(imgs[idx])
    gt = _load_channel_first(gts[idx])
    pred = _load_channel_first(preds[idx])
    uncs = _load_channel_first(pred_uncs[idx])

    img_gt = blend_images(img, gt, alpha=0.5, cmap='Greens')
    img_pred = blend_images(img, pred, alpha=0.5, cmap='summer')

    # Crop every panel with the same box: the foreground of the GT overlay.
    bbox = crop_foreground.compute_bounding_box(img_gt)
    img_gt = crop_foreground(img_gt)
    slices = crop.compute_slices(roi_start=bbox[0], roi_end=bbox[1])
    img_pred = crop(img_pred, slices)
    uncs = crop(uncs, slices)

    # Move the channel axis last, the layout plt.imshow expects.
    img_gt = np.transpose(img_gt, (1, 2, 3, 0))
    img_pred = np.transpose(img_pred, (1, 2, 3, 0))
    uncs = np.transpose(uncs, (1, 2, 3, 0))

    plt.subplot(131)
    plt.imshow(img_gt[:, :, slice_num, :])
    ax = plt.gca()
    # Hide ticks but keep the axis itself so the row title (ylabel) stays visible.
    ax.set_xticks([])
    ax.set_yticks([])
    plt.ylabel(title, fontsize=35)

    plt.subplot(132)
    plt.imshow(img_pred[:, :, slice_num, :])
    plt.axis('off')

    plt.subplot(133)
    # Fixed [0, 1] colour range; presumably uncertainties are normalised - TODO confirm.
    plt.imshow(uncs[:, :, slice_num, :], vmin=0, vmax=1)
    plt.colorbar(orientation='horizontal', pad=0.01)
    plt.axis('off')

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)

In [None]:
# dev_in example: first image, slice 85 of the foreground-cropped volume.
plt.figure(figsize=(30, 12))
imgs, gts, preds, pred_probs, pred_uncs = get_paths(model_name, 'dev_in')
plot_example_predictions(
    img_num=1,
    imgs=imgs,
    gts=gts,
    preds=preds,
    pred_probs=pred_probs,
    pred_uncs=pred_uncs,
    title='dev_in',
    slice_num=85,
)

In [None]:
# eval_in example: first image, slice 85 of the foreground-cropped volume.
plt.figure(figsize=(30, 10))
imgs, gts, preds, pred_probs, pred_uncs = get_paths(model_name, 'eval_in')
plot_example_predictions(
    img_num=1,
    imgs=imgs,
    gts=gts,
    preds=preds,
    pred_probs=pred_probs,
    pred_uncs=pred_uncs,
    title='eval_in',
    slice_num=85,
)

In [None]:
# dev_out example: first image, slice 85 of the foreground-cropped volume.
plt.figure(figsize=(30, 11))
imgs, gts, preds, pred_probs, pred_uncs = get_paths(model_name, 'dev_out')
plot_example_predictions(
    img_num=1,
    imgs=imgs,
    gts=gts,
    preds=preds,
    pred_probs=pred_probs,
    pred_uncs=pred_uncs,
    title='dev_out',
    slice_num=85,
)