# CADS Interactive Segmentation Dashboard (Task 559)

This dashboard allows you to run and visualize CT segmentations for **Task 559** (General tissues, body cavities, bones).

### **Note on Output Format:**
The output is a **Categorical Label Map** (integer indices).
Each voxel contains a value from 0 to 10 identifying the detected structure:
0: Background, 1: Subcutaneous tissue, 2: Muscle, 3: Abdominal cavity, 4: Thoracic cavity, 5: Bones, 6: Gland structure, 7: Pericardium, 8: Breast implant, 9: Mediastinum, 10: Spinal cord.

In [1]:
%matplotlib inline
import os
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
import shutil
from IPython.display import display, clear_output
from matplotlib.colors import ListedColormap

from cads.utils.libs import setup_nnunet_env, check_or_download_model_weights, get_model_weights_dir
from cads.utils.inference import predict
from cads.dataset_utils.bodyparts_labelmaps import map_taskid_to_labelmaps

# --- Global Config ---
TASK_ID = 559
gpfs_weights_path = "/gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/cads_weights"

# Create the weights directory and expose it through CADS_WEIGHTS_PATH before
# any of the CADS helpers below are called.
os.makedirs(gpfs_weights_path, exist_ok=True)
os.environ["CADS_WEIGHTS_PATH"] = gpfs_weights_path

setup_nnunet_env()
MODEL_FOLDER = get_model_weights_dir()
LABELMAP = map_taskid_to_labelmaps[TASK_ID]  # label index -> structure name for this task

print(f"Model weights directory: {MODEL_FOLDER}")
# Presumably downloads the Task 559 weights when missing (per the helper's name).
check_or_download_model_weights(TASK_ID)

nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.
nnUNet_results is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.
Model weights directory: /gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/cads_weights


In [None]:
def run_cads_segmentation(img_path):
    """
    Run the CADS Task 559 segmentation and save the label map as
    'cads_ct_seg.nii.gz' in the same directory as the input image.

    Prediction is written into a private, uniquely named scratch directory
    inside the input image's directory. That scratch directory is always
    removed afterwards, whether prediction succeeds, produces no output, or
    raises.

    Args:
        img_path: Path to the input CT volume (".nii.gz" or a single-extension file).

    Returns:
        The path of the saved segmentation, or None if the expected CADS
        output file was not produced.

    Raises:
        Anything raised by ``predict`` or ``shutil.move`` propagates to the
        caller (after the scratch directory has been cleaned up).
    """
    import tempfile

    img_dir = os.path.dirname(os.path.abspath(img_path))
    filename = os.path.basename(img_path)
    # ".nii.gz" is a double extension that os.path.splitext would only half-strip.
    if filename.endswith(".nii.gz"):
        patient_id = filename[: -len(".nii.gz")]
    else:
        patient_id = os.path.splitext(filename)[0]
    final_seg_path = os.path.join(img_dir, "cads_ct_seg.nii.gz")

    # 1. Run prediction into a unique temporary folder inside the input dir.
    # A unique name avoids reusing (and later deleting) an unrelated
    # pre-existing "cads_temp" folder, and lets concurrent runs coexist.
    temp_out = tempfile.mkdtemp(prefix="cads_temp_", dir=img_dir)
    try:
        predict(
            files_in=[img_path],
            folder_out=temp_out,
            model_folder=MODEL_FOLDER,
            task_ids=[TASK_ID],
            folds="all",
            use_cpu=not torch.cuda.is_available(),
            preprocess_cads=True,
            postprocess_cads=True,
            save_all_combined_seg=False,
            verbose=False,
        )

        # 2. Move the per-task output (at the location this code expects) to its final name.
        raw_seg_path = os.path.join(temp_out, patient_id, f"{patient_id}_part_{TASK_ID}.nii.gz")
        if not os.path.exists(raw_seg_path):
            print("Error: Segmentation failed to generate.")
            return None
        shutil.move(raw_seg_path, final_seg_path)
    finally:
        # Cleanup temp on every path, including failures and exceptions.
        shutil.rmtree(temp_out, ignore_errors=True)

    print(f"Successfully saved segmentation to: {final_seg_path}")
    return final_seg_path

