# 🧠 **Clinical-Grade Multi-Site Brain Tumor Segmentation with AI**
---
<span style="color:red">*by **Ridwan Oladipo, MD | AI Specialist***</span>  

A **state-of-the-art, radiologist-grade nnU-Net v2 pipeline** for automated segmentation of **Whole Tumor (WT)**, **Tumor Core (TC)**, and **Enhancing Tumor (ET)** from multi-modal MRI, targeting **WT Dice ≥ 90%** and **BraTS Avg ≥ 80%**.

Built to **harmonize multi-institutional MRI data**, preserve tumor anatomy with high fidelity, and empower **neurosurgeons, oncologists, and radiologists** in treatment planning, surgical navigation, and disease monitoring.

---

## **🔬 Project Scope**
- 🧠 **Multi-site MRI harmonization** (N4 bias correction + z-score normalization)  
- 📏 **Dataset fingerprinting**: spacings, shapes, intensity stats, tumor volumes  
- ⚙️ **Optimized brain extraction** (≈47.5% mean brain-mask fill of its bounding box) with tumor-preservation checks  
- 🤖 Professional nnU-Net v2 preprocessing + clinical metrics dashboard  
- 📊 **BraTS-standard evaluation metrics** & automated tumor volume quantification  
- 🚀 Deployment-ready: **FastAPI + ONNX** inference on AWS ECS  

---

## **📂 Dataset Summary**
- ~750 multi-parametric MRI cases (484 labeled training + 266 test), each a 4D volume stacking FLAIR, T1w, T1Gd and T2w  
- Glioma types: High-Grade (HGG) + Low-Grade (LGG)  
- Labels: **0 = background, 1 = edema, 2 = non-enhancing tumor, 3 = enhancing tumor**  

---

## **💡 Why This Matters**  
Automated segmentation accelerates precision oncology — transforming hours of manual delineation into **seconds of AI-powered analysis**, enabling accurate tumor tracking, improved surgical planning, and early recurrence detection.

---
> ⚕️ **Created by a medical doctor + AI expert, <span style="color:red"><b>Ridwan Oladipo</b></span>, merging clinical neuroimaging insight with cutting-edge deep learning to advance brain tumor care.**

# 🚀 Ultra-Optimized Brain Tumor AI Environment Setup

## 🧩 Core & Medical Imaging Libraries

In [1]:
# Core libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
import nibabel as nib
from tqdm.notebook import tqdm
import warnings

warnings.filterwarnings('ignore')

# Medical imaging libraries
import SimpleITK as sitk
import pydicom
from scipy import ndimage
from skimage import measure
from sklearn.model_selection import KFold

## 🚀 Environment & Dataset Configuration

In [2]:
# Shared matplotlib theme for every figure produced in this notebook
plt.style.use('seaborn-v0_8-whitegrid')

# Run configuration: quick runs subsample MAX_VOLUMES training cases,
# full runs (USE_FULL_DATASET=True) use every case.
USE_FULL_DATASET = False  # Set True for full training
MAX_VOLUMES = None if USE_FULL_DATASET else 4
RANDOM_SEED = 42

# Input / output locations
base_dir = Path('/kaggle/working/brain_tumor_segmentation')
data_dir = Path('/kaggle/input/braintumor')
processed_dir = base_dir / 'processed'
nnunet_dir = base_dir / 'nnUNet_raw' / 'Dataset001_BrainTumor'
dicom_dir = base_dir / 'dicom_converted'
results_dir = base_dir / 'results'
visualization_dir = base_dir / 'visualizations'

# Every output folder the pipeline writes into, created up front (idempotent)
required_dirs = (
    base_dir,
    processed_dir,
    nnunet_dir,
    *(nnunet_dir / sub_dir for sub_dir in ('imagesTr', 'labelsTr', 'imagesTs')),
    dicom_dir,
    results_dir,
    visualization_dir,
)
for required in required_dirs:
    required.mkdir(parents=True, exist_ok=True)

run_size = 'FULL' if USE_FULL_DATASET else MAX_VOLUMES
print(f"✅ Environment setup - Using {run_size} volumes")
print("🎯 Target metrics: WT Dice ≥ 90%, BraTS Avg ≥ 80%")

✅ Environment setup - Using 4 volumes
🎯 Target metrics: WT Dice ≥ 90%, BraTS Avg ≥ 80%


# 📊 Dataset Metadata & DICOM Support

## 📄 Load Dataset Metadata and NIfTI Files

In [4]:
# Load dataset-level metadata (name, description, modalities, label names)
with open(data_dir / 'dataset.json', 'r', encoding='utf-8') as f:
    dataset_metadata = json.load(f)

print(f"📊 DATASET INFORMATION:")
print(f"Name: {dataset_metadata['name']}")
print(f"Description: {dataset_metadata['description']}")
print(f"Training cases: {dataset_metadata['numTraining']}")
print(f"Modalities: {', '.join([f'{k}: {v}' for k, v in dataset_metadata['modality'].items()])}")
print(f"Labels: {', '.join([f'{k}: {v}' for k, v in dataset_metadata['labels'].items()])}")


def _list_nifti(folder):
    """Return the sorted NIfTI files in *folder*, skipping macOS '._' resource-fork artefacts."""
    return sorted(p for p in folder.glob('*.nii*') if not p.name.startswith('._'))


def _case_key(path):
    """Return the file name without its .nii / .nii.gz extension (used to pair images with labels)."""
    name = path.name
    for ext in ('.nii.gz', '.nii'):
        if name.endswith(ext):
            return name[:-len(ext)]
    return name


# Collect files - filter hidden files
train_image_files = _list_nifti(data_dir / 'imagesTr')
train_label_files = _list_nifti(data_dir / 'labelsTr')
test_image_files = _list_nifti(data_dir / 'imagesTs')

print(f"Found {len(train_image_files)} training images")
print(f"Found {len(train_label_files)} training labels")
print(f"Found {len(test_image_files)} test images")

# Pair every image with the label of the same case by name. Zipping two independently
# sorted lists would silently mispair all later cases if a single file were missing.
labels_by_case = {_case_key(p): p for p in train_label_files}
unmatched_images = [p.name for p in train_image_files if _case_key(p) not in labels_by_case]
if unmatched_images:
    print(f"⚠️ {len(unmatched_images)} training images have no matching label and are skipped "
          f"(e.g. {', '.join(unmatched_images[:3])})")
train_image_files = [p for p in train_image_files if _case_key(p) in labels_by_case]
train_label_files = [labels_by_case[_case_key(p)] for p in train_image_files]

