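# Demo Pipeline

Side-by-side comparison of the CDS and FL model checkpoints (predicted class plus segmentation overlay), rendered as per-case MP4 videos:

- Demo 1: DIPG test cases (anonymized DICOMs)
- Demo 2: Validation examples (from PNG slices)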

In [10]:
import os
import numpy as np
import torch

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from pathlib import Path
from PIL import Image
import pydicom
from scipy.ndimage import zoom
from IPython.display import HTML, display

import config
from model_pretrained import PedBrainNetPretrained

%matplotlib inline
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly) without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CLASS_NAMES = config.CLASS_NAMES
# Output folder for the rendered comparison videos.
os.makedirs('demos', exist_ok=True)

In [11]:
def load_model(path):
    """Load a PedBrainNetPretrained checkpoint onto ``device`` in eval mode.

    Args:
        path: Path to a file written by ``torch.save``. It may hold either a raw
            state dict or a dict with a ``'model_state_dict'`` entry.

    Returns:
        The model, moved to ``device`` and switched to eval mode, or ``None`` if
        ``path`` does not exist.
    """
    # Check for the file before building the network, so a missing checkpoint
    # doesn't pay for model construction.
    if not os.path.exists(path):
        return None
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location=device)
    state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
    model = PedBrainNetPretrained(num_classes=config.NUM_CLASSES, pretrained=False)
    model.load_state_dict(state_dict)
    return model.to(device).eval()

cds_model = load_model('checkpoints/cds_best.pth')
fl_model = load_model('checkpoints/fl_best.pth')
# load_model returns None for a missing checkpoint. Test for that explicitly rather than
# relying on the truthiness of an nn.Module, which __len__/__bool__ can redefine.
print(f"CDS: {'loaded' if cds_model is not None else 'not found'}")
print(f"FL: {'loaded' if fl_model is not None else 'not found'}")

CDS: loaded
FL: loaded


## Demo 1: DIPG Test Cases (DICOMs)

In [12]:
DATA_DIR = Path('data/dipg_test')
CASE_NAMES = ['DIPG_Case_A', 'DIPG_Case_B', 'DIPG_Case_C']

def load_dicom_volume(dicom_dir):
    """Read every ``*.dcm`` file in ``dicom_dir`` into one min-max normalized volume.

    Slices are ordered by their DICOM ``InstanceNumber`` and stacked along axis 0,
    giving a float32 array of shape (num_files, rows, cols) with values in [0, 1].

    Raises:
        FileNotFoundError: if ``dicom_dir`` contains no ``.dcm`` files.
    """
    dcm_files = sorted(Path(dicom_dir).glob('*.dcm'))
    if not dcm_files:
        # Without this guard, np.stack([]) fails with an opaque "need at least one array" error.
        raise FileNotFoundError(f"No .dcm files found in {dicom_dir}")

    slices = []
    for dcm_path in dcm_files:
        ds = pydicom.dcmread(str(dcm_path))
        # NOTE(review): RescaleSlope/RescaleIntercept are not applied. This is harmless under the
        # global min-max below only if all slices share the same rescale values — confirm.
        slices.append((ds.InstanceNumber, ds.pixel_array.astype(np.float32)))
    slices.sort(key=lambda item: item[0])
    volume = np.stack([pixels for _, pixels in slices], axis=0)

    # Global min-max normalization. The epsilon keeps a constant volume from dividing by zero.
    volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-7)
    return volume

def run_inference_dicom(volume, model, target=(64, 256, 256)):
    """Resample ``volume`` to ``target`` and run ``model`` on it.

    Args:
        volume: 3D array (D, H, W), e.g. from ``load_dicom_volume``.
        model: network whose forward returns a 4-tuple ``(seg, cls, _, _)``.
        target: (D, H, W) input shape fed to the network; default (64, 256, 256).

    Returns:
        (resampled volume, ``seg[0, 0]`` as a float numpy array, softmax over ``cls[0]``).
    """
    target = tuple(target)
    if tuple(volume.shape) != target:
        zf = tuple(t / s for t, s in zip(target, volume.shape))
        volume = zoom(volume, zf, order=1)  # linear interpolation along every axis

    # Run on whichever device holds the model's weights.
    param = next(model.parameters(), None)
    run_device = param.device if param is not None else torch.device('cpu')
    x = torch.from_numpy(volume).float()[None, None].to(run_device)  # -> (1, 1, D, H, W)

    # bf16 autocast on CUDA only; other devices run in plain fp32.
    use_amp = run_device.type == 'cuda'
    with torch.no_grad(), torch.amp.autocast(run_device.type, dtype=torch.bfloat16, enabled=use_amp):
        seg, cls, _, _ = model(x)
    seg_np = seg.float().cpu().numpy()[0, 0]
    probs = torch.softmax(cls.float(), dim=1).cpu().numpy()[0]
    return volume, seg_np, probs

# Load each DIPG case and run both checkpoints on the same resampled volume.
# NOTE(review): in the saved run, every case scored ~0.40 top-class probability for both
# models (0.403-0.405), i.e. the classifier output barely varies across these inputs.
# Check for a preprocessing mismatch with training. For example, this path min-max
# normalizes raw DICOM intensities, while the validation demo scales 8-bit PNGs by 1/255.
if cds_model is None or fl_model is None:
    raise RuntimeError("CDS and FL checkpoints must both be loaded before running the DIPG demo.")

dipg_cases = []
for name in CASE_NAMES:
    vol = load_dicom_volume(DATA_DIR / name)
    vol_resized, cds_seg, cds_probs = run_inference_dicom(vol, cds_model)
    # Reuse the resampled volume so FL receives exactly the array CDS received.
    _, fl_seg, fl_probs = run_inference_dicom(vol_resized, fl_model)
    dipg_cases.append({
        'name': name, 'volume': vol_resized,
        'cds_seg': cds_seg, 'cds_probs': cds_probs,
        'fl_seg': fl_seg, 'fl_probs': fl_probs
    })
    print(f"{name}: CDS={CLASS_NAMES[cds_probs.argmax()]} ({cds_probs.max():.3f}), FL={CLASS_NAMES[fl_probs.argmax()]} ({fl_probs.max():.3f})")

DIPG_Case_A: CDS=DIPG (0.405), FL=DIPG (0.405)
DIPG_Case_B: CDS=DIPG (0.405), FL=DIPG (0.405)
DIPG_Case_C: CDS=DIPG (0.405), FL=DIPG (0.403)


In [None]:
# Display DIPG videos.
# This cell sits above the cell that renders the MP4s, so on a fresh Restart & Run All the
# files may not exist yet. Report missing files instead of embedding broken <video> tags.
for case in dipg_cases:
    video_path = f"demos/{case['name']}_cds_vs_fl.mp4"
    display(HTML(f"<h3>{case['name']} (GT: DIPG)</h3>"))
    if not os.path.exists(video_path):
        print(f"Not rendered yet: {video_path} (run the video-generation cell, then re-run this cell)")
        continue
    display(HTML(f'<video width="800" controls><source src="{video_path}" type="video/mp4"></video>'))

In [13]:
def _draw_overlay_panel(ax, image, seg, color, title, show_thresh=0.1, mask_thresh=0.5):
    """Draw ``image`` in grayscale on ``ax`` with pixels where ``seg > mask_thresh`` tinted RGBA ``color``."""
    ax.imshow(image, cmap='gray')
    # Only add an overlay layer when the slice has some response above show_thresh.
    if seg.max() > show_thresh:
        overlay = np.zeros((*image.shape, 4))
        overlay[seg > mask_thresh] = color
        ax.imshow(overlay)
    ax.set_title(title)
    ax.axis('off')


def create_comparison_video(case, output_path, gt_label='DIPG'):
    """Create an MP4 comparing CDS vs FL segmentation, one frame per slice along axis 0.

    Each frame has three panels: the raw slice, the CDS mask (blue) and the FL mask (red).
    The panel titles show each model's arg-max class and its probability.

    Args:
        case: dict with 'name', 'volume', 'cds_seg', 'fl_seg', 'cds_probs', 'fl_probs'.
        output_path: destination MP4 path (written with the ffmpeg writer).
        gt_label: ground-truth label shown in the figure title.

    Returns:
        ``output_path``.
    """
    vol = case['volume']
    n_slices = vol.shape[0]
    cds_probs, fl_probs = case['cds_probs'], case['fl_probs']
    cds_pred = CLASS_NAMES[cds_probs.argmax()]
    fl_pred = CLASS_NAMES[fl_probs.argmax()]
    # (segmentation volume, RGBA overlay colour, panel title) for the two model panels.
    panels = [
        (case['cds_seg'], (0, 0, 1, 0.5), f'CDS: {cds_pred} ({cds_probs.max():.2f})'),
        (case['fl_seg'], (1, 0, 0, 0.5), f'FL: {fl_pred} ({fl_probs.max():.2f})'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    # The title is identical for every frame and survives ax.clear(), so set it once on this
    # figure. plt.suptitle would instead target whichever figure happens to be current.
    fig.suptitle(f"{case['name']} - GT: {gt_label}", fontsize=12, fontweight='bold')

    def update(frame):
        for ax in axes:
            ax.clear()
        axes[0].imshow(vol[frame], cmap='gray')
        axes[0].set_title(f'Original (slice {frame}/{n_slices - 1})')
        axes[0].axis('off')
        for ax, (seg, color, title) in zip(axes[1:], panels):
            _draw_overlay_panel(ax, vol[frame], seg[frame], color, title)
        return axes

    try:
        ani = animation.FuncAnimation(fig, update, frames=n_slices, interval=100)
        ani.save(output_path, writer='ffmpeg', fps=10, dpi=100)
    finally:
        # Close this specific figure, even if ffmpeg fails, so figures don't pile up.
        plt.close(fig)
    return output_path

# Render one CDS-vs-FL comparison MP4 per DIPG test case (all labelled DIPG).
# create_comparison_video returns the path it wrote, which is what gets reported.
for case in dipg_cases:
    output_path = create_comparison_video(
        case, f"demos/{case['name']}_cds_vs_fl.mp4", gt_label='DIPG'
    )
    print(f"Saved: {output_path}")

Saved: demos/DIPG_Case_A_cds_vs_fl.mp4
Saved: demos/DIPG_Case_B_cds_vs_fl.mp4
Saved: demos/DIPG_Case_C_cds_vs_fl.mp4


## Demo 2: Validation Examples (from PNG slices)

In [14]:
VAL_DIR = Path('data/val_examples')

def load_val_example(example_dir, n_slices=64):
    """Load one validation example: ``n_slices`` PNG slices plus ``metadata.txt``.

    Args:
        example_dir: directory holding ``metadata.txt`` and ``slice_00.png``,
            ``slice_01.png``, ...
        n_slices: number of consecutive slices to read (default 64).

    Returns:
        volume: float32 array of shape (n_slices, H, W). Each slice is converted to
            8-bit grayscale and scaled to [0, 1].
        metadata: dict of the ``key: value`` lines in ``metadata.txt``, as strings.

    Raises:
        ValueError: if a non-blank metadata line has no ``': '`` separator.
    """
    example_dir = Path(example_dir)
    metadata_path = example_dir / 'metadata.txt'

    metadata = {}
    with open(metadata_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            # Split on the first separator only, so values may themselves contain ': '.
            key, sep, val = line.partition(': ')
            if not sep:
                raise ValueError(f"Malformed line in {metadata_path}: {line!r}")
            metadata[key] = val

    slices = []
    for i in range(n_slices):
        # The context manager closes the PNG file as soon as its pixels have been read.
        with Image.open(example_dir / f'slice_{i:02d}.png') as img:
            slices.append(np.asarray(img.convert('L'), dtype=np.float32) / 255.0)
    volume = np.stack(slices, axis=0)

    return volume, metadata

def run_inference_val(volume, model, target=(64, 256, 256)):
    """Run ``model`` on a volume (D, H, W), resampling it to ``target`` only if needed.

    Args:
        volume: 3D array (D, H, W), e.g. from ``load_val_example``.
        model: network whose forward returns a 4-tuple ``(seg, cls, _, _)``.
        target: (D, H, W) input shape fed to the network; default (64, 256, 256).

    Returns:
        (volume actually fed to the model, ``seg[0, 0]`` as a float numpy array,
        softmax over ``cls[0]``).
    """
    target = tuple(target)
    if tuple(volume.shape) != target:
        zf = tuple(t / s for t, s in zip(target, volume.shape))
        volume = zoom(volume, zf, order=1)  # linear interpolation along every axis

    # Run on whichever device holds the model's weights.
    param = next(model.parameters(), None)
    run_device = param.device if param is not None else torch.device('cpu')
    x = torch.from_numpy(volume).float()[None, None].to(run_device)  # -> (1, 1, D, H, W)

    # bf16 autocast on CUDA only; other devices run in plain fp32.
    use_amp = run_device.type == 'cuda'
    with torch.no_grad(), torch.amp.autocast(run_device.type, dtype=torch.bfloat16, enabled=use_amp):
        seg, cls, _, _ = model(x)
    seg_np = seg.float().cpu().numpy()[0, 0]
    probs = torch.softmax(cls.float(), dim=1).cpu().numpy()[0]
    return volume, seg_np, probs

# Every sub-directory of VAL_DIR is one validation example, listed in sorted order.
val_examples = sorted(example_path for example_path in VAL_DIR.iterdir() if example_path.is_dir())
print(f"Found {len(val_examples)} validation examples:")
for example_path in val_examples:
    print(f"  {example_path.name}")

Found 8 validation examples:
  DIPG_03
  DIPG_04
  EPEN_01
  EPEN_02
  MEDU_05
  MEDU_06
  PILO_07
  PILO_08


In [15]:
# Process validation examples: run both checkpoints on each one and compare the arg-max
# prediction with the 'class_name' recorded in its metadata.txt.
if cds_model is None or fl_model is None:
    raise RuntimeError("CDS and FL checkpoints must both be loaded before running the validation demo.")

val_cases = []
for example_dir in val_examples:
    volume, metadata = load_val_example(example_dir)
    gt_label = metadata['class_name']

    vol_resized, cds_seg, cds_probs = run_inference_val(volume, cds_model)
    # vol_resized already has the target shape, so the FL call feeds it unchanged.
    _, fl_seg, fl_probs = run_inference_val(vol_resized, fl_model)

    cds_pred = CLASS_NAMES[cds_probs.argmax()]
    fl_pred = CLASS_NAMES[fl_probs.argmax()]

    val_cases.append({
        'name': example_dir.name, 'gt_label': gt_label,
        'volume': vol_resized,
        'cds_seg': cds_seg, 'cds_probs': cds_probs,
        'fl_seg': fl_seg, 'fl_probs': fl_probs
    })

    cds_status = 'OK' if cds_pred == gt_label else 'X'
    fl_status = 'OK' if fl_pred == gt_label else 'X'
    print(f"{example_dir.name}: GT={gt_label}, CDS={cds_pred} [{cds_status}], FL={fl_pred} [{fl_status}]")

DIPG_03: GT=DIPG, CDS=DIPG [OK], FL=DIPG [OK]
DIPG_04: GT=DIPG, CDS=DIPG [OK], FL=DIPG [OK]
EPEN_01: GT=Ependymoma, CDS=Ependymoma [OK], FL=Ependymoma [OK]
EPEN_02: GT=Ependymoma, CDS=Ependymoma [OK], FL=Ependymoma [OK]
MEDU_05: GT=Medulloblastoma, CDS=Medulloblastoma [OK], FL=Medulloblastoma [OK]
MEDU_06: GT=Medulloblastoma, CDS=Medulloblastoma [OK], FL=Medulloblastoma [OK]
PILO_07: GT=Pilocytic, CDS=Pilocytic [OK], FL=Pilocytic [OK]
PILO_08: GT=Pilocytic, CDS=Pilocytic [OK], FL=Pilocytic [OK]


In [16]:
# Generate validation demo videos with the same renderer used for the DIPG cases,
# rather than a copy of its plotting code. Each video is titled with the example's
# own ground-truth label.
for case in val_cases:
    output_path = create_comparison_video(
        case, f"demos/{case['name']}_cds_vs_fl.mp4", gt_label=case['gt_label']
    )
    print(f"Saved: {output_path}")

Saved: demos/DIPG_03_cds_vs_fl.mp4
Saved: demos/DIPG_04_cds_vs_fl.mp4
Saved: demos/EPEN_01_cds_vs_fl.mp4
Saved: demos/EPEN_02_cds_vs_fl.mp4
Saved: demos/MEDU_05_cds_vs_fl.mp4
Saved: demos/MEDU_06_cds_vs_fl.mp4
Saved: demos/PILO_07_cds_vs_fl.mp4
Saved: demos/PILO_08_cds_vs_fl.mp4


In [17]:
# Embed each rendered validation video under a heading showing its ground-truth label.
for case in val_cases:
    gt_label = case['gt_label']
    video_path = f"demos/{case['name']}_cds_vs_fl.mp4"
    heading = f"<h3>{case['name']} (GT: {gt_label})</h3>"
    player = f'<video width="800" controls><source src="{video_path}" type="video/mp4"></video>'
    display(HTML(heading))
    display(HTML(player))

## Summary

In [18]:
print("="*70)
print("DIPG TEST CASES (GT: DIPG)")
print("="*70)
for case in dipg_cases:
    cds_pred = CLASS_NAMES[case['cds_probs'].argmax()]
    fl_pred = CLASS_NAMES[case['fl_probs'].argmax()]
    cds_ok = 'OK' if cds_pred == 'DIPG' else 'X'
    fl_ok = 'OK' if fl_pred == 'DIPG' else 'X'
    print(f"{case['name']}: CDS={cds_pred} [{cds_ok}], FL={fl_pred} [{fl_ok}]")

print("\n" + "="*70)
print("VALIDATION EXAMPLES")
print("="*70)
for case in val_cases:
    gt = case['gt_label']
    cds_pred = CLASS_NAMES[case['cds_probs'].argmax()]
    fl_pred = CLASS_NAMES[case['fl_probs'].argmax()]
    cds_ok = 'OK' if cds_pred == gt else 'X'
    fl_ok = 'OK' if fl_pred == gt else 'X'
    print(f"{case['name']}: GT={gt}, CDS={cds_pred} [{cds_ok}], FL={fl_pred} [{fl_ok}]")

DIPG TEST CASES (GT: DIPG)
DIPG_Case_A: CDS=DIPG [OK], FL=DIPG [OK]
DIPG_Case_B: CDS=DIPG [OK], FL=DIPG [OK]
DIPG_Case_C: CDS=DIPG [OK], FL=DIPG [OK]

VALIDATION EXAMPLES
DIPG_03: GT=DIPG, CDS=DIPG [OK], FL=DIPG [OK]
DIPG_04: GT=DIPG, CDS=DIPG [OK], FL=DIPG [OK]
EPEN_01: GT=Ependymoma, CDS=Ependymoma [OK], FL=Ependymoma [OK]
EPEN_02: GT=Ependymoma, CDS=Ependymoma [OK], FL=Ependymoma [OK]
MEDU_05: GT=Medulloblastoma, CDS=Medulloblastoma [OK], FL=Medulloblastoma [OK]
MEDU_06: GT=Medulloblastoma, CDS=Medulloblastoma [OK], FL=Medulloblastoma [OK]
PILO_07: GT=Pilocytic, CDS=Pilocytic [OK], FL=Pilocytic [OK]
PILO_08: GT=Pilocytic, CDS=Pilocytic [OK], FL=Pilocytic [OK]
