# MRI2CT Visualization Dashboard
Interactive viewer for generated CT volumes compared with Ground Truth.

In [9]:
import os
import glob
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.patches as mpatches

# Import centralized config from src
import sys
sys.path.append(os.path.abspath("."))
from src.data import get_region_key, REGION_MAPS

# Configuration
# Both paths are relative, so they resolve against the notebook's working directory.
# Root scanned for prediction runs: <PRED_ROOT>/<run>/<epoch>/*pred_<subject>.nii.gz
PRED_ROOT = "./dataset/predictions"
# Root searched for reference volumes: <DATA_ROOT>/<resolution>/<split>/<subject>/<file>
DATA_ROOT = "./dataset"

def get_runs(pred_root=None):
    """List the run directories found under the predictions root.

    Args:
        pred_root: Directory to scan. Defaults to the module-level
            ``PRED_ROOT`` when omitted.

    Returns:
        Names of the immediate sub-directories in descending lexicographic
        order, or an empty list when the root is missing or not a directory.
    """
    root = PRED_ROOT if pred_root is None else pred_root
    # isdir rather than exists: a stray file at this path must yield [] instead
    # of letting os.listdir raise NotADirectoryError.
    if not os.path.isdir(root):
        return []
    return sorted(
        (d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))),
        reverse=True,
    )

def get_epochs(run_name, pred_root=None):
    """List the ``epoch_*`` directories of a run, highest epoch number first.

    Args:
        run_name: Name of the run directory. A falsy value (``None``/``""``)
            yields an empty list.
        pred_root: Directory containing the runs. Defaults to the module-level
            ``PRED_ROOT`` when omitted.

    Returns:
        Directory names starting with ``epoch_``. Entries whose token after the
        first underscore parses as an integer come first, sorted by that number
        in descending order; the remaining entries (e.g. ``epoch_best``) follow
        in descending lexicographic order.
    """
    if not run_name:
        return []
    root = PRED_ROOT if pred_root is None else pred_root
    run_dir = os.path.join(root, run_name)
    if not os.path.isdir(run_dir):
        return []

    def sort_key(name):
        # Per-entry fallback: one non-numeric name must not push the whole list
        # back to lexicographic order (which would put epoch_9 above epoch_10).
        try:
            return (1, int(name.split("_")[1]), name)
        except ValueError:
            return (0, 0, name)

    epochs = [
        d
        for d in os.listdir(run_dir)
        if d.startswith("epoch_") and os.path.isdir(os.path.join(run_dir, d))
    ]
    return sorted(epochs, key=sort_key, reverse=True)

def get_subjects(run_name, epoch, pred_root=None):
    """List the subject ids that have a prediction volume in an epoch directory.

    Prediction files are matched by ``*pred_*.nii.gz``; files ending in
    ``_seg.nii.gz`` are the accompanying segmentations and are skipped. The
    subject id is the text after ``_pred_`` when that marker is present,
    otherwise the file stem with ``pred_`` removed.

    Args:
        run_name: Name of the run directory.
        epoch: Name of the epoch directory inside the run.
        pred_root: Directory containing the runs. Defaults to the module-level
            ``PRED_ROOT`` when omitted.

    Returns:
        Sorted list of unique subject ids; empty when the run/epoch is falsy or
        the directory does not exist.
    """
    if not run_name or not epoch:
        return []
    root = PRED_ROOT if pred_root is None else pred_root
    ep_dir = os.path.join(root, run_name, epoch)
    if not os.path.isdir(ep_dir):
        return []

    suffix = ".nii.gz"
    seg_suffix = "_seg" + suffix
    # Escape the directory part so run names containing glob metacharacters
    # such as "[" are matched literally.
    pattern = os.path.join(glob.escape(ep_dir), "*pred_*" + suffix)

    subjects = set()
    for path in glob.glob(pattern):
        name = os.path.basename(path)
        # Only a trailing "_seg" marks a segmentation file; a plain substring
        # test would also drop subjects whose id merely contains "_seg".
        if name.endswith(seg_suffix):
            continue
        stem = name[: -len(suffix)]
        if "_pred_" in stem:
            sid = stem.split("_pred_")[1]
        else:
            sid = stem.replace("pred_", "")
        if sid:
            subjects.add(sid)
    return sorted(subjects)

