# 05 – Ensemble Strategies

**Goal:** Fuse the three model predictions and compare ensemble strategies.

| Strategy | Description |
|----------|-------------|
| **Majority vote** | ≥2 of the 3 models agree per voxel |
| **Union** | Any model predicts tumor |
| **Intersection** | All models agree |
| **STAPLE** | EM-based probabilistic fusion (SimpleITK) |
| **Weighted avg** | Weight each model by its Dice vs GT (oracle: needs GT, upper bound only) |

In [None]:
import sys, os
from pathlib import Path
import numpy as np
import nibabel as nib
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk

# Paths are resolved from the current working directory; this assumes the
# notebook is run from its own folder (TODO confirm for non-Jupyter runs).
NOTEBOOK_DIR = Path.cwd()
REPO_ROOT    = NOTEBOOK_DIR.parent.parent

# Input data folders for P01.
DATA_ROOT = REPO_ROOT / 'P01'
BRATS_DIR = DATA_ROOT / 'BraTS'
MASK_DIR  = DATA_ROOT / 'tumor segmentation'

# Destination for everything this notebook writes (created if missing).
OUT_DIR = NOTEBOOK_DIR.parent.joinpath('outputs', '05_ensemble')
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Prediction folders written by the earlier notebooks (02 nnU-Net, 03 MedGemma, 04 SAM).
NNUNET_OUT = NOTEBOOK_DIR.parent.joinpath('outputs', '02_nnunet', 'predictions')
VLM_OUT    = NOTEBOOK_DIR.parent.joinpath('outputs', '03_medgemma')
SAM_OUT    = NOTEBOOK_DIR.parent.joinpath('outputs', '04_sam')

# Make the shared helper modules in ../utils importable, then import them.
sys.path.insert(0, str(NOTEBOOK_DIR.parent.joinpath('utils')))
from dicom_utils import get_p01_brats_paths, get_p01_mask_paths, load_nifti, save_nifti
from metrics import (
    BenchmarkTracker, dice_coefficient, iou_score, pairwise_dice_matrix, agreement_score,
)
from visualisation import plot_model_comparison, plot_benchmark_bar

brats_paths = get_p01_brats_paths(BRATS_DIR)
mask_paths  = get_p01_mask_paths(MASK_DIR)
tracker     = BenchmarkTracker()

In [None]:
# ── Ensemble strategy functions ──────────────────────────────────────────────

def majority_vote(masks: dict) -> np.ndarray:
    """
    Voxel-wise majority vote over the binarised model masks.

    Each mask is binarised at 0.5. A voxel is foreground when at least half
    of the models (``>= len(masks) / 2``) mark it, so with three models two
    votes are enough, and with an even count a tie counts as foreground.

    Returns a float32 array of 0.0/1.0 values.
    """
    votes = np.stack([pred > 0.5 for pred in masks.values()])
    quorum = len(masks) / 2
    return np.asarray(votes.sum(axis=0) >= quorum, dtype=np.float32)

def union_vote(masks: dict) -> np.ndarray:
    """
    Voxel-wise logical OR: foreground wherever any model exceeds 0.5.

    Returns a float32 array of 0.0/1.0 values shaped like the first mask.
    """
    template = next(iter(masks.values()))
    hit = np.zeros_like(template, dtype=bool)
    for pred in masks.values():
        hit |= pred > 0.5
    return hit.astype(np.float32)

def intersection_vote(masks: dict) -> np.ndarray:
    """
    Voxel-wise logical AND: foreground only where every model exceeds 0.5.

    Returns a float32 array of 0.0/1.0 values shaped like the first mask.
    """
    template = next(iter(masks.values()))
    agree = np.ones_like(template, dtype=bool)
    for pred in masks.values():
        agree &= pred > 0.5
    return agree.astype(np.float32)

