In [11]:
from pathlib import Path

# SET YOUR DATA PATH HERE
DATA_PATH = Path("../data/BraTS-2023/BraTS-GLI-00008-001/")  # Change this to your case directory

# Verify the path exists and is a directory. The volumes are discovered by
# globbing inside DATA_PATH, so a path that points at a single file would pass
# a bare exists() check and only fail later with a misleading
# "no volumes found" error.
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Data path does not exist: {DATA_PATH}")
if not DATA_PATH.is_dir():
    raise NotADirectoryError(f"Data path is not a directory: {DATA_PATH}")
print(f"Loading case from: {DATA_PATH}")

In [12]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, CheckButtons
from IPython.display import display
import ipywidgets as widgets

# Use widget backend for interactive plots
%matplotlib widget

# Color maps for segmentation (matching your LUT).
# Maps integer label id -> [R, G, B, A], every component in [0, 1].
# Enhancing tumor is registered under both 3 and 4: BraTS 2023 releases encode
# it as 3 while earlier releases used 4 (verify against your dataset). Without
# the 3 entry, enhancing-tumor voxels of a 2023 case would be drawn transparent.
SEG_COLORS = {
    0: [0.0, 0.0, 0.0, 0.0],      # Background (transparent)
    1: [0.0, 0.4, 1.0, 0.9],      # NCR/NET (blue)
    2: [0.0, 0.8, 0.0, 0.7],      # Edema (green)
    3: [1.0, 0.1, 0.1, 0.9],      # Enhancing tumor (red), BraTS 2023 label id
    4: [1.0, 0.1, 0.1, 0.9],      # Enhancing tumor (red), legacy label id
}

# Filename suffix (the part between the last "-" and ".nii.gz") -> display name.
MOD_SUFFIXES = {
    "t1n": "T1n",
    "t1c": "T1c",
    "t2w": "T2w",
    "t2f": "FLAIR",
}

print("Imports complete")


In [13]:
def load_nifti_normalized(path):
    """Load a NIfTI volume and rescale its intensities to the [0, 1] range.

    The intensity window runs from the 1st to the 99.5th percentile of the
    finite voxels; values outside the window are clipped. If that window is
    empty (for example a constant volume) the full min/max range is used
    instead.

    Non-finite voxels (NaN / +-inf) are excluded from the window estimate and
    set to 0 in the output. Without this, a single NaN would turn both
    percentiles into NaN and with them the whole normalized volume.

    Args:
        path: Location of the NIfTI file; converted with ``str()`` before
            being handed to ``nib.load``.

    Returns:
        tuple: ``(normalized, header)`` where ``normalized`` is the rescaled
        array and ``header`` is the ``header`` attribute of the loaded image.
    """
    img = nib.load(str(path))
    data = img.get_fdata(dtype=np.float32)

    finite_mask = np.isfinite(data)
    if not finite_mask.any():
        # Empty or entirely non-finite volume: nothing to estimate a window from.
        return np.zeros(data.shape, dtype=np.float32), img.header
    finite_values = data[finite_mask]

    # Robust normalization (1st to 99.5th percentile)
    vmin = np.percentile(finite_values, 1.0)
    vmax = np.percentile(finite_values, 99.5)

    if vmax <= vmin:
        # Degenerate percentile window: fall back to the full value range.
        vmax = np.max(finite_values)
        vmin = np.min(finite_values)

    # The 1e-6 floor avoids dividing by zero for constant volumes.
    rng = max(1e-6, vmax - vmin)
    normalized = np.clip((data - vmin) / rng, 0.0, 1.0)
    normalized[~finite_mask] = 0.0

    return normalized, img.header

def load_segmentation(path):
    """Read a label volume and return it as unsigned 8-bit integers.

    The voxel data is fetched as float32, rounded to the nearest integer and
    cast to ``np.uint8``.

    Returns:
        tuple: ``(labels, header)`` with the integer label array and the
        ``header`` attribute of the loaded image.
    """
    seg_img = nib.load(str(path))
    rounded = np.rint(seg_img.get_fdata(dtype=np.float32))
    return rounded.astype(np.uint8), seg_img.header

