# Step 3
## Autocontouring-specific visualisation

Autocontouring (the automatic delineation of anatomical structures in images) is a pretty big deal in any field that uses medical imaging.

This notebook demonstrates:
- a basic autocontouring algorithm
- how to visually assess the performance of the algorithm
- some other cool stuff!

In [None]:
"""
Import some useful modules
"""

from pathlib import Path

import numpy as np
import SimpleITK as sitk

import matplotlib.pyplot as plt

from platipy.imaging import ImageVisualiser
from platipy.dicom.io.crawl import process_dicom_directory
from platipy.imaging.label.utils import get_com
from platipy.imaging.utils.crop import crop_to_label_extent

from platipy.imaging.registration.linear import linear_registration
from platipy.imaging.registration.deformable import fast_symmetric_forces_demons_registration
from platipy.imaging.registration.utils import apply_transform

from platipy.imaging.visualisation.comparison import contour_comparison

from platipy.imaging.generation.dvf import generate_field_asymmetric_extend

import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

import seaborn as sns

%matplotlib inline

In [None]:
"""
We are going to do very simple atlas-based segmentation
Let's start by loading in some data: one target case and one atlas case.

Each case directory holds an ``IMAGES`` folder with the MRI and a
``STRUCTURES`` folder with one NIfTI file per contoured structure.
"""


def load_case(case_dir, structure_name_offset=26):
    """Load the MRI and all structure labels stored under ``case_dir``.

    Parameters
    ----------
    case_dir : pathlib.Path
        Case directory containing ``IMAGES/*.nii.gz`` and ``STRUCTURES/*.nii.gz``.
    structure_name_offset : int, optional
        Number of leading filename characters stripped to obtain the structure
        name (the trailing ".nii.gz" is always stripped). The default of 26 is
        assumed to be a fixed-length case prefix in the filenames - TODO confirm
        against the dataset's naming scheme.

    Returns
    -------
    dict
        ``{"MRI": image, "LABELS": {structure_name: label_image}}``.

    Raises
    ------
    FileNotFoundError
        If ``IMAGES`` contains no ``.nii.gz`` file.
    """
    # Sort the globs so the chosen MRI and the structure order are deterministic
    # (raw glob order depends on the filesystem).
    image_files = sorted(case_dir.glob("IMAGES/*.nii.gz"))
    if not image_files:
        raise FileNotFoundError(f"No NIfTI image found in {case_dir / 'IMAGES'}")

    case = {
        "MRI": sitk.ReadImage(str(image_files[0]), sitk.sitkUInt32),
        "LABELS": {},
    }
    for structure_file in sorted(case_dir.glob("STRUCTURES/*.nii.gz")):
        structure_name = structure_file.name[structure_name_offset:-len(".nii.gz")]
        case["LABELS"][structure_name] = sitk.ReadImage(str(structure_file))
    return case


target_dir = Path("./input/NIfTI/RTMAC_LIVE_003/")
target = load_case(target_dir)

atlas_dir = Path("./input/NIfTI/RTMAC_LIVE_002/")
atlas = load_case(atlas_dir)

In [None]:
"""
Now let's quickly check what our data look like
"""

# Target MRI with its manual contours drawn as solid lines
vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_contour(target["LABELS"], linestyle="solid")
fig = vis.show()

In [None]:
# Same check for the atlas; dashed lines set its contours apart from the target's
vis = ImageVisualiser(
    atlas["MRI"],
    window=(0, 400),
    figure_size_in=6,
)
vis.add_contour(atlas["LABELS"], linestyle="dashed")
fig = vis.show()

In [None]:
"""
Register the images.

Step 1 is a linear transformation: for inter-patient registration an affine
transform (translation, rotation, scaling and shearing) is a good option.
The atlas MRI is the moving image and the target MRI is the fixed image.
"""

atlas_mri_affine, tfm_affine = linear_registration(
    # The atlas is aligned to the target
    fixed_image=target["MRI"],
    moving_image=atlas["MRI"],
    # Transform model, similarity metric and optimiser
    reg_method="affine",
    metric="mean_squares",
    optimiser="gradient_descent",
    number_of_iterations=50,
    # Multi-resolution schedule
    shrink_factors=[8, 4, 2],
    smooth_sigmas=[4, 2, 0],
    sampling_rate=0.5,
    # Output resampling options
    final_interp=2,
    default_value=0,
    verbose=False,
)

