# Interactive Visualization of Brain Tumor Segmentation

Use the controls below to explore different cases and views:
- **Case**: select a base name (e.g., `BRATS_460`)
- **Plane**: axial, sagittal, or coronal
- **Slice**: slide through slice indices

The figure shows three rows for the selected slice:
1. The original MRI slice (T1)
2. The MRI with the predicted segmentation overlaid
3. The MRI with the ground-truth segmentation overlaid

In [2]:

!pip install --upgrade pip
!pip install SimpleITK matplotlib
!pip install ipywidgets


Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jedi
Successfully installed jedi-0.19.2


In [4]:
import shutil

from google.colab import drive

drive.mount('/content/drive')

# Copy the dataset from Drive to local disk for faster access on Colab.
DRIVE_DATA_DIR = "/content/drive/MyDrive/00-DataScience_BIU/Final Project/3D_UNet_Segmentation/3D_UNet_Brain_Tumor_Segmentation_multiclass_complete/inference_test/"
LOCAL_DATA_DIR = "/content/data/inference_test/"
# copytree creates missing parent directories, raises if the source is missing,
# and with dirs_exist_ok=True can be re-run safely.
# (`mkdir` fails on a second run, and `cp -r src/ dst/` into an existing dst
# nests a second copy inside it.)
shutil.copytree(DRIVE_DATA_DIR, LOCAL_DATA_DIR, dirs_exist_ok=True)


In [5]:
import os
from functools import lru_cache

import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from ipywidgets import interact, Dropdown, IntSlider, VBox, HBox, Output

# Directories (all three currently point at the same local copy of the data)
input_dir = '/content/data/inference_test/'
pred_dir  = '/content/data/inference_test/'
gt_dir    = '/content/data/inference_test/'  # update path if needed

# Collect case basenames: each "<case>_T1.nii.gz" file in input_dir defines one case.
T1_SUFFIX = '_T1.nii.gz'
# Slice off only the trailing suffix (str.replace would strip every occurrence).
cases = sorted(f[:-len(T1_SUFFIX)] for f in os.listdir(input_dir) if f.endswith(T1_SUFFIX))
if not cases:
    # Fail here with a clear message instead of later with an unreadable
    # "None_T1.nii.gz" error when the empty dropdown is first used.
    raise FileNotFoundError(
        f"No '*{T1_SUFFIX}' files found in {input_dir!r}; "
        "check the data copy step and input_dir."
    )

@lru_cache(maxsize=12)
def load_vol(dir_path, base, suffix):
    """Read ``<dir_path>/<base><suffix>`` with SimpleITK and return it as a NumPy array.

    The array is indexed (z, y, x), the order produced by ``sitk.GetArrayFromImage``.

    Results are memoised. The viewer asks for the same volumes on every widget
    change, so caching makes a redraw an in-memory slice instead of a file read.
    ``maxsize`` bounds memory use when switching between many cases.

    The returned array is shared between callers and is marked read-only, so it
    must not be modified in place.
    """
    path = os.path.join(dir_path, f"{base}{suffix}")
    vol = sitk.GetArrayFromImage(sitk.ReadImage(path))
    vol.flags.writeable = False  # protect the cached copy from accidental mutation
    return vol

# Widgets.
# Every slider change re-renders the three-panel figure, so the slider only
# reports a value once the user releases it (continuous_update=False) rather
# than at every intermediate position while dragging.
case_dd = Dropdown(options=cases, description='Case:')
plane_dd = Dropdown(options=['axial', 'sagittal', 'coronal'], value='axial', description='Plane:')
slice_slider = IntSlider(min=0, max=0, step=1, description='Slice:', continuous_update=False)
out = Output()

