# Interactive Viewing of Scans and Segmentation Data

This notebook lets the user slide through each scan (slice) in a series, optionally overlaying the segmentations provided by TotalSegmentator. The widgets need a live kernel, so the user must copy and run the notebook to view the scans. Selecting a series loads every scan into memory, together with its segmentation data if available, so expect a delay of several seconds before the first scan is displayed.

In [1]:
# Load libraries.
import os
import glob

import joblib
import pandas as pd
import numpy as np
import nibabel as nib
import ipywidgets as widgets
import matplotlib.pyplot as plt
import pydicom



In [2]:
# Load the per-series metadata tables for both splits (they provide the
# series_id -> patient_id mapping and the list of selectable series).
train_meta, test_meta = (
    pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/train_series_meta.csv'),
    pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/test_series_meta.csv'),
)

In [3]:
dicom_tag_columns = [
    'Columns',
    'ImageOrientationPatient',
    'ImagePositionPatient',
    'InstanceNumber',
    'PatientID',
    'PatientPosition',
    'PixelSpacing',
    'RescaleIntercept',
    'RescaleSlope',
    'Rows',
    'SeriesNumber',
    'SliceThickness',
    'path',
    'WindowCenter',
    'WindowWidth'
]
ROOT_DIR = '/kaggle/input/rsna-2023-abdominal-trauma-detection/'
VECTOR_DTYPE = np.float32

def process_dicom_tags(df, copy=True):
    """Prepare a raw DICOM-tag table for per-series lookup.

    Adds an integer ``SeriesID`` column taken from the third ``/``-separated
    component of the relative ``path`` (presumably
    ``<split>_images/<patient>/<series>/<file>.dcm``), parses the bracketed,
    comma-separated vector tags into 1-D ``VECTOR_DTYPE`` arrays, prefixes
    ``path`` with ``ROOT_DIR`` and sorts rows by ``(SeriesID, InstanceNumber)``.

    Parameters
    ----------
    df : pandas.DataFrame
        Tag table with at least ``path``, ``InstanceNumber``,
        ``ImageOrientationPatient``, ``ImagePositionPatient`` and ``PixelSpacing``.
    copy : bool, default True
        Work on a copy. When False, ``df`` itself is modified and returned.

    Returns
    -------
    pandas.DataFrame
        The processed table (``df`` itself when ``copy`` is False).
    """
    tags = df if not copy else df.copy()

    path_parts = tags['path'].str.split('/')
    tags['SeriesID'] = path_parts.str.get(2).astype(int)

    def to_vectors(column):
        # '[a, b, c]' strings -> float arrays of VECTOR_DTYPE, one per row.
        raw = column.str.strip('[]').str.encode('utf-8')
        return raw.apply(lambda text: np.fromstring(text, sep=',', dtype=VECTOR_DTYPE))

    for name in ('ImageOrientationPatient', 'ImagePositionPatient', 'PixelSpacing'):
        tags[name] = to_vectors(tags[name])

    # Turn the dataset-relative path into an absolute one.
    tags['path'] = tags['path'].apply(lambda rel: os.path.join(ROOT_DIR, rel))
    tags.sort_values(['SeriesID', 'InstanceNumber'], inplace=True)
    return tags

In [4]:
# Read the selected DICOM-tag columns for each split and normalise them in
# place (copy=False avoids duplicating the freshly loaded frames).
train_dicom_tags, test_dicom_tags = (
    process_dicom_tags(pd.read_parquet(tags_path, columns=dicom_tag_columns), copy=False)
    for tags_path in (
        '/kaggle/input/rsna-2023-abdominal-trauma-detection/train_dicom_tags.parquet',
        '/kaggle/input/rsna-2023-abdominal-trauma-detection/test_dicom_tags.parquet',
    )
)

In [5]:
# Define widgets.
# Spare output area; the @output.capture() decorators that would use it are
# currently commented out.
output = widgets.Output()

# Dataset split; starts unselected so nothing is loaded until the user picks one.
dataset_widget = widgets.Dropdown(
    options=['Train', 'Test'],
    value=None,
    description='Dataset',
    disabled=False
)
# When ticked, the series dropdown only offers series that have segmentations.
filter_series_id_without_segmentation_enable_widget = widgets.Checkbox(
    value=False,
    description='Filter series ID without segmentations',
    disabled=False,
    indent=False
)
# Placeholder option until a dataset is chosen and the real IDs are filled in.
series_id_widget = widgets.Dropdown(
    options=[''],
    description='Series ID',
    disabled=False
)
# Slice index; its max is reset whenever a new series is loaded.
# continuous_update=False: only report the value when the slider is released.
scan_id_widget = widgets.IntSlider(
    value=0,
    min=0,
    max=1,
    step=1,
    description='Scan ID',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)