### Comparing two images

A very common challenge!

Here we use it to check the result of a registration.

This **is not** just displaying a pink image and a green image!

In [None]:
# Overlay the affinely-registered atlas MRI on the target MRI to judge alignment
vis = ImageVisualiser(
    target["MRI"],
    window=(0, 400),
    figure_size_in=6,
)
vis.add_comparison_overlay(atlas_mri_affine)
fig = vis.show()

In [None]:
# savefig raises FileNotFoundError when the target directory is missing, so
# make sure "./figures" exists before the first export of the notebook.
Path("./figures").mkdir(parents=True, exist_ok=True)
fig.savefig("./figures/mri_coreg_affine.jpeg", dpi=300)

In [None]:
"""
For the next step we are going to use deformable image registration (DIR),
starting from the affinely-registered atlas MRI. Besides the deformed image
and the transform, this also returns the deformation vector field (DVF).
"""

atlas_mri_dir, tfm_dir, dvf = fast_symmetric_forces_demons_registration(
    # The affine result is now deformed onto the target
    fixed_image=target["MRI"],
    moving_image=atlas_mri_affine,
    # Coarse-to-fine schedule
    resolution_staging=[8, 4, 1],
    iteration_staging=[70, 50, 30],
    isotropic_resample=True,
    smoothing_sigma_factor=1,
    # Output / execution options
    interp_order=2,
    default_value=0,
    ncores=8,
    verbose=False,
)

In [None]:
# Overlay the deformably-registered atlas MRI on the target MRI
vis = ImageVisualiser(
    target["MRI"],
    window=(0, 400),
    figure_size_in=6,
)
vis.add_comparison_overlay(atlas_mri_dir)
fig = vis.show()

In [None]:
# Export the DIR comparison overlay at print resolution
fig.savefig(fname="./figures/mri_coreg_dir.jpeg", dpi=300)

In [None]:
"""
We are often interested in the deformation vector field (DVF) returned by the
DIR step, so draw it as arrows coloured by the displacement magnitude.
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_vector_overlay(
    dvf,
    name="DVF magnitude [mm]",
    color_function="magnitude",
    colormap=plt.cm.viridis,
    show_colorbar=True,
    alpha=0.75,
    # Arrow styling; subsampling thins out the arrows so the plot stays readable
    subsample=(4, 16, 16),
    arrow_scale=1,
    arrow_width=1,
)
fig = vis.show()

#### This doesn't seem very informative!

I agree! The deformation outside the patient dominates.

We don't really care about this.

So let's mask it out.

In [None]:
"""
A super simple algorithm to generate an external (body) contour:
threshold the MRI, keep the largest connected component, then close gaps.
"""

# RelabelComponent orders components by size, so "== 1" keeps the largest one;
# the closing kernel radius is given per axis in voxels.
external_contour = sitk.BinaryMorphologicalClosing(
    sitk.RelabelComponent(sitk.ConnectedComponent(target["MRI"] > 50)) == 1,
    (10, 10, 10),
)

In [None]:
"""
Quick visualisation to make sure our external mask looks okay.
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_scalar_overlay(
    external_contour,
    show_colorbar=False,  # binary mask - a colorbar adds nothing
)
fig = vis.show()

In [None]:
"""
Visualise the deformation again, now only inside the patient!
Arrows are coloured by the "perpendicular" DVF component, which can be positive
or negative, so a diverging colormap (bwr) is used.
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_vector_overlay(
    sitk.Mask(dvf, external_contour),  # zero displacement outside the body
    name="DVF perpendicular component [mm]",
    color_function="perpendicular",
    colormap=plt.cm.bwr,
    show_colorbar=True,
    alpha=0.75,
    subsample=(2, 8, 8),
    arrow_scale=2,
    arrow_width=1,
)
fig = vis.show()

In [None]:
# Export the masked DVF arrow plot
fig.savefig(fname="./figures/mri_coreg_dvf_vector_perpendicular.jpeg", dpi=300)

### Alternative representation

We might just want to look at locations where there is a lot of deformation.

These correspond to locations of large differences between the target and atlas images.

In [None]:
"""
Calculate the voxel-wise magnitude (vector length) of the DVF inside the patient.
"""

# Zero the displacement outside the external contour ...
internal_dvf = sitk.Mask(dvf, external_contour)

# ... then reduce each displacement vector to its length
internal_dvf_magnitude = sitk.VectorMagnitude(internal_dvf)

In [None]:
"""
The magnitude is non-negative with no meaningful midpoint,
so we should use a sequential colormap!
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_scalar_overlay(
    internal_dvf_magnitude,
    name="DVF magnitude [mm]",
    colormap=plt.cm.magma,
    max_value=20,
    discrete_levels=10,
    show_colorbar=True,
)
fig = vis.show()