def staple_fusion(masks: dict, threshold: float = 0.5) -> np.ndarray:
    """
    Fuse the model masks with the STAPLE algorithm (SimpleITK).

    Each mask is binarised at 0.5 and passed to ``sitk.STAPLEImageFilter`` as
    one rater with foreground value 1. The filter yields a per-voxel
    foreground probability, which is thresholded at ``threshold``. The
    function therefore returns a HARD 0/1 mask, not the soft probability map.

    Parameters
    ----------
    masks : dict
        Model name -> prediction array (presumably all the same shape).
    threshold : float, default 0.5
        Cut-off applied to the STAPLE probability output.

    Returns
    -------
    np.ndarray
        float32 array of 0.0/1.0 values. SimpleITK's array<->image conversion
        reverses axes both ways, so the shape matches the inputs.
    """
    raters = [sitk.GetImageFromArray((m > 0.5).astype(np.uint8)) for m in masks.values()]
    staple_filter = sitk.STAPLEImageFilter()
    staple_filter.SetForegroundValue(1)
    prob = sitk.GetArrayFromImage(staple_filter.Execute(raters))
    return (prob > threshold).astype(np.float32)

def weighted_avg(masks: dict, gt: np.ndarray) -> np.ndarray:
    """
    Oracle-weighted vote: each binarised mask is weighted by its Dice vs GT.

    The weights are normalised by their sum (plus 1e-8 so an all-zero Dice
    set does not divide by zero). The weighted sum is kept where it is > 0.5.
    NOTE: the weights come from the ground truth, so this is an oracle
    (upper-bound) strategy only.

    Returns a float32 array of 0.0/1.0 values.
    """
    dice_by_model = {name: dice_coefficient(pred, gt) for name, pred in masks.items()}
    denom = sum(dice_by_model.values()) + 1e-8
    fused = np.zeros_like(next(iter(masks.values())), dtype=np.float32)
    for name, pred in masks.items():
        share = dice_by_model[name] / denom
        fused += share * (pred > 0.5).astype(np.float32)
    return (fused > 0.5).astype(np.float32)


print('Ensemble functions defined.')

In [None]:
# ── Load predictions from previous notebooks ─────────────────────────────────
# We try to load real predictions; fall back to GT-derived stubs if unavailable.

def load_pred_or_stub(path: Path, gt_arr: np.ndarray, gt_aff, erosion: int = 0) -> np.ndarray:
    """
    Load a saved prediction, or synthesise a stand-in from the ground truth.

    Parameters
    ----------
    path : Path
        NIfTI file written by an earlier notebook. If it exists, the first
        value returned by ``load_nifti`` is returned unchanged.
    gt_arr : np.ndarray
        Ground-truth volume; every voxel > 0 counts as foreground.
    gt_aff :
        Ground-truth affine. Not used; kept so existing call sites still work.
    erosion : int, default 0
        Perturbation applied to the stub:
        ``> 0`` erodes the binarised GT by that many iterations;
        ``< 0`` dilates it by ``abs(erosion)`` iterations;
        ``0`` returns the binarised GT unchanged.

    Returns
    -------
    np.ndarray
        The loaded array, or a float32 0/1 stub with the shape of ``gt_arr``.
    """
    from scipy.ndimage import binary_erosion, binary_dilation
    if path.exists():
        arr, _, _ = load_nifti(str(path))
        return arr
    # Stub: perturb GT to simulate model output
    gt_bin = gt_arr > 0
    if erosion > 0:
        return binary_erosion(gt_bin, iterations=erosion).astype(np.float32)
    if erosion < 0:
        return binary_dilation(gt_bin, iterations=-erosion).astype(np.float32)
    # erosion == 0: scipy treats iterations < 1 as "repeat until the mask stops
    # changing", which would grow the stub until it fills the volume. Return
    # the binarised GT unperturbed instead.
    return gt_bin.astype(np.float32)

all_results = {}