def find_modality_files(case_dir):
    """Locate the modality volumes and the segmentation file of one case.

    Only ``*.nii.gz`` files directly inside ``case_dir`` are considered, and
    matching is done on the lower-cased file name:

    * a name ending in ``-seg.nii.gz`` is the segmentation;
    * a name ending in ``-<suffix>.nii.gz`` for a suffix in ``MOD_SUFFIXES``
      is stored under that suffix's display name.

    Args:
        case_dir: Case directory, as a ``str`` or any path-like object.

    Returns:
        tuple: ``(mod_files, seg_file)`` where ``mod_files`` maps display name
        to ``Path`` and ``seg_file`` is a ``Path`` or ``None`` if no
        segmentation was found.
    """
    case_dir = Path(case_dir)  # accept plain strings as well as Path objects

    mod_files = {}
    seg_file = None

    # sorted() makes the result independent of filesystem enumeration order;
    # if several files match one slot, the last one in sorted order wins.
    for f in sorted(case_dir.glob("*.nii.gz")):
        if f.name.startswith("."):
            # Hidden sidecar files (e.g. "._case-seg.nii.gz") are not volumes.
            continue

        name = f.name.lower()

        if name.endswith("-seg.nii.gz"):
            seg_file = f
            continue

        for suffix, key in MOD_SUFFIXES.items():
            if name.endswith(f"-{suffix}.nii.gz"):
                mod_files[key] = f
                break

    return mod_files, seg_file

print("Data loading functions defined")

In [14]:
print("Loading volumes...")

# Find files
mod_files, seg_file = find_modality_files(DATA_PATH)

if not mod_files:
    raise RuntimeError(f"No modality volumes found in {DATA_PATH}")

# Load all modalities (each is normalized to [0, 1] by the loader)
volumes = {}
for key, path in mod_files.items():
    print(f"  Loading {key}: {path.name}")
    vol, hdr = load_nifti_normalized(path)
    volumes[key] = vol

# Load segmentation if available
seg_volume = None
if seg_file is not None:
    print(f"  Loading segmentation: {seg_file.name}")
    seg_volume, seg_hdr = load_segmentation(seg_file)
else:
    print("  No segmentation file found")

# Get dimensions (the first loaded modality is the reference grid)
first_vol = next(iter(volumes.values()))
dims = first_vol.shape

# The viewer sums modality slices voxel-by-voxel and masks them with the
# segmentation, so every array must share one grid. Fail here with a clear
# message instead of with a broadcasting error deep inside a widget callback.
mismatched = {key: vol.shape for key, vol in volumes.items() if vol.shape != dims}
if mismatched:
    raise ValueError(f"Modality shapes differ from {dims}: {mismatched}")
if seg_volume is not None and seg_volume.shape != dims:
    raise ValueError(
        f"Segmentation shape {seg_volume.shape} does not match volume shape {dims}"
    )

print(f"\nVolume dimensions: {dims[0]} x {dims[1]} x {dims[2]}")
print(f"Loaded modalities: {list(volumes.keys())}")


In [15]:
def create_seg_overlay(seg_slice, grayscale=False):
    """Build a float32 RGBA image from a 2D label slice.

    Every label listed in ``SEG_COLORS`` is painted with its RGBA entry.
    With ``grayscale=True`` all labels greater than 0 are painted white at
    0.9 opacity instead. Pixels whose label is not in ``SEG_COLORS`` stay
    all-zero, i.e. fully transparent.
    """
    rows, cols = seg_slice.shape
    rgba = np.zeros((rows, cols, 4), dtype=np.float32)
    white = [1.0, 1.0, 1.0, 0.9]

    for label_id, label_rgba in SEG_COLORS.items():
        paint_white = grayscale and label_id > 0
        rgba[seg_slice == label_id] = white if paint_white else label_rgba

    return rgba