In [None]:
# Export the DVF magnitude overlay
fig.savefig(fname="./figures/mri_coreg_dvf_scalar_magnitude.jpeg", dpi=300)

In [None]:
"""
The Jacobian determinant measures the local relative volume change.

It is computed on the full (unmasked) DVF and only masked afterwards.
Differentiating a DVF that has already been zeroed outside the external contour
would introduce an artificial step at the body surface and produce spurious
Jacobian values in the boundary voxels.
"""

internal_jac_det = sitk.Mask(
    sitk.DisplacementFieldJacobianDeterminant(dvf),
    external_contour,
    # 1 = no volume change, which is also what a zero displacement gives
    outsideValue=1.0,
)

In [None]:
"""
We would like a colormap that reflects the physical interpretation of the Jac Det:
1 means no local volume change, below 1 contraction, above 1 expansion.
So use a diverging norm spanning [0, 2] with its centre at 1.
"""

# NB: despite its name, this norm is centred on J = 1, not on zero
zero_centered_norm = mcolors.TwoSlopeNorm(vcenter=1, vmin=0, vmax=2)

In [None]:
# Preview the diverging "icefire" palette as 12 discrete colours (cell output)
sns.color_palette(palette="icefire", n_colors=12)

In [None]:
"""
Visualise the Jacobian determinant with the diverging icefire colormap,
normalised so that J = 1 sits at the centre of the colour scale.
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_scalar_overlay(
    internal_jac_det,
    name="Jacobian Determinant",
    colormap=sns.color_palette("icefire", as_cmap=True),
    norm=zero_centered_norm,
    min_value=0.0,
    max_value=2.2,
    discrete_levels=11,
    alpha=0.75,
    show_colorbar=True,
)
fig = vis.show()

In [None]:
# Export the Jacobian determinant overlay (version 1: plain colorbar)
fig.savefig(fname="./figures/mri_coreg_dvf_scalar_jab_det_1.jpeg", dpi=300)

In [None]:
"""
The Jac Det should also not be negative
(why is this?)