# Disabled until a series with an available segmentation is selected.
segmentation_enable_widget = widgets.Checkbox(
    value=False,
    description='Apply Segmentation',
    disabled=True,
    indent=False
)
# Display window in Hounsfield units; pixel values outside it are clipped.
window_range_widget = widgets.FloatRangeSlider(
    value=[-1024.0, 1024.0],
    min=-1024.0,
    max=1024.0,
    step=0.1,
    description='Window Range (HU):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    # The default fixed label width would truncate this long description.
    style={'description_width': 'initial'},
)

In [6]:
# Define TotalSegmentator label codes (label value -> organ name), used for the
# segmentation overlay legend.
total_segmentator_codes = dict(enumerate(
    ['liver', 'spleen', 'kidney_left', 'kidney_right', 'bowel'],
    start=1,
))

In [7]:
def calibrate_img(img, m, c):
    """Convert stored pixel values to Hounsfield units: ``img * m + c``.

    The array is modified IN PLACE and also returned, so ``img`` must have a
    dtype that can hold the result (e.g. float32). An integer array with a
    fractional slope raises numpy's casting error.

    Parameters
    ----------
    img : numpy.ndarray
        Raw pixel data.
    m : float
        Rescale slope.
    c : float
        Rescale intercept.
    """
    np.multiply(img, m, out=img)
    np.add(img, c, out=img)
    return img


def get_patient_id(series_id, use_train_set):
    """Return the patient ID owning ``series_id``, looked up in the split's series metadata.

    Raises IndexError if the series is not listed in that split's metadata.
    """
    meta = test_meta if not use_train_set else train_meta
    patient_ids = meta.loc[meta.series_id == series_id, 'patient_id']
    return int(patient_ids.iloc[0])


def get_scan_files(patient_id, series_id, use_train_set):
    """Return the ``.dcm`` files of one series, ordered by their integer file stem.

    File stems are parsed with ``int`` (presumably instance numbers), so the
    order is numeric rather than lexicographic. Returns an empty list when the
    directory does not exist.
    """
    dataset_dir = 'train_images' if use_train_set else 'test_images'
    pattern = f'/kaggle/input/rsna-2023-abdominal-trauma-detection/{dataset_dir}/{patient_id}/{series_id}/*.dcm'

    def stem_number(path):
        stem, _ = os.path.splitext(os.path.basename(path))
        return int(stem)

    files = glob.glob(pattern)
    files.sort(key=stem_number)
    return files


def load_series(series_id, use_train_set):
    """Load every slice of one series as a calibrated 3-D volume.

    Slices are read in parallel with joblib, stacked along the last axis,
    converted with the series' rescale slope/intercept (taken from the first
    tag row), reordered by position along the slice normal, and finally have
    their first two axes swapped.

    Parameters
    ----------
    series_id : int
        Matched against the ``SeriesID`` column of the DICOM-tag table.
    use_train_set : bool
        Use ``train_dicom_tags`` when True, otherwise ``test_dicom_tags``.

    Returns
    -------
    numpy.ndarray
        float32 volume; the last axis indexes slices in ascending order of
        their projection onto the normal ``cross(row_dir, col_dir)``.

    Notes
    -----
    If no tags match ``series_id``, stacking an empty list presumably raises
    numpy's ValueError.
    """
    tags = train_dicom_tags if use_train_set else test_dicom_tags
    series_tags = tags[tags.SeriesID == series_id]

    def _read_pixels(path):
        # Decode one DICOM file into a float32 2-D array of stored values.
        return pydicom.dcmread(path).pixel_array.astype(np.float32)

    slices = joblib.Parallel(n_jobs=-1)(
        joblib.delayed(_read_pixels)(path) for path in series_tags.path.values
    )
    volume = np.dstack(slices)

    first_row = series_tags.iloc[0]
    slope, intercept, orientation = first_row[['RescaleSlope', 'RescaleIntercept', 'ImageOrientationPatient']]
    volume = calibrate_img(volume, slope, intercept)

    # Make sure images are ordered correctly (superior is +z): sort slices by
    # their signed distance along the normal of the image plane.
    normal = np.cross(orientation[:3], orientation[3:])
    positions = np.vstack(series_tags.ImagePositionPatient.values)
    slice_order = np.argsort(positions @ normal)

    return volume[:, :, slice_order].transpose(1, 0, 2)


def get_segmentation_paths():
    """Return all segmentation ``.nii`` paths, ordered by their numeric series ID stem."""
    def series_number(path):
        stem, _ = os.path.splitext(os.path.basename(path))
        return int(stem)

    found = glob.glob('/kaggle/input/rsna-2023-abdominal-trauma-detection/segmentations/*.nii')
    found.sort(key=series_number)
    return found


