# MRI2CT Visualization Dashboard
Interactive viewer that compares generated (synthetic) CT volumes with the ground-truth CT.

In [20]:
import os
import glob
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Configuration: predictions and the ground-truth data share the same SynthRAD_combined directory.
_SYNTHRAD_COMBINED = "/gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/SynthRAD_combined"
PRED_ROOT = f"{_SYNTHRAD_COMBINED}/predictions"  # one sub-directory per run
DATA_ROOT = f"{_SYNTHRAD_COMBINED}/3.0x3.0x3.0mm"  # Source of GT

def get_runs():
    """Return the names of the run directories directly under PRED_ROOT, sorted.

    Plain files are skipped. Returns an empty list when PRED_ROOT does not exist.
    """
    if os.path.exists(PRED_ROOT):
        entries = os.listdir(PRED_ROOT)
        return sorted(name for name in entries if os.path.isdir(os.path.join(PRED_ROOT, name)))
    return []

def get_epochs(run_name):
    """Return the ``epoch_*`` sub-directories of a run, ordered oldest to newest.

    A name whose part after ``epoch_`` (up to the next underscore) is an
    integer is ordered numerically, so ``epoch_9`` precedes ``epoch_10``.
    Any other ``epoch_*`` directory (e.g. ``epoch_best``) is listed first,
    alphabetically. As a result, the last entry is the highest-numbered epoch
    whenever one exists. Plain files are ignored. Returns an empty list when
    the run directory does not exist.
    """
    run_dir = os.path.join(PRED_ROOT, run_name)
    if not os.path.exists(run_dir):
        return []

    def _order(name):
        # "epoch_12" -> (1, 12, name); non-numeric suffixes sort ahead of every numeric one.
        # Only the expected ValueError is caught, so one odd name no longer disturbs the others.
        try:
            return (1, int(name.split("_")[1]), name)
        except ValueError:
            return (0, 0, name)

    return sorted(
        (
            d
            for d in os.listdir(run_dir)
            if d.startswith("epoch_") and os.path.isdir(os.path.join(run_dir, d))
        ),
        key=_order,
    )

def get_subjects(run_name, epoch):
    """Return the sorted subject IDs that have a prediction file for ``run_name``/``epoch``.

    Prediction files are named ``pred_<SUBJECT>.nii.gz``
    (e.g. ``pred_1ABA005.nii.gz`` -> ``1ABA005``). Returns an empty list when
    the epoch directory does not exist.
    """
    ep_dir = os.path.join(PRED_ROOT, run_name, epoch)
    if not os.path.exists(ep_dir):
        return []
    prefix, suffix = "pred_", ".nii.gz"
    # Escape the directory part so run/epoch names containing glob
    # metacharacters ([, ], *, ?) are matched literally.
    files = glob.glob(os.path.join(glob.escape(ep_dir), f"{prefix}*{suffix}"))
    # Strip only the leading prefix and trailing suffix; str.replace would also
    # remove "pred_" / ".nii.gz" occurring inside the subject ID itself.
    subjs = [os.path.basename(f)[len(prefix):-len(suffix)] for f in files]
    return sorted(subjs)

# --- Viewer Logic ---