Collect the Jacobian determinant of every voxel inside the external contour
as a flat array for the histograms below.
"""

# Reuse the Jacobian determinant image computed earlier instead of running the
# filter a second time. A boolean mask selects the same voxels, in the same
# order, as indexing with np.where.
inside_body = sitk.GetArrayFromImage(external_contour) != 0
internal_jac_det_values = sitk.GetArrayFromImage(internal_jac_det)[inside_body]

In [None]:
"""
Create a nice histogram of the in-body Jacobian determinant values on a
logarithmic count axis, with bins below zero highlighted in red.
"""

histbins = np.linspace(-1, 5, 600)

fig, ax = plt.subplots(1, 1, figsize=(5, 4))

counts, bins, bars = ax.hist(internal_jac_det_values, bins=histbins, lw=0, ec="k")

bin_centers = (bins[1:] + bins[:-1]) / 2

# Colour each bar by the sign of its bin centre. "< 0" matches the criterion
# used for the percentage shown in the legend.
for b, bc in zip(bars.patches, bin_centers):
    b.set_facecolor("#922b21" if bc < 0 else "#212f3d")

ax.set_xlabel("Jacobian Determinant")
ax.set_ylabel("Number of Voxels")

ax.grid()
ax.set_axisbelow(True)

ax.set_yscale("log")

ax.set_xlim(-1, 5)

# np.alen was deprecated in NumPy 1.18 and removed in 1.23. The mean of the
# boolean mask is the fraction of voxels with a negative Jacobian determinant.
frac_below_zero = np.mean(internal_jac_det_values < 0)
handles = [mpatches.Patch(color="#922b21", label=f"Jac. Det. below zero: {100*frac_below_zero:.2f}%")]
ax.legend(handles=handles)

fig.tight_layout()

# fig.show() is meant for GUI backends and only warns under the inline backend;
# plt.show() renders the figure in the notebook.
plt.show()

In [None]:
"""
We could also colour the histogram bins to match up with the displayed image,
using the same colormap and norm as the Jacobian determinant overlay.
"""

histbins = np.linspace(-1, 5, 600)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))

counts, bins, bars = ax.hist(internal_jac_det_values, bins=histbins, lw=0, ec="k")

bin_centers = (bins[1:] + bins[:-1]) / 2

# Map each bin centre through the same norm + colormap as the image overlay
cmap = sns.color_palette("icefire", as_cmap=True)
bin_colors = cmap(zero_centered_norm(bin_centers))

for b, c in zip(bars.patches, bin_colors):
    b.set_facecolor(c)

ax.set_xlabel("Jacobian Determinant")
ax.set_ylabel("Number of Voxels")

ax.grid()
ax.set_axisbelow(True)

fig.tight_layout()

# fig.show() is meant for GUI backends and only warns under the inline backend;
# plt.show() renders the figure in the notebook.
plt.show()

In [None]:
"""
Since platipy just uses matplotlib we can easily add an axis
Then we can plot anything! Here: an inset histogram of the in-body Jacobian
determinant values, with negative values in red.
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_scalar_overlay(
    internal_jac_det,
    name="Jacobian Determinant",
    colormap=sns.color_palette("icefire", as_cmap=True),
    show_colorbar=True,
    max_value=2.2,
    min_value=0.0,
    discrete_levels=11,
    alpha=0.75,
    norm=zero_centered_norm
)
fig = vis.show()

# Inset axes, given as (left, bottom, width, height) in figure coordinates
ax = fig.add_axes((0.7, 0.6, 0.25, 0.25))

# Same binning as the stand-alone histograms, redefined here so this cell does
# not rely on an earlier cell having set `histbins`
histbins = np.linspace(-1, 5, 600)
counts, bins, bars = ax.hist(internal_jac_det_values, bins=histbins, lw=0, ec="k")

bin_centers = (bins[1:] + bins[:-1]) / 2

# Red for bins below zero, matching the "J<0" criterion in the title
for b, bc in zip(bars.patches, bin_centers):
    b.set_facecolor("#922b21" if bc < 0 else "#212f3d")

ax.grid()
ax.set_axisbelow(True)
ax.set_yscale("log")
ax.set_xlim(-1, 5)

# np.alen was removed in NumPy 1.23; the mean of the boolean mask is the fraction
frac_below_zero = np.mean(internal_jac_det_values < 0)
ax.set_title(f"J<0: {100*frac_below_zero:.2f}%");

In [None]:
# Export the Jacobian determinant overlay with the inset histogram
fig.savefig(fname="./figures/mri_coreg_dvf_scalar_jab_det_2.jpeg", dpi=300)

In [None]:
"""
This histogram could double as a colorbar!
The overlay's own colorbar is switched off. A horizontal histogram, coloured
with the same colormap and norm, is drawn in an inset instead.
"""

vis = ImageVisualiser(target["MRI"], window=(0, 400), figure_size_in=6)
vis.add_scalar_overlay(
    internal_jac_det,
    name="Jacobian Determinant",
    colormap=sns.color_palette("icefire", as_cmap=True),
    norm=zero_centered_norm,
    min_value=0.0,
    max_value=2.2,
    discrete_levels=11,
    alpha=0.75,
    show_colorbar=False,
)
fig = vis.show()

# Inset axes, given as (left, bottom, width, height) in figure coordinates
ax = fig.add_axes((0.6, 0.55, 0.375, 0.425))

# Horizontal bars put the Jacobian determinant on the y axis, like a colorbar
counts, bins, bars = ax.hist(
    internal_jac_det_values,
    bins=histbins,
    lw=0,
    ec="k",
    orientation="horizontal",
)