def get_series_ids_with_segmentations():
    """Return the series IDs (ints, ascending) that have a segmentation file."""
    stems = (os.path.splitext(os.path.basename(p))[0] for p in get_segmentation_paths())
    return list(map(int, stems))
    

def get_segmentation_path(series_id):
    """Build the expected segmentation file path for ``series_id`` (existence is not checked)."""
    path = f'/kaggle/input/rsna-2023-abdominal-trauma-detection/segmentations/{series_id}.nii'
    return path


def load_segmentation(x):
    """Load a segmentation NIfTI and reorient it to the closest canonical orientation.

    ``x`` is either a path string or a series ID (anything accepted by ``int``).
    """
    path = x if isinstance(x, str) else get_segmentation_path(int(x))
    image = nib.load(path)
    # Ensures output is RAS, DICOM is LPS.
    return nib.as_closest_canonical(image)


def get_segmentation_image(seg):
    """Return the segmentation label volume as an int array flipped on its first two axes.

    ``seg`` may be an already-loaded ``Nifti1Image``, or anything
    ``load_segmentation`` accepts (path or series ID).
    """
    if isinstance(seg, nib.nifti1.Nifti1Image):
        image = seg
    else:
        image = load_segmentation(seg)
    labels = image.get_fdata().astype(int)
    # Reverse axes 0 and 1, presumably to go from RAS (as produced by
    # load_segmentation) back to the DICOM LPS layout.
    return labels[::-1, ::-1, :]

In [8]:
from matplotlib.patches import Patch
from matplotlib.lines import Line2D


# Mutable viewer state shared by the widget callbacks below.
CURRENT_SCAN_DATA = None         # volume of the selected series, set by handle_series_id_update
CURRENT_SEGMENTATOR_DATA = None  # label volume loaded by handle_series_id_update
IM = None                        # scan AxesImage, created lazily by update_scan
SEG_IM = None                    # overlay AxesImage, created lazily by update_scan

# One persistent figure, closed straight away so it is shown only through the
# explicit display(fig) call in update_scan.
fig, ax = plt.subplots()
plt.close(fig)

series_ids_with_segmentations = get_series_ids_with_segmentations()

# Palette indexed by label value: row 0 is all-NaN (intended to leave the
# background label 0 uncoloured); rows 1-6 are evenly spaced 'rainbow' colours.
segmentator_cmap = plt.get_cmap('rainbow')
segmentator_palette = np.vstack([
    np.full((1, 4), np.nan),
    segmentator_cmap(np.linspace(0, 1, 6)),
])


#@output.capture()
def handle_dataset_update(change):
    """Repopulate the series dropdown for the newly selected dataset split.

    Honours the "filter series without segmentations" checkbox, so switching
    split does not silently drop an active filter. The selection is cleared
    afterwards.
    """
    df = train_meta if change['new'] == 'Train' else test_meta
    options = df.series_id.values
    if filter_series_id_without_segmentation_enable_widget.value:
        # Set lookup: membership in the list would be O(len) per series.
        with_segmentation = set(series_ids_with_segmentations)
        options = [series_id for series_id in options if series_id in with_segmentation]
    series_id_widget.options = sorted(options)
    series_id_widget.value = None

    
#@output.capture()
def handle_filter_series_id(change):
    """Repopulate the series dropdown when the segmentation filter is toggled.

    When the checkbox is ticked (``change['new']`` truthy), only series listed
    in ``series_ids_with_segmentations`` are offered. Does nothing until a
    dataset split has been chosen, instead of falling back to the Test split.
    """
    if dataset_widget.value is None:
        return
    df = train_meta if dataset_widget.value == 'Train' else test_meta
    options = df.series_id.values
    if change['new']:
        # Set lookup: membership in the list would be O(len) per series.
        with_segmentation = set(series_ids_with_segmentations)
        options = [series_id for series_id in options if series_id in with_segmentation]
    series_id_widget.options = sorted(options)


#@output.capture()
def handle_series_id_update(change):
    """Load the newly selected series (and its segmentation, if any) into the viewer state.

    Observer for ``series_id_widget``. It does the following:

    - Replaces ``CURRENT_SCAN_DATA`` and ``CURRENT_SEGMENTATOR_DATA``. The
      latter becomes None when the series has no segmentation, so no stale
      volume from a previous series survives.
    - Resets the scan slider to slice 0 with a max matching the new volume.
    - Enables, or unticks and disables, the segmentation checkbox.
    """
    global CURRENT_SCAN_DATA, CURRENT_SEGMENTATOR_DATA

    series_id = change['new']
    if not series_id:
        # Placeholder option ('') or a cleared selection (None): nothing to load.
        return

    has_segmentation = series_id in series_ids_with_segmentations
    if not has_segmentation:
        # Untick *before* swapping data: each widget change re-runs update_scan,
        # and a still-ticked box would overlay the previous series' masks.
        segmentation_enable_widget.value = False
        segmentation_enable_widget.disabled = True
        segmentation_enable_widget.description = 'Segmentations not available.'

    # Load both volumes before publishing them, so any redraw sees a matching pair.
    scan_data = load_series(series_id, dataset_widget.value == 'Train')
    seg_data = get_segmentation_image(series_id) if has_segmentation else None
    CURRENT_SCAN_DATA, CURRENT_SEGMENTATOR_DATA = scan_data, seg_data

    # Reset the value before shrinking max so it never has to be clamped.
    scan_id_widget.value = 0
    scan_id_widget.max = CURRENT_SCAN_DATA.shape[-1] - 1

    if has_segmentation:
        segmentation_enable_widget.disabled = False
        segmentation_enable_widget.description = 'Show Segmentations'

        