class NiftiViewer:
    """Interactive notebook dashboard comparing predicted CT volumes with ground truth.

    Dropdowns choose a run, epoch and subject under ``PRED_ROOT``. The matching
    ground-truth CT (``ct.nii.gz``) and MRI (``mr.nii.gz``) are looked up under
    ``DATA_ROOT/<split>/<subject>/``. A slider and view selector choose the
    slice drawn for GT and prediction side by side. The MRI input and the
    GT - prediction residual can optionally be shown alongside. Constructing
    the viewer populates the dropdowns and displays the controls and plot area.

    The run -> epoch -> subject dropdowns cascade: changing one refills the
    ones after it. Those refills happen with observers muted and are followed
    by one explicit downstream call. Each user action therefore loads the
    volumes and redraws the figure once, not once per triggered change event.
    """

    def __init__(self):
        # While True, every observer below ignores change events (see _assign_silently).
        self._muted = False

        self.run_dropdown = widgets.Dropdown(description="Run:", options=[])
        self.epoch_dropdown = widgets.Dropdown(description="Epoch:", options=[])
        self.subj_dropdown = widgets.Dropdown(description="Subject:", options=[])

        self.slice_slider = widgets.IntSlider(description="Slice:", min=0, max=100, step=1, continuous_update=True)
        # Option values are the array axis that get_slice slices for each view.
        self.axis_dropdown = widgets.Dropdown(description="View:", options=[("Axial", 2), ("Sagittal", 0), ("Coronal", 1)], value=2)

        self.window_dropdown = widgets.Dropdown(
            description="Window:",
            options=[
                ("Full Range (-1024, 1024)", (-1024, 1024)),
                ("Head CT (-100, 100)", (-100, 100)),
                ("Anatomix Default (-450, 450)", (-450, 450)),
                ("Custom", "custom"),
            ],
            value=(-1024, 1024)
        )

        # Bounds for the "Custom" window; only shown while "Custom" is selected.
        self.vmin_input = widgets.IntText(value=-1024, description="vmin:", layout=widgets.Layout(width="150px"))
        self.vmax_input = widgets.IntText(value=1024, description="vmax:", layout=widgets.Layout(width="150px"))
        self.custom_box = widgets.HBox([self.vmin_input, self.vmax_input], layout=widgets.Layout(visibility="hidden"))

        # Optional extra panels
        self.show_mri_check = widgets.Checkbox(value=False, description="Show MRI Input")
        self.show_diff_check = widgets.Checkbox(value=False, description="Show Residual (GT-Pred)")

        self.output = widgets.Output()

        # Volumes of the currently selected subject (None when not loaded / not found).
        self.pred_vol = None
        self.gt_vol = None
        self.mri_vol = None

        # Layout
        self.controls = widgets.VBox([
            widgets.VBox([
                widgets.HBox([self.run_dropdown, self.epoch_dropdown, self.subj_dropdown]),
                widgets.HBox([self.slice_slider, self.axis_dropdown]),
                widgets.HBox([self.window_dropdown, self.custom_box]),
                widgets.HBox([self.show_mri_check, self.show_diff_check])
            ])
        ])

        # Event Bindings
        self.run_dropdown.observe(self.on_run_change, names="value")
        self.epoch_dropdown.observe(self.on_epoch_change, names="value")
        self.subj_dropdown.observe(self.on_subj_change, names="value")

        self.slice_slider.observe(self.update_plot, names="value")
        self.axis_dropdown.observe(self.update_max_slice, names="value")
        self.window_dropdown.observe(self.on_window_change, names="value")
        self.vmin_input.observe(self.update_plot, names="value")
        self.vmax_input.observe(self.update_plot, names="value")

        self.show_mri_check.observe(self.update_plot, names="value")
        self.show_diff_check.observe(self.update_plot, names="value")

        # Init
        self.refresh_runs()
        display(self.controls, self.output)

    # --- helpers ---

    def _assign_silently(self, widget, **traits):
        """Set ``widget`` traits in the given order while this viewer's observers are muted."""
        previous = self._muted
        self._muted = True
        try:
            for name, value in traits.items():
                setattr(widget, name, value)
        finally:
            self._muted = previous

    def _repopulate(self, dropdown, options, default):
        """Silently replace ``dropdown``'s options and, if there are any, select ``default``."""
        # Assigning options can itself change the selected value; that event is muted too.
        self._assign_silently(dropdown, options=options)
        if options:
            self._assign_silently(dropdown, value=default)

    def _show_message(self, text=None):
        """Clear the plot area and print ``text`` into it (just clear when ``text`` is None)."""
        with self.output:
            clear_output()
            if text is not None:
                print(text)

    def _current_window(self):
        """Return the (vmin, vmax) display window, ordered so that vmin <= vmax."""
        if self.window_dropdown.value == "custom":
            vmin, vmax = self.vmin_input.value, self.vmax_input.value
        else:
            vmin, vmax = self.window_dropdown.value
        # matplotlib refuses to normalise with vmin > vmax (e.g. custom bounds typed in the wrong order).
        return (vmin, vmax) if vmin <= vmax else (vmax, vmin)

    # --- selection cascade ---

    def refresh_runs(self):
        """Re-scan PRED_ROOT for runs, select the first one and cascade to epochs/subjects."""
        runs = get_runs()
        self._repopulate(self.run_dropdown, runs, runs[0] if runs else None)
        self.on_run_change(None)

    def on_run_change(self, change):
        """Refill the epoch list for the selected run, defaulting to its last (latest) epoch."""
        if change is not None and self._muted:
            return
        run = self.run_dropdown.value
        epochs = get_epochs(run) if run else []
        self._repopulate(self.epoch_dropdown, epochs, epochs[-1] if epochs else None)  # Default to latest
        self.on_epoch_change(None)

    def on_epoch_change(self, change):
        """Refill the subject list for the selected run/epoch, defaulting to the first subject."""
        if change is not None and self._muted:
            return
        run = self.run_dropdown.value
        ep = self.epoch_dropdown.value
        subjs = get_subjects(run, ep) if (run and ep) else []
        self._repopulate(self.subj_dropdown, subjs, subjs[0] if subjs else None)
        self.on_subj_change(None)

    def on_window_change(self, change):
        """Show the custom vmin/vmax inputs only for the "Custom" window, then redraw."""
        if change is not None and self._muted:
            return
        if self.window_dropdown.value == "custom":
            self.custom_box.layout.visibility = "visible"
        else:
            self.custom_box.layout.visibility = "hidden"
        self.update_plot(None)

    def load_data(self, subj_id, filename):
        """Load ``DATA_ROOT/<split>/<subj_id>/<filename>`` from the first split that has it.

        Splits are searched in the order train, val, test. Returns the array
        from nibabel's ``get_fdata()``, or None if no split contains the file.
        """
        splits = ["train", "val", "test"]
        for split in splits:
            path = os.path.join(DATA_ROOT, split, subj_id, filename)
            if os.path.exists(path):
                return nib.load(path).get_fdata()
        return None

    def on_subj_change(self, change):
        """Load prediction, GT CT and MRI for the selected subject and redraw."""
        if change is not None and self._muted:
            return
        run = self.run_dropdown.value
        ep = self.epoch_dropdown.value
        subj = self.subj_dropdown.value

        # Forget the previous subject first so a failed or empty selection
        # never leaves stale volumes on screen.
        self.pred_vol = self.gt_vol = self.mri_vol = None
        if not (run and ep and subj):
            self._show_message()
            return

        pred_path = os.path.join(PRED_ROOT, run, ep, f"pred_{subj}.nii.gz")
        if not os.path.exists(pred_path):
            # Printed into the Output widget: output from a bare print() in a
            # widget callback may not appear in the cell at all.
            self._show_message(f"File not found: {pred_path}")
            return
        self.pred_vol = nib.load(pred_path).get_fdata()

        # Load GT and MRI (None when not found in any split)
        self.gt_vol = self.load_data(subj, "ct.nii.gz")
        self.mri_vol = self.load_data(subj, "mr.nii.gz")

        # Resizes the slider and redraws exactly once.
        self.update_max_slice(None)

    # --- rendering ---

    def update_max_slice(self, change):
        """Fit the slice slider to the prediction's extent along the selected axis, then redraw."""
        if change is not None and self._muted:
            return
        if self.pred_vol is None:
            return
        max_s = max(self.pred_vol.shape[self.axis_dropdown.value] - 1, 0)
        sl = self.slice_slider.value
        if sl > max_s:
            sl = max_s // 2  # out of range for this view: jump to the middle slice
        # ipywidgets may clamp (and re-emit) the value when `max` drops below it;
        # muting lets the explicit call below be the only redraw.
        self._assign_silently(self.slice_slider, max=max_s, value=sl)
        self.update_plot(None)

    def get_slice(self, vol, axis, sl):
        """Return index ``sl`` of ``vol`` along ``axis``, rotated with np.rot90 for display.

        Returns None when ``vol`` is None or ``sl`` lies outside ``vol``'s extent
        along ``axis``. The second case arises when the GT/MRI grid is smaller
        than the prediction that drives the slider.
        """
        if vol is None:
            return None
        if not 0 <= sl < vol.shape[axis]:
            return None
        return np.rot90(np.take(vol, sl, axis=axis))

    def update_plot(self, change):
        """Redraw the GT / prediction panels (plus optional MRI and residual) for the current slice."""
        if change is not None and self._muted:
            return
        if self.pred_vol is None:
            return

        axis = self.axis_dropdown.value
        sl = self.slice_slider.value
        vmin, vmax = self._current_window()

        show_mri = self.show_mri_check.value and (self.mri_vol is not None)
        show_diff = self.show_diff_check.value and (self.gt_vol is not None)

        img_pred = self.get_slice(self.pred_vol, axis, sl)
        img_gt = self.get_slice(self.gt_vol, axis, sl)
        img_mri = self.get_slice(self.mri_vol, axis, sl) if show_mri else None

        img_diff = None
        # Subtract only when both slices exist on the same grid; otherwise the
        # residual panel reports "Data Not Available" instead of raising.
        if show_diff and img_gt is not None and img_pred is not None and img_gt.shape == img_pred.shape:
            img_diff = img_gt - img_pred

        # (title, image, colormap, vmin, vmax); None bounds let imshow auto-scale.
        plots = []
        if show_mri:
            plots.append(("MRI Input", img_mri, "gray", None, None))
        plots.append(("Ground Truth (HU)", img_gt, "gray", vmin, vmax))
        plots.append(("Prediction (HU)", img_pred, "gray", vmin, vmax))
        if show_diff:
            plots.append(("Residual (GT - Pred)", img_diff, "seismic", -500, 500))

        with self.output:
            clear_output(wait=True)
            # squeeze=False keeps `axes` 2-D for any number of panels.
            fig, axes = plt.subplots(1, len(plots), figsize=(5 * len(plots), 5), squeeze=False)
            for ax, (title, img, cmap, v_min, v_max) in zip(axes[0], plots):
                ax.set_title(title)
                ax.axis("off")
                if img is None:
                    ax.text(0.5, 0.5, "Data Not Available", ha="center", va="center", transform=ax.transAxes)
                    continue
                im = ax.imshow(img, cmap=cmap, vmin=v_min, vmax=v_max)
                if cmap == "seismic":
                    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            fig.suptitle(f"Slice {sl} | Window {vmin} to {vmax}")
            fig.tight_layout()
            plt.show()

In [21]:
# Build the dashboard: the constructor populates the run/epoch/subject dropdowns
# and displays the controls and the plot output below this cell.
viewer = NiftiViewer()

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Run:', options=('CNN_Train613_20260202_0144…

Output(outputs=({'output_type': 'display_data', 'data': {'text/plain': '<Figure size 1000x500 with 2 Axes>', '…