bin_centers = (bins[1:] + bins[:-1]) / 2

# Colour every bar exactly as its bin centre appears in the image overlay
cmap = sns.color_palette("icefire", as_cmap=True)
bin_colors = cmap(zero_centered_norm(bin_centers))
for patch, colour in zip(bars.patches, bin_colors):
    patch.set_facecolor(colour)

ax.set_ylabel("Jacobian Determinant")

# Counts span orders of magnitude: linear below 500 voxels, logarithmic above
ax.set_xscale("symlog", linthresh=500)
ax.set_xlim(500, 2e5)

ax.grid()
ax.set_axisbelow(True)

In [None]:
# Export the Jacobian determinant overlay with the histogram-as-colorbar
fig.savefig(fname="./figures/mri_coreg_dvf_scalar_jab_det_3.jpeg", dpi=300)

In [None]:
"""
A bit of a detour, but let's get back to the segmentation problem!
We now need to map across the contours from the atlas onto the target.
Thankfully SimpleITK makes this really easy: chain the affine and deformable
transforms and resample every atlas label onto the target grid.
"""

# SimpleITK composite transforms apply their members last-added-first, so target
# points go through the DIR transform and then the affine transform.
tfm_combined = sitk.CompositeTransform((tfm_affine, tfm_dir))

auto_contours = {
    structure_name: apply_transform(
        input_image=atlas_label,
        reference_image=target["MRI"],
        transform=tfm_combined,
        default_value=0,
        interpolator=1,  # presumably nearest-neighbour (sitk.sitkNearestNeighbor == 1) to keep labels binary
    )
    for structure_name, atlas_label in atlas["LABELS"].items()
}

In [None]:
"""
Platipy to the rescue!
contour_comparison compares the manual (target) contours with the automatic
(atlas-propagated) contours on the target MRI, structure by structure.
"""

fig = contour_comparison(
    img=target["MRI"],
    # The two contour sets being compared, with their legend labels
    contour_dict_a=target["LABELS"],
    contour_label_a="Manual",
    contour_dict_b=auto_contours,
    contour_label_b="Auto",
    # Structure selection and framing
    s_select=sorted(auto_contours.keys()),
    structure_for_com=None,
    structure_for_limits=None,
    structure_name_dict=None,
    # Titles
    title="Atlas-based Segmentation",
    subtitle="H&N Glands",
    subsubtitle="Single atlas\nLog-domain symmetric diffeomorphic DIR algorithm",
    # Appearance
    contour_cmap=plt.cm.rainbow,
    img_vis_kw=dict(window=(0, 0.9), figure_size_in=8, projection="mean"),
)

In [None]:
# Export the contour-comparison figure (mean-projection view of the MRI)
fig.savefig(fname="./figures/mri_atlas_results_drr.jpeg", dpi=300)

### Extending visualisation tools

One great thing about approaching visualisation through code is extensibility.

We can create informative and visually appealing figures.

In [None]:
# NOTE(review): no new figure is created between here and the contour_comparison
# call, so `fig` is still that figure and this writes the same image as
# mri_atlas_results_drr.jpeg - confirm this cell is intended.
fig.savefig(fname="./figures/mri_atlas_results.jpeg", dpi=300)

In [None]:
"""
Maybe we are interested in places where our algorithm makes correct/incorrect predictions.
So classify every voxel of each structure as a true/false positive/negative
(manual contour = truth, auto contour = prediction) and visualise those.
"""

predictions = {}

for structure_name, manual_label in target["LABELS"].items():
    auto_label = auto_contours[structure_name]

    # Complements are computed once and shared by the FP/TN/FN masks
    manual_absent = sitk.Not(manual_label)
    auto_absent = sitk.Not(auto_label)

    predictions[structure_name] = {
        "TP": auto_label & manual_label,
        "FP": auto_label & manual_absent,
        "TN": auto_absent & manual_absent,
        "FN": auto_absent & manual_label,
    }

In [None]:
"""
Visualise the voxel-wise outcomes for the left parotid:
green = true positive, red = false positive, blue = false negative.
"""