def create_random_colormap(num_classes=120):
    """Build a reproducible random colormap for label images.

    Args:
        num_classes: Highest label id; the map has ``num_classes + 1`` entries
            so that index 0 can be reserved for the background.

    Returns:
        A ``ListedColormap`` whose entry 0 is fully transparent black and whose
        remaining entries are opaque random colours.
    """
    # A private, fixed-seed generator yields the same colours as seeding the
    # global NumPy RNG with 42, without clobbering the caller's random state.
    rng = np.random.RandomState(42)
    colors = rng.rand(num_classes + 1, 4)
    colors[:, 3] = 1.0  # every class is opaque ...
    colors[0, :] = [0, 0, 0, 0]  # ... except label 0, which stays invisible
    return ListedColormap(colors)

class NiftiViewer:
    """Notebook dashboard for browsing predicted CT volumes next to ground truth.

    Three chained dropdowns (run -> epoch -> subject) select a prediction under
    ``PRED_ROOT``; the matching CT, MRI and segmentation volumes are looked up
    under ``DATA_ROOT``. A slider and a view dropdown pick the slice, a window
    dropdown sets the intensity range, and checkboxes toggle the optional MRI,
    difference, segmentation and overlay panels. Constructing the object builds
    the widgets, wires the observers and displays the dashboard.
    """

    def __init__(self):
        """Create all widgets, register observers and display the controls."""
        self.num_classes = 120
        self.seg_cmap = create_random_colormap(self.num_classes)
        self.run_dropdown = widgets.Dropdown(description="Run:", options=[])
        self.epoch_dropdown = widgets.Dropdown(description="Epoch:", options=[])
        self.subj_dropdown = widgets.Dropdown(description="Subject:", options=[])
        self.slice_slider = widgets.IntSlider(description="Slice:", continuous_update=True)
        # Option values are the array axis that is sliced for each view.
        self.axis_dropdown = widgets.Dropdown(description="View:", options=[("Axial", 2), ("Sagittal", 0), ("Coronal", 1)], value=2)
        self.window_dropdown = widgets.Dropdown(
            description="Window:",
            options=[
                ("Full Range (-1024, 1024)", (-1024, 1024)),
                ("Head CT (-100, 100)", (-100, 100)),
                ("Anatomix Default (-450, 450)", (-450, 450)),
                ("Custom", "custom"),
            ],
            value=(-1024, 1024)
        )
        self.vmin_input = widgets.IntText(value=-1024, description="vmin:", layout=widgets.Layout(width="120px"))
        self.vmax_input = widgets.IntText(value=1024, description="vmax:", layout=widgets.Layout(width="120px"))
        # Only shown while the "Custom" window is selected (see on_window_change).
        self.custom_box = widgets.HBox([self.vmin_input, self.vmax_input], layout=widgets.Layout(visibility="hidden"))

        self.show_mri_check = widgets.Checkbox(value=False, description="Show MRI")
        self.show_diff_check = widgets.Checkbox(value=False, description="Show Diff")
        self.show_seg_check = widgets.Checkbox(value=False, description="Show Segs")
        self.show_target_only_check = widgets.Checkbox(value=False, description="Target Organs Only")
        self.show_overlay_check = widgets.Checkbox(value=False, description="Show Seg Overlay")

        self.output = widgets.Output()
        self.pred_vol = self.gt_vol = self.mri_vol = self.gt_seg = self.pred_seg = self.eval_df = None
        # (run, epoch, subject) of the volumes currently in memory; lets
        # on_subj_change skip reloading an unchanged selection.
        self._loaded_key = None

        self.controls = widgets.VBox([
            widgets.HBox([self.run_dropdown, self.epoch_dropdown, self.subj_dropdown]),
            widgets.HBox([self.slice_slider, self.axis_dropdown]),
            widgets.HBox([self.window_dropdown, self.custom_box]),
            widgets.HBox([self.show_mri_check, self.show_diff_check, self.show_seg_check, self.show_target_only_check, self.show_overlay_check])
        ])

        self.run_dropdown.observe(self.on_run_change, names="value")
        self.epoch_dropdown.observe(self.on_epoch_change, names="value")
        self.subj_dropdown.observe(self.on_subj_change, names="value")
        self.slice_slider.observe(self.update_plot, names="value")
        self.axis_dropdown.observe(self.on_axis_change, names="value")
        self.window_dropdown.observe(self.on_window_change, names="value")
        for i in [self.vmin_input, self.vmax_input, self.show_mri_check, self.show_diff_check, self.show_seg_check, self.show_target_only_check, self.show_overlay_check]:
            i.observe(self.update_plot, names="value")

        self.refresh_runs()
        display(self.controls, self.output)

    def refresh_runs(self):
        """Reload the run list from disk and refresh everything that depends on it."""
        self.run_dropdown.options = get_runs()
        # Observers fire only when the selected value changes, so cascade
        # explicitly to keep the epoch/subject lists in sync with the run.
        self.on_run_change(None)

    def on_run_change(self, change):
        """Repopulate the epoch dropdown for the selected run."""
        run = self.run_dropdown.value
        self.epoch_dropdown.options = get_epochs(run) if run else []
        # Two runs often share epoch names; in that case the epoch value does
        # not change and its observer stays silent, which would leave the
        # previous run's subjects and metrics on screen.
        self.on_epoch_change(None)

    def on_epoch_change(self, change):
        """Load the epoch's evaluation CSV (if any) and repopulate the subjects."""
        run, ep = self.run_dropdown.value, self.epoch_dropdown.value
        self.eval_df = None
        if not run or not ep:
            self.subj_dropdown.options = []
            self.on_subj_change(None)
            return
        for n in ["functional_evaluation.csv", "functional_eval_1x1x1.csv", "functional_eval_3x3x3.csv"]:
            p = os.path.join(PRED_ROOT, run, ep, n)
            if not os.path.exists(p):
                continue
            try:
                self.eval_df = pd.read_csv(p)
            except (OSError, ValueError):
                # Metrics are optional: an unreadable/empty CSV must not break
                # browsing (pandas' parser errors derive from ValueError).
                self.eval_df = None
                continue
            break
        self.subj_dropdown.options = get_subjects(run, ep)
        self.on_subj_change(None)

    def on_window_change(self, change):
        """Show the custom vmin/vmax inputs only for the "Custom" window, then redraw."""
        self.custom_box.layout.visibility = "visible" if self.window_dropdown.value == "custom" else "hidden"
        self.update_plot(None)

    def on_axis_change(self, change):
        """Adjust the slider range for the new view and redraw.

        The slider only triggers a redraw when its value changes, so the plot
        has to be refreshed explicitly after switching the view.
        """
        self.update_max_slice(change)
        self.update_plot(change)

    def load_gt_data(self, subj_id, filename, target_shape=None):
        """Find and load a reference volume for a subject.

        Searches ``DATA_ROOT/<resolution>/<split>/<subj_id>/`` over the
        resolutions ``3.0x3.0x3.0mm`` then ``1.0x1.0x1.0mm`` and the splits
        ``val``, ``test``, ``train``. For ``"mr.nii.gz"`` the registered
        ``registration_output/moved_mr.nii.gz`` is tried before the plain file.

        Args:
            subj_id: Subject directory name.
            filename: File to load from the subject directory.
            target_shape: If given, only a volume with exactly this shape is
                accepted; other candidates are skipped.

        Returns:
            The first matching volume as an array, or ``None`` if none is found.
        """
        res_list = ["3.0x3.0x3.0mm", "1.0x1.0x1.0mm"]
        for res in res_list:
            for split in ["val", "test", "train"]:
                subj_dir = os.path.join(DATA_ROOT, res, split, subj_id)
                if filename == "mr.nii.gz":
                    paths = [os.path.join(subj_dir, "registration_output", "moved_mr.nii.gz"), os.path.join(subj_dir, "mr.nii.gz")]
                else:
                    paths = [os.path.join(subj_dir, filename)]
                for p in paths:
                    if os.path.exists(p):
                        vol_obj = nib.load(p)
                        # Compare the header shape first so mismatching
                        # resolutions are rejected without reading voxel data.
                        if target_shape is None or vol_obj.shape == target_shape:
                            return vol_obj.get_fdata()
        return None

    def _reset_volumes(self):
        """Forget every loaded volume and the selection they belonged to."""
        self.pred_vol = self.pred_seg = self.gt_vol = self.mri_vol = self.gt_seg = None
        self._loaded_key = None

    def on_subj_change(self, change):
        """Load the prediction and reference volumes of the selected subject and redraw."""
        run, ep, subj = self.run_dropdown.value, self.epoch_dropdown.value, self.subj_dropdown.value
        if not (run and ep and subj):
            # Nothing selectable: drop stale volumes and the stale figure.
            self._reset_volumes()
            self.output.clear_output()
            return
        key = (run, ep, subj)
        if key == self._loaded_key:
            # Called both by the dropdown observer and explicitly by the
            # cascade; avoid reading the same volumes from disk twice.
            return
        self._reset_volumes()
        # Escape the literal parts so names with glob metacharacters still match.
        ep_dir = glob.escape(os.path.join(PRED_ROOT, run, ep))
        subj_pat = glob.escape(subj)
        c_pred = glob.glob(os.path.join(ep_dir, f"*pred_{subj_pat}.nii.gz"))
        c_seg = glob.glob(os.path.join(ep_dir, f"*pred_{subj_pat}_seg.nii.gz"))
        self.pred_vol = nib.load(c_pred[0]).get_fdata() if c_pred else None
        self.pred_seg = nib.load(c_seg[0]).get_fdata() if c_seg else None
        # Reference volumes must match the prediction grid to be comparable.
        target_shape = self.pred_vol.shape if self.pred_vol is not None else None
        self.gt_vol = self.load_gt_data(subj, "ct.nii.gz", target_shape)
        self.mri_vol = self.load_gt_data(subj, "mr.nii.gz", target_shape)
        self.gt_seg = self.load_gt_data(subj, "ct_seg.nii.gz", target_shape)
        self._loaded_key = key
        if self.pred_vol is None:
            self.output.clear_output()
            return
        self.update_max_slice(None)
        self.update_plot(None)

    def update_max_slice(self, change):
        """Fit the slider range to the prediction's extent along the current view axis."""
        if self.pred_vol is None:
            return
        ax = self.axis_dropdown.value
        max_s = self.pred_vol.shape[ax] - 1
        # Re-centre BEFORE shrinking the range: assigning a smaller max clamps
        # the value to it, after which an out-of-range check can never be true.
        if self.slice_slider.value > max_s:
            self.slice_slider.value = max_s // 2
        self.slice_slider.max = max_s

    def get_slice(self, vol, axis, sl):
        """Return slice ``sl`` of ``vol`` along ``axis`` rotated by 90 degrees, or ``None`` if ``vol`` is ``None``."""
        if vol is None:
            return None
        if axis == 0:
            return np.rot90(vol[sl, :, :])
        elif axis == 1:
            return np.rot90(vol[:, sl, :])
        return np.rot90(vol[:, :, sl])

    @staticmethod
    def _ordered_window(vmin, vmax):
        """Return the window bounds as ``(low, high)``.

        While the custom bounds are being typed they can momentarily be
        inverted; ordering them keeps the plot from failing on vmin > vmax.
        """
        return (vmin, vmax) if vmin <= vmax else (vmax, vmin)

    def _metrics_text(self, subj_id):
        """Return the ``" | Dice: x.xxxx"`` title suffix for a subject, or ``""``.

        Ids are compared as strings because a CSV column of numeric ids is
        parsed as integers while the dropdown holds strings.
        """
        df = self.eval_df
        if df is None or "subj_id" not in df.columns or "avg_dice" not in df.columns:
            return ""
        rows = df[df["subj_id"].astype(str) == str(subj_id)]
        if rows.empty:
            return ""
        try:
            dice = float(rows["avg_dice"].values[0])
        except (TypeError, ValueError):
            return ""
        return f" | Dice: {dice:.4f}"

    def update_plot(self, change):
        """Redraw all enabled panels for the current slice into the output widget."""
        if self.pred_vol is None:
            return
        ax = self.axis_dropdown.value
        # Defensive clamp in case the slider range has not caught up yet.
        sl = min(max(self.slice_slider.value, 0), self.pred_vol.shape[ax] - 1)
        if self.window_dropdown.value == "custom":
            vmin, vmax = self._ordered_window(self.vmin_input.value, self.vmax_input.value)
        else:
            vmin, vmax = self.window_dropdown.value

        s_p = self.get_slice(self.pred_vol, ax, sl)
        s_g = self.get_slice(self.gt_vol, ax, sl)
        s_m = self.get_slice(self.mri_vol, ax, sl)
        sg_g = self.get_slice(self.gt_seg, ax, sl)
        sg_p = self.get_slice(self.pred_seg, ax, sl)

        # Filter target organs if requested
        if self.show_target_only_check.value and self.subj_dropdown.value:
            reg = get_region_key(self.subj_dropdown.value)
            targets = REGION_MAPS.get(reg, {})
            target_ids = []
            for v in targets.values():
                if isinstance(v, list):
                    target_ids.extend(v)
                else:
                    target_ids.append(v)
            if sg_g is not None:
                sg_g = np.where(np.isin(sg_g, target_ids), sg_g, 0)
            if sg_p is not None:
                sg_p = np.where(np.isin(sg_p, target_ids), sg_p, 0)

        show_seg = self.show_seg_check.value
        with self.output:
            clear_output(wait=True)
            # Each entry: (title, image, cmap, vmin, vmax, segmentation overlay)
            plots = []
            if self.show_mri_check.value and s_m is not None:
                # NOTE(review): the fixed 0..1 range presumably expects a normalised MRI - confirm.
                plots.append(("MRI", s_m, "gray", 0, 1, None))
            if s_g is not None:
                # Without this guard a subject lacking ground truth would make imshow fail.
                plots.append(("GT", s_g, "gray", vmin, vmax, sg_g if show_seg else None))
            plots.append(("Pred", s_p, "gray", vmin, vmax, sg_p if show_seg else None))
            if self.show_diff_check.value and s_g is not None and s_p is not None and s_g.shape == s_p.shape:
                plots.append(("Diff", s_g - s_p, "seismic", -400, 400, None))
            if self.show_overlay_check.value and sg_g is not None and sg_p is not None and sg_g.shape == sg_p.shape:
                overlay = np.zeros((*sg_g.shape, 3))
                overlay[..., 0] = (sg_g > 0).astype(float)
                overlay[..., 1] = (sg_p > 0).astype(float)
                plots.append(("Overlay (R=GT, G=Pred)", overlay, None, None, None, None))

            n = len(plots)
            # squeeze=False always returns a 2-D axes array, also for one panel.
            fig, axes = plt.subplots(1, n, figsize=(6 * n, 5), squeeze=False)
            for panel, (t, img, cmap, v_min, v_max, seg) in zip(axes[0], plots):
                if cmap:
                    im = panel.imshow(img, cmap=cmap, vmin=v_min, vmax=v_max)
                    if t in ["GT", "Pred", "Diff", "MRI"]:
                        fig.colorbar(im, ax=panel, fraction=0.046, pad=0.04)
                else:
                    panel.imshow(img)
                if seg is not None:
                    panel.imshow(seg, cmap=self.seg_cmap, alpha=0.4, vmin=0, vmax=self.num_classes)
                panel.set_title(t)
                panel.axis("off")

            metrics_text = self._metrics_text(self.subj_dropdown.value)
            fig.suptitle(f"Subject: {self.subj_dropdown.value} | Slice {sl}{metrics_text}")
            fig.tight_layout()
            plt.show()
            # Every slider tick creates a figure; release it so figures do not
            # pile up under backends that keep them open after show().
            plt.close(fig)

In [10]:
# Instantiate the dashboard; the constructor builds the widgets and displays them.
viewer = NiftiViewer()

VBox(children=(HBox(children=(Dropdown(description='Run:', options=('CNN_Train613_20260204_1811', 'CNN_Train61…

Output()