def blend_with_overlay(intensity_slice, seg_slice, show_seg=True, grayscale_seg=False):
    """Alpha-composite a segmentation overlay onto a grayscale slice.

    The intensity slice is replicated into three channels. When ``show_seg``
    is true and ``seg_slice`` is not ``None``, the RGBA overlay produced by
    ``create_seg_overlay`` is blended on top using its alpha channel.
    The result is clipped to [0, 1] in either case.
    """
    gray_rgb = np.stack((intensity_slice,) * 3, axis=-1)

    # No overlay requested or available: return the clipped grayscale image.
    if not show_seg or seg_slice is None:
        return np.clip(gray_rgb, 0, 1)

    rgba = create_seg_overlay(seg_slice, grayscale=grayscale_seg)
    opacity = rgba[..., 3:4]  # keeps a trailing axis so it broadcasts over RGB
    blended = gray_rgb * (1 - opacity) + rgba[..., :3] * opacity

    return np.clip(blended, 0, 1)

print("Overlay functions defined")


In [16]:
class BraTSSliceViewer:
    """Interactive slice viewer driven by ipywidgets controls and one matplotlib axes.

    The enabled modalities are combined into a weighted average, passed through
    a window/level mapping, optionally blended with the segmentation overlay
    and drawn on the axes created by :meth:`show`.

    Args:
        volumes: Non-empty mapping of modality name to 3D array. The shape of
            the first entry is used as the slider range for every axis.
        seg_volume: Optional label volume; when ``None`` the segmentation
            checkboxes are disabled.

    Raises:
        ValueError: If ``volumes`` is empty.
    """

    def __init__(self, volumes, seg_volume=None):
        if not volumes:
            raise ValueError("BraTSSliceViewer requires at least one volume")

        self.volumes = volumes
        self.seg_volume = seg_volume
        self.dims = next(iter(volumes.values())).shape

        # Default settings
        self.current_axis = 2  # axial by default
        self.current_slice = self.dims[self.current_axis] // 2
        self.active_modalities = list(volumes.keys())
        self.weights = {k: 1.0 for k in volumes.keys()}
        self.window_width = 1.0
        self.window_level = 0.5
        self.show_seg = True
        self.grayscale_seg = False

        # Figure handles are created by show(); update_display() is a no-op
        # until then.
        self.fig = None
        self.ax = None

        self._create_widgets()

    def _create_widgets(self):
        """Create the control widgets and connect their change callbacks."""
        # Slice slider
        self.slice_slider = widgets.IntSlider(
            value=self.current_slice,
            min=0,
            max=self.dims[self.current_axis] - 1,
            description='Slice:',
            continuous_update=True
        )

        # Axis selector (option values are array axis indices)
        self.axis_dropdown = widgets.Dropdown(
            options=[('Axial (Z)', 2), ('Coronal (Y)', 1), ('Sagittal (X)', 0)],
            value=self.current_axis,
            description='View:'
        )

        # Modality checkboxes
        self.mod_checks = {
            mod: widgets.Checkbox(value=True, description=mod, indent=False)
            for mod in self.volumes.keys()
        }

        # Modality weights
        self.mod_weights = {
            mod: widgets.FloatSlider(
                value=1.0,
                min=0.0,
                max=2.0,
                step=0.1,
                description=f'{mod} weight:',
                continuous_update=True
            )
            for mod in self.volumes.keys()
        }

        # Window/Level
        self.ww_slider = widgets.FloatSlider(
            value=self.window_width,
            min=0.01,
            max=2.0,
            step=0.01,
            description='Window Width:',
            continuous_update=True
        )

        self.wl_slider = widgets.FloatSlider(
            value=self.window_level,
            min=0.0,
            max=1.0,
            step=0.01,
            description='Window Level:',
            continuous_update=True
        )

        # Segmentation controls (disabled when there is nothing to overlay)
        self.seg_check = widgets.Checkbox(
            value=self.show_seg,
            description='Show Segmentation',
            disabled=(self.seg_volume is None)
        )

        self.seg_gray_check = widgets.Checkbox(
            value=self.grayscale_seg,
            description='Grayscale Seg',
            disabled=(self.seg_volume is None)
        )

        # Connect callbacks
        self.slice_slider.observe(self._on_change, 'value')
        self.axis_dropdown.observe(self._on_axis_change, 'value')
        self.ww_slider.observe(self._on_change, 'value')
        self.wl_slider.observe(self._on_change, 'value')
        self.seg_check.observe(self._on_change, 'value')
        self.seg_gray_check.observe(self._on_change, 'value')

        for check in self.mod_checks.values():
            check.observe(self._on_change, 'value')
        for weight in self.mod_weights.values():
            weight.observe(self._on_change, 'value')

    def _on_axis_change(self, change):
        """Re-range the slice slider for the new axis and redraw exactly once.

        The redraw is triggered explicitly rather than through the slider's
        value event. When the new axis has the same length as the old one and
        the slider already sits at its midpoint, the slider value does not
        change, no event fires, and the view would keep showing the old axis.
        """
        self.current_axis = change['new']
        axis_len = self.dims[self.current_axis]

        # Detach the slider callback while its range and value are adjusted so
        # the intermediate states do not each cause a redraw.
        self.slice_slider.unobserve(self._on_change, 'value')
        try:
            self.slice_slider.max = axis_len - 1
            self.slice_slider.value = axis_len // 2
        finally:
            self.slice_slider.observe(self._on_change, 'value')

        self.update_display()

    def _on_change(self, change):
        """Handle any parameter change by redrawing."""
        self.update_display()

    def _get_slice(self, volume, axis, idx):
        """Return the transposed 2D slice ``idx`` of ``volume`` along ``axis``.

        Axis 0 and 1 select along the first and second array dimension; any
        other value selects along the third.
        """
        if axis == 0:  # sagittal
            return volume[idx, :, :].T
        elif axis == 1:  # coronal
            return volume[:, idx, :].T
        else:  # axial
            return volume[:, :, idx].T

    def _apply_window_level(self, image):
        """Map ``image`` through the current window width/level and clip to [0, 1]."""
        ww = self.ww_slider.value
        wl = self.wl_slider.value

        vmin = wl - ww / 2  # lower edge of the window maps to 0

        # The 1e-6 floor guards the division if the width is ever set to 0.
        return np.clip((image - vmin) / max(ww, 1e-6), 0, 1)

    def update_display(self):
        """Recompute the blended slice from the widget state and redraw it.

        Does nothing if the figure has not been created yet, so a widget
        callback firing before :meth:`show` is harmless.
        """
        if getattr(self, "ax", None) is None:
            return

        slice_idx = self.slice_slider.value
        axis = self.axis_dropdown.value

        # Keep the public state attributes in step with the widgets.
        self.current_axis = axis
        self.current_slice = slice_idx

        # Combine enabled modalities into a weighted average
        combined = None
        total_weight = 0.0

        for mod_key, volume in self.volumes.items():
            if not self.mod_checks[mod_key].value:
                continue
            weight = self.mod_weights[mod_key].value
            slice_data = self._get_slice(volume, axis, slice_idx)

            if combined is None:
                combined = slice_data * weight
            else:
                combined += slice_data * weight
            total_weight += weight

        if combined is None or total_weight == 0:
            # Nothing enabled (or all weights zero): show a black slice.
            reference = next(iter(self.volumes.values()))
            combined = np.zeros(self._get_slice(reference, axis, slice_idx).shape)
        else:
            combined /= total_weight

        # Apply window/level
        combined = self._apply_window_level(combined)

        # Get segmentation slice if available
        seg_slice = None
        if self.seg_volume is not None and self.seg_check.value:
            seg_slice = self._get_slice(self.seg_volume, axis, slice_idx)

        # Blend with overlay
        display_image = blend_with_overlay(
            combined,
            seg_slice,
            show_seg=self.seg_check.value,
            grayscale_seg=self.seg_gray_check.value
        )

        # Update plot
        self.ax.clear()
        self.ax.imshow(display_image, origin='lower', interpolation='bilinear')
        self.ax.set_title(f"Slice {slice_idx} / {self.dims[axis]-1}")
        self.ax.axis('off')
        self.fig.canvas.draw_idle()

    def show(self):
        """Create the figure, draw the initial slice and display the controls."""
        # Create figure
        self.fig, self.ax = plt.subplots(figsize=(10, 10))
        plt.subplots_adjust(left=0.1, bottom=0.35)

        # Layout widgets
        controls = widgets.VBox([
            widgets.HBox([self.axis_dropdown, self.slice_slider]),
            widgets.Label("Modalities:"),
            widgets.HBox([self.mod_checks[k] for k in self.volumes.keys()]),
            widgets.Label("Weights:"),
            *[self.mod_weights[k] for k in self.volumes.keys()],
            self.ww_slider,
            self.wl_slider,
            widgets.Label("Segmentation:"),
            widgets.HBox([self.seg_check, self.seg_gray_check])
        ])

        # Initial display
        self.update_display()

        display(controls)
        plt.show()