# NOTE(review): assumes get_p01_brats_paths returns timepoints in insertion
# order baseline, fu1, fu2 — confirm.
for tp in list(brats_paths.keys())[:3]:  # baseline, fu1, fu2
    gt_arr, gt_aff, _ = load_nifti(mask_paths[tp])
    # Voxel sizes of the first three axes, passed to tracker.add as `spacing`.
    spacing = tuple(float(s) for s in nib.load(mask_paths[tp]).header.get_zooms()[:3])
    # Models with no saved prediction, replaced by a GT-derived stub. Their
    # metrics (and any ensemble using them) are synthetic, so they are recorded
    # and reported rather than silently mixed in with real results.
    stubbed = []

    # nnU-Net – saved prediction, else GT eroded by 2 iterations
    nnunet_path = NNUNET_OUT / f'P01_{tp}_pred.nii.gz'
    if not nnunet_path.exists():
        stubbed.append('nnunet')
    nnunet = load_pred_or_stub(nnunet_path, gt_arr, gt_aff, erosion=2)

    # VLM – try medgemma first, then llava_med, then stub (GT dilated by 3)
    vlm = None
    for vlm_name in ['medgemma', 'llava_med']:
        p = VLM_OUT / f'{vlm_name}_{tp}_pred.nii.gz'
        if p.exists():
            vlm, _, _ = load_nifti(str(p))
            break
    if vlm is None:
        from scipy.ndimage import binary_dilation
        vlm = binary_dilation((gt_arr > 0), iterations=3).astype(np.float32)
        stubbed.append('vlm')

    # SAM – try the sam3, sam2, sam `*_box_pred` files, then stub
    sam = None
    for sam_name in ['sam3', 'sam2', 'sam']:
        p = SAM_OUT / f'{sam_name}_{tp}_box_pred.nii.gz'
        if p.exists():
            sam, _, _ = load_nifti(str(p))
            break
    if sam is None:
        # The SAM stub is the binarised GT itself, with no perturbation.
        sam = (gt_arr > 0).astype(np.float32)
        stubbed.append('sam')

    if stubbed:
        print(f'WARNING [{tp}]: no saved prediction for {", ".join(stubbed)}; '
              'using GT-derived stubs - their metrics are not real model results.')

    masks = {'nnunet': nnunet, 'vlm': vlm, 'sam': sam}
    all_results[tp] = {'masks': masks, 'gt': gt_arr, 'spacing': spacing,
                       'affine': gt_aff, 'stubbed': stubbed}

print(f'Loaded predictions for {list(all_results.keys())}')

In [None]:
# ── Run all ensemble strategies ──────────────────────────────────────────────
# Label-free strategies: each is called with only the dict of model masks.
strategy_names = ['majority_vote', 'union', 'intersection']
strategy_fns   = [majority_vote, union_vote, intersection_vote]

# STAPLE needs SimpleITK. The setup cell imports it unconditionally, so this
# guard only takes effect if that top-level import is ever made optional.
# Only a missing/broken install (ImportError) should disable STAPLE.
try:
    import SimpleITK as sitk
except ImportError as e:
    print(f'STAPLE skipped: {e}')
else:
    strategy_names.append('staple')
    strategy_fns.append(staple_fusion)
    print('STAPLE available')

for tp, data in all_results.items():
    masks  = data['masks']
    gt_arr = data['gt']
    spacing = data['spacing']

    # Individual model metrics, as a reference for the ensembles.
    for model_name, pred in masks.items():
        tracker.add(
            model=model_name, timepoint=tp,
            pred=pred, gt=gt_arr, spacing=spacing
        )

    # Oracle weighted: its weights come from GT, so it is scored but not saved.
    wt_ensemble = weighted_avg(masks, gt_arr)
    tracker.add(model='weighted_oracle', timepoint=tp,
                pred=wt_ensemble, gt=gt_arr, spacing=spacing)

    # Other strategies. Best effort: a failing strategy is reported and skipped
    # for this timepoint instead of aborting the whole loop.
    for name, fn in zip(strategy_names, strategy_fns):
        try:
            ensemble_pred = fn(masks)
        except Exception as e:
            print(f'  {name} failed for {tp}: {e}')
            continue
        tracker.add(model=name, timepoint=tp,
                    pred=ensemble_pred, gt=gt_arr, spacing=spacing)
        save_nifti(ensemble_pred, data['affine'], OUT_DIR / f'{name}_{tp}.nii.gz')

print('\n=== Per-model summary ===')
print(tracker.summary().to_string())

In [None]:
# ── Pairwise agreement matrix (baseline) ─────────────────────────────────────
import seaborn as sns