#@output.capture()
def update_scan(dataset, filter_series_id_without_segmentation, series_id, scan_id, show_segmentations, window_range):
    """Render one slice of the current series, optionally with its segmentation overlay.

    Callback for ``widgets.interactive``, which passes the value of every bound
    widget. ``dataset`` and ``filter_series_id_without_segmentation`` are unused
    here; their effects are applied by the observer callbacks.

    Parameters
    ----------
    series_id : int or None
        Selected series; nothing is drawn while it is None.
    scan_id : int
        Slice index along the last axis of ``CURRENT_SCAN_DATA``.
    show_segmentations : bool
        Draw the label overlay and legend. Unticking hides them again.
    window_range : tuple of float
        ``(low, high)`` display window in HU. It is mapped onto the full
        colormap, so a given HU value keeps the same grey level across slices.
    """
    global CURRENT_SCAN_DATA, CURRENT_SEGMENTATOR_DATA, IM, SEG_IM
    if series_id is None or scan_id is None or CURRENT_SCAN_DATA is None:
        return

    def _extent(shape):
        # imshow's default extent (origin='upper'): pixel centres on integer coordinates.
        return (-0.5, shape[1] - 0.5, shape[0] - 0.5, -0.5)

    low, high = window_range
    # .T undoes the axis swap applied by load_series.
    data = np.clip(CURRENT_SCAN_DATA[:, :, scan_id].T, low, high)

    if IM is None:
        IM = ax.imshow(data, cmap=plt.cm.bone, vmin=low, vmax=high)
    else:
        IM.set_data(data)
        IM.set_clim(low, high)
        # set_data keeps the old extent; refresh it in case the slice size changed.
        IM.set_extent(_extent(data.shape))

    seg_volume = CURRENT_SEGMENTATOR_DATA
    draw_overlay = (
        bool(show_segmentations)
        and seg_volume is not None
        and 0 <= scan_id < seg_volume.shape[-1]
    )

    legend_elements = []
    if draw_overlay:
        seg_slice = seg_volume[:, :, scan_id].T
        overlay = segmentator_palette[seg_slice]
        if SEG_IM is None:
            SEG_IM = ax.imshow(overlay, alpha=0.5)
        else:
            SEG_IM.set_data(overlay)
            SEG_IM.set_extent(_extent(overlay.shape))
        SEG_IM.set_visible(True)
        for label in np.unique(seg_slice):
            if label == 0:
                continue
            colour = segmentator_palette[label]
            name = total_segmentator_codes.get(int(label), str(label))
            legend_elements.append(Patch(facecolor=colour, edgecolor=colour, label=name))
    elif SEG_IM is not None:
        # Hide rather than remove, so the artist is reused if re-enabled.
        SEG_IM.set_visible(False)

    old_legend = ax.get_legend()
    if old_legend is not None:
        old_legend.remove()
    if legend_elements:
        ax.legend(handles=legend_elements, loc='upper right')
    display(fig)


# Wire the observer callbacks. They are registered before widgets.interactive
# adds its own observers, presumably so that (with registration-order dispatch)
# the dropdown and data updates run before update_scan redraws. Keep this order.
dataset_widget.observe(handle_dataset_update, names='value')
filter_series_id_without_segmentation_enable_widget.observe(handle_filter_series_id, names='value')
series_id_widget.observe(handle_series_id_update, names='value')

# Bind every control to update_scan; a change to any of their values re-renders.
w = widgets.interactive(
    update_scan,
    dataset=dataset_widget,
    filter_series_id_without_segmentation=filter_series_id_without_segmentation_enable_widget,
    series_id=series_id_widget,
    scan_id=scan_id_widget,
    show_segmentations=segmentation_enable_widget,
    window_range=window_range_widget,
)
# Last child of the interactive widget (presumably its output area): give the
# rendered figure a fixed height.
w.children[-1].layout.height = '450px'
display(w)

interactive(children=(Dropdown(description='Dataset', options=('Train', 'Test'), value=None), Checkbox(value=F…