# Apply volume limitation for resource constraints (seeded, so the subset is reproducible)
if not USE_FULL_DATASET and MAX_VOLUMES:
    np.random.seed(RANDOM_SEED)
    indices = np.random.choice(len(train_image_files), min(MAX_VOLUMES, len(train_image_files)), replace=False)
    selected = sorted(indices)
    train_image_files = [train_image_files[i] for i in selected]
    train_label_files = [train_label_files[i] for i in selected]
    print(f"Limited to {len(train_image_files)} volumes for current run")

📊 DATASET INFORMATION:
Name: BRATS
Description: Gliomas segmentation tumour and oedema in on brain images
Training cases: 484
Modalities: 0: FLAIR, 1: T1w, 2: t1gd, 3: T2w
Labels: 0: background, 1: edema, 2: non-enhancing tumor, 3: enhancing tumour
Found 484 training images
Found 484 training labels
Found 266 test images
Limited to 4 volumes for current run


## 🏥 DICOM Discovery & Conversion

In [None]:
print("🏥 DICOM SUPPORT:")

# Directory-name patterns that conventionally hold DICOM series
DICOM_GLOB_PATTERNS = ['**/DICOM*', '**/dicom*', '**/DCM*', '**/dcm*']


def _looks_like_dicom_file(filename):
    """Cheap name-based pre-filter: a .dcm/.dicom extension, or no extension at all."""
    return filename.lower().endswith(('.dcm', '.dicom')) or '.' not in filename


def _safe_filename_part(value):
    """Replace characters that are unsafe inside a file name (path separators, spaces, ...) with '_'."""
    return ''.join(ch if ch.isalnum() or ch in '-_.' else '_' for ch in str(value))


def _contains_dicom(directory):
    """Return True as soon as one regular file directly in *directory* parses as DICOM with a PatientID."""
    for candidate in sorted(directory.iterdir()):
        if not candidate.is_file():
            continue
        try:
            # Header only: pixel data is not needed to decide validity
            ds = pydicom.dcmread(candidate, force=True, stop_before_pixels=True)
        except Exception:
            continue  # not DICOM / unreadable - try the next file
        if hasattr(ds, 'PatientID'):
            return True
    return False