def update_slider(*args):
    """Resize the slice slider for the selected case and plane, then centre it.

    This is used as an observer on the case and plane dropdowns. ``*args``
    receives the ipywidgets change notification, which is ignored; the current
    widget values are read directly.
    """
    base = case_dd.value
    if base is None:
        # Empty dropdown (no cases found): there is no volume to size the slider to.
        return
    vol = load_vol(input_dir, base, '_T1.nii.gz')
    # SimpleITK arrays are indexed (z, y, x): axial slices run along axis 0,
    # coronal along axis 1 and sagittal along axis 2 -- the same mapping plot() uses.
    axis = {'axial': 0, 'coronal': 1, 'sagittal': 2}[plane_dd.value]
    n_slices = vol.shape[axis]
    slice_slider.max = n_slices - 1
    slice_slider.value = n_slices // 2

# Keep the slider range in sync with the chosen case and plane
# (case first, then plane), then size it once for the initial selection.
for _dropdown in (case_dd, plane_dd):
    _dropdown.observe(update_slider, names='value')
update_slider()

def plot(case, plane, idx):
    """Draw slice ``idx`` of ``case`` along ``plane`` as three stacked panels.

    Rows, top to bottom:
      1. the T1 MRI slice alone,
      2. the MRI with the predicted segmentation overlaid,
      3. the MRI with the ground-truth segmentation overlaid.

    Args:
        case: case basename, e.g. ``'BRATS_460'``.
        plane: ``'axial'``, ``'sagittal'`` or ``'coronal'``.
        idx: slice index along the plane's axis.

    Raises:
        ValueError: if ``plane`` is not one of the three supported planes.
    """
    img = load_vol(input_dir, case, '_T1.nii.gz')
    pred = load_vol(pred_dir, case, '_predict_seg.nii.gz')
    gt = load_vol(gt_dir, case, '.nii.gz')

    # SimpleITK arrays are indexed (z, y, x): axial -> axis 0, coronal -> axis 1,
    # sagittal -> axis 2 (the same mapping update_slider uses for the slider range).
    axis = {'axial': 0, 'coronal': 1, 'sagittal': 2}.get(plane)
    if axis is None:
        raise ValueError(f"unknown plane {plane!r}; expected 'axial', 'sagittal' or 'coronal'")
    img_sl, pred_sl, gt_sl = (np.take(vol, idx, axis=axis) for vol in (img, pred, gt))

    # Use one fixed label -> colour mapping for both overlays. Without vmin/vmax,
    # imshow rescales each panel to that slice's own min/max, so the same label
    # could be drawn in different colours in the prediction and ground-truth rows.
    vmax = max(int(pred.max()), int(gt.max()), 1)
    overlay_kw = dict(cmap='jet', alpha=0.5, vmin=0, vmax=vmax,
                      interpolation='nearest')  # never blend discrete labels

    # NOTE(review): sagittal/coronal slices put z on the rows with imshow's default
    # origin='upper'. Depending on the volume's direction they may appear upside
    # down; consider origin='lower' -- verify on real data.
    fig, axs = plt.subplots(3, 1, figsize=(6, 12))
    panels = [('MRI', None), ('Prediction', pred_sl), ('Ground Truth', gt_sl)]
    for ax, (title, label_sl) in zip(axs, panels):
        ax.imshow(img_sl, cmap='gray')
        if label_sl is not None:
            # Mask label 0 (rendered transparent) so only labelled voxels are tinted.
            ax.imshow(np.ma.masked_equal(label_sl, 0), **overlay_kw)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    plt.show()

def on_change(change):
    """Re-render the figure inside the output widget from the current widget values.

    ``change`` (the ipywidgets change notification, or None for the initial
    draw) is not inspected; case, plane and slice are read from the widgets.
    """
    with out:
        out.clear_output()
        current = (case_dd.value, plane_dd.value, slice_slider.value)
        plot(*current)

# Redraw whenever any control changes (registered in order: case, plane, slice).
for _widget in (case_dd, plane_dd, slice_slider):
    _widget.observe(on_change, names='value')

controls = HBox([case_dd, plane_dd])
ui = VBox([controls, slice_slider, out])
display(ui)
on_change(None)  # initial render for the default selection


VBox(children=(HBox(children=(Dropdown(description='Case:', options=('BRATS_460', 'BRATS_461', 'BRATS_462', 'B…