# Medical Image Preprocessing Pipeline

This notebook implements a comprehensive preprocessing pipeline for 3D medical images, converting them into normalized 2D patches suitable for deep learning segmentation models.

## Pipeline Overview
1. **Configuration & Setup** - Environment validation and parameter configuration
2. **3D to 2D Conversion** - Extract valid slices from 3D volumes
3. **Patch Extraction** - Generate fixed-size patches with quality filtering
4. **Data Normalization** - Intensity normalization for consistent model input

## Requirements
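- PyTorch, MONAI, NiBabel, NumPy, Matplotlib, tqdm
- A local `functions` module (importable by this notebook) that provides the pipeline helper functions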
- FSL (for geometry operations)
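- Input: 3D NIfTI files (`.nii.gz`) organized by side, by default in `data/raw/left/` and `data/raw/right/`
- Output: Normalized 32x32 patches ready for segmentation, by default written under `data/processed/patches_normalized/`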

In [None]:
# Core imports
import os
import sys
from pathlib import Path
import logging
from typing import Dict, List, Tuple

# Scientific computing
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from tqdm.auto import tqdm

# Import pipeline functions
from functions import (
    extract_identifier,
    process_3d_to_2d,
    extract_defined_patches,
    save_patches_to_nifti,
    reconstruct_3d_volume,
    normalize_all_patches
)

# Configure matplotlib for inline display
%matplotlib inline

# Logging: timestamped INFO-level records for the whole notebook session
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Start-up banner, including the directory relative paths are resolved against
print("Medical Image Preprocessing Pipeline - Initialized")
print("Working directory:", os.getcwd())

## 1.1 Configuration & Setup

Configure all pipeline parameters in this single cell. Modify paths and parameters according to your data structure.

In [None]:
# =============================================================================
# PIPELINE CONFIGURATION - adjust the values in this cell to match your setup
# =============================================================================

# --- Directory layout ---------------------------------------------------------
DATA_ROOT = Path("data")                       # root directory for all data
RAW_DATA_PATH = DATA_ROOT / "raw"              # holds the left/ and right/ input folders
PROCESSED_DATA_PATH = DATA_ROOT / "processed"  # everything the pipeline writes

# --- Input description --------------------------------------------------------
INPUT_SIDES = ['left', 'right']   # sub-directories expected below RAW_DATA_PATH
IMAGE_EXTENSION = '.nii.gz'       # file-name suffix of the input images

# Example patient identifiers (replace with your actual data)
SAMPLE_PATIENTS = ['patient001', 'patient002']

# --- Processing parameters ----------------------------------------------------
PATCH_SIZE = (32, 32)             # patch dimensions as (height, width)
BLACK_THRESHOLD = 0.95            # maximum fraction of black pixels allowed in a patch
VOLUME_SHAPE = (176, 240, 165)    # expected 3D volume dimensions (H, W, D)

# --- Output locations ---------------------------------------------------------
OUTPUT_DIRS = {
    key: PROCESSED_DATA_PATH / folder_name
    for key, folder_name in (
        ('patches', 'patches'),
        ('patches_normalized', 'patches_normalized'),
        ('reconstructed', 'reconstructed_validation'),
    )
}

# Make sure every output directory exists before any stage tries to write to it
for path in OUTPUT_DIRS.values():
    path.mkdir(parents=True, exist_ok=True)

print(
    "Configuration loaded successfully:",
    f"  Raw data path: {RAW_DATA_PATH}",
    f"  Processed data path: {PROCESSED_DATA_PATH}",
    f"  Patch size: {PATCH_SIZE}",
    f"  Volume shape: {VOLUME_SHAPE}",
    f"  Expected sides: {INPUT_SIDES}",
    sep="\n",
)

### Environment Validation

Verify data structure and availability before processing.

In [None]:
# Validate data structure
def validate_data_structure() -> Dict[str, List[str]]:
    """Check the raw-data layout and list the image files found for each side.

    Reads the module-level RAW_DATA_PATH, INPUT_SIDES and IMAGE_EXTENSION
    settings. For every side a preview of at most five files is printed,
    together with the identifier extract_identifier() derives from each name.

    Returns:
        Mapping of side name to the sorted file names found in that side's
        directory. A side whose directory is missing maps to an empty list.

    Raises:
        FileNotFoundError: If RAW_DATA_PATH does not exist.
    """
    if not RAW_DATA_PATH.exists():
        raise FileNotFoundError(f"Raw data directory not found: {RAW_DATA_PATH}")

    preview_limit = 5
    found = {}

    for side in INPUT_SIDES:
        side_dir = RAW_DATA_PATH / side

        if side_dir.exists():
            # Collect matching file names, then order them for stable output
            names = [entry.name for entry in side_dir.glob(f'*{IMAGE_EXTENSION}')]
            names.sort()
            found[side] = names

            print(f"Found {len(names)} files in {side}/ directory:")
            for name in names[:preview_limit]:
                patient_id = extract_identifier(name)
                print(f"  {name} -> Patient ID: {patient_id}")

            hidden = len(names) - preview_limit
            if hidden > 0:
                print(f"  ... and {hidden} more files")
        else:
            # A missing side is tolerated: record it as empty and carry on
            logger.warning(f"Side directory not found: {side_dir}")
            found[side] = []

    grand_total = sum(map(len, found.values()))
    print(f"\nTotal files found: {grand_total}")

    return found

# Run validation; the resulting side -> file-name mapping is reused by the
# volume-loading cell and by the final pipeline summary
available_files = validate_data_structure()

## 1.2 3D to 2D Conversion

Convert 3D medical volumes into 2D slices, filtering out empty/black slices.

In [None]:
# Load 3D volumes
def load_3d_volumes(available_files: Dict[str, List[str]]) -> Dict[str, nib.Nifti1Image]:
    """Open every listed NIfTI file and index the images by patient and side.

    Args:
        available_files: Mapping of side name to the file names located in
            ``RAW_DATA_PATH / side``.

    Returns:
        Mapping of ``"<patient_id>-<side>"`` to the image object returned by
        ``nib.load``. Files that cannot be opened are logged and left out, so
        the result may hold fewer entries than there are input files.

    Note:
        ``nib.load`` normally defers reading the voxel data until it is
        requested, so this function on its own does not pull whole volumes
        into memory.
    """
    volumes_3d = {}

    for side, files in available_files.items():
        side_path = RAW_DATA_PATH / side

        for filename in files:
            file_path = side_path / filename
            patient_id = extract_identifier(filename)
            key = f"{patient_id}-{side}"

            try:
                volume = nib.load(str(file_path))
                if key in volumes_3d:
                    # Two files resolved to the same patient/side. The later one
                    # still wins, but the collision is no longer silent.
                    logger.warning(
                        f"Duplicate volume key {key}: {file_path} replaces a previously loaded file"
                    )
                volumes_3d[key] = volume
                logger.info(f"Loaded {key}: {volume.shape}")
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole run
                logger.error(f"Failed to load {file_path}: {e}")

    return volumes_3d

# Open every volume discovered during validation
print("Loading 3D volumes...")
volumes_3d = load_3d_volumes(available_files)
print(f"Successfully loaded {len(volumes_3d)} volumes")

# One line per volume: key, shape and on-disk data type
for key in volumes_3d:
    volume = volumes_3d[key]
    print(f"  {key}: {volume.shape} - {volume.get_data_dtype()}")

In [None]:
# Slice every loaded volume into 2D images
print("Converting 3D volumes to 2D slices...")
slices_2d = process_3d_to_2d(volumes_3d)

# Overall counts
total_valid_slices = 0
for data in slices_2d.values():
    total_valid_slices += len(data['image'])

print("\nConversion complete:")
print(f"  Total patients: {len(slices_2d)}")
print(f"  Total valid slices: {total_valid_slices}")

# Per-entry breakdown
# NOTE(review): the patch-extraction cell indexes data['slices'] in parallel
# with data['image'], so 'slices' may list only the valid slices; if so, this
# ratio always reads N/N. Confirm against process_3d_to_2d.
for patient_key in slices_2d:
    data = slices_2d[patient_key]
    slice_ids = data['slices']
    total_slices = len(slice_ids) if slice_ids else 0
    valid_slices = len(data['image'])
    print(f"  {patient_key}: {valid_slices}/{total_slices} valid slices")

### Quality Assessment Visualization

Visualize sample slices to verify conversion quality.

In [None]:
# Visualize sample slices
def visualize_sample_slices(slices_2d: Dict, num_samples: int = 2) -> None:
    """Plot the first and the middle slice for up to ``num_samples`` entries.

    The figure has two rows (first slice on top, middle slice below) and one
    column per displayed entry. Entries are taken in the iteration order of
    ``slices_2d``.

    Args:
        slices_2d: Mapping of patient key to a dict whose ``'image'`` entry is
            a sequence of slices; each slice must provide ``.squeeze()`` and
            be displayable with ``imshow``.
        num_samples: Number of columns in the figure. Must be at least 1.
    """
    # squeeze=False always returns a 2D (2, num_samples) array of Axes, so the
    # single-column case needs no manual reshaping.
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 8), squeeze=False)

    # Hide ticks and frames on every panel up front, so placeholder panels and
    # columns without a matching patient do not show empty default axes.
    for ax in axes.flat:
        ax.axis('off')

    patient_keys = list(slices_2d.keys())[:num_samples]

    for i, patient_key in enumerate(patient_keys):
        images = slices_2d[patient_key]['image']

        if len(images) == 0:
            # Name the patient so the empty column can still be identified
            axes[0, i].set_title(f'{patient_key}')
            for ax in axes[:, i]:
                ax.text(0.5, 0.5, 'No valid slices', ha='center', va='center',
                        transform=ax.transAxes)
            continue

        # Show first and middle slice
        mid_idx = len(images) // 2
        first_slice = images[0].squeeze()
        mid_slice = images[mid_idx].squeeze()

        axes[0, i].imshow(first_slice, cmap='gray')
        axes[0, i].set_title(f'{patient_key}\nFirst slice')

        axes[1, i].imshow(mid_slice, cmap='gray')
        axes[1, i].set_title(f'Middle slice ({mid_idx+1}/{len(images)})')

    fig.suptitle('Sample 2D Slices After Conversion', fontsize=16)
    plt.tight_layout()
    plt.show()

# Nothing to draw when the conversion produced no entries; otherwise show at most two
if not slices_2d:
    print("No valid slices available for visualization")
else:
    visualize_sample_slices(slices_2d, min(2, len(slices_2d)))

## 1.3 Patch Extraction

Extract fixed-size patches from 2D slices with quality filtering.

In [None]:
# Extract patches from all slices
def extract_all_patches(slices_2d: Dict, patch_size: Tuple[int, int],
                        black_threshold: float) -> Tuple[Dict, Dict, int]:
    """Extract patches from every 2D slice of every entry.

    Args:
        slices_2d: Mapping of patient key to a dict with an ``'image'``
            sequence of slices and an optional ``'slices'`` sequence holding
            the slice number of each image.
        patch_size: Patch dimensions, forwarded to extract_defined_patches.
        black_threshold: Black-pixel threshold, forwarded to
            extract_defined_patches.

    Returns:
        A 3-tuple ``(patches_data, patches_coordinates, total_patches)``:
        ``patches_data`` maps each patient key to ``{'image': [patches...]}``,
        ``patches_coordinates`` maps each patient key to the coordinates
        returned alongside those patches (same order), and ``total_patches``
        is the number of patches across all entries.
    """
    patches_data = {}
    patches_coordinates = {}
    total_patches = 0

    for patient_key, slice_data in tqdm(slices_2d.items(), desc="Extracting patches"):
        patient_patches = []
        patient_coordinates = []

        # 'slices' may be absent or None; in that case (or when it is shorter
        # than the image list) the positional index stands in as slice number.
        slice_coordinates = slice_data.get('slices')
        if slice_coordinates is None:
            slice_coordinates = []

        for idx, image_slice in enumerate(slice_data['image']):
            slice_num = slice_coordinates[idx] if idx < len(slice_coordinates) else idx

            # Extract patches from this slice
            patches, coordinates = extract_defined_patches(
                image_slice, patch_size, black_threshold, slice_num
            )

            patient_patches.extend(patches)
            patient_coordinates.extend(coordinates)

        patches_data[patient_key] = {'image': patient_patches}
        patches_coordinates[patient_key] = patient_coordinates

        patient_patch_count = len(patient_patches)
        total_patches += patient_patch_count
        logger.info(f"{patient_key}: {patient_patch_count} patches extracted")

    return patches_data, patches_coordinates, total_patches

# Run patch extraction over every converted slice
print("Extracting patches...")
patches_data, patches_coordinates, total_patches = extract_all_patches(
    slices_2d,
    patch_size=PATCH_SIZE,
    black_threshold=BLACK_THRESHOLD,
)

# Report the outcome together with the parameters that produced it
print(
    "\nPatch extraction complete:",
    f"  Total patches: {total_patches}",
    f"  Patch size: {PATCH_SIZE}",
    f"  Black threshold: {BLACK_THRESHOLD}",
    sep="\n",
)

In [None]:
# Save patches to NIfTI format
# Hands the extracted patches and their coordinates to save_patches_to_nifti,
# with OUTPUT_DIRS['patches'] as the base output directory.
print("Saving patches to NIfTI format...")
output_patches_dir = OUTPUT_DIRS['patches']

# NOTE(review): the on-disk layout is decided inside save_patches_to_nifti.
# Later cells look for <base>/<side>/<patient>/*.nii.gz; confirm the helper
# writes that structure.
save_patches_to_nifti(
    patches_dir2D=patches_data,
    patches_dir2D_coordinates=patches_coordinates,
    base_output_dir=str(output_patches_dir)  # passed as a plain string path
)

print(f"Patches saved to: {output_patches_dir}")

# Verify saved structure
def verify_saved_structure(base_path: Path) -> None:
    """Print a short per-side overview of the patch folders below ``base_path``.

    For each side in INPUT_SIDES whose directory exists, prints the number of
    patient sub-directories and, for the first three of them, how many
    ``*.nii.gz`` files each one contains. Missing side directories are
    skipped without a message.
    """
    print("\nSaved directory structure:")

    for side in INPUT_SIDES:
        side_dir = base_path / side
        if not side_dir.exists():
            continue

        patient_names = [entry.name for entry in side_dir.iterdir() if entry.is_dir()]
        print(f"  {side}/: {len(patient_names)} patients")

        # Only the first three patients are listed to keep the output short
        for name in patient_names[:3]:
            n_patches = sum(1 for _ in (side_dir / name).glob('*.nii.gz'))
            print(f"    {name}: {n_patches} patches")

# Print what is actually on disk as a quick sanity check of the save step
verify_saved_structure(output_patches_dir)

### Optional: Validation Reconstruction

Reconstruct 3D volumes from patches to validate extraction accuracy.

In [None]:
# Perform validation reconstruction (optional)
def perform_validation_reconstruction(sample_limit: int = 2) -> List:
    """Reconstruct 3D volumes from saved patches for a few patients.

    Patient folders are discovered under ``OUTPUT_DIRS['patches'] / <side>``.
    The first ``sample_limit`` patient ids (in sorted order) are reconstructed
    for every side in INPUT_SIDES, using a raw image matching
    ``<patient_id>_<side>*<IMAGE_EXTENSION>`` as the reference.

    Args:
        sample_limit: Maximum number of patients to reconstruct.

    Returns:
        The values returned by reconstruct_3d_volume for each successful
        reconstruction (presumably output file paths; confirm against the
        helper). Patient/side pairs without a reference image, or whose
        reconstruction raises, are logged and skipped.
    """
    reconstruction_dir = OUTPUT_DIRS['reconstructed']
    patches_dir = OUTPUT_DIRS['patches']

    reconstructed_files = []

    # Collect patient ids across all sides; a set removes ids present on both
    all_patients = set()
    for side in INPUT_SIDES:
        side_path = patches_dir / side
        if side_path.exists():
            all_patients.update(d.name for d in side_path.iterdir() if d.is_dir())

    sample_patients = sorted(all_patients)[:sample_limit]

    for patient_id in sample_patients:
        for side in INPUT_SIDES:
            # Find reference image. Glob order is filesystem-dependent, so the
            # matches are sorted to make the chosen reference reproducible.
            reference_pattern = f"{patient_id}_{side}*{IMAGE_EXTENSION}"
            reference_files = sorted((RAW_DATA_PATH / side).glob(reference_pattern))

            if not reference_files:
                logger.warning(f"No reference image found for {patient_id}-{side}")
                continue

            reference_path = str(reference_files[0])

            try:
                output_path = reconstruct_3d_volume(
                    base_patch_dir=str(patches_dir),
                    patient_id_short=patient_id,
                    side=side,
                    output_base_dir=str(reconstruction_dir),
                    reference_image_path=reference_path,
                    volume_shape=VOLUME_SHAPE
                )
                reconstructed_files.append(output_path)
                logger.info(f"Reconstructed: {output_path}")
            except Exception as e:
                # Validation is optional: log the failure and try the next one
                logger.error(f"Reconstruction failed for {patient_id}-{side}: {e}")

    return reconstructed_files

# Optional sanity check: rebuild a couple of volumes from their patches.
# Any failure is downgraded to a warning so the pipeline carries on.
print("Performing validation reconstruction...")
try:
    reconstructed_files = perform_validation_reconstruction(sample_limit=2)
    print("Validation reconstruction complete. Files saved:")
    for reconstructed in reconstructed_files:
        print(f"  {reconstructed}")
except Exception as exc:
    logger.warning(f"Validation reconstruction failed: {exc}")
    print("Skipping validation reconstruction (optional step)")

## 1.4 Data Normalization

Apply intensity normalization to all patches for consistent model input.

In [None]:
# Normalize all patches
def normalize_patient_patches() -> Dict[str, Dict]:
    """Normalize the saved patches of every patient on every side.

    Reads patient folders from ``OUTPUT_DIRS['patches'] / <side>`` and calls
    normalize_all_patches with the matching folder below
    ``OUTPUT_DIRS['patches_normalized']`` as its output folder. Sides whose
    input directory is missing are logged and skipped.

    Returns:
        Mapping of ``"<patient_id>-<side>"`` to a stats dict. On success the
        dict holds ``'input_patches'`` and ``'output_patches'`` (counts of
        ``*.nii.gz`` files in the input and output folder); on failure it
        holds a single ``'error'`` entry with the exception message.
        Patients are processed in sorted order within each side.
    """
    patches_dir = OUTPUT_DIRS['patches']
    normalized_dir = OUTPUT_DIRS['patches_normalized']

    normalization_stats = {}

    for side in INPUT_SIDES:
        side_input_path = patches_dir / side
        side_output_path = normalized_dir / side

        if not side_input_path.exists():
            logger.warning(f"Input side directory not found: {side_input_path}")
            continue

        # iterdir() order is filesystem-dependent; sorting keeps runs reproducible
        patients = sorted(d.name for d in side_input_path.iterdir() if d.is_dir())

        for patient_id in tqdm(patients, desc=f"Normalizing {side} patches"):
            key = f"{patient_id}-{side}"
            input_folder = side_input_path / patient_id
            output_folder = side_output_path / patient_id

            # Count patches before normalization
            input_patches = len(list(input_folder.glob('*.nii.gz')))

            try:
                # Directory creation is inside the try so a filesystem error is
                # recorded for this patient instead of aborting the whole run
                output_folder.mkdir(parents=True, exist_ok=True)
                normalize_all_patches(str(input_folder), str(output_folder))

                # Count patches after normalization
                output_patches = len(list(output_folder.glob('*.nii.gz')))
            except Exception as e:
                logger.error(f"Normalization failed for {key}: {e}")
                normalization_stats[key] = {'error': str(e)}
                continue

            normalization_stats[key] = {
                'input_patches': input_patches,
                'output_patches': output_patches
            }
            logger.info(f"Normalized {key}: {input_patches} -> {output_patches} patches")

    return normalization_stats

# Run normalization for every patient/side
print("Normalizing patches...")
normalization_stats = normalize_patient_patches()

# Per-entry report while tallying successes and failures
print("\nNormalization complete:")
total_normalized = 0
failed_normalizations = 0

for key, stats in normalization_stats.items():
    if 'error' not in stats:
        patches_count = stats['output_patches']
        total_normalized += patches_count
        print(f"  {key}: {patches_count} patches normalized")
        continue

    failed_normalizations += 1
    print(f"  {key}: FAILED - {stats['error']}")

print(
    "\nSummary:",
    f"  Total normalized patches: {total_normalized}",
    f"  Failed normalizations: {failed_normalizations}",
    f"  Output directory: {OUTPUT_DIRS['patches_normalized']}",
    sep="\n",
)

## Pipeline Summary

Complete preprocessing pipeline execution summary.

In [None]:
# Generate pipeline summary
def generate_pipeline_summary() -> None:
    """Print a summary of the whole pipeline run.

    Reads the notebook-level results produced by the earlier cells
    (available_files, volumes_3d, slices_2d, total_patches,
    normalization_stats) together with the configuration constants. The final
    status line reports success only when every discovered input file was
    loaded and no normalization failed.
    """
    print("=" * 60)
    print("MEDICAL IMAGE PREPROCESSING PIPELINE - SUMMARY")
    print("=" * 60)

    # Input summary: files found on disk versus volumes that actually loaded
    files_found = sum(len(files) for files in available_files.values())
    volumes_loaded = len(volumes_3d)
    print("Input Data:")
    print(f"  Total 3D volumes processed: {volumes_loaded}")
    if volumes_loaded != files_found:
        print(f"  Input files found: {files_found}")
    print(f"  Input directory: {RAW_DATA_PATH}")

    # Processing summary
    total_slices = sum(len(data['image']) for data in slices_2d.values())
    print("\n2D Conversion:")
    print(f"  Valid 2D slices extracted: {total_slices}")

    # Patch extraction summary
    print("\nPatch Extraction:")
    print(f"  Total patches extracted: {total_patches}")
    print(f"  Patch size: {PATCH_SIZE}")
    print(f"  Black pixel threshold: {BLACK_THRESHOLD}")

    # Normalization summary
    succeeded = [stats for stats in normalization_stats.values() if 'error' not in stats]
    successful_normalizations = len(succeeded)
    failed_normalizations = len(normalization_stats) - successful_normalizations
    total_normalized_patches = sum(stats.get('output_patches', 0) for stats in succeeded)

    print("\nNormalization:")
    print(f"  Successfully normalized datasets: {successful_normalizations}/{len(normalization_stats)}")
    print(f"  Total normalized patches: {total_normalized_patches}")

    # Output summary
    print("\nOutput Directories:")
    for name, path in OUTPUT_DIRS.items():
        status = "✓" if path.exists() else "✗"
        print(f"  {name}: {status} {path}")

    # Report success only when nothing was dropped along the way
    problems = []
    if volumes_loaded < files_found:
        problems.append(f"{files_found - volumes_loaded} input file(s) not loaded")
    if failed_normalizations:
        problems.append(f"{failed_normalizations} normalization failure(s)")

    if problems:
        print(f"\nPipeline Status: COMPLETED WITH ERRORS ({'; '.join(problems)})")
    else:
        print("\nPipeline Status: COMPLETED SUCCESSFULLY")
    print("=" * 60)

# Print the end-of-run report
generate_pipeline_summary()

# Collect the run parameters and headline numbers in one place
metadata = dict(
    pipeline_version='1.0',
    patch_size=PATCH_SIZE,
    black_threshold=BLACK_THRESHOLD,
    volume_shape=VOLUME_SHAPE,
    total_input_volumes=len(volumes_3d),
    total_valid_slices=sum(len(entry['image']) for entry in slices_2d.values()),
    total_patches=total_patches,
    normalization_stats=normalization_stats,
)

# Note: In production, you might want to save this metadata to a JSON file
print("\nMetadata available in 'metadata' variable for further processing.")