print("Viewer class defined")

In [17]:
# Build the viewer from the volumes loaded above and render it in this cell.
# show() creates the matplotlib figure and displays the widget controls.
print("Launching interactive viewer...")
viewer = BraTSSliceViewer(volumes, seg_volume)
viewer.show()


In [18]:
def show_multi_view(volumes, seg_volume=None, slice_idx=None):
    """Show one axial slice of every modality side by side.

    Each modality gets its own grayscale panel; when ``seg_volume`` is given,
    a final panel shows the color-coded labels of the same slice.

    Args:
        volumes: Non-empty mapping of modality name to 3D array.
        seg_volume: Optional label volume indexed like the modality volumes.
        slice_idx: Index along the third array axis. Defaults to the middle
            slice of the first volume.

    Raises:
        ValueError: If ``volumes`` is empty.
    """
    if not volumes:
        raise ValueError("show_multi_view requires at least one volume")

    n_mods = len(volumes)
    n_panels = n_mods + (1 if seg_volume is not None else 0)

    # squeeze=False always returns a 2D array of axes, so a single panel can
    # be indexed like any other layout.
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
    axes = axes.ravel()

    if slice_idx is None:
        slice_idx = next(iter(volumes.values())).shape[2] // 2

    for idx, (mod_name, vol) in enumerate(volumes.items()):
        axes[idx].imshow(vol[:, :, slice_idx].T, cmap='gray', origin='lower')
        axes[idx].set_title(mod_name)
        axes[idx].axis('off')

    if seg_volume is not None:
        seg_slice = seg_volume[:, :, slice_idx].T
        # Size the RGB canvas from the transposed slice so that volumes whose
        # first two dimensions differ are handled correctly.
        seg_rgb = np.zeros((*seg_slice.shape, 3))

        # Color code the segmentation from the shared lookup table so this
        # panel agrees with the overlay; label 0 (background) stays black.
        for label, rgba in SEG_COLORS.items():
            if label == 0:
                continue
            seg_rgb[seg_slice == label] = rgba[:3]

        axes[-1].imshow(seg_rgb, origin='lower')
        axes[-1].set_title('Segmentation')
        axes[-1].axis('off')

    plt.tight_layout()
    plt.show()


In [19]:
# Static side-by-side comparison at the middle index of the third array axis.
show_multi_view(volumes, seg_volume, slice_idx=dims[2]//2)