# Segmentation Visualization Dashboard
Interactive viewer for CT volumes with CADS segmentation overlays (labelmaps: Task 559 + Brain, Task 553).

In [1]:
%matplotlib inline
import contextlib
import os

import ipywidgets as widgets
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from IPython.display import display, clear_output

from cads.dataset_utils.bodyparts_labelmaps import map_taskid_to_labelmaps

# --- Global Config ---
# Labelmaps selectable in the dashboard, keyed by display name. Each value maps
# a label index to a structure name. The CADS mappings are copied so that adding
# the extra "brain" entry below leaves the shared cads dictionary untouched.
_labelmap_559_brain = map_taskid_to_labelmaps[559].copy()
_labelmap_559_brain[11] = "brain"  # label index 11 is assigned to the added brain class
LABELMAPS = {
    "Task 559 + Brain": _labelmap_559_brain,
    "Task 553": map_taskid_to_labelmaps[553].copy(),
}
del _labelmap_559_brain

In [4]:
class SegmentationDashboard:
    """Interactive notebook viewer for a CT volume with an optional label overlay.

    Workflow: type a CT path and press "Load Data". The volume is read with
    nibabel. Every ``.nii.gz`` file in the same folder whose name contains
    "ct" or "mask" is offered in the "Seg File" dropdown; the CT file itself
    is excluded. A chosen segmentation is drawn over the CT slice. Each label
    index gets a colour from the ``tab20`` colormap, cycling for indices above
    the colormap size. The labels that can be toggled come from the selected
    entry of the module-level ``LABELMAPS`` dict.

    All widgets are displayed as a side effect of construction.
    """

    # Path pre-filled in the "CT Path" box (override with ``img_path``).
    DEFAULT_IMG_PATH = "/gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/SynthRAD_combined/3.0x3.0x3.0mm/val/1ABB139/ct.nii.gz"

    def __init__(self, img_path=DEFAULT_IMG_PATH):
        """Build the widgets, wire their callbacks and display them.

        Args:
            img_path: initial contents of the "CT Path" text box.
        """
        self.img_vol = None
        self.seg_vol = None
        # While True, update_plot() is a no-op. This lets several widget
        # changes be applied and followed by a single redraw.
        self._suspend_redraw = False
        self.view_names = {2: "Axial", 0: "Sagittal", 1: "Coronal"}
        self.cmap = plt.get_cmap("tab20")

        # UI: Input Section
        self.img_path_input = widgets.Textarea(
            value=img_path,
            description="CT Path:",
            layout=widgets.Layout(width="70%", height="50px"),
        )
        self.seg_dropdown = widgets.Dropdown(description="Seg File:", options=[("None", None)], value=None, layout=widgets.Layout(width="70%"))
        self.labelmap_dropdown = widgets.Dropdown(description="Labelmap:", options=list(LABELMAPS.keys()), value="Task 559 + Brain", layout=widgets.Layout(width="30%"))
        self.load_btn = widgets.Button(description="Load Data", button_style="primary", layout=widgets.Layout(width="15%"))

        # UI: Viewer Section
        self.slice_slider = widgets.IntSlider(description="Slice:", continuous_update=True, layout=widgets.Layout(width="50%"))
        self.axis_dropdown = widgets.Dropdown(description="View:", options=[("Axial", 2), ("Sagittal", 0), ("Coronal", 1)], value=2)

        # UI: Windowing
        self.window_dropdown = widgets.Dropdown(
            description="Window:",
            options=[
                ("Full Range", "full"),
                ("CT Bone (-500, 1300)", (-500, 1300)),
                ("CT Soft Tissue (-150, 350)", (-150, 350)),
                ("CT Lung (-1000, 400)", (-1000, 400)),
                ("Brain (-100, 100)", (-100, 100)),
                ("Custom", "custom"),
            ],
            value="full",
        )
        self.vmin_input = widgets.FloatText(value=-200, description="vmin:", layout=widgets.Layout(width="150px"))
        self.vmax_input = widgets.FloatText(value=400, description="vmax:", layout=widgets.Layout(width="150px"))
        self.custom_box = widgets.HBox([self.vmin_input, self.vmax_input], layout=widgets.Layout(visibility="hidden"))

        self.alpha_slider = widgets.FloatSlider(value=0.5, min=0, max=1.0, step=0.05, description="Alpha:")

        # UI: Label Checks
        self.label_checks = []
        self.label_box = widgets.VBox(
            [widgets.Label(value="Visible Labels:")],
            layout=widgets.Layout(padding="10px", border="1px solid #ccc", width="auto", max_height="300px", overflow_y="scroll"),
        )
        self.update_label_checks()

        self.status_output = widgets.Output()
        self.viewer_output = widgets.Output()

        # Event Handlers
        self.load_btn.on_click(self.on_load_click)
        self.seg_dropdown.observe(self.on_seg_change, names="value")
        self.labelmap_dropdown.observe(self.on_labelmap_change, names="value")
        self.slice_slider.observe(self.update_plot, names="value")
        self.axis_dropdown.observe(self.on_axis_change, names="value")
        self.window_dropdown.observe(self.on_window_change, names="value")
        self.vmin_input.observe(self.update_plot, names="value")
        self.vmax_input.observe(self.update_plot, names="value")
        self.alpha_slider.observe(self.update_plot, names="value")

        # Layout Assembly
        self.input_ui = widgets.VBox([widgets.HBox([self.img_path_input, self.load_btn]), widgets.HBox([self.seg_dropdown, self.labelmap_dropdown])])
        self.window_ui = widgets.HBox([self.window_dropdown, self.custom_box])
        self.viewer_ui = widgets.HBox(
            [
                widgets.VBox([widgets.HBox([self.slice_slider, self.axis_dropdown]), self.window_ui, self.alpha_slider], layout=widgets.Layout(width="60%")),
                self.label_box,
            ]
        )

        display(self.input_ui, self.status_output, self.viewer_ui, self.viewer_output)

    @contextlib.contextmanager
    def _redraw_suspended(self):
        """Disable update_plot() for the duration of the ``with`` block.

        Nesting-safe: the previous state is restored on exit, even on error.
        """
        previous = self._suspend_redraw
        self._suspend_redraw = True
        try:
            yield
        finally:
            self._suspend_redraw = previous

    def update_label_checks(self):
        """Rebuild one visibility checkbox per label of the active labelmap.

        Label 0 (presumably background) gets no checkbox. Each checkbox
        triggers a redraw when toggled.
        """
        labelmap = LABELMAPS[self.labelmap_dropdown.value]
        self.label_checks = []
        for idx, name in sorted(labelmap.items()):
            if idx == 0:
                continue
            cb = widgets.Checkbox(value=True, description=f"{idx}: {name}", indent=False)
            cb.observe(self.update_plot, names="value")
            self.label_checks.append((idx, cb))
        self.label_box.children = [widgets.Label(value="Visible Labels:")] + [cb for _, cb in self.label_checks]

    def on_labelmap_change(self, change):
        """Swap the label checkboxes for the newly selected labelmap and redraw."""
        self.update_label_checks()
        self.update_plot(None)

    def on_load_click(self, b):
        """Load the CT from the path box, list candidate segmentations and redraw.

        Errors (missing file, unreadable NIfTI) are reported in the status
        area and leave the previously loaded data untouched.
        """
        img_path = self.img_path_input.value.strip()
        # isfile rather than exists: a directory path would otherwise reach nib.load.
        if not os.path.isfile(img_path):
            with self.status_output:
                clear_output()
                print(f"Error: CT file not found at {img_path}")
            return

        with self.status_output:
            clear_output()
            print(f"Loading CT: {img_path}")
            try:
                img_vol = nib.load(img_path).get_fdata()
            except Exception as e:
                print(f"Error: could not load CT: {e}")
                return

        img_dir, img_name = os.path.split(os.path.abspath(img_path))
        # Candidate segmentations: .nii.gz files in the CT's folder whose name
        # contains "ct" or "mask", excluding the CT file itself.
        seg_files = sorted(
            f
            for f in os.listdir(img_dir)
            if f.endswith(".nii.gz") and f != img_name and ("ct" in f.lower() or "mask" in f.lower())
        )

        # Apply every widget change first, then draw exactly once.
        with self._redraw_suspended():
            self.img_vol = img_vol
            self.seg_vol = None
            self.seg_dropdown.options = [("None", None)] + [(f, os.path.join(img_dir, f)) for f in seg_files]
            self.seg_dropdown.value = None  # Don't default to anything
            if self.window_dropdown.value == "full":
                self._apply_full_range()
            self._reset_slider()
        self.update_plot(None)

        with self.status_output:
            print("CT Load complete.")

    def on_seg_change(self, change):
        """Load the segmentation picked in the dropdown (or clear it) and redraw."""
        seg_path = self.seg_dropdown.value
        self.seg_vol = None
        if seg_path and os.path.exists(seg_path):
            with self.status_output:
                self.seg_vol = self._load_seg(seg_path)
        self.update_plot(None)

    def _load_seg(self, seg_path):
        """Read a segmentation volume; return None (after printing why) if it is unusable."""
        print(f"Loading Seg: {seg_path}")
        try:
            seg_vol = nib.load(seg_path).get_fdata()
        except Exception as e:
            print(f"Error: could not load segmentation: {e}")
            return None
        # The overlay is sliced with the CT's axis/slice index, so shapes must agree.
        if self.img_vol is not None and seg_vol.shape != self.img_vol.shape:
            print(f"Error: segmentation shape {seg_vol.shape} does not match CT shape {self.img_vol.shape}")
            return None
        print("Seg Load complete.")
        return seg_vol

    def _reset_slider(self):
        """Fit the slice slider to the current view axis and centre it, without redrawing."""
        max_s = self.img_vol.shape[self.axis_dropdown.value] - 1
        with self._redraw_suspended():
            self.slice_slider.max = max_s
            self.slice_slider.value = max_s // 2

    def on_axis_change(self, change):
        """Re-centre the slice slider for the newly selected view and redraw."""
        if self.img_vol is None:
            return
        self._reset_slider()
        # Redraw explicitly: if the centred slice index is unchanged (e.g. a
        # cubic volume), the slider fires no event and the old view would remain.
        self.update_plot(None)

    def _set_window(self, vmin, vmax):
        """Write both window bounds into the vmin/vmax fields without per-field redraws."""
        with self._redraw_suspended():
            self.vmin_input.value = float(vmin)
            self.vmax_input.value = float(vmax)

    def _apply_full_range(self):
        """Set the window to the loaded CT's min/max intensity (NaNs ignored)."""
        if self.img_vol is not None:
            self._set_window(np.nanmin(self.img_vol), np.nanmax(self.img_vol))

    def on_window_change(self, change):
        """Apply the selected window preset; only "Custom" shows the vmin/vmax fields."""
        val = self.window_dropdown.value
        self.custom_box.layout.visibility = "visible" if val == "custom" else "hidden"
        if val == "full":
            self._apply_full_range()
        elif val != "custom":
            self._set_window(*val)  # presets are (vmin, vmax) tuples
        self.update_plot(None)

    def get_slice(self, vol, axis, sl):
        """Return slice ``sl`` of ``vol`` along ``axis`` (0, 1 or 2), rotated 90° by np.rot90.

        Returns None when ``vol`` is None.
        """
        if vol is None:
            return None
        if axis == 0:
            return np.rot90(vol[sl, :, :])
        elif axis == 1:
            return np.rot90(vol[:, sl, :])
        return np.rot90(vol[:, :, sl])

    def _label_color(self, idx):
        """RGBA colour for label ``idx``; wraps around the colormap for idx > cmap.N."""
        return self.cmap((idx - 1) % self.cmap.N)

    def update_plot(self, change):
        """Redraw the current slice with its window and the visible label overlay.

        Used directly as a widget observer, so ``change`` is ignored. Does
        nothing until a CT is loaded or while redraws are suspended. Errors
        are printed into the viewer output.
        """
        if self.img_vol is None or self._suspend_redraw:
            return
        with self.viewer_output:
            clear_output(wait=True)
            fig = None
            try:
                ax, sl = self.axis_dropdown.value, self.slice_slider.value
                vmin, vmax = self.vmin_input.value, self.vmax_input.value
                if vmin > vmax:
                    print(f"Error: vmin ({vmin}) must not be greater than vmax ({vmax}).")
                    return
                selected_indices = [idx for idx, cb in self.label_checks if cb.value]
                labelmap = LABELMAPS[self.labelmap_dropdown.value]

                fig, ax_plot = plt.subplots(figsize=(10, 10))
                ax_plot.imshow(self.get_slice(self.img_vol, ax, sl), cmap="gray", vmin=vmin, vmax=vmax)

                if self.seg_vol is not None and selected_indices:
                    labels = np.rint(self.get_slice(self.seg_vol, ax, sl)).astype(np.int64)
                    # Per-pixel RGBA taken from each label's colour. Pixels whose
                    # label is not checked (including 0) get alpha 0, i.e. are hidden.
                    overlay = self.cmap((labels - 1) % self.cmap.N)
                    overlay[..., 3] = np.where(np.isin(labels, selected_indices), self.alpha_slider.value, 0.0)
                    ax_plot.imshow(overlay, interpolation="nearest")
                    patches = [
                        mpatches.Patch(color=self._label_color(idx), label=f"{idx}: {labelmap[idx]}")
                        for idx in sorted(selected_indices)
                        if idx in labelmap
                    ]
                    if patches:
                        ax_plot.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left")

                ax_plot.set_title(f"Slice {sl} | {self.view_names[ax]} | Window: [{vmin:.1f}, {vmax:.1f}]")
                ax_plot.axis("off")
                plt.show()
            except Exception as e:
                print(f"Error: {e}")
            finally:
                # Close even on error: with the inline backend a figure left open
                # here would be displayed again by the next plt.show().
                if fig is not None:
                    plt.close(fig)

In [5]:
# Bind the instance so the notebook does not echo its repr under the widgets
# and so the dashboard stays reachable for inspection from later cells.
dashboard = SegmentationDashboard()

VBox(children=(HBox(children=(Textarea(value='/gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/SynthRAD_…

Output()

HBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, description='Slice:', layout=Layout(width='50%…

Output()

<__main__.SegmentationDashboard at 0x14ae072fbcd0>