vis = ImageVisualiser(
    target["MRI"],
    cut=get_com(target["LABELS"]["PAROTID_L"]),  # slices through the left parotid's centre of mass
    window=(0, 400),
    figure_size_in=6,
)

# (outcome key, colormap, legend label): the overlays and the legend share this
# table, so their colours cannot drift apart.
outcome_styles = [
    ("TP", plt.cm.Greens, "True Positive"),
    ("FP", plt.cm.Reds, "False Positive"),
    ("FN", plt.cm.Blues, "False Negative"),
]

# We are using scalar overlays, but we could try contours instead!
for outcome_key, outcome_cmap, _outcome_label in outcome_styles:
    vis.add_scalar_overlay(
        predictions["PAROTID_L"][outcome_key],
        colormap=outcome_cmap,
        show_colorbar=False,
        alpha=0.75,
    )

vis.set_limits_from_label(target["LABELS"]["PAROTID_L"], expansion=30)

handles = [
    mpatches.Patch(color=style_cmap(240), label=style_label)
    for _key, style_cmap, style_label in outcome_styles
]

fig = vis.show()

# The legend can only be added after vis.show(), because that call creates `fig`
fig.legend(handles=handles, loc=2, borderaxespad=0, bbox_to_anchor=(0.6, 0.8), bbox_transform=fig.transFigure);

In [None]:
# Export the TP/FP/FN analysis figure
fig.savefig(fname="./figures/mri_atlas_results_analysis.jpeg", dpi=300)

In [None]:
"""
Another fairly common visualisation task: displaying sequential segmentations.
Some examples:
- auto contours derived by increasing a parameter
- auto contours from using more training epochs
- auto contours from expanding/contracting a volume
"""

# Generate an additional contour by asymmetrically extending our left-parotid
# auto-contour (the 3-tuple presumably holds one signed extension per image axis).
auto_contour_ext, tfm_ext, dvf_ext = generate_field_asymmetric_extend(
    auto_contours["PAROTID_L"],
    gaussian_smooth=2,
    vector_asymmetric_extend=(-2, 1, -4),
)

In [None]:
# Show the synthetic DVF used for the extension, plus both contours
vis = ImageVisualiser(
    target["MRI"],
    cut=get_com(target["LABELS"]["PAROTID_L"]),
    window=(0, 400),
    figure_size_in=6,
)
vis.add_vector_overlay(
    dvf_ext,
    name="DVF magnitude [mm]",
    color_function="magnitude",
    colormap=plt.cm.viridis,
    show_colorbar=True,
    alpha=0.75,
    subsample=(1, 4, 4),
    arrow_scale=2,
    arrow_width=0.25,
)

vis.add_contour(auto_contours["PAROTID_L"], name="Original Auto-Contour", color="red")
vis.add_contour(
    auto_contour_ext,
    name="Extended Auto-Contour",
    color="purple",
    show_legend=True,
)

vis.set_limits_from_label(target["LABELS"]["PAROTID_L"], expansion=30)

fig = vis.show()

In [None]:
"""
Generate a sequence of extended contours: for ext = 1..5 the left-parotid
auto-contour is extended by (-1, 0.5, -2) * ext.
"""

extended_contours = {}

for ext in [1, 2, 3, 4, 5]:
    # Use a name distinct from `auto_contour_ext`. Reusing it would overwrite the
    # single extended contour from the earlier cell with the ext=5 result, and
    # re-running that cell would then silently plot the wrong contour.
    contour_ext, _tfm_unused, _dvf_unused = generate_field_asymmetric_extend(
        auto_contours["PAROTID_L"],
        vector_asymmetric_extend=(-1 * ext, 0.5 * ext, -2 * ext),
        gaussian_smooth=2,
    )

    extended_contours[ext] = contour_ext

In [None]:
"""
Visualise the whole sequence of extended contours on one figure,
coloured with the reversed plasma colormap.
"""

vis = ImageVisualiser(
    target["MRI"],
    cut=get_com(target["LABELS"]["PAROTID_L"]),
    window=(0, 400),
    figure_size_in=6,
)
vis.add_contour(extended_contours, colormap=plt.cm.plasma_r)
vis.set_limits_from_label(target["LABELS"]["PAROTID_L"], expansion=30)
fig = vis.show()