# Pairwise Dice between the individual models on the baseline timepoint.
# Requires 'baseline' to be among the timepoints loaded above (KeyError otherwise).
# `masks_bl` is reused by the agreement-score cell further down.
masks_bl = all_results['baseline']['masks']
agreement_df = pairwise_dice_matrix(masks_bl)
print('Pairwise Dice between models (baseline):')
print(agreement_df.round(3))

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(agreement_df, annot=True, fmt='.3f', cmap='YlOrRd', vmin=0, vmax=1, ax=ax,
            linewidths=0.5, square=True)
# Title previously contained a mis-encoded en dash that rendered as mojibake.
ax.set_title('Inter-Model Agreement (Dice) – Baseline', fontsize=13, fontweight='bold')
fig.tight_layout()
# Save the figure created above explicitly rather than whatever is "current".
fig.savefig(OUT_DIR / 'pairwise_dice_heatmap.png', dpi=120, bbox_inches='tight')
plt.show()

In [None]:
# ── Visualise ensemble results (baseline) ────────────────────────────────────
# Baseline T1c volume and reference mask used for the overlay figure.
t1c_arr, _, _ = load_nifti(brats_paths['baseline']['t1c'])
gt_arr, _, _  = load_nifti(mask_paths['baseline'])

# Gather whichever baseline ensemble masks exist on disk in OUT_DIR.
ensemble_preds = {}
for name in ('majority_vote', 'union', 'intersection', 'staple'):
    saved = OUT_DIR / f'{name}_baseline.nii.gz'
    if not saved.exists():
        continue
    arr, _, _ = load_nifti(str(saved))
    ensemble_preds[name] = arr

if ensemble_preds:
    fig = plot_model_comparison(mri=t1c_arr, predictions=ensemble_preds, gt=gt_arr)
    plt.savefig(OUT_DIR / 'ensemble_comparison.png', dpi=120, bbox_inches='tight')
    plt.show()

# Summary bar chart of Dice for every tracked model/strategy.
summary = tracker.summary()
fig = plot_benchmark_bar(
    summary.reset_index(),
    metric='dice',
    title='Dice by Strategy (mean over timepoints)',
)
plt.savefig(OUT_DIR / 'ensemble_dice_bar.png', dpi=120, bbox_inches='tight')
plt.show()

# Persist every per-timepoint metric row collected by the tracker.
tracker.to_dataframe().to_csv(OUT_DIR / 'ensemble_metrics.csv', index=False)
print('Saved: outputs/05_ensemble/ensemble_metrics.csv')

In [None]:
# ── Agreement score for chosen ensemble ──────────────────────────────────────
# Agreement between the individual baseline masks (`masks_bl`, defined in the
# pairwise-agreement cell) and the saved majority-vote ensemble, as computed
# by metrics.agreement_score.
best_ensemble_path = OUT_DIR / 'majority_vote_baseline.nii.gz'
if best_ensemble_path.exists():
    ensemble_arr, _, _ = load_nifti(str(best_ensemble_path))
    agreement = agreement_score(masks_bl, ensemble_arr)
    print('=== Agreement Score (baseline) ===')
    for k, v in agreement.items():
        print(f'  {k}: {v}')

    # Clinical flags: mean agreement >= 0.90 high, >= 0.75 moderate, else low.
    # (These literals previously held mis-encoded emoji/dashes that printed as mojibake.)
    mean_ag = agreement['mean_agreement']
    if mean_ag >= 0.90:
        flag = '✅ HIGH – auto-report'
    elif mean_ag >= 0.75:
        flag = '⚠️  MODERATE – flag for review'
    else:
        flag = '🔴 LOW – require manual check'
    print(f'\nAgreement flag: {flag}')
else:
    # Previously this cell produced no output at all in this case.
    print(f'Agreement score skipped: {best_ensemble_path} not found '
          '(majority-vote ensemble was not written for baseline).')

## 📋 Ensemble Strategy Recommendations

| Strategy | Mean Dice | Agreement Score | Recommendation |
|----------|-----------|-----------------|----------------|
| Majority vote | _score_ | _score_ | ✅ **Default choice** |
| STAPLE | _score_ | _score_ | Consider when model calibration varies |
| Union | _score_ | _score_ | High sensitivity, noisy |
| Intersection | _score_ | _score_ | High specificity, misses edges |
| Weighted (oracle) | _score_ | _score_ | Upper bound – not usable in production |