# Single-atlas segmentation

This notebook demonstrates the use of [scikit-rt](https://github.com/scikit-rt/scikit-rt) for single-atlas segmentation, and for evaluating the results obtained.

In single-atlas segmentation, segmentations from one image (the atlas) are mapped to another image (the target). There are two similar strategies, corresponding to the two types of ROI representation (masks and contours):

- pull strategy: the atlas (moving) is registered to the target (fixed), then ROI masks are pulled from the former to the latter using the registration transform;
- push strategy: the target (moving) is registered to the atlas (fixed), then ROI contours are pushed from the latter to the former using the registration transform.
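The difference between the two strategies can be illustrated with a toy example. Registration typically yields a transform that maps points in the fixed image to points in the moving image. With the pull strategy (target fixed), each target voxel can therefore be looked up in the atlas mask; with the push strategy (atlas fixed), atlas contour points can be mapped directly into the target. The sketch below assumes pure 2D translations and nearest-neighbour lookup; `translation`, `pull_mask` and `push_contour` are hypothetical helpers, not part of `scikit-rt`.

```python
# Toy 2D model: each registration result is a pure translation (hypothetical;
# real registrations give more general transforms, e.g. B-splines).

def translation(dx, dy):
    """Return a transform mapping fixed-image points (x, y) to moving-image points."""
    return lambda x, y: (x + dx, y + dy)

def pull_mask(atlas_mask, fixed_to_atlas, shape):
    """Pull strategy: target is fixed, atlas is moving.

    Each target voxel is mapped into atlas space, and the atlas mask value
    found there (nearest neighbour) is copied back to the target voxel.
    """
    ny, nx = shape
    pulled = [[0] * nx for _ in range(ny)]
    for y in range(ny):
        for x in range(nx):
            xa, ya = (round(c) for c in fixed_to_atlas(x, y))
            if 0 <= ya < len(atlas_mask) and 0 <= xa < len(atlas_mask[0]):
                pulled[y][x] = atlas_mask[ya][xa]
    return pulled

def push_contour(atlas_contour, fixed_to_target):
    """Push strategy: atlas is fixed, target is moving.

    Each atlas contour point is mapped directly into target space.
    """
    return [fixed_to_target(x, y) for x, y in atlas_contour]

# Atlas ROI: a 2x2 square at x = 1..2, y = 1..2 in a 5x5 image.
atlas_mask = [[1 if 1 <= x <= 2 and 1 <= y <= 2 else 0 for x in range(5)] for y in range(5)]
atlas_contour = [(1, 1), (2, 1), (2, 2), (1, 2)]

# Suppose the same anatomy lies 1 voxel further along x in the target.
pulled = pull_mask(atlas_mask, translation(-1, 0), shape=(5, 5))  # target -> atlas
pushed = push_contour(atlas_contour, translation(+1, 0))           # atlas -> target
print(pulled[1])  # [0, 0, 1, 1, 0]
print(pushed)     # [(2, 1), (3, 1), (3, 2), (2, 2)]
```

Both strategies move the ROI by the same amount; they differ in which image is fixed, and hence in which direction the transform can be applied directly.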

In `scikit-rt`, the underlying image registration can be performed with either [elastix](https://elastix.lumc.nl) or [NiftyReg](http://cmictig.cs.ucl.ac.uk/wiki/index.php/NiftyReg). With elastix, both the pull and the push strategies are available; with NiftyReg, only the pull strategy is available.
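This constraint can be summarised as a small lookup table. The helper below, `supports`, is hypothetical and is not part of `scikit-rt`; it only records which strategies can be used with each engine.

```python
# Hypothetical lookup table (not part of scikit-rt), summarising the
# strategies available with each registration engine.
SUPPORTED_STRATEGIES = {
    "elastix": {"pull", "push"},
    "niftyreg": {"pull"},
}

def supports(engine: str, strategy: str) -> bool:
    """Return True if the strategy can be used with the registration engine."""
    return strategy in SUPPORTED_STRATEGIES.get(engine.lower(), set())

for name in SUPPORTED_STRATEGIES:
    for strategy in ("pull", "push"):
        print(f"{name:8} {strategy}: {supports(name, strategy)}")
```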

When evaluating (or optimising) the method, the atlas and the target each have their own segmentations. Segmentations mapped from the atlas to the target are then compared with the original segmentations of the target.
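One common comparison metric is the Dice score, 2|A ∩ B| / (|A| + |B|), which is the metric requested in the comparison at the end of this notebook. A minimal sketch, with masks represented as sets of voxel indices (a representation chosen here for illustration, not the one used by `scikit-rt`):

```python
def dice(mask_a, mask_b):
    """Dice score between two masks given as sets of voxel indices."""
    if not mask_a and not mask_b:
        return 1.0  # convention chosen here: two empty masks agree perfectly
    overlap = len(mask_a & mask_b)
    return 2 * overlap / (len(mask_a) + len(mask_b))

# Made-up 2D masks: the mapped mask is shifted by one voxel along the second index.
original = {(0, 0), (0, 1), (1, 0), (1, 1)}
mapped = {(0, 1), (1, 1), (0, 2), (1, 2)}
print(dice(original, mapped))  # 0.5
```

A score of 1 indicates perfect overlap, and 0 indicates no overlap.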

Documentation for `scikit-rt` is at:  
[https://scikit-rt.github.io/scikit-rt/](https://scikit-rt.github.io/scikit-rt/)

This notebook uses the dataset:

Peihan Li, "SPECT_CT_data.zip", Figshare dataset (2020)  
https://doi.org/10.6084/m9.figshare.12579707.v1

If the dataset is not already present on the computer where this notebook is run, it will be downloaded to the directory specified by `topdir` in the first code cell below. The download is about 1.6 GB, so it may take a while.

## Module import and data download

The following imports modules needed for this example, defines the path to the data directory, downloads the example dataset if not already present, obtains the list of paths to patient folders, and sets some viewer options.
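Two details of the first code cell may be worth spelling out: the data directory is named after the downloaded archive (`Path(url).stem` keeps only the last URL component, without its `.zip` suffix), and patient folders are taken to be the entries whose names start with `0`, sorted so that indices into `paths` are reproducible. The following sketch checks both behaviours, using a temporary directory with made-up folder names:

```python
import tempfile
from pathlib import Path

url = "https://figshare.com/ndownloader/files/23528954/SPECT_CT_data.zip"
stem = Path(url).stem
print(stem)  # SPECT_CT_data

# Made-up folder names, standing in for the dataset's patient folders.
with tempfile.TemporaryDirectory() as tmp:
    datadir = Path(tmp)
    for name in ["010", "002", "README", "001"]:
        (datadir / name).mkdir()
    names = [p.name for p in sorted(datadir.glob("0*"))]

print(names)  # ['001', '002', '010']
```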

```python
from pathlib import Path
from skrt import set_viewer_options, BetterViewer, Patient
from skrt.core import alphanumeric, compress_user, download, Defaults
from skrt.dose import sum_doses
from skrt.registration import get_default_pfiles
from skrt.segmentation import SingleAtlasSegmentation

# Define URL of source dataset, and local data directory.
url = "https://figshare.com/ndownloader/files/23528954/SPECT_CT_data.zip"
topdir = Path("~/data/spect_ct").expanduser()
datadir = topdir / Path(url).stem

# Download dataset if not already present.
if not datadir.exists():
    download(url, topdir, unzip=True)

# Obtain sorted list of paths to patient folders.
paths = sorted(datadir.glob("0*"))

# Set Matplotlib runtime configuration (optional).
set_viewer_options()

# To display static graphics (which can be saved with the notebook)
# in place of interactive images, uncomment the following line.
# Defaults().no_ui = True

# Omit user part of paths when printing object attributes.
Defaults().compress_user = True
```

## Sample data

The following defines some sample data. The indices selecting the patient folders can be changed, but the ROIs outlined, and their labels, may then be different, so `roi_names` may need to be updated.
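Each key of `roi_names` is a standardised name, and each value is a label pattern, or list of patterns, that may be used for that ROI in the original data. The sketch below shows one way such a mapping can be applied, assuming case-insensitive, shell-style wildcard matching; `standardise_roi_names` is a hypothetical helper, and the matching rules of `filtered_copy()` in `scikit-rt` may differ. In the sketch, `keep_renamed_only=True` drops labels that match no pattern.

```python
from fnmatch import fnmatchcase

def standardise_roi_names(labels, roi_names, keep_renamed_only=True):
    """Hypothetical: map original ROI labels to standardised names.

    roi_names maps each standardised name to a pattern or list of patterns.
    Matching is case-insensitive, with shell-style wildcards (e.g. "lung_l*").
    """
    renamed = {}
    for label in labels:
        for new_name, patterns in roi_names.items():
            if isinstance(patterns, str):
                patterns = [patterns]
            if any(fnmatchcase(label.lower(), p.lower()) for p in patterns):
                renamed[label] = new_name
                break
        else:
            # No pattern matched: keep the original label only if requested.
            if not keep_renamed_only:
                renamed[label] = label
    return renamed

roi_names = {
    "heart": "heart",
    "lung_left": "lung_l*",
    "lung_right": "lung_r*",
    "spinal_cord": ["cord*", "spinal*cord"],
}
# Made-up labels, as might appear in a structure set.
labels = ["Heart", "Lung_L", "Lung_R", "SpinalCord", "Body"]
print(standardise_roi_names(labels, roi_names))
# {'Heart': 'heart', 'Lung_L': 'lung_left', 'Lung_R': 'lung_right', 'SpinalCord': 'spinal_cord'}
```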

```python
# Map possible ROI labels to standardised names.
roi_names = {
    "heart": "heart",
    "lung_left": "lung_l*",
    "lung_right": "lung_r*",
    "spinal_cord": ["cord*", "spinal*cord"],
}

# Load data from selected paths.
indices = [10, 16]
patients = [Patient(paths[idx], unsorted_dicom=True) for idx in indices]

# Obtain filtered copies of structure sets, keeping only ROIs renamed to standardised names.
structure_sets = [
    p.get_structure_sets("ct")[0].filtered_copy(names=roi_names, keep_renamed_only=True)
    for p in patients
]

# Obtain references to CT images.
images = [ss.get_image() for ss in structure_sets]

# Assign filtered structure sets to images.
for im, ss in zip(images, structure_sets):
    im.assign_structure_set(ss)

# Obtain references to summed doses.
doses = [sum_doses(im.get_doses()) for im in images]
```
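Each CT image may have several associated dose objects, which `sum_doses()` combines into one. As an illustration of the idea only, the following sums dose grids voxel by voxel, assuming that they share the same geometry; `sum_dose_grids` is hypothetical, is not a `scikit-rt` function, and does no resampling.

```python
def sum_dose_grids(grids):
    """Hypothetical: voxel-wise sum of 2D dose grids with identical shapes."""
    if not grids:
        raise ValueError("No dose grids to sum")
    ny, nx = len(grids[0]), len(grids[0][0])
    for grid in grids:
        if len(grid) != ny or any(len(row) != nx for row in grid):
            raise ValueError("Dose grids must have identical shapes")
    return [[sum(grid[y][x] for grid in grids) for x in range(nx)] for y in range(ny)]

# Made-up dose values (Gy) for two grids covering the same region.
dose_a = [[1.0, 2.0], [0.5, 0.0]]
dose_b = [[0.5, 0.5], [0.5, 1.0]]
total = sum_dose_grids([dose_a, dose_b])
print(total)  # [[1.5, 2.5], [1.0, 1.0]]
```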

## Setting up single-atlas segmentation

```python
# Set paths to directories containing registration software
# (adjust to match the local installation).
engine_dirs = {
    "elastix": "~/sw/elastix-5.0.1-mac",
    "niftyreg": "~/sw/niftyreg",
}

# Choose registration engine ("elastix" or "niftyreg").
engine = "elastix"

# Set indices for target and atlas (one should be 0, and the other should be 1).
target = 0
atlas = 1

# Set ROI for initial alignment, and margins around it for cropping.
roi_to_align = "heart"
crop_margins = (1000, (-100, 100), 100)

# Set voxel size (larger x-y dimensions, to reduce computing time).
voxel_size = (2, 2, None)

# Define intensity bands: set intensities of 80 and lower to -1024.
bands = {-1024: (None, 80)}
```
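The `bands` dictionary maps a replacement intensity to an intensity range, with `None` meaning unbounded. A minimal sketch of the banding described in the comment of the cell above, using a hypothetical `apply_bands` helper and assuming that band limits are inclusive:

```python
def apply_bands(values, bands):
    """Hypothetical: replace intensities falling within a band by the band's value.

    bands maps a new value to a (lower, upper) range; None means unbounded.
    Band limits are treated as inclusive (an assumption).
    """
    result = []
    for v in values:
        for new_value, (lower, upper) in bands.items():
            if (lower is None or v >= lower) and (upper is None or v <= upper):
                v = new_value
                break
        result.append(v)
    return result

bands = {-1024: (None, 80)}
print(apply_bands([-1000, 0, 80, 81, 400], bands))  # [-1024, -1024, -1024, 81, 400]
```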

```python
# Create the segmentation instance (im1: target image; im2: atlas image).
sas = SingleAtlasSegmentation(
    engine=engine,
    engine_dir=engine_dirs[engine],
    im1=images[target],
    im2=images[atlas],
    workdir=Path(f"sas_results/{patients[target].id}_{patients[atlas].id}"),
    roi_names=None,
    initial_crop_focus=roi_to_align,
    initial_crop_margins=crop_margins,
    initial_alignment=roi_to_align,
    voxel_size1=voxel_size,
    bands1=bands,
    pfiles1={"bspline": get_default_pfiles("*BSpline15*", engine)[0]},
    default_roi_crop_margins=(20, 20, 20),
    roi_crop_margins={"heart": (20, 20, 20)},
    voxel_size2=voxel_size,
    auto=True,
    auto_reg_setup_only=True,
    default_step=-1,
    default_strategy="pull",
    overwrite=True,
    capture_output=True,
    keep_tmp_dir=False,
    log_level="INFO",
)
```
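Results are retrieved below for two steps, `global` and `local`. For the local step, registration appears to be performed per ROI, in a region cropped around it. Assuming that the crop margins are added on both sides of the ROI's extents along x, y and z, and that `roi_crop_margins` overrides `default_roi_crop_margins` for the ROIs it lists, the crop region could be computed as in this hypothetical sketch (the extents are made up, and units of mm are assumed):

```python
def local_crop_box(roi_extents, roi_name, roi_crop_margins, default_margins):
    """Hypothetical: region for an ROI's local registration.

    roi_extents: ((xmin, xmax), (ymin, ymax), (zmin, zmax)).
    Margins come from roi_crop_margins if the ROI is listed there, otherwise
    from default_margins, and are applied on both sides along each axis.
    """
    margins = roi_crop_margins.get(roi_name, default_margins)
    return tuple((lo - m, hi + m) for (lo, hi), m in zip(roi_extents, margins))

# Margins as in the cell above; extents are made up.
default_roi_crop_margins = (20, 20, 20)
roi_crop_margins = {"heart": (20, 20, 20)}
heart_extents = ((-50.0, 40.0), (-120.0, -30.0), (10.0, 110.0))

print(local_crop_box(heart_extents, "heart", roi_crop_margins, default_roi_crop_margins))
# ((-70.0, 60.0), (-140.0, -10.0), (-10.0, 130.0))
```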

```python
# Display the original images and segmentations.
sas.im1.view(
    init_view="y-z",
    images=sas.im2,
    rois=[sas.ss1_filtered, sas.ss2_filtered],
    match_axes="y",
    comparison=True,
);
```

```python
# For the global step, show the mutual information and the mapped
# heart segmentation after each registration step.
roi_name = "heart"
for step in ["global"]:
    reg = sas.get_registration(roi_name=roi_name, step=step)
    for reg_step in reg.steps:
        print(step, reg_step, reg.get_mutual_information(reg_step, variant="iqr"))
        reg.view_result(
            reg_step,
            init_view="y-z",
            rois=[
                sas.ss1_filtered[roi_name],
                sas.get_segmentation(step=step, reg_step=reg_step)[roi_name],
            ],
        )
```
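Mutual information measures how well the intensities of one image predict those of the other. It is commonly maximised in image registration, so higher values generally indicate better alignment. The following computes the textbook quantity, in nats, from two equal-length sequences of already-binned intensities; it does not reproduce the `iqr` variant requested above, and `mutual_information` is a hypothetical helper, not a `scikit-rt` function.

```python
import math
from collections import Counter

def mutual_information(a, b):
    """Mutual information (nats) between two equal-length sequences of binned intensities."""
    if len(a) != len(b) or not a:
        raise ValueError("Sequences must be non-empty and of equal length")
    n = len(a)
    joint = Counter(zip(a, b))  # joint histogram
    count_a = Counter(a)        # marginal histograms
    count_b = Counter(b)
    mi = 0.0
    for (x, y), count in joint.items():
        p_xy = count / n
        mi += p_xy * math.log(p_xy / ((count_a[x] / n) * (count_b[y] / n)))
    return mi

print(mutual_information([0, 0, 1, 1], [0, 0, 1, 1]))  # ln 2, about 0.693
print(mutual_information([0, 0, 1, 1], [0, 1, 0, 1]))  # 0.0 (independent)
```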

```python
# Display the final result: original target segmentations (gold), and
# segmentations mapped from the atlas after the global initial alignment (blue)
# and after the local "bspline" registration step (red).
ss2_global = sas.get_segmentation(strategy="pull", step="global", reg_step="initial_alignment")
ss2_local = sas.get_segmentation(strategy="pull", step="local", reg_step=["bspline"])
for roi_name in roi_names:
    ss2_global[roi_name].set_color("blue")
    ss2_local[roi_name].set_color("red")
    sas.ss1_filtered[roi_name].set_color("gold")

# Combine the structure sets, and associate them with the target image for display.
rois = sas.ss1_filtered + ss2_global + ss2_local
rois.set_image(sas.im1, add_to_image=False)

sas.im1.view(rois=rois, legend=False, init_view="x-y", figsize=10, zoom=1, zoom_ui=True);
```

```python
# Compare segmentations mapped to target and original segmentations of target,
# using the Dice score, with results broken down by step and registration step.
df = sas.get_comparison(to_keep=list(roi_names), metrics=["dice"], steps=True, reg_steps=True)
for roi_name in roi_names:
    print(df[df["ROI"] == roi_name])
    print()
```