In [None]:
class SegmentationDashboard:
    """Interactive viewer that segments a CT with CADS Task 559 and overlays the labels.

    Enter a CT path and press "Run Segmentation". The dashboard reuses
    'cads_ct_seg.nii.gz' from the image's directory when it already exists;
    otherwise it calls ``run_cads_segmentation`` to create it. The selected
    slice is shown in grayscale with the ticked labels overlaid in colour.
    """

    def __init__(self):
        # Loaded volumes and the paths they came from. They are only replaced
        # together after a fully successful load, so the viewer never shows a
        # mix of old volumes and new paths.
        self.img_vol = None
        self.seg_vol = None
        self.img_path = ""
        self.seg_path = ""
        self.view_names = {2: "Axial", 0: "Sagittal", 1: "Coronal"}
        self.cmap = plt.get_cmap("tab10")

        # Discrete overlay colormap with one entry per label 1..max_label,
        # using the same tab10 colour ((idx - 1) % 10) as the legend patches.
        # Paired with vmin=0.5 / vmax=max_label + 0.5 in imshow, every integer
        # label lands in its own colour bin, however many labels LABELMAP has.
        self.max_label = max(max(LABELMAP, default=0), 1)
        self.overlay_cmap = ListedColormap([self.cmap(i % 10) for i in range(self.max_label)])

        # UI: Input Section
        self.path_input = widgets.Text(value="dataset/1.0x1.0x1.0mm/val/1ABB139/ct.nii.gz", description="CT Path:", layout=widgets.Layout(width="75%"))
        self.run_btn = widgets.Button(description="Run Segmentation", button_style="success", layout=widgets.Layout(width="20%"))

        # UI: Viewer Section (dropdown options derived from view_names so the two never diverge)
        self.slice_slider = widgets.IntSlider(description="Slice:", continuous_update=True, layout=widgets.Layout(width="50%"))
        self.axis_dropdown = widgets.Dropdown(description="View:", options=[(name, axis) for axis, name in self.view_names.items()], value=2)

        # UI: Windowing (display intensity range for the CT)
        self.vmin_input = widgets.FloatText(value=-200, description="vmin:", layout=widgets.Layout(width="180px"))
        self.vmax_input = widgets.FloatText(value=400, description="vmax:", layout=widgets.Layout(width="180px"))

        self.alpha_slider = widgets.FloatSlider(value=0.5, min=0, max=1.0, step=0.05, description="Alpha:")

        # UI: Label Checks (one checkbox per foreground label; 0 is background)
        self.label_checks = []
        for idx, name in LABELMAP.items():
            if idx == 0:
                continue
            cb = widgets.Checkbox(value=True, description=name, indent=False)
            cb.observe(self.update_plot, names="value")
            self.label_checks.append((idx, cb))

        self.label_box = widgets.VBox(
            [widgets.Label(value="Visible Labels:")] + [cb for idx, cb in self.label_checks], layout=widgets.Layout(padding="10px", border="1px solid #ccc", width="auto", max_height="250px")
        )

        self.status_output = widgets.Output()
        self.viewer_output = widgets.Output()

        # Event Handlers
        self.run_btn.on_click(self.on_run_click)
        self.slice_slider.observe(self.update_plot, names="value")
        self.axis_dropdown.observe(self.on_axis_change, names="value")
        self.vmin_input.observe(self.update_plot, names="value")
        self.vmax_input.observe(self.update_plot, names="value")
        self.alpha_slider.observe(self.update_plot, names="value")

        # Layout Assembly
        self.input_ui = widgets.HBox([self.path_input, self.run_btn])
        self.viewer_ui = widgets.HBox(
            [
                widgets.VBox([widgets.HBox([self.slice_slider, self.axis_dropdown]), widgets.HBox([self.vmin_input, self.vmax_input]), self.alpha_slider], layout=widgets.Layout(width="60%")),
                self.label_box,
            ]
        )

        display(self.input_ui, self.status_output, self.viewer_ui, self.viewer_output)

    def on_run_click(self, b):
        """Load the CT at the entered path, segment it if needed, and draw it.

        On any failure the previously loaded state is left untouched and the
        reason is printed to the status area.
        """
        img_path = self.path_input.value.strip()
        if not os.path.exists(img_path):
            with self.status_output:
                clear_output()
                print(f"Error: File not found at {img_path}")
            return

        # Check if segmentation already exists in input directory
        img_dir = os.path.dirname(os.path.abspath(img_path))
        seg_path = os.path.join(img_dir, "cads_ct_seg.nii.gz")

        with self.status_output:
            clear_output()
            if os.path.exists(seg_path):
                print(f"Found existing segmentation: {seg_path}")
            else:
                print(f"No segmentation found. Running CADS for {img_path}...")
                try:
                    seg_path = run_cads_segmentation(img_path)
                except Exception as e:
                    # Exceptions raised inside a widget callback are typically
                    # not shown in the cell output, so report them here.
                    print(f"Error: CADS segmentation raised: {e}")
                    return

            if not seg_path or not os.path.exists(seg_path):
                return

            try:
                img_vol = nib.load(img_path).get_fdata()
                seg_vol = nib.load(seg_path).get_fdata()
            except Exception as e:
                print(f"Error loading volumes: {e}")
                return

            # Slices are taken with the same index from both volumes, so they must align.
            if img_vol.shape != seg_vol.shape:
                print(f"Error: image shape {img_vol.shape} does not match segmentation shape {seg_vol.shape}")
                return

        # Commit the new state only after everything loaded successfully.
        self.img_path, self.seg_path = img_path, seg_path
        self.img_vol, self.seg_vol = img_vol, seg_vol
        self.on_axis_change(None)  # resets the slider and redraws

    def on_axis_change(self, change):
        """Reset the slice slider to the middle of the selected axis and redraw."""
        if self.img_vol is None:
            return
        max_s = self.img_vol.shape[self.axis_dropdown.value] - 1
        # Mute the slider's redraw handler while max/value change, then redraw
        # once explicitly. Relying on the value-change event alone skips the
        # redraw when the midpoint index is unchanged (e.g. switching views on
        # an isotropic volume), leaving the previous view on screen.
        self.slice_slider.unobserve(self.update_plot, names="value")
        try:
            self.slice_slider.max = max_s
            self.slice_slider.value = max_s // 2
        finally:
            self.slice_slider.observe(self.update_plot, names="value")
        self.update_plot(None)

    def get_slice(self, vol, axis, sl):
        """Return slice ``sl`` of ``vol`` along ``axis`` (0, 1 or 2), rotated 90 degrees for display."""
        if vol is None:
            return None
        if axis == 0:
            return np.rot90(vol[sl, :, :])
        elif axis == 1:
            return np.rot90(vol[:, sl, :])
        return np.rot90(vol[:, :, sl])

    def update_plot(self, change):
        """Redraw the current slice with the ticked labels overlaid."""
        if self.img_vol is None or self.seg_vol is None:
            return
        import matplotlib.patches as mpatches

        with self.viewer_output:
            fig = None
            try:
                clear_output(wait=True)
                print(f"Visualizing Image: {self.img_path}")
                print(f"Visualizing Seg:   {self.seg_path}")

                ax, sl = self.axis_dropdown.value, self.slice_slider.value
                vmin, vmax = self.vmin_input.value, self.vmax_input.value
                selected_indices = [idx for idx, cb in self.label_checks if cb.value]

                img_slice = self.get_slice(self.img_vol, ax, sl)
                seg_slice = self.get_slice(self.seg_vol, ax, sl)

                # Keep only ticked labels; everything else becomes 0 and is masked (transparent).
                masked_seg = np.where(np.isin(seg_slice, selected_indices), seg_slice, 0)
                display_seg = np.ma.masked_where(masked_seg == 0, masked_seg)

                fig, ax_plot = plt.subplots(figsize=(10, 10))
                ax_plot.imshow(img_slice, cmap="gray", vmin=vmin, vmax=vmax)

                if selected_indices:
                    ax_plot.imshow(
                        display_seg,
                        cmap=self.overlay_cmap,
                        alpha=self.alpha_slider.value,
                        vmin=0.5,
                        vmax=self.max_label + 0.5,
                        interpolation="nearest",
                    )
                    patches = [mpatches.Patch(color=self.cmap((idx - 1) % 10), label=LABELMAP[idx]) for idx in sorted(selected_indices)]
                    ax_plot.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left")

                view_name = self.view_names.get(ax, "Unknown")
                ax_plot.set_title(f"Slice {sl} | View: {view_name} | Window: [{vmin:.1f}, {vmax:.1f}]")
                ax_plot.axis("off")
                plt.show()
            except Exception as e:
                print(f"Error in viewer: {e}")
            finally:
                # Close the figure on every path so failed redraws don't accumulate figures.
                if fig is not None:
                    plt.close(fig)

In [3]:
# Keep a named handle so the widgets' callbacks stay reachable and the
# notebook does not echo the object's default repr as cell output.
dashboard = SegmentationDashboard()

HBox(children=(Text(value='dataset/1.0x1.0x1.0mm/val/1ABB139/ct.nii.gz', description='CT Path:', layout=Layout…

Output()

HBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, description='Slice:', layout=Layout(width='50%…

Output()

<__main__.SegmentationDashboard at 0x149f62fd1030>