def _convert_series(dicom_path, series_id):
    """Convert one DICOM series in *dicom_path* to NIfTI in dicom_dir and write a JSON metadata sidecar."""
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(str(dicom_path), series_id)
    if not dicom_names:
        return
    reader.SetFileNames(dicom_names)

    # Series-level metadata is taken from the first file of the series
    sample_ds = pydicom.dcmread(dicom_names[0], force=True, stop_before_pixels=True)
    patient_id = getattr(sample_ds, 'PatientID', 'Unknown')
    study_date = getattr(sample_ds, 'StudyDate', 'Unknown')
    series_desc = getattr(sample_ds, 'SeriesDescription', 'Unknown')
    modality = getattr(sample_ds, 'Modality', 'Unknown')
    slice_thickness = getattr(sample_ds, 'SliceThickness', 'Unknown')
    pixel_spacing = getattr(sample_ds, 'PixelSpacing', [1.0, 1.0])

    print(f"  Series: {series_desc} ({modality})")
    print(f"  Patient: {patient_id}, Date: {study_date}")
    print(f"  Slice thickness: {slice_thickness}, Pixel spacing: {pixel_spacing}")

    image = reader.Execute()
    # NOTE(review): SimpleITK orientation codes follow ITK conventions, which differ from
    # nibabel's; confirm 'RAS' produces the intended array axis order.
    image = sitk.DICOMOrient(image, 'RAS')

    # IDs come from file headers, so sanitize them before building a path
    output_name = (f"{_safe_filename_part(patient_id)}_{_safe_filename_part(series_id)}_"
                   f"{_safe_filename_part(modality)}.nii.gz")
    sitk.WriteImage(image, str(dicom_dir / output_name))

    metadata = {
        'patient_id': str(patient_id),
        'study_date': str(study_date),
        'series_description': str(series_desc),
        'modality': str(modality),
        'slice_thickness': str(slice_thickness),
        'pixel_spacing': [float(x) for x in pixel_spacing],
        'original_dicom_path': str(dicom_path),
        'series_id': series_id,
        'num_slices': len(dicom_names)
    }
    with open(dicom_dir / f"{output_name}_metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"  ✅ Converted: {output_name}")


# Candidate directories, de-duplicated in discovery order (dict used as an ordered set)
candidate_dirs = {}
for root, _subdirs, files in os.walk(data_dir):
    if any(_looks_like_dicom_file(name) for name in files):
        candidate_dirs[Path(root)] = None
for pattern in DICOM_GLOB_PATTERNS:
    for match in data_dir.glob(pattern):
        if match.is_dir():  # patterns can also match plain files
            candidate_dirs.setdefault(match, None)
all_dicom_dirs = list(candidate_dirs)

if all_dicom_dirs:
    print(f"📁 Found {len(all_dicom_dirs)} DICOM directories")

    for dicom_path in all_dicom_dirs:
        print(f"Processing DICOM directory: {dicom_path}")

        try:
            if not _contains_dicom(dicom_path):
                continue
            series_ids = sitk.ImageSeriesReader().GetGDCMSeriesIDs(str(dicom_path))
        except Exception as e:
            print(f"  ❌ Error processing {dicom_path}: {e}")
            continue

        # One bad series no longer aborts the remaining series in the same directory
        for series_id in series_ids:
            try:
                _convert_series(dicom_path, series_id)
            except Exception as e:
                print(f"  ❌ Error processing {dicom_path} (series {series_id}): {e}")

    print(f"✅ DICOM conversion complete - {len(list(dicom_dir.glob('*.nii.gz')))} files converted")
else:
    print("📁 No DICOM directories found - using NIfTI files directly")

# 🔬 nnU-Net Dataset Fingerprinting & Target Spacing Calibration

## 🧠 nnU-Net Dataset Fingerprinting

In [None]:
print("🔬 nnU-Net DATASET FINGERPRINTING:")

# Analyze dataset properties for nnU-Net planning
dataset_properties = {
    'spacings': [],
    'shapes': [],
    'modalities': ['FLAIR', 'T1w', 'T1Gd', 'T2w'],
    'intensity_properties': {mod: {'percentiles': [], 'mean': [], 'std': []} for mod in
                             ['FLAIR', 'T1w', 'T1Gd', 'T2w']},
    'label_properties': {'labels': [], 'volumes': []}
}

# One Otsu filter reused for every channel. inside=0 / outside=1 is intended to mark the
# brighter (above-threshold) voxels as foreground.
# NOTE(review): relies on ITK assigning InsideValue to voxels at or below the threshold.
otsu_filter = sitk.OtsuThresholdImageFilter()
otsu_filter.SetInsideValue(0)
otsu_filter.SetOutsideValue(1)

# Extract comprehensive dataset fingerprint
for img_path, label_path in tqdm(zip(train_image_files, train_label_files),
                                 total=len(train_image_files), desc="Fingerprinting dataset"):

    img_nifti = nib.load(img_path)
    label_nifti = nib.load(label_path)

    # float32 instead of get_fdata()'s float64 default halves peak memory per case
    img_data = img_nifti.get_fdata(dtype=np.float32)
    # Labels are small integers: round, then store compactly rather than as float64
    label_data = np.rint(label_nifti.get_fdata(dtype=np.float32)).astype(np.int16)
    spacing = img_nifti.header.get_zooms()[:3]

    # Collect spacing and shape information (last axis is the modality channel)
    dataset_properties['spacings'].append(spacing)
    dataset_properties['shapes'].append(img_data.shape[:3])

    # Per-modality intensity statistics inside an Otsu foreground mask
    for mod_idx, modality in enumerate(dataset_properties['modalities']):
        mod_data = img_data[..., mod_idx]

        # Contiguous copy for SimpleITK; the array round-trip keeps mod_data's axis order
        sitk_img = sitk.GetImageFromArray(np.ascontiguousarray(mod_data))
        brain_mask = sitk.GetArrayFromImage(otsu_filter.Execute(sitk_img)) > 0
        brain_voxels = mod_data[brain_mask]

        if brain_voxels.size > 1000:  # Ensure sufficient voxels
            stats = dataset_properties['intensity_properties'][modality]
            # nnU-Net style percentile analysis
            stats['percentiles'].append(np.percentile(brain_voxels, [0.5, 10, 50, 90, 99.5]))
            # Accumulate in float64 so the statistics are not degraded by float32 input
            stats['mean'].append(np.mean(brain_voxels, dtype=np.float64))
            stats['std'].append(np.std(brain_voxels, dtype=np.float64))

    # Label inventory and per-label volumes in a single pass
    labels_present, voxel_counts = np.unique(label_data, return_counts=True)
    dataset_properties['label_properties']['labels'].append(labels_present)

    voxel_volume = np.prod(spacing)
    dataset_properties['label_properties']['volumes'].append({
        int(label_val): count * voxel_volume / 1000  # Convert to cm³
        for label_val, count in zip(labels_present, voxel_counts)
        if label_val > 0  # Skip background
    })

## 📐 Target Spacing & Intensity Normalization Calibration

In [None]:
# Calculate nnU-Net target properties from the fingerprint collected above
spacings_array = np.array(dataset_properties['spacings'])
shapes_array = np.array(dataset_properties['shapes'])

# Without any fingerprinted case the medians below are NaN and every later step breaks
if spacings_array.size == 0:
    raise RuntimeError(
        "Dataset fingerprint is empty - no training cases were analysed, so target spacing "
        "and normalization parameters cannot be derived."
    )

# nnU-Net spacing determination - median spacing per axis
median_spacing = np.median(spacings_array, axis=0)
target_spacing = [float(x) for x in median_spacing]
# Median image size in voxels per axis (recorded for reference)
median_shape = [float(x) for x in np.median(shapes_array, axis=0)]

# Calculate intensity normalization parameters per modality (medians across cases)
normalization_params = {}
for modality in dataset_properties['modalities']:
    stats = dataset_properties['intensity_properties'][modality]
    all_percentiles = np.array(stats['percentiles'])
    all_means = np.array(stats['mean'])
    all_stds = np.array(stats['std'])

    if len(all_percentiles) > 0:
        # nnU-Net normalization scheme
        normalization_params[modality] = {
            'clip_lower': float(np.median(all_percentiles[:, 0])),  # 0.5th percentile
            'clip_upper': float(np.median(all_percentiles[:, 4])),  # 99.5th percentile
            'mean_intensity': float(np.median(all_means)),
            'std_intensity': float(np.median(all_stds))
        }

print(f"✅ Dataset fingerprinting complete")
print(f"📏 Target spacing: {target_spacing} mm")
print(f"🎯 Modalities analyzed: {len(dataset_properties['modalities'])}")

# Save dataset fingerprint
fingerprint = {
    'target_spacing': target_spacing,
    'normalization_params': normalization_params,
    'dataset_properties': {
        'num_cases': len(train_image_files),
        'modalities': dataset_properties['modalities'],
        'median_spacing': target_spacing,
        'median_shape': median_shape
    }
}

with open(results_dir / 'nnunet_fingerprint.json', 'w') as f:
    json.dump(fingerprint, f, indent=2)

# 🔧 Multi-Site MRI Harmonization Configuration

In [None]:
print("🔧 MULTI-SITE HARMONIZATION PIPELINE:")

# Harmonization approach for multi-site MRI data
harmonization_methods = ['n4_bias_correction', 'z_score_harmonization']

print(f"Implementing: {', '.join(harmonization_methods)}")

# Z-score standardization reference per modality. 'clip_percentiles' holds intensity
# *values* (the dataset-median 0.5th / 99.5th percentile intensities from
# normalization_params), not percentile ranks.
harmonization_reference = {
    modality: {
        'target_mean': 0.0,  # Z-score normalization target
        'target_std': 1.0,
        'clip_percentiles': [
            normalization_params[modality]['clip_lower'],
            normalization_params[modality]['clip_upper']
        ]
    }
    for modality in dataset_properties['modalities']
    if modality in normalization_params
}

# Save harmonization parameters
harmonization_config = {
    'methods': harmonization_methods,
    'harmonization_reference': harmonization_reference,
    'n4_parameters': {
        'max_iterations': [50, 50, 30, 20],
        'convergence_threshold': 1e-6,
        # NOTE(review): the N4 step in the preprocessing cell sets only iterations,
        # convergence threshold and shrink factor; this value is recorded, not applied.
        'bspline_fitting_distance': 300,
        # Matches the shrink factor the preprocessing N4 step uses (previously recorded as 3)
        'shrink_factor': 4
    }
}

with open(results_dir / 'harmonization_config.json', 'w') as f:
    json.dump(harmonization_config, f, indent=2)

print("✅ Multi-site harmonization parameters configured")

# 🚀 nnU-Net v2 Preprocessing Pipeline

In [None]:
print("🚀 nnU-Net v2 PREPROCESSING:")
print(f"Processing {len(train_image_files)} volumes with harmonization pipeline")
print("=" * 60)

# N4 settings applied in STEP 3
N4_SHRINK_FACTOR = 4
N4_MAX_ITERATIONS = [50, 50, 30, 20]
N4_CONVERGENCE_THRESHOLD = 1e-6


def _bounding_box(mask):
    """Return ``(mins, maxs)`` of the True region of *mask* (``maxs`` exclusive), or None if empty."""
    coords = np.argwhere(mask)
    if len(coords) == 0:
        return None
    return coords.min(axis=0), coords.max(axis=0) + 1


def _extract_brain_mask(img_data):
    """Build a brain mask from a (x, y, z, channel) volume.

    The threshold is the 2nd percentile of all non-zero voxels pooled over every channel.
    It is applied to channel 0 (FLAIR), followed by a 5x5x5 dilation, hole filling and
    selection of the largest connected component.

    Raises:
        ValueError: if the volume has no non-zero voxels.
    """
    nonzero_voxels = img_data[img_data != 0]
    if nonzero_voxels.size == 0:
        raise ValueError("image contains no non-zero voxels; cannot build a brain mask")
    low_thresh = np.percentile(nonzero_voxels, 2.0)

    mask = img_data[..., 0] > low_thresh
    mask = ndimage.binary_dilation(mask, structure=np.ones((5, 5, 5)))
    mask = ndimage.binary_fill_holes(mask)

    labeled_mask, num_labels = ndimage.label(mask)
    if num_labels > 1:
        sizes = ndimage.sum(mask, labeled_mask, range(1, num_labels + 1))
        mask = labeled_mask == (np.argmax(sizes) + 1)
    return mask


def _n4_correct(img_data, brain_mask, spacing_xyz):
    """N4 bias-field correct every channel; channels where N4 fails are passed through unchanged.

    The bias field is estimated on a shrunk copy and evaluated back on the full-resolution grid.
    """
    corrected = np.zeros_like(img_data)
    # nibabel arrays are indexed (x, y, z) but SimpleITK maps array axes to (z, y, x),
    # so the spacing must be reversed to line up with the right physical axes.
    sitk_spacing = [float(s) for s in reversed(spacing_xyz)]
    shrink = [N4_SHRINK_FACTOR] * 3

    mask_img = sitk.GetImageFromArray(brain_mask.astype(np.uint8))
    mask_img.SetSpacing(sitk_spacing)
    mask_shrunk = sitk.Shrink(mask_img, shrink)

    for channel in range(img_data.shape[-1]):
        channel_data = img_data[..., channel]
        channel_img = sitk.GetImageFromArray(channel_data.astype(np.float32))
        channel_img.SetSpacing(sitk_spacing)

        corrector = sitk.N4BiasFieldCorrectionImageFilter()
        corrector.SetMaximumNumberOfIterations(N4_MAX_ITERATIONS)
        corrector.SetConvergenceThreshold(N4_CONVERGENCE_THRESHOLD)

        try:
            corrector.Execute(sitk.Shrink(channel_img, shrink), mask_shrunk)
            log_bias_field = corrector.GetLogBiasFieldAsImage(channel_img)
            bias_field = np.exp(sitk.GetArrayFromImage(log_bias_field))
            corrected[..., channel] = (channel_data / (bias_field + 1e-8)).astype(np.float32)
        except Exception as e:
            print(f"    ⚠️ N4 failed for modality {channel}: {e}")
            corrected[..., channel] = channel_data
    return corrected


def _normalize_intensities(img_corrected, brain_mask, modalities, norm_params):
    """Clip and z-score each channel with dataset-wide parameters, then zero voxels outside the brain.

    Channel i is matched to ``modalities[i]``. A modality without an entry in *norm_params*
    is only masked.
    """
    normalized = np.zeros_like(img_corrected)
    brain_present = bool(np.any(brain_mask))
    for channel, modality in enumerate(modalities):
        channel_data = img_corrected[..., channel]
        params = norm_params.get(modality)
        if params is None:
            normalized[..., channel] = channel_data * brain_mask
            continue

        # nnU-Net style percentile clipping on the full image
        clipped = np.clip(channel_data, params['clip_lower'], params['clip_upper'])
        if not brain_present:
            normalized[..., channel] = clipped * brain_mask
            continue

        mean_val = params['mean_intensity']
        std_val = params['std_intensity']
        if std_val > 0:
            normalized[..., channel] = (clipped - mean_val) / std_val * brain_mask
        else:
            normalized[..., channel] = (clipped - mean_val) * brain_mask
    return normalized


def _zoom(volume, zoom_factors, order):
    """Resample *volume* by *zoom_factors* with spline *order*; identity factors just copy."""
    if all(factor == 1.0 for factor in zoom_factors):
        return np.array(volume, copy=True)
    return ndimage.zoom(volume, zoom_factors, order=order)


def _resampled_affine(original_affine, new_spacing):
    """Return a copy of *original_affine* whose voxel-axis columns have length *new_spacing*.

    Whole columns are rescaled, instead of only the diagonal being overwritten. This keeps any
    rotation or axis permutation in the original affine intact. The translation is unchanged.
    """
    new_affine = np.array(original_affine, dtype=np.float64)
    column_lengths = np.linalg.norm(new_affine[:3, :3], axis=0)
    column_lengths[column_lengths == 0] = 1.0  # guard against a degenerate affine
    new_affine[:3, :3] = new_affine[:3, :3] / column_lengths * np.asarray(new_spacing, dtype=np.float64)
    return new_affine


processing_stats = []
total_cases = len(train_image_files)

for idx, (img_path, label_path) in enumerate(tqdm(zip(train_image_files, train_label_files),
                                                  total=total_cases,
                                                  desc="nnU-Net v2 preprocessing")):

    case_id = img_path.stem.split('.')[0]

    try:
        # STEP 1: LOAD VOLUMES AND EXTRACT METADATA
        img_nifti = nib.load(img_path)
        label_nifti = nib.load(label_path)

        # Load directly as float32 (get_fdata defaults to float64, doubling peak memory)
        img_data = img_nifti.get_fdata(dtype=np.float32)
        # Round before casting so labels stored as e.g. 0.9999 are not truncated to 0
        label_data = np.rint(label_nifti.get_fdata(dtype=np.float32)).astype(np.uint8)
        original_spacing = img_nifti.header.get_zooms()[:3]
        original_affine = img_nifti.affine

        # STEP 2: BRAIN EXTRACTION
        brain_mask = _extract_brain_mask(img_data)
        bbox = _bounding_box(brain_mask)
        if bbox is None:
            brain_coverage = 0.0
        else:
            mins, maxs = bbox
            # "Coverage" = percentage of the mask's bounding box occupied by brain voxels
            brain_coverage = np.count_nonzero(brain_mask) / np.prod(maxs - mins) * 100

        # STEP 3: N4 BIAS FIELD CORRECTION
        img_corrected = _n4_correct(img_data, brain_mask, original_spacing)

        # STEP 4: nnU-Net INTENSITY NORMALIZATION
        img_normalized = _normalize_intensities(
            img_corrected, brain_mask, dataset_properties['modalities'], normalization_params
        )

        # STEP 5: SPATIAL RESAMPLING TO nnU-Net TARGET
        zoom_factors = [float(orig) / float(target) for orig, target in zip(original_spacing, target_spacing)]

        # Stack per-channel results rather than pre-allocating: ndimage.zoom rounds the output
        # shape, so a pre-computed, int()-truncated shape can disagree with it.
        img_resampled = np.stack(
            [_zoom(img_normalized[..., c], zoom_factors, order=3) for c in range(img_normalized.shape[-1])],
            axis=-1,
        ).astype(np.float32, copy=False)
        # Nearest neighbour keeps labels and mask discrete
        label_resampled = _zoom(label_data, zoom_factors, order=0).astype(np.uint8)
        brain_mask_resampled = _zoom(brain_mask.astype(np.uint8), zoom_factors, order=0).astype(bool)

        bbox_resampled = _bounding_box(brain_mask_resampled)
        if bbox_resampled is None:
            mins_resampled, maxs_resampled = [0, 0, 0], brain_mask_resampled.shape
        else:
            mins_resampled, maxs_resampled = bbox_resampled

        # STEP 6: QUALITY CONTROL AND VALIDATION
        if not np.isfinite(img_resampled).all():
            print("    ⚠️ Found NaN/Inf values - applying correction...")
            img_resampled = np.nan_to_num(img_resampled, nan=0.0, posinf=0.0, neginf=0.0)

        valid_labels = [0, 1, 2, 3]  # Background, Edema, Non-enhancing, Enhancing
        invalid_voxels = ~np.isin(label_resampled, valid_labels)
        if np.any(invalid_voxels):
            print(f"    ⚠️ Found {np.sum(invalid_voxels)} invalid label voxels - setting to background")
            label_resampled = np.where(invalid_voxels, 0, label_resampled).astype(np.uint8)

        # Preservation is a physical-volume ratio, so it stays ~1.0 when the spacing changes
        tumor_before = np.count_nonzero(label_data)
        tumor_after = np.count_nonzero(label_resampled)
        volume_before = tumor_before * float(np.prod(original_spacing))
        volume_after = tumor_after * float(np.prod(target_spacing))
        preservation_ratio = volume_after / (volume_before + 1e-8)
        print(f"    📊 Tumor preservation: {preservation_ratio:.3f} ({tumor_before} → {tumor_after} voxels)")

        # STEP 7: SAVE nnU-Net FORMAT AND METADATA
        new_affine = _resampled_affine(original_affine, target_spacing)

        # nnU-Net v2 expects one 3D file per channel: {case}_0000 ... {case}_0003
        for channel in range(img_resampled.shape[-1]):
            channel_path = nnunet_dir / 'imagesTr' / f"{case_id}_{channel:04d}.nii.gz"
            nib.save(nib.Nifti1Image(img_resampled[..., channel], new_affine), channel_path)

        output_label_path = nnunet_dir / 'labelsTr' / f"{case_id}.nii.gz"
        nib.save(nib.Nifti1Image(label_resampled.astype(np.uint8), new_affine), output_label_path)

        # Save preprocessing metadata
        np.savez_compressed(
            processed_dir / f"{case_id}_preprocessed.npz",
            image=img_resampled.astype(np.float32),
            label=label_resampled.astype(np.uint8),
            brain_mask_full=brain_mask_resampled,
            brain_mask_roi=brain_mask_resampled[mins_resampled[0]:maxs_resampled[0],
                                                mins_resampled[1]:maxs_resampled[1],
                                                mins_resampled[2]:maxs_resampled[2]],
            brain_mask_bbox_mins=mins_resampled,
            brain_mask_bbox_maxs=maxs_resampled,
            original_spacing=original_spacing,
            target_spacing=target_spacing,
            original_shape=img_data.shape[:3],
            final_shape=img_resampled.shape[:3],
            normalization_applied=True,
            harmonization_applied=True,
            bias_correction_applied=True,
            histogram_matching_applied=False
        )

        # Record processing statistics
        processing_stats.append({
            'case_id': case_id,
            'success': True,
            'original_shape': img_data.shape[:3],
            'final_shape': img_resampled.shape[:3],
            'original_spacing': original_spacing,
            'final_spacing': target_spacing,
            'tumor_voxels_before': int(tumor_before),
            'tumor_voxels_after': int(tumor_after),
            'tumor_preservation_ratio': float(preservation_ratio),
            'brain_mask_coverage': float(brain_coverage),
            'intensity_range_per_modality': [
                [float(img_resampled[..., i].min()), float(img_resampled[..., i].max())]
                for i in range(img_resampled.shape[-1])
            ]
        })

        print(f"  ✅ {case_id} processing completed successfully!")

        # Progress update
        if (idx + 1) % 2 == 0 or (idx + 1) == total_cases:
            successful = sum(1 for s in processing_stats if s['success'])
            print(f"\n📊 Progress: {idx + 1}/{total_cases} cases processed ({successful} successful)")
    except Exception as e:
        print(f"  ❌ Error processing {case_id}: {e}")
        processing_stats.append({
            'case_id': case_id,
            'success': False,
            'error': str(e)
        })

# Save comprehensive processing statistics
processing_df = pd.DataFrame(processing_stats)
processing_df.to_csv(results_dir / 'nnunet_preprocessing_stats.csv', index=False)

# An empty run has no 'success' column and would otherwise divide by zero
total_processed = len(processing_df)
successful_cases = int(processing_df['success'].sum()) if total_processed else 0
failed_cases = total_processed - successful_cases
success_rate = successful_cases / total_processed * 100 if total_processed else 0.0

print("\n" + "=" * 60)
print("🎯 nnU-Net v2 PREPROCESSING COMPLETE")
print("=" * 60)
print(f"✅ Successfully processed: {successful_cases}/{total_processed} volumes")
print(f"❌ Failed cases: {failed_cases}")
print(f"📊 Success rate: {success_rate:.1f}%")

if successful_cases > 0:
    successful_stats = processing_df[processing_df['success'].astype(bool)].copy()
    avg_preservation = successful_stats['tumor_preservation_ratio'].mean()
    avg_brain_coverage = successful_stats['brain_mask_coverage'].mean()

    print(f"📊 Average tumor preservation: {avg_preservation:.3f}")
    print(f"🧠 Average brain mask coverage: {avg_brain_coverage:.1f}%")

print(f"\n📁 Outputs saved to:")
print(f"  • nnU-Net format: {nnunet_dir}")
print(f"  • Preprocessed data: {processed_dir}")
print(f"  • Statistics: {results_dir}")
print("\n🚀 Ready for nnU-Net training!")

# 📜 nnU-Net Dataset JSON & Cross-Validation Setup

In [None]:
from pathlib import Path
import json
import os

from sklearn.model_selection import KFold

# nnU-Net dataset folder name; also used to locate the matching preprocessed folder below.
NNUNET_DATASET_NAME = "Dataset001_BrainTumor"

# Case IDs that survived preprocessing. Deriving the training listing, numTraining and the
# CV folds from this single list keeps all three consistent with each other.
all_cases = [str(stats['case_id']) for stats in processing_stats if stats['success']]

# Convert to JSON-safe format (plain str/int values only)
nnunet_dataset_json = {
    "channel_names": {
        "0": "FLAIR",
        "1": "T1w",
        "2": "T1Gd",
        "3": "T2w"
    },
    "labels": {
        "background": 0,
        "edema": 1,
        "non_enhancing_tumor": 2,
        "enhancing_tumor": 3
    },
    # NOTE(review): per the nnU-Net v2 region-based training docs, regions_class_order is only
    # used when labels are declared as regions (e.g. "whole_tumor": [1, 2, 3]). With the plain
    # integer labels above it is ignored and the 4 classes are trained, not WT/TC/ET regions.
    "regions_class_order": [1, 2, 3],
    "numTraining": len(all_cases),
    "file_ending": ".nii.gz",
    "overwrite_image_reader_writer": "NibabelIOWithReorient",
    "nnUNet_version": "2.0",
    "dataset_name": NNUNET_DATASET_NAME,
    "description": "Brain tumor segmentation with nnU-Net v2 preprocessing",
    "reference": "BraTS challenge targeting WT Dice ≥ 90%, BraTS Avg ≥ 80%",
    "tensorImageSize": "4D",
    # MSD-style listing kept for reference (first channel file only). nnU-Net v2 presumably
    # discovers cases from imagesTr/labelsTr rather than this key — verify if relied upon.
    "training": [{"image": f"./imagesTr/{case_id}_0000.nii.gz",
                  "label": f"./labelsTr/{case_id}.nii.gz"}
                 for case_id in all_cases]
}

with open(nnunet_dir / 'dataset.json', 'w') as f:
    json.dump(nnunet_dataset_json, f, indent=2)

# Up to 5-fold CV. KFold needs at least 2 samples, so a single case falls back to using it
# for both train and val (i.e. no real held-out validation).
if len(all_cases) > 1:
    kfold = KFold(n_splits=min(5, len(all_cases)), shuffle=True, random_state=42)
    splits = [
        {"train": [all_cases[i] for i in train_idx],
         "val": [all_cases[i] for i in val_idx]}
        for train_idx, val_idx in kfold.split(all_cases)
    ]
else:
    splits = [{"train": all_cases, "val": all_cases}]

with open(nnunet_dir / 'splits_final.json', 'w') as f:
    json.dump(splits, f, indent=2)

# nnU-Net v2 reads custom folds from nnUNet_preprocessed/<dataset>/splits_final.json, not from
# the raw dataset folder, so mirror the file there when that location is configured.
preprocessed_root = os.environ.get('nnUNet_preprocessed')
if preprocessed_root:
    preprocessed_dataset_dir = Path(preprocessed_root) / NNUNET_DATASET_NAME
    preprocessed_dataset_dir.mkdir(parents=True, exist_ok=True)
    with open(preprocessed_dataset_dir / 'splits_final.json', 'w') as f:
        json.dump(splits, f, indent=2)
    print(f"📁 Splits mirrored to: {preprocessed_dataset_dir / 'splits_final.json'}")

# Print result summary
print("✅ nnU-Net dataset configuration complete")
print(f"📁 Training images: {len(list((nnunet_dir / 'imagesTr').glob('*.nii.gz')))} files")
print(f"📁 Training labels: {len(list((nnunet_dir / 'labelsTr').glob('*.nii.gz')))} files")
print(f"📁 Cross-validation: {len(splits)} folds configured")

# 🩺 Clinical Volume Metrics & Image Quality Analysis

In [None]:
# Clinical metrics are computed from the ORIGINAL (raw) NIfTI volumes listed in
# train_image_files / train_label_files, using their native voxel spacing.


def _brats_region_volumes(label_data, voxel_volume_cm3):
    """Return (WT, TC, ET) volumes in cm³ for a BraTS-style label map.

    WT = any non-zero label, TC = labels 2 or 3, ET = label 3.
    """
    wt = np.count_nonzero(label_data > 0) * voxel_volume_cm3
    tc = np.count_nonzero((label_data == 2) | (label_data == 3)) * voxel_volume_cm3
    et = np.count_nonzero(label_data == 3) * voxel_volume_cm3
    return wt, tc, et


def _modality_intensity_stats(mod_data, modality):
    """Return '<modality>_snr/_mean/_std' over voxels above the 5th percentile of positive voxels.

    'SNR' here is mean/std inside that crude foreground mask (an inverse coefficient of
    variation), not a noise-floor-based SNR. All three values are 0 when no foreground exists.
    """
    positive = mod_data[mod_data > 0]
    if positive.size == 0:
        # np.percentile is not defined for an empty array; a channel with no positive
        # voxels simply has no foreground.
        brain_voxels = positive
    else:
        brain_voxels = mod_data[mod_data > np.percentile(positive, 5)]

    if brain_voxels.size == 0:
        return {f'{modality}_snr': 0, f'{modality}_mean': 0, f'{modality}_std': 0}

    mean = np.mean(brain_voxels)
    std = np.std(brain_voxels)
    return {f'{modality}_snr': mean / (std + 1e-8),
            f'{modality}_mean': mean,
            f'{modality}_std': std}


# Channel order along the last axis of the 4D images.
modality_order = ['FLAIR', 'T1w', 'T1Gd', 'T2w']
metadata_list = []

for img_path, label_path in zip(train_image_files, train_label_files):
    # Path('X.nii.gz').stem == 'X.nii' -> 'X'
    case_id = img_path.stem.split('.')[0]

    img_nifti = nib.load(img_path)
    label_nifti = nib.load(label_path)

    img_data = img_nifti.get_fdata()
    label_data = label_nifti.get_fdata()
    spacing = img_nifti.header.get_zooms()[:3]

    # Voxel size in mm³ -> cm³
    voxel_volume_cm3 = np.prod(spacing) / 1000
    wt_volume, tc_volume, et_volume = _brats_region_volumes(label_data, voxel_volume_cm3)

    # Image quality metrics per modality
    modality_stats = {}
    for mod_idx, modality in enumerate(modality_order):
        modality_stats.update(_modality_intensity_stats(img_data[..., mod_idx], modality))

    metadata_list.append({
        'case_id': case_id,
        'wt_volume_cm3': wt_volume,
        'tc_volume_cm3': tc_volume,
        'et_volume_cm3': et_volume,
        'tumor_present': wt_volume > 0,
        'original_spacing': spacing,
        **modality_stats
    })

# Create clinical metadata DataFrame
clinical_df = pd.DataFrame(metadata_list)
clinical_df.to_csv(results_dir / 'clinical_analysis.csv', index=False)

# Clinical visualization: row 0 = BraTS region volumes, row 1 = per-modality "SNR"
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

volume_cols = ['wt_volume_cm3', 'tc_volume_cm3', 'et_volume_cm3']
colors = ['royalblue', 'crimson', 'forestgreen']
titles = ['Whole Tumor (WT)', 'Tumor Core (TC)', 'Enhancing Tumor (ET)']

for idx, (col, color, title) in enumerate(zip(volume_cols, colors, titles)):
    ax = axes[0, idx]
    # Only tumor-positive cases, so empty cases don't dominate the histogram at 0.
    data_to_plot = clinical_df[clinical_df[col] > 0][col]
    if len(data_to_plot) > 0:
        data_to_plot.hist(bins=10, alpha=0.7, ax=ax, color=color, edgecolor='black')
    ax.set_title(f'{title} Volume Distribution')
    ax.set_xlabel('Volume (cm³)')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

# Only 3 panels are available in this row, so T2w_snr is computed (and saved to CSV) but not plotted.
quality_cols = ['FLAIR_snr', 'T1w_snr', 'T1Gd_snr']
for idx, col in enumerate(quality_cols):
    ax = axes[1, idx]
    clinical_df[col].hist(bins=10, alpha=0.7, ax=ax, color='orange', edgecolor='black')
    ax.set_title(f'{col.replace("_", " ").upper()}')
    ax.set_xlabel('Signal-to-Noise Ratio')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(visualization_dir / 'clinical_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

# Sample std is undefined for a single case, so show "N/A" instead.
if len(clinical_df) == 1:
    wt_std = tc_std = et_std = "N/A"
else:
    wt_std = f"{clinical_df['wt_volume_cm3'].std():.1f}"
    tc_std = f"{clinical_df['tc_volume_cm3'].std():.1f}"
    et_std = f"{clinical_df['et_volume_cm3'].std():.1f}"

print("📊 Clinical analysis complete")
print(f"• Cases with tumor: {clinical_df['tumor_present'].sum()}/{len(clinical_df)}")
print(f"• Mean WT volume: {clinical_df['wt_volume_cm3'].mean():.1f} ± {wt_std} cm³")
print(f"• Mean TC volume: {clinical_df['tc_volume_cm3'].mean():.1f} ± {tc_std} cm³")
print(f"• Mean ET volume: {clinical_df['et_volume_cm3'].mean():.1f} ± {et_std} cm³")

# 🔍 Preprocessing Validation & Visual Quality Control

In [None]:
def _case_id_from_npz(npz_path):
    """Return '<case_id>' from a '<case_id>_preprocessed.npz' path (case IDs may contain '_')."""
    return npz_path.stem.rsplit('_preprocessed', 1)[0]


# Sorted so the validated/visualised cases are the same on every run (glob order is arbitrary).
sample_files = sorted(processed_dir.glob('*_preprocessed.npz'))

if sample_files:
    print("🔍 PREPROCESSING VALIDATION:")

    # Validate first sample. NpzFile decompresses an array on every key access and keeps the
    # archive open, so read each array once inside a context manager.
    with np.load(sample_files[0]) as sample_data:
        sample_image = sample_data['image']
        sample_label = sample_data['label']

        print(f"Sample case: {sample_files[0].stem}")
        print(f"Image shape: {sample_image.shape}")
        print(f"Label shape: {sample_label.shape}")
        print(f"Target spacing: {sample_data['target_spacing']}")
        print(f"Unique labels: {np.unique(sample_label)}")
        print("Image intensity range per modality:")
        for i in range(sample_image.shape[-1]):
            channel = sample_image[..., i]
            nonzero = channel[channel != 0]
            # Empty foreground -> NaN without numpy's "mean of empty slice" warning.
            mod_mean = nonzero.mean() if nonzero.size else float('nan')
            print(f"  Modality {i}: [{channel.min():.3f}, {channel.max():.3f}], mean: {mod_mean:.3f}")

        # ROI brain mask coverage
        if 'brain_mask_roi' in sample_data:
            roi_mask = sample_data['brain_mask_roi']
            roi_coverage = np.sum(roi_mask) / np.prod(roi_mask.shape) * 100
            print(f"Brain mask coverage (ROI): {roi_coverage:.1f}%")
        else:
            print("⚠️ No ROI mask found in file")

        print(f"Normalization applied: {sample_data['normalization_applied']}")
        print(f"Harmonization applied: {sample_data['harmonization_applied']}")
else:
    print(f"⚠️ No *_preprocessed.npz files found in {processed_dir}")

# Visualization of up to three cases: mid axial slice of every modality
vis_cases = sample_files[:3]
modality_names = ['FLAIR', 'T1w', 'T1Gd', 'T2w']
# RGBA overlay per label: edema green, non-enhancing yellow, enhancing red
label_colors = {1: [0, 1, 0, 0.4], 2: [1, 1, 0, 0.4], 3: [1, 0, 0, 0.4]}

if vis_cases:
    # One row per available case; squeeze=False keeps `axes` 2-D even for a single case.
    fig, axes = plt.subplots(len(vis_cases), len(modality_names),
                             figsize=(20, 5 * len(vis_cases)), squeeze=False)

    for row, case_file in enumerate(vis_cases):
        with np.load(case_file) as data:
            image = data['image']
            label = data['label']
            # Use full brain mask when the file provides one
            brain_mask_full = data['brain_mask_full'] if 'brain_mask_full' in data else None

        mid_slice = image.shape[2] // 2
        label_slice = label[:, :, mid_slice]

        # The overlay depends only on the label slice, so build it once per case.
        overlay = None
        if np.any(label_slice > 0):
            overlay = np.zeros((*label_slice.shape, 4))
            for label_value, rgba in label_colors.items():
                overlay[label_slice == label_value] = rgba

        mask_slice = brain_mask_full[:, :, mid_slice] if brain_mask_full is not None else None

        for col, modality in enumerate(modality_names):
            ax = axes[row, col]

            # Fixed [-3, 3] window assumes z-score normalised intensities.
            ax.imshow(image[:, :, mid_slice, col], cmap='gray', vmin=-3, vmax=3)

            if overlay is not None:
                ax.imshow(overlay)

            # Skip empty masks: contouring them only emits a "no contour levels" warning.
            if mask_slice is not None and np.any(mask_slice):
                ax.contour(mask_slice, colors='cyan', linewidths=0.5)

            ax.set_title(f'{_case_id_from_npz(case_file)} - {modality}')
            ax.axis('off')

    plt.suptitle('nnU-Net Preprocessed Volumes - Quality Control', fontsize=16)
    plt.tight_layout()
    plt.savefig(visualization_dir / 'preprocessing_validation.png', dpi=300, bbox_inches='tight')
    plt.show()

# 📑 Final Preprocessing Report

In [None]:
import gc
import json

# Generate comprehensive preprocessing report


def _mean_or_none(values):
    """Return mean(values) as a float, or None when empty (np.mean([]) is NaN, which is not valid JSON)."""
    return float(np.mean(values)) if values else None


def _fmt_or_na(value, spec):
    """Format a possibly-missing value for printing."""
    return "N/A" if value is None else format(value, spec)


successful_records = [s for s in processing_stats if s['success']]
# Computed once and reused for both the JSON report and the console summary.
mean_brain_coverage = _mean_or_none([s['brain_mask_coverage'] for s in successful_records])
mean_tumor_preservation = _mean_or_none([s['tumor_preservation_ratio'] for s in successful_records])

n_input_cases = len(train_image_files)
success_rate_percent = (round(float(successful_cases / n_input_cases * 100), 1)
                        if n_input_cases else 0.0)

final_report = {
    'preprocessing_summary': {
        'total_cases_input': int(n_input_cases),
        'successfully_processed': int(successful_cases),
        'success_rate_percent': success_rate_percent,
        'target_metrics': 'WT Dice ≥ 90%, BraTS Avg ≥ 80%'
    },
    'nnunet_configuration': {
        'dataset_name': 'Dataset001_BrainTumor',
        'target_spacing_mm': [float(x) for x in target_spacing],
        'modalities': ['FLAIR', 'T1w', 'T1Gd', 'T2w'],
        'labels': ['background', 'edema', 'non_enhancing_tumor', 'enhancing_tumor'],
        'cross_validation_folds': int(len(splits))
    },
    'clinical_characteristics': {
        'cases_with_tumor': int(clinical_df['tumor_present'].sum()),
        'mean_wt_volume_cm3': round(float(clinical_df['wt_volume_cm3'].mean()), 2),
        'mean_tc_volume_cm3': round(float(clinical_df['tc_volume_cm3'].mean()), 2),
        'mean_et_volume_cm3': round(float(clinical_df['et_volume_cm3'].mean()), 2)
    },
    'preprocessing_pipeline': [
        'Brain extraction using 2% threshold + 5x5x5 dilation',
        'N4 bias field correction for multi-site harmonization',
        'nnU-Net v2 intensity normalization with global parameters',
        'Spacing-based resampling with cubic interpolation',
        'Quality control with NaN/Inf handling and label validation'
    ],
    'quality_metrics': {
        # None (JSON null) when no case succeeded, instead of a non-standard NaN token.
        'mean_brain_mask_coverage_percent': (
            None if mean_brain_coverage is None else round(mean_brain_coverage, 1)),
        'mean_tumor_preservation_ratio': (
            None if mean_tumor_preservation is None else round(mean_tumor_preservation, 3))
    }
}

# Save report
with open(results_dir / 'final_preprocessing_report.json', 'w') as f:
    json.dump(final_report, f, indent=2)

# Sample std is undefined for a single case; show "N/A" (as the clinical analysis does)
# rather than a misleading 0.0.
if len(clinical_df) > 1:
    wt_std = f"{clinical_df['wt_volume_cm3'].std():.1f}"
    tc_std = f"{clinical_df['tc_volume_cm3'].std():.1f}"
    et_std = f"{clinical_df['et_volume_cm3'].std():.1f}"
else:
    wt_std = tc_std = et_std = "N/A"

# Final summary
print("🎯 nnU-Net v2 PREPROCESSING COMPLETE")
print("=" * 60)
print(f"✅ Successfully processed: {successful_cases}/{n_input_cases} volumes")
print(f"✅ Brain mask coverage: {_fmt_or_na(mean_brain_coverage, '.1f')}%")
print(f"✅ Tumor preservation: {_fmt_or_na(mean_tumor_preservation, '.3f')}")

print("\n📊 CLINICAL METRICS:")
print(f"• Whole Tumor: {clinical_df['wt_volume_cm3'].mean():.1f} ± {wt_std} cm³")
print(f"• Tumor Core: {clinical_df['tc_volume_cm3'].mean():.1f} ± {tc_std} cm³")
print(f"• Enhancing Tumor: {clinical_df['et_volume_cm3'].mean():.1f} ± {et_std} cm³")

print("\n🚀 READY FOR: BRAIN_model_training.ipynb")

# Memory cleanup: gc.collect() only reclaims objects that are no longer referenced;
# notebook globals defined above stay alive.
gc.collect()
print("✅ Memory cleaned - preprocessing pipeline complete")

In [None]:
# pip freeze