# 07 - Post-processing and Morphological Operations

## Overview
This notebook implements advanced post-processing techniques for cardiac MRI segmentation, including:
- **Morphological Operations**: Opening, closing, erosion, dilation
- **Connected Component Analysis**: Removing isolated pixels and small components
- **Anatomical Constraints**: Enforcing realistic cardiac anatomy
- **Boundary Refinement**: Smoothing and regularizing segmentation boundaries
- **False Positive Removal**: Eliminating artifacts and noise
- **Quality Assessment**: Validating post-processed results
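As a minimal illustration of the first two operations, the sketch below applies opening and closing to a toy mask. For brevity it uses `scipy.ndimage.binary_opening` and `binary_closing` with a 3×3 square structuring element; the notebook itself uses the OpenCV equivalents. The mask layout is invented for illustration.

```python
import numpy as np
from scipy import ndimage

# Toy binary mask: an 8x8 square with a one-pixel hole, plus an isolated noise pixel
mask = np.zeros((12, 15), dtype=bool)
mask[2:10, 2:10] = True   # main object
mask[5, 5] = False        # small hole inside the object
mask[6, 13] = True        # isolated false-positive pixel

structure = np.ones((3, 3), dtype=bool)  # 3x3 square structuring element

# Opening (erosion then dilation) removes objects smaller than the structuring element
opened = ndimage.binary_opening(mask, structure=structure)
# Closing (dilation then erosion) fills holes smaller than the structuring element
closed = ndimage.binary_closing(mask, structure=structure)

print(opened[6, 13], closed[5, 5])  # → False True
```

Opening keeps the square itself because a 3×3 element still fits everywhere inside it; only structures thinner than the element disappear.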

## Learning Objectives
- Understand morphological image processing for medical segmentation
- Implement connected component analysis for noise removal
- Apply anatomical constraints for cardiac structures
- Develop boundary refinement techniques
- Create quality assessment metrics for post-processing
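The configuration later in this notebook defines a `dice_threshold` for quality assessment. Assuming it refers to the standard Dice similarity coefficient, 2|A ∩ B| / (|A| + |B|), a minimal NumPy sketch looks like this. The helper name `dice_coefficient` and the convention of returning 1.0 when both masks are empty are assumptions, not part of the original notebook.

```python
import numpy as np

def dice_coefficient(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice similarity coefficient 2|A ∩ B| / (|A| + |B|) for binary masks (hypothetical helper)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    denominator = pred.sum() + target.sum()
    if denominator == 0:
        return 1.0  # both masks empty: treated as perfect agreement (an assumed convention)
    return float(2.0 * intersection / denominator)

a = np.zeros((10, 10), dtype=np.uint8)
a[2:6, 2:6] = 1  # 16 pixels
b = np.zeros((10, 10), dtype=np.uint8)
b[3:7, 2:6] = 1  # 16 pixels, 12 of them shared with a
print(dice_coefficient(a, b))  # → 0.75
```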

---

In [None]:
# Environment and Package Management
import sys
import os
from pathlib import Path

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install required packages for Colab
    !pip install -q opencv-python scikit-image scipy matplotlib seaborn
    
    # Mount Google Drive if needed
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Set working directory
    os.chdir('/content/drive/MyDrive/cardiac_segmentation')
else:
    # For local development, ensure packages are available
    try:
        import cv2
        import skimage
        import scipy
    except ImportError as e:
        print(f"Please install missing package: {e}")
        print("Run: pip install opencv-python scikit-image scipy")

# Create directories if they don't exist
Path("outputs/postprocessed").mkdir(parents=True, exist_ok=True)
Path("outputs/morphology").mkdir(parents=True, exist_ok=True)

print("Environment setup complete!")
print(f"Running in: {'Google Colab' if IN_COLAB else 'Local Environment'}")
print(f"Current working directory: {os.getcwd()}")

In [None]:
# Core Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Image Processing
import cv2
from skimage import morphology, measure, filters, segmentation
from skimage.feature import peak_local_max
from skimage.segmentation import watershed, clear_border
from scipy import ndimage
from scipy.spatial.distance import cdist
from scipy.optimize import minimize

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

# Visualization
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.figure_factory as ff

# Utilities
from typing import Dict, List, Tuple, Optional, Union, Callable
from dataclasses import dataclass
from abc import ABC, abstractmethod
import json
import pickle
from collections import defaultdict

# Set style for better visualization
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✅ All imports successful!")

In [None]:
@dataclass
class PostProcessConfig:
    """Configuration for post-processing operations"""
    
    # Morphological operations
    kernel_size: int = 3
    erosion_iterations: int = 1
    dilation_iterations: int = 2
    opening_iterations: int = 1
    closing_iterations: int = 2
    
    # Connected components
    min_component_size: int = 100
    max_component_count: int = 5
    connectivity: int = 2  # skimage convention: 1 = 4-connected, 2 = 8-connected (2D)
    
    # Anatomical constraints
    min_heart_area: int = 1000
    max_heart_area: int = 50000
    aspect_ratio_range: Tuple[float, float] = (0.5, 2.0)
    circularity_threshold: float = 0.3
    
    # Boundary refinement
    gaussian_sigma: float = 1.0
    contour_epsilon: float = 0.02
    smooth_iterations: int = 3
    
    # Distance constraints
    max_distance_between_structures: float = 100.0
    min_distance_from_border: int = 10
    
    # Quality thresholds
    confidence_threshold: float = 0.7
    dice_threshold: float = 0.8
    hausdorff_threshold: float = 10.0

@dataclass
class SegmentationResult:
    """Container for segmentation results with metadata"""
    prediction: np.ndarray
    confidence: np.ndarray
    original_shape: Tuple[int, ...]
    processing_time: float
    metadata: Dict
    
@dataclass
class PostProcessingMetrics:
    """Metrics for post-processing evaluation"""
    components_removed: int
    area_change: float
    boundary_smoothness: float
    confidence_improvement: float
    processing_time: float
    quality_score: float

# Default configuration
config = PostProcessConfig()
print("✅ Configuration classes defined!")
print(f"Default config: {config}")
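`PostProcessConfig.connectivity` is passed directly to `skimage.measure.label`, which measures connectivity in orthogonal hops: in 2D it accepts 1 (4-connectivity) or 2 (8-connectivity) and raises a `ValueError` for larger values such as OpenCV's 8. A small check of the two settings on two pixels that touch only diagonally:

```python
import numpy as np
from skimage import measure

# Two foreground pixels that touch only at a corner
mask = np.array([[1, 0],
                 [0, 1]], dtype=np.uint8)

# connectivity=1: orthogonal neighbours only (4-connectivity in 2D)
n4 = measure.label(mask, connectivity=1).max()
# connectivity=2: orthogonal and diagonal neighbours (8-connectivity in 2D)
n8 = measure.label(mask, connectivity=2).max()
print(n4, n8)  # → 2 1
```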

In [None]:
class MorphologicalProcessor:
    """Advanced morphological operations for cardiac segmentation"""
    
    def __init__(self, config: PostProcessConfig):
        self.config = config
        self.kernels = self._create_kernels()
        
    def _create_kernels(self) -> Dict[str, np.ndarray]:
        """Create morphological kernels for different operations"""
        kernels = {}
        
        # Standard circular kernel
        kernels['circular'] = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, 
            (self.config.kernel_size, self.config.kernel_size)
        )
        
        # Cross-shaped kernel for connectivity
        kernels['cross'] = cv2.getStructuringElement(
            cv2.MORPH_CROSS, 
            (self.config.kernel_size, self.config.kernel_size)
        )
        
        # Rectangular kernel for directional operations
        kernels['rect'] = cv2.getStructuringElement(
            cv2.MORPH_RECT, 
            (self.config.kernel_size, self.config.kernel_size)
        )
        
        # Custom cardiac-specific kernel (elongated)
        cardiac_kernel = np.zeros((7, 5), dtype=np.uint8)
        cardiac_kernel[1:6, 1:4] = 1
        kernels['cardiac'] = cardiac_kernel
        
        return kernels
    
    def erosion(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply erosion operation"""
        kernel = self.kernels[kernel_type]
        return cv2.erode(
            mask.astype(np.uint8), 
            kernel, 
            iterations=self.config.erosion_iterations
        )
    
    def dilation(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply dilation operation"""
        kernel = self.kernels[kernel_type]
        return cv2.dilate(
            mask.astype(np.uint8), 
            kernel, 
            iterations=self.config.dilation_iterations
        )
    
    def opening(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply opening operation (erosion followed by dilation)"""
        kernel = self.kernels[kernel_type]
        return cv2.morphologyEx(
            mask.astype(np.uint8), 
            cv2.MORPH_OPEN, 
            kernel,
            iterations=self.config.opening_iterations
        )
    
    def closing(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply closing operation (dilation followed by erosion)"""
        kernel = self.kernels[kernel_type]
        return cv2.morphologyEx(
            mask.astype(np.uint8), 
            cv2.MORPH_CLOSE, 
            kernel,
            iterations=self.config.closing_iterations
        )
    
    def gradient(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply morphological gradient (dilation - erosion)"""
        kernel = self.kernels[kernel_type]
        return cv2.morphologyEx(
            mask.astype(np.uint8), 
            cv2.MORPH_GRADIENT, 
            kernel
        )
    
    def tophat(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply top-hat transform (mask minus its opening): thin foreground details"""
        kernel = self.kernels[kernel_type]
        return cv2.morphologyEx(
            mask.astype(np.uint8), 
            cv2.MORPH_TOPHAT, 
            kernel
        )
    
    def blackhat(self, mask: np.ndarray, kernel_type: str = 'circular') -> np.ndarray:
        """Apply black-hat transform (closing minus mask): small gaps and holes"""
        kernel = self.kernels[kernel_type]
        return cv2.morphologyEx(
            mask.astype(np.uint8), 
            cv2.MORPH_BLACKHAT, 
            kernel
        )
    
    def adaptive_morphology(self, mask: np.ndarray, area_threshold: int = 1000) -> np.ndarray:
        """Apply size-adaptive morphology: close small regions, open large ones"""
        # Choose the kernel per connected component based on its area
        labels = measure.label(mask)
        small_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        large_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        # uint8 accumulator: np.zeros_like(mask) would fail with `+=` on a bool mask
        processed_mask = np.zeros(mask.shape, dtype=np.uint8)
        
        for region in measure.regionprops(labels):
            region_mask = (labels == region.label).astype(np.uint8)
            if region.area < area_threshold:
                # Small regions: closing with a small kernel fills tiny gaps
                processed_region = cv2.morphologyEx(region_mask, cv2.MORPH_CLOSE, small_kernel)
            else:
                # Large regions: opening with a larger kernel removes spurs and thin bridges
                processed_region = cv2.morphologyEx(region_mask, cv2.MORPH_OPEN, large_kernel)
            
            # Union of all processed regions
            processed_mask |= processed_region
        
        return (processed_mask > 0).astype(np.uint8)
    
    def multi_scale_morphology(self, mask: np.ndarray, scales: Tuple[int, ...] = (3, 5, 7)) -> np.ndarray:
        """Apply opening at several kernel sizes and combine the results by majority vote"""
        results = []
        
        for scale in scales:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (scale, scale))
            # Apply opening at each scale
            opened = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
            results.append(opened)
        
        # Combine results using majority voting
        combined = np.stack(results, axis=-1)
        return (np.sum(combined, axis=-1) > len(scales) // 2).astype(np.uint8)

# Test the morphological processor
morph_processor = MorphologicalProcessor(config)
print("✅ Morphological Processor initialized!")
print(f"Available kernels: {list(morph_processor.kernels.keys())}")

# Visualize kernels
fig, axes = plt.subplots(1, 4, figsize=(15, 3))
for i, (name, kernel) in enumerate(morph_processor.kernels.items()):
    axes[i].imshow(kernel, cmap='gray')
    axes[i].set_title(f'{name} kernel')
    axes[i].axis('off')
plt.tight_layout()
plt.show()

In [None]:
class ConnectedComponentAnalyzer:
    """Advanced connected component analysis for cardiac segmentation"""
    
    def __init__(self, config: PostProcessConfig):
        self.config = config
        
    def remove_small_components(self, mask: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Remove small connected components"""
        # Label connected components
        labels = measure.label(mask, connectivity=self.config.connectivity)
        properties = measure.regionprops(labels)
        
        # Filter components by size
        filtered_mask = np.zeros_like(mask)
        removed_components = 0
        kept_components = 0
        
        for prop in properties:
            if prop.area >= self.config.min_component_size:
                filtered_mask[labels == prop.label] = 1
                kept_components += 1
            else:
                removed_components += 1
        
        stats = {
            'original_components': len(properties),
            'kept_components': kept_components,
            'removed_components': removed_components,
            'size_threshold': self.config.min_component_size
        }
        
        return filtered_mask.astype(np.uint8), stats
    
    def keep_largest_components(self, mask: np.ndarray, n_components: int = None) -> Tuple[np.ndarray, Dict]:
        """Keep only the N largest connected components"""
        if n_components is None:
            n_components = self.config.max_component_count
            
        # Label and analyze components
        labels = measure.label(mask, connectivity=self.config.connectivity)
        properties = measure.regionprops(labels)
        
        if len(properties) == 0:
            return mask, {'kept_components': 0, 'removed_components': 0}
        
        # Sort by area (largest first)
        properties_sorted = sorted(properties, key=lambda x: x.area, reverse=True)
        
        # Keep only the largest N components
        filtered_mask = np.zeros_like(mask)
        for i, prop in enumerate(properties_sorted[:n_components]):
            filtered_mask[labels == prop.label] = 1
        
        stats = {
            'original_components': len(properties),
            'kept_components': min(n_components, len(properties)),
            'removed_components': max(0, len(properties) - n_components),
            'largest_area': properties_sorted[0].area if properties_sorted else 0
        }
        
        return filtered_mask.astype(np.uint8), stats
    
    def analyze_component_properties(self, mask: np.ndarray) -> pd.DataFrame:
        """Analyze properties of connected components"""
        labels = measure.label(mask, connectivity=self.config.connectivity)
        properties = measure.regionprops(labels)
        
        if not properties:
            return pd.DataFrame()
        
        data = []
        for prop in properties:
            # Basic properties
            area = prop.area
            perimeter = prop.perimeter
            centroid = prop.centroid
            bbox = prop.bbox
            
            # Geometric properties
            circularity = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0
            aspect_ratio = prop.major_axis_length / prop.minor_axis_length if prop.minor_axis_length > 0 else 0
            solidity = prop.solidity
            extent = prop.extent
            
            # Shape properties
            eccentricity = prop.eccentricity
            orientation = prop.orientation
            
            data.append({
                'label': prop.label,
                'area': area,
                'perimeter': perimeter,
                'centroid_y': centroid[0],
                'centroid_x': centroid[1],
                'bbox_min_row': bbox[0],
                'bbox_min_col': bbox[1],
                'bbox_max_row': bbox[2],
                'bbox_max_col': bbox[3],
                'circularity': circularity,
                'aspect_ratio': aspect_ratio,
                'solidity': solidity,
                'extent': extent,
                'eccentricity': eccentricity,
                'orientation': orientation
            })
        
        return pd.DataFrame(data)
    
    def filter_by_shape(self, mask: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Filter components based on shape criteria"""
        labels = measure.label(mask, connectivity=self.config.connectivity)
        properties = measure.regionprops(labels)
        
        filtered_mask = np.zeros_like(mask)
        kept_count = 0
        removed_count = 0
        
        for prop in properties:
            # Calculate shape metrics
            area = prop.area
            perimeter = prop.perimeter
            circularity = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0
            aspect_ratio = prop.major_axis_length / prop.minor_axis_length if prop.minor_axis_length > 0 else 0
            
            # Apply filters
            valid_area = self.config.min_heart_area <= area <= self.config.max_heart_area
            valid_aspect_ratio = self.config.aspect_ratio_range[0] <= aspect_ratio <= self.config.aspect_ratio_range[1]
            valid_circularity = circularity >= self.config.circularity_threshold
            
            if valid_area and valid_aspect_ratio and valid_circularity:
                filtered_mask[labels == prop.label] = 1
                kept_count += 1
            else:
                removed_count += 1
        
        stats = {
            'kept_components': kept_count,
            'removed_components': removed_count,
            'shape_criteria': {
                'area_range': (self.config.min_heart_area, self.config.max_heart_area),
                'aspect_ratio_range': self.config.aspect_ratio_range,
                'circularity_threshold': self.config.circularity_threshold
            }
        }
        
        return filtered_mask.astype(np.uint8), stats
    
    def remove_border_components(self, mask: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Remove components touching the image border"""
        # Use skimage's clear_border function
        cleared_mask = clear_border(mask.astype(bool)).astype(np.uint8)
        
        # Calculate statistics
        original_components = len(measure.regionprops(measure.label(mask)))
        remaining_components = len(measure.regionprops(measure.label(cleared_mask)))
        
        stats = {
            'original_components': original_components,
            'remaining_components': remaining_components,
            'removed_components': original_components - remaining_components
        }
        
        return cleared_mask, stats
    
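    def watershed_separation(self, mask: np.ndarray) -> np.ndarray:
        """Separate touching components using a distance-transform watershed.

        Touching objects are split by one-pixel watershed lines in the binary output.
        """
        from skimage.feature import peak_local_max
        
        binary = mask.astype(bool)
        if not binary.any():
            return np.zeros(mask.shape, dtype=np.uint8)
        
        # Distance from each foreground pixel to the nearest background pixel
        distance = ndimage.distance_transform_edt(binary)
        
        # Seeds: distance-map maxima as (row, col) coords; exclude_border=False keeps edge peaks
        peaks = peak_local_max(distance, min_distance=10,
                               threshold_abs=0.3 * distance.max(), exclude_border=False)
        markers = np.zeros(mask.shape, dtype=np.int32)
        markers[tuple(peaks.T)] = np.arange(1, len(peaks) + 1)
        
        # Flood the inverted distance map; watershed_line=True keeps a 0-valued line
        # between basins, otherwise (labels > 0) would simply merge them back together
        labels = watershed(-distance, markers, mask=binary, watershed_line=True)
        
        # Components that received no seed are never flooded; keep them unchanged
        components = measure.label(binary)
        unseeded = binary & ~np.isin(components, components[markers > 0])
        return ((labels > 0) | unseeded).astype(np.uint8)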

# Test the connected component analyzer
cc_analyzer = ConnectedComponentAnalyzer(config)
print("✅ Connected Component Analyzer initialized!")

# Create a test mask with multiple components
test_mask = np.zeros((200, 200), dtype=np.uint8)
# Large component
cv2.circle(test_mask, (100, 100), 40, 1, -1)
# Small components (noise)
cv2.circle(test_mask, (50, 50), 5, 1, -1)
cv2.circle(test_mask, (150, 150), 3, 1, -1)
cv2.circle(test_mask, (30, 170), 7, 1, -1)

# Analyze components
df = cc_analyzer.analyze_component_properties(test_mask)
print(f"Component analysis:\n{df[['label', 'area', 'circularity', 'aspect_ratio']].round(3)}")

# Test filtering
filtered_mask, stats = cc_analyzer.remove_small_components(test_mask)
print(f"Filtering stats: {stats}")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(test_mask, cmap='gray')
axes[0].set_title('Original Mask')
axes[1].imshow(filtered_mask, cmap='gray')
axes[1].set_title('Filtered Mask')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()
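The circularity used in `analyze_component_properties`, `filter_by_shape`, and `validate_cardiac_structure` is the isoperimetric ratio $C = 4\pi A / P^2$. Worked for ideal continuous shapes, it shows why a threshold such as `circularity_threshold = 0.3` rejects thin, elongated fragments. Note that `regionprops` estimates the perimeter from pixels, so measured values for digitized shapes can deviate somewhat from these ideal numbers.

$$
\begin{aligned}
\text{circle of radius } r:&\quad C = \frac{4\pi(\pi r^2)}{(2\pi r)^2} = 1\\
\text{square of side } s:&\quad C = \frac{4\pi s^2}{(4s)^2} = \frac{\pi}{4} \approx 0.785\\
1 \times 10 \text{ rectangle}:&\quad C = \frac{4\pi \cdot 10}{22^2} = \frac{40\pi}{484} \approx 0.260
\end{aligned}
$$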

In [None]:
class AnatomicalConstraintValidator:
    """Validate and enforce anatomical constraints for cardiac structures"""
    
    def __init__(self, config: PostProcessConfig):
        self.config = config
        self.cardiac_anatomy = self._define_cardiac_anatomy()
        
    def _define_cardiac_anatomy(self) -> Dict:
        """Define anatomical constraints for cardiac structures"""
        return {
            'left_ventricle': {
                'expected_area_range': (2000, 15000),
                'expected_position': 'center-left',
                'shape_constraints': {
                    'circularity_min': 0.4,
                    'aspect_ratio_range': (0.7, 1.5),
                    'solidity_min': 0.8
                }
            },
            'right_ventricle': {
                'expected_area_range': (1500, 12000),
                'expected_position': 'center-right',
                'shape_constraints': {
                    'circularity_min': 0.3,
                    'aspect_ratio_range': (0.6, 2.0),
                    'solidity_min': 0.7
                }
            },
            'myocardium': {
                'expected_area_range': (3000, 20000),
                'expected_position': 'surrounding',
                'shape_constraints': {
                    'circularity_min': 0.2,
                    'aspect_ratio_range': (0.5, 2.5),
                    'solidity_min': 0.6
                }
            }
        }
    
    def validate_cardiac_structure(self, mask: np.ndarray, structure_type: str = 'left_ventricle') -> Dict:
        """Validate a cardiac structure against anatomical constraints"""
        if structure_type not in self.cardiac_anatomy:
            raise ValueError(f"Unknown cardiac structure: {structure_type}")
        
        constraints = self.cardiac_anatomy[structure_type]
        
        # Analyze the mask
        labels = measure.label(mask)
        properties = measure.regionprops(labels)
        
        if not properties:
            return {
                'valid': False,
                'reason': 'No components found',
                'violations': ['empty_mask']
            }
        
        # Find the largest component (assumed to be the main structure)
        main_component = max(properties, key=lambda x: x.area)
        
        # Check constraints
        violations = []
        
        # Area constraint
        area = main_component.area
        expected_area = constraints['expected_area_range']
        if not (expected_area[0] <= area <= expected_area[1]):
            violations.append(f'area_violation: {area} not in {expected_area}')
        
        # Shape constraints
        shape_constraints = constraints['shape_constraints']
        
        # Circularity
        perimeter = main_component.perimeter
        circularity = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0
        if circularity < shape_constraints['circularity_min']:
            violations.append(f'circularity_violation: {circularity:.3f} < {shape_constraints["circularity_min"]}')
        
        # Aspect ratio
        aspect_ratio = main_component.major_axis_length / main_component.minor_axis_length if main_component.minor_axis_length > 0 else 0
        aspect_range = shape_constraints['aspect_ratio_range']
        if not (aspect_range[0] <= aspect_ratio <= aspect_range[1]):
            violations.append(f'aspect_ratio_violation: {aspect_ratio:.3f} not in {aspect_range}')
        
        # Solidity
        solidity = main_component.solidity
        if solidity < shape_constraints['solidity_min']:
            violations.append(f'solidity_violation: {solidity:.3f} < {shape_constraints["solidity_min"]}')
        
        return {
            'valid': len(violations) == 0,
            'violations': violations,
            'metrics': {
                'area': area,
                'circularity': circularity,
                'aspect_ratio': aspect_ratio,
                'solidity': solidity,
                'centroid': main_component.centroid
            }
        }
    
    def validate_multi_structure(self, masks: Dict[str, np.ndarray]) -> Dict:
        """Validate multiple cardiac structures and their relationships"""
        results = {}
        
        # Validate each structure individually
        for structure_name, mask in masks.items():
            results[structure_name] = self.validate_cardiac_structure(mask, structure_name)
        
        # Check spatial relationships
        spatial_violations = self._check_spatial_relationships(masks)
        results['spatial_relationships'] = spatial_violations
        
        # Overall validation
        all_valid = all(result['valid'] for result in results.values() if isinstance(result, dict) and 'valid' in result)
        results['overall_valid'] = all_valid and len(spatial_violations) == 0
        
        return results
    
    def _check_spatial_relationships(self, masks: Dict[str, np.ndarray]) -> List[str]:
        """Check spatial relationships between cardiac structures"""
        violations = []
        
        if 'left_ventricle' in masks and 'right_ventricle' in masks:
            lv_props = measure.regionprops(measure.label(masks['left_ventricle']))
            rv_props = measure.regionprops(measure.label(masks['right_ventricle']))
            
            if lv_props and rv_props:
                # Use the largest component of each structure (label order is arbitrary)
                lv_centroid = max(lv_props, key=lambda p: p.area).centroid
                rv_centroid = max(rv_props, key=lambda p: p.area).centroid
                
                # Euclidean distance between centroids; centroids are (row, col)
                distance = np.hypot(lv_centroid[0] - rv_centroid[0],
                                    lv_centroid[1] - rv_centroid[1])
                
                if distance > self.config.max_distance_between_structures:
                    violations.append(f'ventricles_too_far: distance={distance:.2f}')
                
                # Relative position: assumes an image orientation in which the LV lies left of
                # the RV (smaller column index); adjust to your data's display convention
                if lv_centroid[1] > rv_centroid[1]:
                    violations.append('ventricles_wrong_orientation: LV not left of RV')
        
        return violations
    
    def enforce_anatomical_constraints(self, mask: np.ndarray, structure_type: str = 'left_ventricle') -> np.ndarray:
        """Enforce anatomical constraints by modifying the mask"""
        validation_result = self.validate_cardiac_structure(mask, structure_type)
        
        if validation_result['valid']:
            return mask
        
        # Apply corrections based on violations (cv2 needs a uint8 array)
        corrected_mask = mask.astype(np.uint8).copy()
        area = validation_result.get('metrics', {}).get('area')
        min_area, max_area = self.cardiac_anatomy[structure_type]['expected_area_range']
        
        for violation in validation_result['violations']:
            if violation.startswith('area_violation') and area is not None:
                # Violation strings don't say which bound failed, so compare against the range
                if area < min_area:
                    # Dilate to increase area
                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
                    corrected_mask = cv2.dilate(corrected_mask, kernel, iterations=2)
                elif area > max_area:
                    # Erode to decrease area
                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
                    corrected_mask = cv2.erode(corrected_mask, kernel, iterations=1)
            
            elif violation.startswith('circularity_violation'):
                # Apply closing to improve circularity
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
                corrected_mask = cv2.morphologyEx(corrected_mask, cv2.MORPH_CLOSE, kernel)
        
        return corrected_mask.astype(np.uint8)
    
    def get_anatomical_priors(self, image_shape: Tuple[int, int]) -> Dict[str, np.ndarray]:
        """Generate anatomical priors for cardiac structures"""
        height, width = image_shape
        center_y, center_x = height // 2, width // 2
        
        priors = {}
        
        # Left ventricle prior (center-left)
        lv_prior = np.zeros(image_shape, dtype=np.float32)
        lv_center = (center_y, center_x - width // 6)
        y_grid, x_grid = np.ogrid[:height, :width]
        lv_mask = ((y_grid - lv_center[0])**2 + (x_grid - lv_center[1])**2) <= (width//8)**2
        lv_prior[lv_mask] = 1.0
        priors['left_ventricle'] = lv_prior
        
        # Right ventricle prior (center-right)
        rv_prior = np.zeros(image_shape, dtype=np.float32)
        rv_center = (center_y, center_x + width // 6)
        rv_mask = ((y_grid - rv_center[0])**2 + (x_grid - rv_center[1])**2) <= (width//10)**2
        rv_prior[rv_mask] = 1.0
        priors['right_ventricle'] = rv_prior
        
        # Myocardium prior (surrounding)
        myo_prior = np.zeros(image_shape, dtype=np.float32)
        outer_mask = ((y_grid - center_y)**2 + (x_grid - center_x)**2) <= (width//4)**2
        inner_mask = ((y_grid - center_y)**2 + (x_grid - center_x)**2) <= (width//8)**2
        myo_prior[outer_mask & ~inner_mask] = 1.0
        priors['myocardium'] = myo_prior
        
        return priors

# Test the anatomical constraint validator
constraint_validator = AnatomicalConstraintValidator(config)
print("✅ Anatomical Constraint Validator initialized!")
print(f"Supported cardiac structures: {list(constraint_validator.cardiac_anatomy.keys())}")

# Create test masks for validation
test_image_shape = (256, 256)
test_lv_mask = np.zeros(test_image_shape, dtype=np.uint8)
# cv2.circle takes the centre as (x, y) = (column, row)
cv2.circle(test_lv_mask, (100, 128), 50, 1, -1)  # Left ventricle: left of centre

test_rv_mask = np.zeros(test_image_shape, dtype=np.uint8)
cv2.circle(test_rv_mask, (180, 128), 35, 1, -1)  # Right ventricle: right of centre

# Validate structures
lv_validation = constraint_validator.validate_cardiac_structure(test_lv_mask, 'left_ventricle')
rv_validation = constraint_validator.validate_cardiac_structure(test_rv_mask, 'right_ventricle')

print(f"LV validation: Valid={lv_validation['valid']}, Violations={len(lv_validation['violations'])}")
print(f"RV validation: Valid={rv_validation['valid']}, Violations={len(rv_validation['violations'])}")

# Generate anatomical priors
priors = constraint_validator.get_anatomical_priors(test_image_shape)

# Visualize priors
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (name, prior) in enumerate(priors.items()):
    axes[i].imshow(prior, cmap='hot', alpha=0.7)
    axes[i].set_title(f'{name.replace("_", " ").title()} Prior')
    axes[i].axis('off')
plt.tight_layout()
plt.show()

In [None]:
class BoundaryRefinement:
    """Advanced boundary refinement and smoothing for cardiac segmentation"""
    
    def __init__(self, config: PostProcessConfig):
        self.config = config
        
    def gaussian_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply Gaussian smoothing to boundaries"""
        # Convert to float for smoothing
        mask_float = mask.astype(np.float32)
        
        # Apply Gaussian filter
        smoothed = filters.gaussian(mask_float, sigma=self.config.gaussian_sigma)
        
        # Threshold back to binary
        return (smoothed > 0.5).astype(np.uint8)
    
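    def contour_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Simplify outer contours with Douglas-Peucker polygon approximation.

        Only external contours are retrieved, so interior holes are filled.
        """
        # Find outer contours (OpenCV 4.x returns (contours, hierarchy); input must be uint8)
        mask = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)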
        
        # Create output mask
        smoothed_mask = np.zeros_like(mask)
        
        for contour in contours:
            # Approximate contour
            epsilon = self.config.contour_epsilon * cv2.arcLength(contour, True)
            approx_contour = cv2.approxPolyDP(contour, epsilon, True)
            
            # Fill the approximated contour
            cv2.fillPoly(smoothed_mask, [approx_contour], 1)
        
        return smoothed_mask
    
    def active_contour_refinement(self, mask: np.ndarray, image: np.ndarray = None) -> np.ndarray:
        """Refine boundaries using active contours (snakes)"""
        from skimage.segmentation import active_contour
        
        # If no image provided, use the mask itself
        if image is None:
            image = mask.astype(np.float32)
        
        # Find outer contours (input must be uint8)
        mask = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        refined_mask = np.zeros_like(mask)
        
        for contour in contours:
            if len(contour) < 10:  # Skip very small contours
                continue
                
            # OpenCV contours are (x, y); skimage's active_contour expects (row, col)
            snake = contour.squeeze(axis=1)[:, ::-1].astype(np.float64)
            
            # Apply active contour
            try:
                refined_snake = active_contour(
                    image, snake, 
                    alpha=0.015, beta=10, gamma=0.001,
                    max_num_iter=100
                )
                
                # Convert (row, col) back to OpenCV's (x, y) point format and fill
                refined_contour = np.round(refined_snake[:, ::-1]).astype(np.int32).reshape(-1, 1, 2)
                cv2.fillPoly(refined_mask, [refined_contour], 1)
                
            except Exception:
                # If active contour fails, use original contour
                cv2.fillPoly(refined_mask, [contour], 1)
        
        return refined_mask
    
    def iterative_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply iterative smoothing operations"""
        current_mask = mask.copy()
        
        for i in range(self.config.smooth_iterations):
            # Alternate between different smoothing methods
            if i % 2 == 0:
                current_mask = self.gaussian_smoothing(current_mask)
            else:
                current_mask = self.contour_smoothing(current_mask)
        
        return current_mask
    
    def edge_preserving_smoothing(self, mask: np.ndarray, image: np.ndarray = None) -> np.ndarray:
        """Apply edge-preserving smoothing"""
        if image is None:
            # Use bilateral filter on the mask
            mask_8bit = (mask * 255).astype(np.uint8)
            smoothed = cv2.bilateralFilter(mask_8bit, d=9, sigmaColor=75, sigmaSpace=75)
            return (smoothed > 127).astype(np.uint8)
        else:
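            # Use image gradients to preserve edges
            grad_x = cv2.Sobel(image.astype(np.float64), cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(image.astype(np.float64), cv2.CV_64F, 0, 1, ksize=3)
            gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
            
            # Edge-aware weights: pixels on strong image edges contribute less
            weights = np.exp(-gradient_magnitude / (2 * self.config.gaussian_sigma**2))
            
            # Normalized (weighted-average) smoothing: dividing by the smoothed weights keeps
            # the result in [0, 1] instead of shrinking the mask wherever weights are small
            mask_float = mask.astype(np.float32)
            numerator = filters.gaussian(mask_float * weights, sigma=self.config.gaussian_sigma)
            denominator = filters.gaussian(weights, sigma=self.config.gaussian_sigma)
            smoothed = numerator / (denominator + 1e-8)
            
            return (smoothed > 0.5).astype(np.uint8)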
    
    def topology_preserving_smoothing(self, mask: np.ndarray) -> np.ndarray:
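        """Conservative morphological smoothing (small opening, then closing).

        Note: opening and closing do not strictly preserve topology (they can split thin
        necks or fill small holes); the small structuring elements only limit such changes.
        """
        from skimage.morphology import binary_opening, binary_closing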
        
        # Apply conservative smoothing operations
        smoothed = binary_opening(mask.astype(bool), morphology.disk(2))
        smoothed = binary_closing(smoothed, morphology.disk(3))
        
        return smoothed.astype(np.uint8)
    
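    def boundary_regularization(self, mask: np.ndarray, lambda_smooth: float = 0.1) -> np.ndarray:
        """Regularize the boundary by Gaussian-smoothing its signed distance function.

        This is a simplified stand-in for energy-minimization approaches: `lambda_smooth`
        is used directly as the Gaussian sigma (in pixels), so values well below 1 have
        little visible effect.
        """
        # Signed distance: positive inside the mask, negative outside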
        distance_inside = ndimage.distance_transform_edt(mask)
        distance_outside = ndimage.distance_transform_edt(1 - mask)
        signed_distance = distance_inside - distance_outside
        
        # Apply regularization (simplified version)
        regularized = filters.gaussian(signed_distance, sigma=lambda_smooth)
        
        # Convert back to binary mask
        return (regularized > 0).astype(np.uint8)
    
    def multi_scale_refinement(self, mask: np.ndarray, scales: Tuple[float, ...] = (0.5, 1.0, 2.0)) -> np.ndarray:
        """Apply multi-scale boundary refinement"""
        refined_masks = []
        
        for scale in scales:
            # Adjust sigma based on scale
            scaled_sigma = self.config.gaussian_sigma * scale
            
            # Apply Gaussian smoothing
            mask_float = mask.astype(np.float32)
            smoothed = filters.gaussian(mask_float, sigma=scaled_sigma)
            refined_masks.append((smoothed > 0.5).astype(np.uint8))
        
        # Combine results using majority voting
        combined = np.stack(refined_masks, axis=-1)
        return (np.sum(combined, axis=-1) > len(scales) // 2).astype(np.uint8)
    
    def calculate_boundary_smoothness(self, mask: np.ndarray) -> float:
        """Calculate boundary smoothness metric"""
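        # Find contours; CHAIN_APPROX_NONE keeps every boundary pixel, which the curvature
        # estimate below needs (CHAIN_APPROX_SIMPLE drops points along straight runs)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        
        if not contours:
            return 0.0
        
        smoothness_scores = []
        
        for contour in contours:
            if len(contour) < 10:
                continue
                
            # Calculate curvature along contour
            points = contour.squeeze().astype(np.float64)
            if len(points.shape) == 1:
                continue
                
            # Compute first and second differences
            diff1 = np.diff(points, axis=0)
            diff2 = np.diff(diff1, axis=0)
            
            # Signed curvature k = (x'y'' - y'x'') / |r'|^3; the 2D cross product is written out
            # explicitly because np.cross on 2-element vectors is deprecated in NumPy 2.0
            d1 = diff1[:-1]
            cross = d1[:, 0] * diff2[:, 1] - d1[:, 1] * diff2[:, 0]
            curvature = cross / (np.linalg.norm(d1, axis=1)**3 + 1e-8)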
            
            # Smoothness is inverse of curvature variation
            smoothness = 1.0 / (np.std(curvature) + 1e-8)
            smoothness_scores.append(smoothness)
        
        return np.mean(smoothness_scores) if smoothness_scores else 0.0

# Test the boundary refinement
boundary_refiner = BoundaryRefinement(config)
print("✅ Boundary Refinement initialized!")

# Create a test mask with rough boundaries
test_mask = np.zeros((200, 200), dtype=np.uint8)
# Create a rough heart-like shape
points = np.array([
    [100, 60], [120, 80], [140, 100], [130, 140], [100, 160],
    [70, 140], [60, 100], [80, 80]
], dtype=np.int32)
cv2.fillPoly(test_mask, [points], 1)

# Add some noise
noise = np.random.random(test_mask.shape) > 0.95
test_mask = test_mask | noise.astype(np.uint8)

# Apply different smoothing methods
gaussian_smoothed = boundary_refiner.gaussian_smoothing(test_mask)
contour_smoothed = boundary_refiner.contour_smoothing(test_mask)
iterative_smoothed = boundary_refiner.iterative_smoothing(test_mask)
multi_scale_smoothed = boundary_refiner.multi_scale_refinement(test_mask)

# Calculate smoothness metrics
original_smoothness = boundary_refiner.calculate_boundary_smoothness(test_mask)
gaussian_smoothness = boundary_refiner.calculate_boundary_smoothness(gaussian_smoothed)
contour_smoothness = boundary_refiner.calculate_boundary_smoothness(contour_smoothed)

print(f"Boundary smoothness comparison:")
print(f"Original: {original_smoothness:.3f}")
print(f"Gaussian: {gaussian_smoothness:.3f}")
print(f"Contour: {contour_smoothness:.3f}")

# Visualize results
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

masks = [test_mask, gaussian_smoothed, contour_smoothed, 
         iterative_smoothed, multi_scale_smoothed, test_mask]
titles = ['Original', 'Gaussian Smoothed', 'Contour Smoothed',
          'Iterative Smoothed', 'Multi-scale Smoothed', 'Original vs Gaussian-Smoothed Boundary']

for i, (mask, title) in enumerate(zip(masks, titles)):
    if i == 5:  # Overlay: original mask in gray, Gaussian-smoothed boundary in red
        axes[i].imshow(test_mask, cmap='gray', alpha=0.5)
        axes[i].contour(gaussian_smoothed, colors='red', linewidths=2, alpha=0.8)
    else:
        axes[i].imshow(mask, cmap='gray')
    axes[i].set_title(title)
    axes[i].axis('off')

plt.tight_layout()
plt.show()
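
The smoothness metric in `calculate_boundary_smoothness` scores a boundary as the inverse of the spread of its discrete curvature. The following is a minimal NumPy-only sketch of that idea on synthetic contours. Unlike the class method, it assumes the contour is closed and uses wrap-around central differences (`np.roll`) so the first and last points are treated like all others. The helper names `contour_curvature` and `smoothness_score` are hypothetical.

```python
import numpy as np


def contour_curvature(points: np.ndarray) -> np.ndarray:
    """Discrete signed curvature of a closed 2D contour given as an (N, 2) array."""
    pts = points.astype(np.float64)
    # Central differences with wrap-around, because the contour is closed
    d1 = (np.roll(pts, -1, axis=0) - np.roll(pts, 1, axis=0)) / 2.0
    d2 = np.roll(pts, -1, axis=0) - 2.0 * pts + np.roll(pts, 1, axis=0)
    # 2D cross product x'y'' - y'x'' written out explicitly
    cross = d1[:, 0] * d2[:, 1] - d1[:, 1] * d2[:, 0]
    speed = np.linalg.norm(d1, axis=1)
    return cross / (speed ** 3 + 1e-8)


def smoothness_score(points: np.ndarray) -> float:
    """Inverse of the curvature standard deviation (higher means smoother)."""
    return 1.0 / (np.std(contour_curvature(points)) + 1e-8)


# A uniformly sampled circle of radius 30 has (nearly) constant curvature 1/30
t = np.linspace(0, 2 * np.pi, 200, endpoint=False)
circle = np.stack([50 + 30 * np.cos(t), 50 + 30 * np.sin(t)], axis=1)

# The same circle with random jitter added to every point
rng = np.random.default_rng(0)
jagged = circle + rng.normal(scale=1.0, size=circle.shape)

print(f"mean curvature of circle: {contour_curvature(circle).mean():.4f}")  # about 1/30
print(f"smoothness: circle={smoothness_score(circle):.3g}, jagged={smoothness_score(jagged):.3g}")
```

A perfect circle produces an almost zero curvature spread and therefore a very large score. This is one reason raw scores of this kind are usually normalized or capped before being combined with other metrics, as `_calculate_quality_score` does with `min(smoothness / 10.0, 1.0)`.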

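`contour_smoothing` relies on `cv2.approxPolyDP`, which simplifies a contour with the Douglas-Peucker algorithm. The tolerance is `epsilon = contour_epsilon * arcLength`. The recursive sketch below illustrates how that tolerance decides which vertices survive. It works on an open polyline using NumPy only and is not OpenCV's implementation. The function name `rdp` is hypothetical.

```python
import numpy as np


def rdp(points: np.ndarray, epsilon: float) -> np.ndarray:
    """Ramer-Douglas-Peucker simplification of an open polyline given as (N, 2) points."""
    if len(points) < 3:
        return points
    start, end = points[0].astype(float), points[-1].astype(float)
    seg = end - start
    seg_len = np.hypot(seg[0], seg[1])
    rel = points.astype(float) - start
    if seg_len == 0:
        # Degenerate chord: fall back to distance from the start point
        dists = np.hypot(rel[:, 0], rel[:, 1])
    else:
        # Perpendicular distance of every point to the start-end chord
        dists = np.abs(seg[0] * rel[:, 1] - seg[1] * rel[:, 0]) / seg_len
    idx = int(np.argmax(dists))
    if dists[idx] <= epsilon:
        # Every point is within tolerance: keep only the endpoints
        return np.array([points[0], points[-1]])
    # Keep the farthest point and simplify both halves recursively
    left = rdp(points[: idx + 1], epsilon)
    right = rdp(points[idx:], epsilon)
    return np.vstack([left[:-1], right])


# A slightly wiggly line collapses to its two endpoints
line = np.array([[i, 0.2 * (-1) ** i] for i in range(10)])
print(len(rdp(line, epsilon=1.0)))

# An L-shaped polyline keeps its corner
l_shape = np.array([[i, 0] for i in range(6)] + [[5, j] for j in range(1, 6)])
print(rdp(l_shape, epsilon=0.5).tolist())
```

Small deviations below `epsilon` (pixel staircase, jitter) are discarded, while genuine corners are kept. A larger `contour_epsilon` therefore yields coarser, more polygonal masks.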
In [None]:
class PostProcessingPipeline:
    """Complete post-processing pipeline for cardiac segmentation"""
    
    def __init__(self, config: PostProcessConfig):
        self.config = config
        self.morph_processor = MorphologicalProcessor(config)
        self.cc_analyzer = ConnectedComponentAnalyzer(config)
        self.constraint_validator = AnatomicalConstraintValidator(config)
        self.boundary_refiner = BoundaryRefinement(config)
        
        # Pipeline steps configuration
        self.pipeline_steps = [
            'noise_removal',
            'morphological_operations',
            'connected_component_analysis',
            'anatomical_validation',
            'boundary_refinement',
            'quality_assessment'
        ]
        
    def process_single_mask(self, 
                           mask: np.ndarray, 
                           image: np.ndarray = None,
                           structure_type: str = 'left_ventricle',
                           steps: List[str] = None) -> Tuple[np.ndarray, PostProcessingMetrics]:
        """Process a single segmentation mask through the complete pipeline"""
        
        if steps is None:
            steps = self.pipeline_steps
            
        processed_mask = mask.copy()
        metrics = PostProcessingMetrics(
            components_removed=0,
            area_change=0.0,
            boundary_smoothness=0.0,
            confidence_improvement=0.0,
            processing_time=0.0,
            quality_score=0.0
        )
        
        start_time = time.time()
        original_area = np.sum(mask)
        
        # Step 1: Noise Removal
        if 'noise_removal' in steps:
            processed_mask = self._remove_noise(processed_mask)
        
        # Step 2: Morphological Operations
        if 'morphological_operations' in steps:
            processed_mask = self._apply_morphological_operations(processed_mask)
        
        # Step 3: Connected Component Analysis
        if 'connected_component_analysis' in steps:
            processed_mask, cc_stats = self.cc_analyzer.remove_small_components(processed_mask)
            processed_mask, _ = self.cc_analyzer.keep_largest_components(processed_mask)
            metrics.components_removed = cc_stats['removed_components']
        
        # Step 4: Anatomical Validation
        if 'anatomical_validation' in steps:
            processed_mask = self.constraint_validator.enforce_anatomical_constraints(
                processed_mask, structure_type
            )
        
        # Step 5: Boundary Refinement
        if 'boundary_refinement' in steps:
            processed_mask = self.boundary_refiner.iterative_smoothing(processed_mask)
            metrics.boundary_smoothness = self.boundary_refiner.calculate_boundary_smoothness(processed_mask)
        
        # Step 6: Quality Assessment
        if 'quality_assessment' in steps:
            metrics.quality_score = self._calculate_quality_score(mask, processed_mask)
        
        # Calculate final metrics
        final_area = np.sum(processed_mask)
        metrics.area_change = (final_area - original_area) / original_area if original_area > 0 else 0
        metrics.processing_time = time.time() - start_time
        
        return processed_mask, metrics
    
    def process_multi_structure(self, 
                               masks: Dict[str, np.ndarray],
                               image: np.ndarray = None) -> Tuple[Dict[str, np.ndarray], Dict]:
        """Process multiple cardiac structures with inter-structure constraints"""
        
        processed_masks = {}
        all_metrics = {}
        
        # Process each structure individually
        for structure_name, mask in masks.items():
            processed_mask, metrics = self.process_single_mask(
                mask, image, structure_name
            )
            processed_masks[structure_name] = processed_mask
            all_metrics[structure_name] = metrics
        
        # Apply inter-structure constraints
        processed_masks = self._apply_inter_structure_constraints(processed_masks)
        
        # Validate spatial relationships
        spatial_validation = self.constraint_validator.validate_multi_structure(processed_masks)
        all_metrics['spatial_validation'] = spatial_validation
        
        return processed_masks, all_metrics
    
    def _remove_noise(self, mask: np.ndarray) -> np.ndarray:
        """Remove noise from segmentation mask"""
        # Remove small isolated pixels
        mask = self.morph_processor.opening(mask, 'circular')
        
        # Remove components touching border
        mask, _ = self.cc_analyzer.remove_border_components(mask)
        
        # Apply median filter to remove salt-and-pepper noise
        mask = cv2.medianBlur(mask.astype(np.uint8), 3)
        
        return mask
    
    def _apply_morphological_operations(self, mask: np.ndarray) -> np.ndarray:
        """Apply sequence of morphological operations"""
        # Opening to remove small noise
        mask = self.morph_processor.opening(mask, 'circular')
        
        # Closing to fill small holes
        mask = self.morph_processor.closing(mask, 'circular')
        
        # Adaptive morphology based on component size
        mask = self.morph_processor.adaptive_morphology(mask)
        
        return mask
    
    def _apply_inter_structure_constraints(self, masks: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Apply constraints between different cardiac structures"""
        processed_masks = masks.copy()
        
        # Ensure no overlap between ventricles
        if 'left_ventricle' in masks and 'right_ventricle' in masks:
            lv_mask = processed_masks['left_ventricle']
            rv_mask = processed_masks['right_ventricle']
            
            # Remove overlapping regions
            overlap = lv_mask & rv_mask
            if np.sum(overlap) > 0:
                # Assign overlap to the structure with higher confidence
                # For simplicity, assign to left ventricle
                processed_masks['right_ventricle'] = rv_mask & ~overlap
        
        # Myocardium should surround the ventricles; here we only enforce that it does not overlap them
        if 'myocardium' in masks and ('left_ventricle' in masks or 'right_ventricle' in masks):
            myo_mask = processed_masks['myocardium']
            ventricle_masks = []
            
            if 'left_ventricle' in masks:
                ventricle_masks.append(processed_masks['left_ventricle'])
            if 'right_ventricle' in masks:
                ventricle_masks.append(processed_masks['right_ventricle'])
            
            # Combine ventricle masks
            combined_ventricles = np.zeros_like(myo_mask)
            for v_mask in ventricle_masks:
                combined_ventricles = combined_ventricles | v_mask
            
            # Ensure myocardium doesn't overlap with ventricles
            processed_masks['myocardium'] = myo_mask & ~combined_ventricles
        
        return processed_masks
    
    def _calculate_quality_score(self, original_mask: np.ndarray, processed_mask: np.ndarray) -> float:
        """Calculate overall quality score for post-processing"""
        scores = []
        
        # Area preservation score
        original_area = np.sum(original_mask)
        processed_area = np.sum(processed_mask)
        area_ratio = processed_area / original_area if original_area > 0 else 0
        area_score = max(0.0, 1.0 - abs(1.0 - area_ratio))  # 1.0 when area is unchanged; clipped at 0
        scores.append(area_score)
        
        # Boundary smoothness score
        smoothness = self.boundary_refiner.calculate_boundary_smoothness(processed_mask)
        smoothness_score = min(smoothness / 10.0, 1.0)  # Normalize
        scores.append(smoothness_score)
        
        # Connectivity score (prefer fewer components)
        labels = measure.label(processed_mask)
        n_components = len(measure.regionprops(labels))
        connectivity_score = max(0, 1.0 - (n_components - 1) * 0.2)
        scores.append(connectivity_score)
        
        # Compactness score (4*pi*A / P^2) of the largest component
        if processed_area > 0:
            contours, _ = cv2.findContours(processed_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                largest = max(contours, key=cv2.contourArea)
                perimeter = cv2.arcLength(largest, True)
                area = cv2.contourArea(largest)
                compactness = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0
                compactness_score = min(compactness, 1.0)
                scores.append(compactness_score)
        
        return np.mean(scores)
    
    def batch_process(self, 
                     masks: List[np.ndarray], 
                     images: List[np.ndarray] = None,
                     structure_types: List[str] = None) -> Tuple[List[np.ndarray], List[PostProcessingMetrics]]:
        """Process multiple masks in batch"""
        
        if images is None:
            images = [None] * len(masks)
        if structure_types is None:
            structure_types = ['left_ventricle'] * len(masks)
        
        processed_masks = []
        all_metrics = []
        
        for i, (mask, image, structure_type) in enumerate(zip(masks, images, structure_types)):
            try:
                processed_mask, metrics = self.process_single_mask(mask, image, structure_type)
                processed_masks.append(processed_mask)
                all_metrics.append(metrics)
            except Exception as e:
                print(f"Error processing mask {i}: {e}")
                processed_masks.append(mask)  # Return original on error
                all_metrics.append(PostProcessingMetrics(0, 0, 0, 0, 0, 0))
        
        return processed_masks, all_metrics
    
    def get_pipeline_summary(self) -> Dict:
        """Get summary of pipeline configuration and capabilities"""
        return {
            'pipeline_steps': self.pipeline_steps,
            'morphological_operations': list(self.morph_processor.kernels.keys()),
            'supported_structures': list(self.constraint_validator.cardiac_anatomy.keys()),
            'config': self.config.__dict__
        }

# Test the complete post-processing pipeline
import time
pipeline = PostProcessingPipeline(config)
print("✅ Post-processing Pipeline initialized!")
print("Pipeline steps:", pipeline.pipeline_steps)

# Create test data
test_mask = np.zeros((256, 256), dtype=np.uint8)
cv2.circle(test_mask, (128, 128), 60, 1, -1)  # Main structure
cv2.circle(test_mask, (50, 50), 5, 1, -1)     # Small noise
cv2.circle(test_mask, (200, 200), 3, 1, -1)   # Small noise

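# Add some boundary roughness: random pixels in a thin band just outside the structures
# (noise drawn inside the mask would have no effect after the OR below)
ring = cv2.dilate(test_mask, np.ones((5, 5), np.uint8)) - test_mask
noise_mask = (np.random.random(test_mask.shape) > 0.9).astype(np.uint8)
boundary_noise = ring & noise_mask
test_mask = test_mask | boundary_noise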

# Process single mask
processed_mask, metrics = pipeline.process_single_mask(test_mask, structure_type='left_ventricle')

print(f"Processing Results:")
print(f"  Components removed: {metrics.components_removed}")
print(f"  Area change: {metrics.area_change:.3f}")
print(f"  Boundary smoothness: {metrics.boundary_smoothness:.3f}")
print(f"  Quality score: {metrics.quality_score:.3f}")
print(f"  Processing time: {metrics.processing_time:.3f}s")

# Test multi-structure processing
test_masks = {
    'left_ventricle': test_mask,
    'right_ventricle': np.roll(test_mask, 50, axis=1)  # Shifted version
}

processed_multi, multi_metrics = pipeline.process_multi_structure(test_masks)

print(f"\nMulti-structure processing:")
for structure, metrics in multi_metrics.items():
    if isinstance(metrics, PostProcessingMetrics):
        print(f"  {structure}: Quality={metrics.quality_score:.3f}, Time={metrics.processing_time:.3f}s")

# Visualize results
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

axes[0, 0].imshow(test_mask, cmap='gray')
axes[0, 0].set_title('Original Mask')
axes[0, 0].axis('off')

axes[0, 1].imshow(processed_mask, cmap='gray')
axes[0, 1].set_title('Post-processed Mask')
axes[0, 1].axis('off')

# Overlay comparison
axes[1, 0].imshow(test_mask, cmap='gray', alpha=0.5)
axes[1, 0].contour(processed_mask, colors='red', linewidths=2)
axes[1, 0].set_title('Boundary Comparison')
axes[1, 0].axis('off')

# Multi-structure result
multi_combined = processed_multi['left_ventricle'] + processed_multi['right_ventricle'] * 2
axes[1, 1].imshow(multi_combined, cmap='viridis')
axes[1, 1].set_title('Multi-structure Result')
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()
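
`_apply_inter_structure_constraints` resolves LV/RV overlap by always favoring the left ventricle. The following is a minimal sketch of a generalization that walks an explicit priority list and gives each contested pixel to the first structure that claims it. The function `resolve_overlaps` and the `priority` argument are hypothetical, and structures missing from `priority` are simply dropped in this sketch.

```python
from typing import Dict, Sequence

import numpy as np


def resolve_overlaps(masks: Dict[str, np.ndarray],
                     priority: Sequence[str]) -> Dict[str, np.ndarray]:
    """Assign each contested pixel to the highest-priority structure that claims it."""
    resolved: Dict[str, np.ndarray] = {}
    claimed = None  # union of all higher-priority masks processed so far
    for name in priority:
        current = masks[name].astype(bool)
        if claimed is None:
            claimed = np.zeros_like(current)
        # Keep only the pixels no higher-priority structure has claimed
        resolved[name] = (current & ~claimed).astype(np.uint8)
        claimed |= current
    return resolved


# Two overlapping squares standing in for LV and RV masks
lv = np.zeros((20, 20), dtype=np.uint8)
rv = np.zeros((20, 20), dtype=np.uint8)
lv[2:12, 2:12] = 1
rv[8:18, 8:18] = 1

out = resolve_overlaps({'left_ventricle': lv, 'right_ventricle': rv},
                       priority=['left_ventricle', 'right_ventricle'])
print("overlapping pixels after:", int(np.sum(out['left_ventricle'] & out['right_ventricle'])))
```

With per-pixel class probabilities available, taking the `argmax` over classes is a common alternative that removes the need for a fixed priority order. That approach is not shown here.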

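The compactness term in `_calculate_quality_score` is the isoperimetric quotient 4πA/P². It equals 1 for a circle and shrinks for elongated or ragged shapes. The sketch below computes it on polygon vertices with NumPy only: the shoelace formula gives the area and summed edge lengths give the perimeter. These stand in for `cv2.contourArea` and `cv2.arcLength`. The helper name `polygon_compactness` is hypothetical.

```python
import numpy as np


def polygon_compactness(vertices: np.ndarray) -> float:
    """Isoperimetric quotient 4*pi*A / P**2 of a closed polygon given as (N, 2) vertices."""
    x, y = vertices[:, 0].astype(float), vertices[:, 1].astype(float)
    # Shoelace formula for the enclosed area
    area = 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))
    # Sum of edge lengths, including the closing edge back to the first vertex
    edges = np.roll(vertices, -1, axis=0) - vertices
    perimeter = np.hypot(edges[:, 0], edges[:, 1]).sum()
    return 4 * np.pi * area / perimeter ** 2 if perimeter > 0 else 0.0


square = np.array([[0, 0], [10, 0], [10, 10], [0, 10]])
t = np.linspace(0, 2 * np.pi, 360, endpoint=False)
disk = np.stack([20 * np.cos(t), 20 * np.sin(t)], axis=1)
thin = np.array([[0, 0], [40, 0], [40, 2], [0, 2]])

for name, shape in [('square', square), ('disk', disk), ('thin', thin)]:
    print(f"{name}: {polygon_compactness(shape):.3f}")
```

A square scores π/4 and a thin 40×2 rectangle scores well below that. On a pixel grid, contour-based perimeters are only approximations, so the pipeline clips the score with `min(compactness, 1.0)`.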
In [None]:
class PostProcessingVisualizer:
    """Advanced visualization tools for post-processing analysis"""
    
    def __init__(self):
        self.color_maps = {
            'original': 'Blues',
            'processed': 'Reds',
            'difference': 'RdBu',
            'overlay': 'viridis'
        }
    
    def plot_processing_comparison(self, 
                                 original_mask: np.ndarray,
                                 processed_mask: np.ndarray,
                                 image: np.ndarray = None,
                                 title: str = "Post-processing Comparison"):
        """Create comprehensive comparison visualization"""
        
        fig = plt.figure(figsize=(16, 12))
        
        # Original mask
        plt.subplot(3, 3, 1)
        plt.imshow(original_mask, cmap='gray')
        plt.title('Original Mask')
        plt.axis('off')
        
        # Processed mask
        plt.subplot(3, 3, 2)
        plt.imshow(processed_mask, cmap='gray')
        plt.title('Processed Mask')
        plt.axis('off')
        
        # Difference
        plt.subplot(3, 3, 3)
        difference = processed_mask.astype(int) - original_mask.astype(int)
        # RdBu_r maps +1 (added pixels) to red and -1 (removed pixels) to blue
        plt.imshow(difference, cmap='RdBu_r', vmin=-1, vmax=1)
        plt.title('Difference (Red: Added, Blue: Removed)')
        plt.colorbar(shrink=0.8)
        plt.axis('off')
        
        # Overlay with original image
        if image is not None:
            plt.subplot(3, 3, 4)
            plt.imshow(image, cmap='gray', alpha=0.7)
            plt.contour(original_mask, colors='blue', linewidths=2, alpha=0.8)
            plt.contour(processed_mask, colors='red', linewidths=2, alpha=0.8)
            plt.title('Overlay on Original Image')
            plt.axis('off')
        
        # Boundary comparison
        plt.subplot(3, 3, 5)
        # Find contours
        orig_contours, _ = cv2.findContours(original_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        proc_contours, _ = cv2.findContours(processed_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        boundary_img = np.zeros((*original_mask.shape, 3), dtype=np.uint8)
        # Colors are RGB because plt.imshow displays the array as RGB
        if orig_contours:
            cv2.drawContours(boundary_img, orig_contours, -1, (0, 0, 255), 2)  # Blue
        if proc_contours:
            cv2.drawContours(boundary_img, proc_contours, -1, (0, 255, 0), 2)  # Green
        
        plt.imshow(boundary_img)
        plt.title('Boundary Comparison (Blue: Original, Green: Processed)')
        plt.axis('off')
        
        # Morphological analysis
        plt.subplot(3, 3, 6)
        # Apply morphological gradient to show boundaries
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        orig_gradient = cv2.morphologyEx(original_mask, cv2.MORPH_GRADIENT, kernel)
        proc_gradient = cv2.morphologyEx(processed_mask, cv2.MORPH_GRADIENT, kernel)
        
        # Scale the 0/1 gradients to 0/255 so they are visible as uint8 RGB
        combined_gradient = np.stack([orig_gradient, proc_gradient, np.zeros_like(orig_gradient)], axis=-1) * 255
        plt.imshow(combined_gradient)
        plt.title('Morphological Gradient (Red: Original, Green: Processed)')
        plt.axis('off')
        
        # Area analysis
        plt.subplot(3, 3, 7)
        areas = [np.sum(original_mask), np.sum(processed_mask)]
        labels = ['Original', 'Processed']
        colors = ['lightblue', 'lightcoral']
        bars = plt.bar(labels, areas, color=colors)
        plt.title('Area Comparison')
        plt.ylabel('Area (pixels)')
        
        # Add value labels on bars
        for bar, area in zip(bars, areas):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(areas)*0.01,
                    f'{area}', ha='center', va='bottom')
        
        # Component analysis
        plt.subplot(3, 3, 8)
        orig_labels = measure.label(original_mask)
        proc_labels = measure.label(processed_mask)
        orig_components = len(measure.regionprops(orig_labels))
        proc_components = len(measure.regionprops(proc_labels))
        
        components = [orig_components, proc_components]
        bars = plt.bar(labels, components, color=colors)
        plt.title('Connected Components')
        plt.ylabel('Number of Components')
        
        for bar, comp in zip(bars, components):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(components)*0.05,
                    f'{comp}', ha='center', va='bottom')
        
        # Quality metrics
        plt.subplot(3, 3, 9)
        # Calculate various quality metrics
        boundary_refiner = BoundaryRefinement(PostProcessConfig())
        orig_smoothness = boundary_refiner.calculate_boundary_smoothness(original_mask)
        proc_smoothness = boundary_refiner.calculate_boundary_smoothness(processed_mask)
        
        metrics = ['Smoothness', 'Components']
        orig_values = [orig_smoothness, orig_components]
        proc_values = [proc_smoothness, proc_components]
        
        x = np.arange(len(metrics))
        width = 0.35
        
        plt.bar(x - width/2, orig_values, width, label='Original', color='lightblue')
        plt.bar(x + width/2, proc_values, width, label='Processed', color='lightcoral')
        
        plt.xlabel('Metrics')
        plt.title('Quality Metrics Comparison')
        plt.xticks(x, metrics)
        plt.legend()
        
        plt.suptitle(title, fontsize=16)
        plt.tight_layout()
        plt.show()
    
    def plot_pipeline_metrics(self, metrics_list: List[PostProcessingMetrics], titles: List[str] = None):
        """Plot comprehensive metrics from pipeline processing"""
        
        if titles is None:
            titles = [f'Sample {i+1}' for i in range(len(metrics_list))]
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        
        # Extract metrics
        components_removed = [m.components_removed for m in metrics_list]
        area_changes = [m.area_change for m in metrics_list]
        boundary_smoothness = [m.boundary_smoothness for m in metrics_list]
        processing_times = [m.processing_time for m in metrics_list]
        quality_scores = [m.quality_score for m in metrics_list]
        
        # Components removed
        axes[0, 0].bar(range(len(components_removed)), components_removed, color='lightcoral')
        axes[0, 0].set_title('Components Removed')
        axes[0, 0].set_xlabel('Sample')
        axes[0, 0].set_ylabel('Count')
        axes[0, 0].set_xticks(range(len(titles)))
        axes[0, 0].set_xticklabels(titles, rotation=45)
        
        # Area changes
        axes[0, 1].bar(range(len(area_changes)), area_changes, 
                      color=['green' if x >= 0 else 'red' for x in area_changes])
        axes[0, 1].set_title('Area Changes')
        axes[0, 1].set_xlabel('Sample')
        axes[0, 1].set_ylabel('Relative Change')
        axes[0, 1].axhline(y=0, color='black', linestyle='--', alpha=0.5)
        axes[0, 1].set_xticks(range(len(titles)))
        axes[0, 1].set_xticklabels(titles, rotation=45)
        
        # Boundary smoothness
        axes[0, 2].bar(range(len(boundary_smoothness)), boundary_smoothness, color='lightblue')
        axes[0, 2].set_title('Boundary Smoothness')
        axes[0, 2].set_xlabel('Sample')
        axes[0, 2].set_ylabel('Smoothness Score')
        axes[0, 2].set_xticks(range(len(titles)))
        axes[0, 2].set_xticklabels(titles, rotation=45)
        
        # Processing times
        axes[1, 0].bar(range(len(processing_times)), processing_times, color='orange')
        axes[1, 0].set_title('Processing Times')
        axes[1, 0].set_xlabel('Sample')
        axes[1, 0].set_ylabel('Time (seconds)')
        axes[1, 0].set_xticks(range(len(titles)))
        axes[1, 0].set_xticklabels(titles, rotation=45)
        
        # Quality scores
        axes[1, 1].bar(range(len(quality_scores)), quality_scores, color='lightgreen')
        axes[1, 1].set_title('Quality Scores')
        axes[1, 1].set_xlabel('Sample')
        axes[1, 1].set_ylabel('Quality Score (0-1)')
        axes[1, 1].set_ylim(0, 1)
        axes[1, 1].set_xticks(range(len(titles)))
        axes[1, 1].set_xticklabels(titles, rotation=45)
        
        # Summary scatter plot, colored by boundary smoothness
        scatter = axes[1, 2].scatter(quality_scores, processing_times, c=boundary_smoothness,
                                     cmap='viridis', s=100, alpha=0.7)
        axes[1, 2].set_xlabel('Quality Score')
        axes[1, 2].set_ylabel('Processing Time (s)')
        axes[1, 2].set_title('Quality vs Performance')
        plt.colorbar(scatter, ax=axes[1, 2], label='Boundary Smoothness')
        
        plt.tight_layout()
        plt.show()
    
    def create_interactive_comparison(self, 
                                    original_mask: np.ndarray,
                                    processed_mask: np.ndarray,
                                    image: np.ndarray = None):
        """Create interactive Plotly visualization"""
        
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Original Mask', 'Processed Mask', 'Difference', 'Overlay'),
            specs=[[{"type": "image"}, {"type": "image"}],
                   [{"type": "image"}, {"type": "image"}]]
        )
        
        # Original mask
        fig.add_trace(
            go.Heatmap(z=original_mask, colorscale='Blues', showscale=False),
            row=1, col=1
        )
        
        # Processed mask
        fig.add_trace(
            go.Heatmap(z=processed_mask, colorscale='Reds', showscale=False),
            row=1, col=2
        )
        
        # Difference
        difference = processed_mask.astype(int) - original_mask.astype(int)
        fig.add_trace(
            go.Heatmap(z=difference, colorscale='RdBu', zmid=0, showscale=True),
            row=2, col=1
        )
        
        # Overlay (assumes `image` is scaled to [0, 1] so the mask terms remain visible)
        if image is not None:
            overlay = image * 0.7 + (original_mask * 0.15) + (processed_mask * 0.15)
        else:
            overlay = original_mask * 0.5 + processed_mask * 0.5
        
        fig.add_trace(
            go.Heatmap(z=overlay, colorscale='Viridis', showscale=False),
            row=2, col=2
        )
        
        fig.update_layout(
            title="Interactive Post-processing Comparison",
            height=800,
            showlegend=False
        )
        
        # Remove axis ticks; reverse y so row 0 is at the top, as with matplotlib's imshow
        fig.update_xaxes(showticklabels=False)
        fig.update_yaxes(showticklabels=False, autorange='reversed')
        
        return fig

class QualityAssessment:
    """Comprehensive quality assessment for post-processed segmentations"""
    
    def __init__(self):
        self.metrics = {
            'dice_coefficient': self._dice_coefficient,
            'jaccard_index': self._jaccard_index,
            'hausdorff_distance': self._hausdorff_distance,
            'mean_surface_distance': self._mean_surface_distance,
            'boundary_smoothness': self._boundary_smoothness,
            'topological_consistency': self._topological_consistency
        }
    
    def _dice_coefficient(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Calculate Dice coefficient"""
        intersection = np.sum(mask1 & mask2)
        return 2.0 * intersection / (np.sum(mask1) + np.sum(mask2) + 1e-8)
    
    def _jaccard_index(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Calculate Jaccard index"""
        intersection = np.sum(mask1 & mask2)
        union = np.sum(mask1 | mask2)
        return intersection / (union + 1e-8)
    
    def _hausdorff_distance(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Symmetric Hausdorff distance between the external boundary points of two masks"""
        from scipy.spatial.distance import directed_hausdorff
        
        # Get all boundary points; CHAIN_APPROX_NONE keeps every contour pixel
        contours1, _ = cv2.findContours(mask1.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        contours2, _ = cv2.findContours(mask2.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        
        if not contours1 or not contours2:
            return float('inf')
        
        # Use the points of every component, not just the first contour found
        points1 = np.vstack([c.reshape(-1, 2) for c in contours1])
        points2 = np.vstack([c.reshape(-1, 2) for c in contours2])
        
        return max(directed_hausdorff(points1, points2)[0], 
                  directed_hausdorff(points2, points1)[0])
    
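    def _mean_surface_distance(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Calculate the mean symmetric surface distance (pixels)"""
        # Boolean masks are required: indexing with a uint8 array would be
        # integer (fancy) indexing instead of boolean selection
        m1 = mask1.astype(bool)
        m2 = mask2.astype(bool)
        
        # Surface points: foreground pixels removed by a single erosion
        surface1 = m1 & ~ndimage.binary_erosion(m1)
        surface2 = m2 & ~ndimage.binary_erosion(m2)
        
        if not np.any(surface1) or not np.any(surface2):
            return float('inf')
        
        # Distance from every pixel to the nearest surface pixel of each mask
        dist_to_surface1 = ndimage.distance_transform_edt(~surface1)
        dist_to_surface2 = ndimage.distance_transform_edt(~surface2)
        
        # Mean surface-to-surface distances in both directions
        mean_dist1to2 = np.mean(dist_to_surface2[surface1])
        mean_dist2to1 = np.mean(dist_to_surface1[surface2])
        
        return (mean_dist1to2 + mean_dist2to1) / 2.0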
    
    def _boundary_smoothness(self, mask: np.ndarray) -> float:
        """Calculate boundary smoothness"""
        boundary_refiner = BoundaryRefinement(PostProcessConfig())
        return boundary_refiner.calculate_boundary_smoothness(mask)
    
    def _topological_consistency(self, mask: np.ndarray) -> float:
        """Heuristic topology score in [0, 1]; 1.0 means one component and no holes"""
        # Count connected components
        labels = measure.label(mask)
        n_components = len(measure.regionprops(labels))
        
        # Hole area in pixels: background pixels enclosed by the mask
        filled_mask = ndimage.binary_fill_holes(mask)
        hole_area = np.sum(filled_mask) - np.sum(mask > 0)
        
        # Penalize each extra component (0.2) and each enclosed hole pixel (0.001)
        return max(0, 1.0 - abs(n_components - 1) * 0.2 - hole_area * 0.001)
    
    def evaluate_post_processing(self, 
                               original_mask: np.ndarray, 
                               processed_mask: np.ndarray,
                               ground_truth: np.ndarray = None) -> Dict[str, float]:
        """Comprehensive evaluation of post-processing results"""
        
        results = {}
        
        # Compare processed vs original
        for metric_name, metric_func in self.metrics.items():
            if metric_name in ['boundary_smoothness', 'topological_consistency']:
                # These metrics are calculated on single masks
                results[f'original_{metric_name}'] = metric_func(original_mask)
                results[f'processed_{metric_name}'] = metric_func(processed_mask)
                results[f'{metric_name}_improvement'] = (
                    results[f'processed_{metric_name}'] - results[f'original_{metric_name}']
                )
            else:
                # These metrics compare two masks
                results[f'original_vs_processed_{metric_name}'] = metric_func(original_mask, processed_mask)
        
        # If ground truth is available, compare both versions against it
        if ground_truth is not None:
            for metric_name, metric_func in self.metrics.items():
                if metric_name not in ['boundary_smoothness', 'topological_consistency']:
                    results[f'original_vs_gt_{metric_name}'] = metric_func(original_mask, ground_truth)
                    results[f'processed_vs_gt_{metric_name}'] = metric_func(processed_mask, ground_truth)
                    
                    # Calculate improvement
                    improvement = (results[f'processed_vs_gt_{metric_name}'] - 
                                 results[f'original_vs_gt_{metric_name}'])
                    results[f'gt_{metric_name}_improvement'] = improvement
        
        return results

# Test the visualization and quality assessment tools
visualizer = PostProcessingVisualizer()
quality_assessor = QualityAssessment()

print("✅ Visualization and Quality Assessment tools initialized!")

# Create test data with different levels of processing
test_masks = []
processed_masks = []
metrics_list = []

for i in range(3):
    # Create test mask with varying noise levels
    test_mask = np.zeros((200, 200), dtype=np.uint8)
    cv2.circle(test_mask, (100, 100), 50 + i*10, 1, -1)
    
    # Add noise
    noise_level = (i + 1) * 0.02
    noise = np.random.random(test_mask.shape) > (1 - noise_level)
    test_mask = test_mask | noise.astype(np.uint8)
    
    # Process mask
    processed_mask, metrics = pipeline.process_single_mask(test_mask)
    
    test_masks.append(test_mask)
    processed_masks.append(processed_mask)
    metrics_list.append(metrics)

# Test quality assessment
print("\nQuality Assessment Results:")
for i, (orig, proc) in enumerate(zip(test_masks, processed_masks)):
    quality_results = quality_assessor.evaluate_post_processing(orig, proc)
    print(f"\nSample {i+1}:")
    for metric, value in quality_results.items():
        if 'improvement' in metric:
            print(f"  {metric}: {value:.4f}")

# Plot pipeline metrics
visualizer.plot_pipeline_metrics(metrics_list, [f'Sample {i+1}' for i in range(3)])

# Show detailed comparison for first sample
visualizer.plot_processing_comparison(test_masks[0], processed_masks[0], 
                                    title="Detailed Post-processing Analysis")
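`_topological_consistency` returns a weighted heuristic (component count and hole area), not the Euler number itself. If the actual Euler number (foreground components minus holes) is wanted, a minimal sketch using only SciPy follows; `euler_number_2d` is a hypothetical helper, and the connectivity pairing (8-connected foreground, 4-connected holes) is an assumption. A ring-shaped, myocardium-like mask should score 0, a filled cavity 1.

```python
import numpy as np
from scipy import ndimage


def euler_number_2d(mask: np.ndarray) -> int:
    """Euler number = (# foreground components) - (# holes) of a 2D binary mask.

    Foreground components use 8-connectivity; holes (enclosed background
    regions) use 4-connectivity, the usual complementary pairing.
    """
    m = mask.astype(bool)
    # Foreground components, 8-connected
    _, n_components = ndimage.label(m, structure=np.ones((3, 3), dtype=bool))
    # Holes: pixels added by hole filling, grouped into 4-connected regions
    holes = ndimage.binary_fill_holes(m) & ~m
    _, n_holes = ndimage.label(holes)
    return n_components - n_holes


# Quick checks on synthetic shapes
yy, xx = np.mgrid[:100, :100]
r2 = (yy - 50) ** 2 + (xx - 50) ** 2
disk = r2 <= 30 ** 2
ring = disk & (r2 > 15 ** 2)
two_disks = ((yy - 25) ** 2 + (xx - 25) ** 2 <= 100) | ((yy - 75) ** 2 + (xx - 75) ** 2 <= 100)
print(euler_number_2d(disk), euler_number_2d(ring), euler_number_2d(two_disks))  # → 1 0 2
```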

## 🔬 Practical Applications and Case Studies

### Real-world Post-processing Scenarios

This section outlines how the post-processing pipeline can be adapted to real cardiac segmentation challenges; the demonstration below uses synthetic masks that mimic several of them:

1. **Clinical Workflow Integration**: Automating post-processing in clinical environments
2. **Multi-modal Cardiac Imaging**: Handling different MRI sequences (T1, T2, CINE)
3. **Pathological Cases**: Processing segmentations with cardiac abnormalities
4. **Quality Control**: Implementing automated quality checks for clinical use
5. **Performance Optimization**: Balancing accuracy vs processing speed

### Key Learning Points

- Understanding when and how to apply different post-processing techniques
- Adapting parameters for different cardiac conditions and imaging protocols
- Implementing robust error handling and fallback mechanisms
- Creating efficient batch processing workflows
- Validating results against clinical standards
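One way to implement the error-handling and fallback point above is a guard around each processing step: if the step raises, empties a non-empty mask, or changes its area by more than a set fraction, the input mask is kept and the reason is reported. `safe_postprocess` and the 50% area-change limit are illustrative assumptions, not part of the notebook's pipeline.

```python
import numpy as np
from scipy import ndimage


def safe_postprocess(mask, process_fn, max_rel_area_change=0.5):
    """Apply process_fn to mask, falling back to the input mask when the step
    raises, empties a non-empty mask, or changes the area too much.

    Returns (mask_to_use, status_string).
    """
    original_area = int(np.count_nonzero(mask))
    try:
        processed = process_fn(mask)
    except Exception as exc:  # any failure -> keep the original mask
        return mask, f"fallback: error ({exc})"
    processed_area = int(np.count_nonzero(processed))
    if original_area > 0 and processed_area == 0:
        return mask, "fallback: empty result"
    if original_area > 0:
        rel_change = abs(processed_area - original_area) / original_area
        if rel_change > max_rel_area_change:
            return mask, f"fallback: area changed by {rel_change:.0%}"
    return processed, "ok"


# A light opening passes; an over-aggressive erosion is rejected
yy, xx = np.mgrid[:128, :128]
disk = (yy - 64) ** 2 + (xx - 64) ** 2 <= 30 ** 2
_, status_ok = safe_postprocess(disk, lambda m: ndimage.binary_opening(m, iterations=1))
_, status_bad = safe_postprocess(disk, lambda m: ndimage.binary_erosion(m, iterations=25))
print(status_ok, "|", status_bad)
```

In a batch workflow the status string can be stored next to each case so that fallbacks stay visible in the report.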

---

In [None]:
class ClinicalPostProcessor:
    """Clinical-grade post-processing with condition-specific adaptations"""
    
    def __init__(self):
        self.condition_configs = self._setup_condition_configs()
        self.quality_thresholds = self._setup_quality_thresholds()
        
    def _setup_condition_configs(self) -> Dict[str, PostProcessConfig]:
        """Setup configurations for different cardiac conditions"""
        configs = {}
        
        # Normal heart
        configs['normal'] = PostProcessConfig(
            min_component_size=200,
            min_heart_area=2000,
            max_heart_area=15000,
            circularity_threshold=0.4
        )
        
        # Dilated cardiomyopathy (enlarged heart)
        configs['dilated'] = PostProcessConfig(
            min_component_size=300,
            min_heart_area=4000,
            max_heart_area=25000,
            circularity_threshold=0.3,
            aspect_ratio_range=(0.6, 2.5)
        )
        
        # Hypertrophic cardiomyopathy (thickened walls)
        configs['hypertrophic'] = PostProcessConfig(
            min_component_size=150,
            min_heart_area=1500,
            max_heart_area=20000,
            circularity_threshold=0.5,
            closing_iterations=3  # More aggressive closing for thick walls
        )
        
        # Post-infarction (irregular shape)
        configs['post_infarction'] = PostProcessConfig(
            min_component_size=100,
            min_heart_area=1000,
            max_heart_area=18000,
            circularity_threshold=0.2,
            smooth_iterations=5,  # More smoothing for irregular boundaries
            gaussian_sigma=1.5
        )
        
        # Pediatric cases (smaller hearts)
        configs['pediatric'] = PostProcessConfig(
            min_component_size=50,
            min_heart_area=500,
            max_heart_area=8000,
            circularity_threshold=0.4,
            kernel_size=2  # Smaller kernels for smaller structures
        )
        
        return configs
    
    def _setup_quality_thresholds(self) -> Dict[str, Dict[str, float]]:
        """Setup quality thresholds for clinical acceptance"""
        return {
            'normal': {
                'min_dice': 0.85,
                'max_hausdorff': 15.0,
                'min_smoothness': 0.5,
                'max_components': 3
            },
            'dilated': {
                'min_dice': 0.80,
                'max_hausdorff': 20.0,
                'min_smoothness': 0.4,
                'max_components': 3
            },
            'hypertrophic': {
                'min_dice': 0.82,
                'max_hausdorff': 18.0,
                'min_smoothness': 0.6,
                'max_components': 2
            },
            'post_infarction': {
                'min_dice': 0.75,
                'max_hausdorff': 25.0,
                'min_smoothness': 0.3,
                'max_components': 4
            },
            'pediatric': {
                'min_dice': 0.88,
                'max_hausdorff': 10.0,
                'min_smoothness': 0.6,
                'max_components': 2
            }
        }
    
    def process_clinical_case(self, 
                            mask: np.ndarray,
                            condition: str = 'normal',
                            patient_metadata: Dict = None) -> Tuple[np.ndarray, Dict]:
        """Process a clinical case with condition-specific parameters"""
        
        if condition not in self.condition_configs:
            print(f"Warning: Unknown condition '{condition}', using 'normal' config")
            condition = 'normal'
        
        # Get condition-specific configuration
        config = self.condition_configs[condition]
        
        # Create pipeline with condition-specific config
        pipeline = PostProcessingPipeline(config)
        
        # Process mask
        processed_mask, metrics = pipeline.process_single_mask(mask)
        
        # Perform quality assessment
        quality_assessor = QualityAssessment()
        quality_results = quality_assessor.evaluate_post_processing(mask, processed_mask)
        
        # Check against clinical thresholds
        quality_check = self._validate_clinical_quality(
            quality_results, condition, processed_mask
        )
        
        # Compile comprehensive results
        results = {
            'processed_mask': processed_mask,
            'processing_metrics': metrics,
            'quality_assessment': quality_results,
            'clinical_validation': quality_check,
            'condition': condition,
            'config_used': config.__dict__,
            'patient_metadata': patient_metadata or {}
        }
        
        return processed_mask, results
    
    def _validate_clinical_quality(self, 
                                 quality_results: Dict, 
                                 condition: str,
                                 processed_mask: np.ndarray) -> Dict:
        """Validate results against clinical quality thresholds"""
        
        thresholds = self.quality_thresholds[condition]
        validation_results = {
            'passed': True,
            'warnings': [],
            'errors': [],
            'recommendations': []
        }
        
        # Check smoothness
        if 'processed_boundary_smoothness' in quality_results:
            smoothness = quality_results['processed_boundary_smoothness']
            if smoothness < thresholds['min_smoothness']:
                validation_results['warnings'].append(
                    f"Low boundary smoothness: {smoothness:.3f} < {thresholds['min_smoothness']}"
                )
                validation_results['recommendations'].append("Consider additional smoothing")
        
        # Check number of components
        labels = measure.label(processed_mask)
        n_components = len(measure.regionprops(labels))
        if n_components > thresholds['max_components']:
            validation_results['errors'].append(
                f"Too many components: {n_components} > {thresholds['max_components']}"
            )
            validation_results['recommendations'].append("Apply more aggressive component filtering")
            validation_results['passed'] = False
        
        # Overall validation
        if validation_results['errors']:
            validation_results['passed'] = False
        
        return validation_results
    
    def batch_clinical_processing(self, 
                                cases: List[Dict]) -> List[Dict]:
        """Process multiple clinical cases in batch"""
        
        results = []
        
        for i, case in enumerate(tqdm(cases, desc="Processing clinical cases")):
            try:
                mask = case['mask']
                condition = case.get('condition', 'normal')
                metadata = case.get('metadata', {})
                
                processed_mask, case_results = self.process_clinical_case(
                    mask, condition, metadata
                )
                
                case_results['case_id'] = i
                case_results['success'] = True
                results.append(case_results)
                
            except Exception as e:
                print(f"Error processing case {i}: {e}")
                results.append({
                    'case_id': i,
                    'success': False,
                    'error': str(e),
                    'processed_mask': case['mask']  # Return original on error
                })
        
        return results
    
    def generate_clinical_report(self, results: List[Dict]) -> str:
        """Generate clinical report from processing results"""
        
        report = []
        report.append("# CARDIAC SEGMENTATION POST-PROCESSING REPORT")
        report.append("=" * 50)
        report.append("")
        
        # Summary statistics
        total_cases = len(results)
        successful_cases = sum(1 for r in results if r.get('success', False))
        passed_quality = sum(1 for r in results 
                           if r.get('success', False) and 
                           r.get('clinical_validation', {}).get('passed', False))
        
        denom = max(total_cases, 1)  # avoid division by zero for an empty batch
        report.append(f"Total cases processed: {total_cases}")
        report.append(f"Successful processing: {successful_cases} ({successful_cases/denom*100:.1f}%)")
        report.append(f"Passed clinical quality: {passed_quality} ({passed_quality/denom*100:.1f}%)")
        report.append("")
        
        # Condition breakdown
        conditions = {}
        for result in results:
            if result.get('success', False):
                condition = result.get('condition', 'unknown')
                if condition not in conditions:
                    conditions[condition] = {'total': 0, 'passed': 0}
                conditions[condition]['total'] += 1
                if result.get('clinical_validation', {}).get('passed', False):
                    conditions[condition]['passed'] += 1
        
        report.append("## Condition Breakdown:")
        for condition, stats in conditions.items():
            pass_rate = stats['passed'] / stats['total'] * 100 if stats['total'] > 0 else 0
            report.append(f"  {condition.title()}: {stats['passed']}/{stats['total']} ({pass_rate:.1f}% pass rate)")
        report.append("")
        
        # Quality metrics summary
        quality_metrics = defaultdict(list)
        for result in results:
            if result.get('success', False) and 'quality_assessment' in result:
                qa = result['quality_assessment']
                for metric, value in qa.items():
                    if isinstance(value, (int, float)) and not np.isnan(value) and not np.isinf(value):
                        quality_metrics[metric].append(value)
        
        if quality_metrics:
            report.append("## Quality Metrics Summary:")
            for metric, values in quality_metrics.items():
                if values:
                    mean_val = np.mean(values)
                    std_val = np.std(values)
                    report.append(f"  {metric}: {mean_val:.3f} ± {std_val:.3f}")
            report.append("")
        
        # Recommendations
        all_recommendations = []
        for result in results:
            if result.get('success', False):
                recs = result.get('clinical_validation', {}).get('recommendations', [])
                all_recommendations.extend(recs)
        
        if all_recommendations:
            unique_recs = list(set(all_recommendations))
            report.append("## Clinical Recommendations:")
            for rec in unique_recs:
                count = all_recommendations.count(rec)
                report.append(f"  • {rec} (mentioned in {count} cases)")
        
        return "\n".join(report)

# Create synthetic clinical cases for demonstration
def create_synthetic_clinical_cases(n_cases: int = 10) -> List[Dict]:
    """Create synthetic clinical cases for testing"""
    cases = []
    conditions = ['normal', 'dilated', 'hypertrophic', 'post_infarction', 'pediatric']
    
    for i in range(n_cases):
        condition = np.random.choice(conditions)
        
        # Create condition-specific mask
        if condition == 'normal':
            mask = create_normal_heart_mask()
        elif condition == 'dilated':
            mask = create_dilated_heart_mask()
        elif condition == 'hypertrophic':
            mask = create_hypertrophic_heart_mask()
        elif condition == 'post_infarction':
            mask = create_post_infarction_mask()
        else:  # pediatric
            mask = create_pediatric_heart_mask()
        
        cases.append({
            'mask': mask,
            'condition': condition,
            'metadata': {
                'patient_id': f'P{i+1:03d}',
                'age': np.random.randint(20, 80) if condition != 'pediatric' else np.random.randint(1, 18),
                'sex': np.random.choice(['M', 'F']),
                'scan_date': f'2024-01-{np.random.randint(1, 30):02d}'
            }
        })
    
    return cases

def create_normal_heart_mask() -> np.ndarray:
    """Create a normal heart mask"""
    mask = np.zeros((256, 256), dtype=np.uint8)
    cv2.ellipse(mask, (128, 128), (60, 45), 0, 0, 360, 1, -1)
    # Add some noise
    noise = np.random.random(mask.shape) > 0.98
    return mask | noise.astype(np.uint8)

def create_dilated_heart_mask() -> np.ndarray:
    """Create a dilated cardiomyopathy mask (enlarged)"""
    mask = np.zeros((256, 256), dtype=np.uint8)
    cv2.ellipse(mask, (128, 128), (85, 70), 0, 0, 360, 1, -1)
    # Add irregularities
    noise = np.random.random(mask.shape) > 0.95
    return mask | noise.astype(np.uint8)

def create_hypertrophic_heart_mask() -> np.ndarray:
    """Create a hypertrophic cardiomyopathy mask (thick walls)"""
    mask = np.zeros((256, 256), dtype=np.uint8)
    # Outer boundary
    cv2.ellipse(mask, (128, 128), (55, 50), 0, 0, 360, 1, -1)
    # Inner cavity (smaller)
    inner_mask = np.zeros((256, 256), dtype=np.uint8)
    cv2.ellipse(inner_mask, (128, 128), (25, 20), 0, 0, 360, 1, -1)
    # Carve out the cavity; boolean indexing avoids relying on uint8 bitwise NOT
    mask[inner_mask.astype(bool)] = 0
    return mask

def create_post_infarction_mask() -> np.ndarray:
    """Create a post-infarction mask (irregular shape)"""
    mask = np.zeros((256, 256), dtype=np.uint8)
    # Create irregular shape using multiple ellipses
    cv2.ellipse(mask, (128, 128), (50, 40), 0, 0, 180, 1, -1)
    cv2.ellipse(mask, (138, 138), (40, 35), 45, 0, 270, 1, -1)
    # Add significant noise
    noise = np.random.random(mask.shape) > 0.92
    return mask | noise.astype(np.uint8)

def create_pediatric_heart_mask() -> np.ndarray:
    """Create a pediatric heart mask (smaller, more circular)"""
    mask = np.zeros((256, 256), dtype=np.uint8)
    cv2.circle(mask, (128, 128), 35, 1, -1)
    # Minimal noise for cleaner pediatric images
    noise = np.random.random(mask.shape) > 0.99
    return mask | noise.astype(np.uint8)

# Test the clinical post-processor
clinical_processor = ClinicalPostProcessor()
print("✅ Clinical Post-processor initialized!")
print(f"Supported conditions: {list(clinical_processor.condition_configs.keys())}")

# Create and process synthetic clinical cases
test_cases = create_synthetic_clinical_cases(6)
print(f"\nCreated {len(test_cases)} synthetic clinical cases")

# Process cases
results = clinical_processor.batch_clinical_processing(test_cases)

# Generate clinical report
report = clinical_processor.generate_clinical_report(results)
print("\n" + "="*60)
print("CLINICAL PROCESSING REPORT")
print("="*60)
print(report)

# Visualize some results
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for i, (case, result) in enumerate(zip(test_cases[:6], results[:6])):
    if result.get('success', False):
        original = case['mask']
        processed = result['processed_mask']
        condition = case['condition']
        
        # Show overlay
        overlay = np.zeros((*original.shape, 3))
        overlay[:, :, 0] = original  # Red channel for original
        overlay[:, :, 1] = processed  # Green channel for processed
        
        axes[i].imshow(overlay)
        axes[i].set_title(f"{condition.title()}\nPatient: {case['metadata']['patient_id']}")
        axes[i].axis('off')
        
        # Add quality indicator
        passed = result.get('clinical_validation', {}).get('passed', False)
        color = 'green' if passed else 'red'
        axes[i].add_patch(plt.Rectangle((5, 5), 20, 20, facecolor=color, alpha=0.7))

plt.suptitle('Clinical Post-processing Results\n'
             '(Red: original only, Green: processed only, Yellow: both; corner square: green = passed, red = failed)',
             fontsize=14)
plt.tight_layout()
plt.show()
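`_setup_quality_thresholds` defines `min_dice` and `max_hausdorff`, but `_validate_clinical_quality` never checks them: without ground truth, `evaluate_post_processing` only produces original-vs-processed comparisons. When a reference mask is passed as `ground_truth`, the results gain `processed_vs_gt_*` keys, and a check along these lines could use them. `check_gt_thresholds` is a hypothetical helper, sketched under that assumption.

```python
from typing import Dict, List


def check_gt_thresholds(quality_results: Dict[str, float],
                        thresholds: Dict[str, float]) -> List[str]:
    """Return error messages for the ground-truth-based clinical thresholds.

    Expects the 'processed_vs_gt_*' keys that evaluate_post_processing adds
    when a ground truth mask is given; missing keys are simply skipped.
    """
    errors = []
    dice = quality_results.get('processed_vs_gt_dice_coefficient')
    if dice is not None and dice < thresholds['min_dice']:
        errors.append(f"Dice vs ground truth too low: {dice:.3f} < {thresholds['min_dice']}")
    hd = quality_results.get('processed_vs_gt_hausdorff_distance')
    if hd is not None and hd > thresholds['max_hausdorff']:
        errors.append(f"Hausdorff vs ground truth too high: {hd:.1f} > {thresholds['max_hausdorff']}")
    return errors


# Example with the 'normal' thresholds used by ClinicalPostProcessor
normal_thresholds = {'min_dice': 0.85, 'max_hausdorff': 15.0}
print(check_gt_thresholds({'processed_vs_gt_dice_coefficient': 0.80,
                           'processed_vs_gt_hausdorff_distance': 12.0}, normal_thresholds))
# → ['Dice vs ground truth too low: 0.800 < 0.85']
```

Inside `_validate_clinical_quality`, the returned messages could be appended to `validation_results['errors']`, which already marks the case as failed.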

## 📋 Summary and Best Practices

### Key Achievements in This Notebook

1. **Comprehensive Morphological Processing**: Implemented advanced morphological operations with adaptive kernels
2. **Intelligent Component Analysis**: Created sophisticated connected component filtering with anatomical awareness
3. **Anatomical Constraint Validation**: Developed condition-specific validation for different cardiac pathologies
4. **Advanced Boundary Refinement**: Implemented multiple smoothing techniques preserving important anatomical features
5. **Clinical Integration**: Built a condition-aware pipeline with quality assessment and reporting, demonstrated on synthetic cases

### Best Practices for Cardiac Segmentation Post-processing

#### 🎯 **Parameter Selection**
- **Condition-Specific**: Adapt parameters based on cardiac condition (normal, dilated, hypertrophic, etc.)
- **Image Resolution**: Scale morphological kernel sizes based on image resolution
- **Patient Demographics**: Adjust area thresholds for pediatric vs adult cases
- **Imaging Protocol**: Modify smoothing parameters based on MRI sequence type
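For the resolution point, one approach is to specify kernel radii in millimetres and area thresholds in mm² and convert them with the pixel spacing from the image header (for DICOM, the PixelSpacing attribute). The sketch below assumes isotropic in-plane spacing for the kernel; both helper names are hypothetical, and the notebook's `PostProcessConfig` stores sizes in pixels.

```python
import numpy as np


def disk_kernel_from_mm(radius_mm: float, pixel_spacing_mm: float) -> np.ndarray:
    """Disk-shaped structuring element with a radius given in millimetres.

    Assumes isotropic in-plane pixel spacing (mm per pixel).
    """
    r = max(1, int(round(radius_mm / pixel_spacing_mm)))  # radius in pixels, at least 1
    y, x = np.ogrid[-r:r + 1, -r:r + 1]
    return (x * x + y * y <= r * r).astype(np.uint8)


def area_threshold_pixels(area_mm2: float, spacing_yx_mm: tuple) -> int:
    """Convert an area threshold from mm^2 to a pixel count."""
    return int(round(area_mm2 / (spacing_yx_mm[0] * spacing_yx_mm[1])))


kernel = disk_kernel_from_mm(2.0, 1.25)           # 2 mm / 1.25 mm = 1.6 -> radius 2 px
print(kernel.shape)                               # (5, 5)
print(area_threshold_pixels(2000, (1.25, 1.25)))  # 1280
```

The resulting uint8 kernel can be passed as `structure=` to `scipy.ndimage` morphology functions or as the kernel argument of `cv2.morphologyEx`.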

#### 🔧 **Pipeline Design**
- **Modular Architecture**: Keep each processing step independent and configurable
- **Error Handling**: Implement robust fallback mechanisms for edge cases
- **Quality Gates**: Add validation checkpoints throughout the pipeline
- **Performance Monitoring**: Track processing times and resource usage
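A minimal sketch of the modular-architecture and performance-monitoring points: each step is a named, swappable function, and the runner records its wall-clock time. `run_steps` and the chosen steps are illustrative; the notebook's `PostProcessingPipeline` may organize its steps differently.

```python
import time
from typing import Callable, Dict, List, Tuple

import numpy as np
from scipy import ndimage

# A step is a (name, function) pair; every function maps a mask to a mask
Step = Tuple[str, Callable[[np.ndarray], np.ndarray]]


def run_steps(mask: np.ndarray, steps: List[Step]) -> Tuple[np.ndarray, Dict[str, float]]:
    """Run independent processing steps in order and record per-step wall time (s)."""
    timings: Dict[str, float] = {}
    for name, fn in steps:
        start = time.perf_counter()
        mask = fn(mask)
        timings[name] = time.perf_counter() - start
    return mask, timings


steps: List[Step] = [
    ("opening", lambda m: ndimage.binary_opening(m, iterations=1)),
    ("closing", lambda m: ndimage.binary_closing(m, iterations=2)),
    ("fill_holes", ndimage.binary_fill_holes),
]

rng = np.random.default_rng(0)
noisy = rng.random((128, 128)) > 0.97  # scattered speckle
noisy[34:94, 34:94] = True             # a square "structure"
cleaned, timings = run_steps(noisy, steps)
print(list(timings))  # ['opening', 'closing', 'fill_holes']
```

Because steps are plain (name, function) pairs, a condition-specific configuration can reorder, drop, or swap steps without touching the runner.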

#### 🏥 **Clinical Considerations**
- **Validation Requirements**: Ensure all outputs meet clinical quality standards
- **Traceability**: Maintain detailed logs of all processing steps and parameters
- **User Feedback**: Provide clear quality indicators and recommendations
- **Integration**: Design for seamless integration with existing clinical workflows
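For traceability, each processed case can be logged as a JSON record that ties the output to its input and parameters. The sketch hashes the masks with SHA-256 so a record can later be matched to the exact arrays; `provenance_record` and the field names are assumptions. The `params` dict could be the `config.__dict__` that `ClinicalPostProcessor` already stores under `config_used`.

```python
import hashlib
import json
from datetime import datetime, timezone

import numpy as np


def provenance_record(input_mask: np.ndarray, output_mask: np.ndarray,
                      params: dict, condition: str) -> str:
    """Build a JSON log entry tying an output mask to its input and parameters."""
    def digest(arr: np.ndarray) -> str:
        # Hash dtype, shape and raw bytes so identical arrays give identical digests
        h = hashlib.sha256()
        h.update(str(arr.dtype).encode())
        h.update(str(arr.shape).encode())
        h.update(np.ascontiguousarray(arr).tobytes())
        return h.hexdigest()

    record = {
        "timestamp_utc": datetime.now(timezone.utc).isoformat(),
        "condition": condition,
        "parameters": params,
        "input_sha256": digest(input_mask),
        "output_sha256": digest(output_mask),
    }
    # default=str turns values json cannot encode natively (e.g. numpy scalars) into strings
    return json.dumps(record, sort_keys=True, default=str)


mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 20:40] = 1
print(provenance_record(mask, mask, {"kernel_size": 3, "min_component_size": 200}, "normal"))
```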

### Performance Optimization Tips

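```python
# Example optimization strategies (sketches; adapt to your own pipeline):
import numpy as np
from multiprocessing import Pool
from scipy import ndimage

# 1. Parallel processing for batch operations.
#    Pool.map pickles the callable, so `pipeline` (the PostProcessingPipeline
#    created earlier) must be picklable; with the "spawn" start method
#    (Windows/macOS), run this from a script module rather than a notebook cell.
def parallel_processing(masks_batch):
    with Pool() as pool:
        return pool.map(pipeline.process_single_mask, masks_batch)

# 2. Memory-efficient processing for large images: process overlapping tiles
#    and keep only each tile's core. Valid for local operations whose total
#    reach in pixels is <= overlap (an opening or closing reaches about twice
#    its erosion/dilation radius); not for global steps such as component filtering.
def process_large_image(large_mask, op, tile_size=512, overlap=32):
    out = np.zeros_like(large_mask)
    h, w = large_mask.shape
    for y in range(0, h, tile_size):
        for x in range(0, w, tile_size):
            y0, x0 = max(y - overlap, 0), max(x - overlap, 0)
            y1, x1 = min(y + tile_size + overlap, h), min(x + tile_size + overlap, w)
            result = op(large_mask[y0:y1, x0:x1])
            out[y:y + tile_size, x:x + tile_size] = \
                result[y - y0:y - y0 + tile_size, x - x0:x - x0 + tile_size]
    return out

# 3. GPU acceleration for morphological operations (requires CuPy and a CUDA GPU)
try:
    import cupy as cp
    from cupyx.scipy import ndimage as cp_ndimage
except ImportError:
    cp = None

def gpu_morphology(mask, iterations=1):
    if cp is None:  # CPU fallback when CuPy is unavailable
        return ndimage.binary_opening(mask, iterations=iterations)
    gpu_mask = cp.asarray(mask, dtype=bool)
    gpu_mask = cp_ndimage.binary_opening(gpu_mask, iterations=iterations)
    return cp.asnumpy(gpu_mask)
```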

### Common Pitfalls and Solutions

#### ❌ **Common Mistakes**
1. **Over-smoothing**: Losing important anatomical details
2. **Aggressive filtering**: Removing small but important structures
3. **Ignoring pathology**: Using normal parameters for pathological cases
4. **No validation**: Skipping quality assessment steps

#### ✅ **Solutions**
1. **Adaptive smoothing**: Use image gradients to preserve edges
2. **Multi-scale analysis**: Process at different scales and combine results
3. **Condition awareness**: Implement pathology-specific processing paths
4. **Continuous validation**: Monitor quality metrics throughout processing
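The multi-scale idea can be realized in several ways. One simple variant, sketched here, opens the mask with disks of several radii and keeps only pixels that survive at a majority of scales: isolated noise disappears at every scale, while structures much larger than the biggest disk survive all of them. The radii and the majority rule are assumptions to tune; note that thin structures surviving only the smallest scale are removed as well.

```python
import numpy as np
from scipy import ndimage


def disk(r: int) -> np.ndarray:
    """Boolean disk structuring element of radius r pixels."""
    y, x = np.ogrid[-r:r + 1, -r:r + 1]
    return x * x + y * y <= r * r


def multiscale_opening(mask: np.ndarray, radii=(1, 2, 3)) -> np.ndarray:
    """Open the mask at several scales and keep pixels retained by a majority."""
    m = mask.astype(bool)
    votes = np.zeros(m.shape, dtype=np.int32)
    for r in radii:
        votes += ndimage.binary_opening(m, structure=disk(r))
    return votes * 2 > len(radii)


yy, xx = np.mgrid[:128, :128]
heart = (yy - 64) ** 2 + (xx - 64) ** 2 <= 30 ** 2
noisy = heart.copy()
noisy[5, 5] = True                   # isolated false-positive pixel
result = multiscale_opening(noisy)
print(result[5, 5], result[64, 64])  # False True
```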

### Integration with the Complete Pipeline

This post-processing notebook connects to the other components as follows:

- **Input**: Raw segmentation masks from notebook 06 (Model Evaluation)
- **Processing**: Advanced morphological and anatomical refinement
- **Output**: Refined segmentations for notebook 08 (Final Inference)
- **Quality Control**: Comprehensive metrics for clinical validation

### Next Steps

1. **Advanced Techniques**: Implement deep learning-based post-processing
2. **Real-time Processing**: Optimize for real-time clinical applications
3. **Multi-modal Integration**: Combine with other imaging modalities
4. **Automated Parameter Tuning**: Develop adaptive parameter selection
5. **Clinical Validation**: Extensive testing with real clinical data
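Automated parameter tuning can start as a plain grid search: evaluate each candidate value against ground-truth masks and keep the best. The sketch tunes only a minimum component size, with its own small helpers rather than the notebook's `PostProcessingPipeline`; the candidate grid, 8-connectivity and the mean-Dice objective are assumptions. On real data the (prediction, ground truth) pairs should come from a validation split, not from the cases being reported.

```python
import numpy as np
from scipy import ndimage


def remove_small_components(mask: np.ndarray, min_size: int) -> np.ndarray:
    """Drop 8-connected components with fewer than min_size pixels."""
    labels, _ = ndimage.label(mask, structure=np.ones((3, 3), dtype=bool))
    sizes = np.bincount(labels.ravel())
    keep = sizes >= min_size
    keep[0] = False  # label 0 is background
    return keep[labels]


def dice(a: np.ndarray, b: np.ndarray) -> float:
    a, b = a.astype(bool), b.astype(bool)
    denom = a.sum() + b.sum()
    return 1.0 if denom == 0 else 2.0 * np.logical_and(a, b).sum() / denom


def tune_min_component_size(pairs, candidates=(1, 5, 50, 500, 5000)):
    """Pick the min_component_size with the best mean Dice over (prediction, gt) pairs."""
    scores = {}
    for size in candidates:
        scores[size] = float(np.mean([dice(remove_small_components(p, size), gt)
                                      for p, gt in pairs]))
    best = max(scores, key=scores.get)  # first candidate wins ties
    return best, scores


rng = np.random.default_rng(0)
yy, xx = np.mgrid[:128, :128]
gt = (yy - 64) ** 2 + (xx - 64) ** 2 <= 30 ** 2
pred = gt | (rng.random(gt.shape) > 0.99)  # ground truth plus speckle
best, scores = tune_min_component_size([(pred, gt)])
print(best, scores)
```

The same loop extends to several parameters (for example kernel size and smoothing iterations) by iterating over their Cartesian product, at a cost that grows with the grid size.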

---

## 🚀 Ready for Final Inference

The post-processed segmentations are now ready for the final inference pipeline in **notebook 08**, where we'll implement:

- End-to-end inference workflows
- Batch processing optimization
- Performance benchmarking
- Clinical deployment preparation
- Results compilation and reporting

The post-processing pipeline checks every segmentation against the condition-specific quality thresholds before it is passed on to final inference! 🏥✨