# Deep Learning Model Inference Pipeline

This notebook implements the inference pipeline for optic nerve segmentation using a pre-trained 2D U-Net model.

## Pipeline Overview
1. **Model Architecture** - Load and configure 2D U-Net segmentation model
2. **Data Loading** - Prepare normalized patches for inference
3. **Inference Pipeline** - Batch processing with performance monitoring
4. **Results Processing** - Save predictions and generate visualizations

## Requirements
- Pre-trained model weights (model.pth)
- Normalized patches from preprocessing pipeline
- PyTorch, MONAI for model architecture
- CUDA support (optional, CPU fallback available)

In [None]:
# Core imports
import logging
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Scientific computing and visualization
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Deep learning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Import model and dataset functions
from functions import (
    SegmentationDataset,
    create_monai_unet,
    evaluate_model,
    visualize_segmentation_predictions,
    predict_and_save_masks
)

# Configure matplotlib for inline display
%matplotlib inline

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

print("Deep Learning Model Inference Pipeline - Initialized")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2.1 Configuration & Setup

Configure inference parameters and model settings.

In [None]:
# =============================================================================
# INFERENCE CONFIGURATION - Modify these according to your setup
# =============================================================================

# --- Input locations ---------------------------------------------------------
DATA_ROOT = Path("data")
PROCESSED_DATA_PATH = DATA_ROOT / "processed"
NORMALIZED_PATCHES_PATH = PROCESSED_DATA_PATH / "patches_normalized"

# --- Network definition (unpacked into create_monai_unet) --------------------
MODEL_WEIGHTS_PATH = "model.pth"  # pre-trained weights, relative to the working directory
MODEL_CONFIG = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    dropout=0.2,
)

# --- Hardware ----------------------------------------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Inference behaviour -----------------------------------------------------
INFERENCE_CONFIG = dict(
    batch_size=4,
    threshold=0.5,                          # sigmoid cut-off for the binary mask
    num_workers=0,                          # 0 = load in the main process (Windows-safe)
    pin_memory=torch.cuda.is_available(),   # page-locked host memory only helps GPU copies
)

# --- Expected data layout ----------------------------------------------------
INPUT_SIDES = ['left', 'right']
SAMPLE_PATIENTS = []  # populated from the data after validation

# --- Outputs (created eagerly) -----------------------------------------------
OUTPUT_DIRS = {
    'predictions': PROCESSED_DATA_PATH / 'predictions',
    'visualizations': PROCESSED_DATA_PATH / 'visualizations',
}
for out_dir in OUTPUT_DIRS.values():
    out_dir.mkdir(parents=True, exist_ok=True)

report = [
    "Inference configuration loaded:",
    f"  Device: {DEVICE}",
    f"  Model weights: {MODEL_WEIGHTS_PATH}",
    f"  Batch size: {INFERENCE_CONFIG['batch_size']}",
    f"  Threshold: {INFERENCE_CONFIG['threshold']}",
    f"  Input data path: {NORMALIZED_PATCHES_PATH}",
    "  Output directories:",
    *(f"    {label}: {location}" for label, location in OUTPUT_DIRS.items()),
]
print("\n".join(report))

### Environment & Data Validation

Verify model weights and input data availability.

In [None]:
def validate_inference_setup() -> Dict[str, List[str]]:
    """Check that model weights and input patches exist, and discover patients.

    Reads the module-level ``MODEL_WEIGHTS_PATH``, ``NORMALIZED_PATCHES_PATH``
    and ``INPUT_SIDES`` settings. Patches are expected under
    ``NORMALIZED_PATCHES_PATH/<side>/<patient>/*.nii.gz``.

    Returns:
        Mapping of side name -> sorted list of patient directory names. A side
        whose directory is missing maps to an empty list, and a warning is logged.

    Raises:
        FileNotFoundError: if the weights file or the patches root does not exist.
    """
    weights_path = Path(MODEL_WEIGHTS_PATH)
    if not weights_path.exists():
        raise FileNotFoundError(f"Model weights not found: {MODEL_WEIGHTS_PATH}")
    model_size = weights_path.stat().st_size / (1024 * 1024)  # MB
    print(f"✓ Model weights found: {MODEL_WEIGHTS_PATH} ({model_size:.1f} MB)")

    if not NORMALIZED_PATCHES_PATH.exists():
        raise FileNotFoundError(f"Normalized patches directory not found: {NORMALIZED_PATCHES_PATH}")

    available_data: Dict[str, List[str]] = {}
    total_patches = 0
    preview_limit = 3  # patients listed individually per side

    for side in INPUT_SIDES:
        side_path = NORMALIZED_PATCHES_PATH / side
        if not side_path.exists():
            logger.warning(f"Side directory not found: {side_path}")
            available_data[side] = []
            continue

        # Path.iterdir() order is filesystem-dependent; sort so patient order
        # (and hence dataset / prediction order downstream) is reproducible.
        patient_dirs = sorted(d.name for d in side_path.iterdir() if d.is_dir())
        available_data[side] = patient_dirs
        print(f"✓ Found {len(patient_dirs)} patients in {side}/ directory")

        # Count every patient's patches, but only print the first few.
        for position, patient in enumerate(patient_dirs):
            patch_count = sum(1 for _ in (side_path / patient).glob('*.nii.gz'))
            total_patches += patch_count
            if position < preview_limit:
                print(f"    {patient}: {patch_count} patches")

        if len(patient_dirs) > preview_limit:
            print(f"    ... and {len(patient_dirs) - preview_limit} more patients")

    print(f"\nTotal patches available for inference: {total_patches}")

    return available_data

# Validate the environment, then derive the sorted set of patient IDs seen on any side
available_data = validate_inference_setup()

all_patients = {
    patient_id
    for side_patients in available_data.values()
    for patient_id in side_patients
}
SAMPLE_PATIENTS = sorted(all_patients)

print(f"Detected patients: {SAMPLE_PATIENTS}")

## 2.2 Model Architecture

Load and configure the 2D U-Net segmentation model.

In [None]:
# Build the U-Net architecture (no checkpoint is applied in this function)
def initialize_model() -> nn.Module:
    """Construct the U-Net described by ``MODEL_CONFIG`` and print a summary.

    This only builds the architecture via ``create_monai_unet``; it does not
    read ``MODEL_WEIGHTS_PATH``. The weights path is handed to ``evaluate_model``
    in the inference section (presumably loaded into the model there — confirm
    in ``functions``).

    Returns:
        The module returned by ``create_monai_unet(**MODEL_CONFIG)``.
    """
    print("Creating U-Net model architecture...")
    model = create_monai_unet(**MODEL_CONFIG)

    # Materialize once so both counts walk the same parameter list.
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    print(f"Model Architecture Summary:")
    print(f"  Type: 2D U-Net (MONAI)")
    print(f"  Input channels: {MODEL_CONFIG['in_channels']}")
    print(f"  Output channels: {MODEL_CONFIG['out_channels']}")
    print(f"  Feature channels: {MODEL_CONFIG['channels']}")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    return model

# Initialize model
# NOTE(review): no weights are loaded at this point; MODEL_WEIGHTS_PATH is
# passed to evaluate_model in section 2.4.
model = initialize_model()

# Print the complete module tree (all layers)
print("\nModel Architecture Details:")
print(model)

print(f"\nModel successfully initialized and ready for inference.")

## 2.3 Data Loading

Prepare datasets and data loaders for efficient batch processing.

In [None]:
# Prepare dataset paths
# The annotation is quoted so it is not evaluated at definition time
# (Tuple / SegmentationDataset only need to resolve for type checkers).
def prepare_inference_data() -> "Tuple[List[str], SegmentationDataset]":
    """Collect per-patient patch directories and wrap them in a dataset.

    Walks ``available_data`` (side -> patient names, from the validation cell)
    in ``INPUT_SIDES`` order, keeping every
    ``NORMALIZED_PATCHES_PATH/<side>/<patient>`` directory that exists. The
    same ordered list is passed to ``SegmentationDataset`` and returned, so
    later steps can map dataset items back to directories.

    Returns:
        ``(inference_dirs, dataset)``: the directory paths as strings and the
        ``SegmentationDataset`` built from them.
    """
    inference_dirs: List[str] = []

    for side in INPUT_SIDES:
        for patient in available_data.get(side, []):
            patient_path = NORMALIZED_PATCHES_PATH / side / patient
            # available_data comes from an earlier cell; skip entries no longer on disk.
            if patient_path.exists():
                inference_dirs.append(str(patient_path))

    if not inference_dirs:
        logger.warning("No patient directories found for inference under %s", NORMALIZED_PATCHES_PATH)

    print(f"Prepared {len(inference_dirs)} dataset directories:")
    for i, dir_path in enumerate(inference_dirs, start=1):
        patch_count = sum(1 for _ in Path(dir_path).glob('*.nii.gz'))
        side_patient = '/'.join(Path(dir_path).parts[-2:])  # "<side>/<patient>"
        print(f"  {i}. {side_patient}: {patch_count} patches")

    print("\nCreating inference dataset...")
    dataset = SegmentationDataset(inference_dirs)

    print(f"Dataset created successfully:")
    print(f"  Total patches: {len(dataset)}")
    print(f"  Directories: {len(inference_dirs)}")

    return inference_dirs, dataset

# Build the dataset, then an order-preserving batch loader over it
inference_dirs, inference_dataset = prepare_inference_data()

print("\nCreating data loader...")
loader_options = {
    'batch_size': INFERENCE_CONFIG['batch_size'],
    'shuffle': False,  # keep dataset order so outputs line up with inference_dirs
    'num_workers': INFERENCE_CONFIG['num_workers'],
    'pin_memory': INFERENCE_CONFIG['pin_memory'],
}
inference_loader = DataLoader(inference_dataset, **loader_options)

print(f"Data loader configuration:")
for label, value in (
    ("Batch size", INFERENCE_CONFIG['batch_size']),
    ("Total batches", len(inference_loader)),
    ("Pin memory", INFERENCE_CONFIG['pin_memory']),
    ("Workers", INFERENCE_CONFIG['num_workers']),
):
    print(f"  {label}: {value}")

### Sample Data Visualization

Visualize sample patches to verify data loading.

In [None]:
# Visualize sample patches from dataset
def visualize_sample_patches(dataset, num_samples: int = 4) -> None:
    """Plot ``num_samples`` evenly spaced patches from ``dataset`` side by side.

    Each panel shows the squeezed patch in grayscale with its index and shape,
    plus an overlay of the intensity range and mean. A sample that fails to
    load or plot is drawn as an error panel and logged; it is not raised.

    Args:
        dataset: Indexable object supporting ``len()``. ``dataset[i]`` is
            expected to be the image as a tensor or array-like.
            NOTE(review): if ``SegmentationDataset`` yields ``(image, mask)``
            tuples, each panel will show the error placeholder. Confirm.
        num_samples: Number of panels to draw; values < 1 draw nothing.
    """
    if len(dataset) == 0:
        print("No data available for visualization")
        return
    if num_samples < 1:
        # plt.subplots cannot build a figure with zero columns.
        print("No samples requested for visualization")
        return

    # Evenly spaced integer indices between 0 and len(dataset) - 1
    indices = np.linspace(0, len(dataset) - 1, num_samples, dtype=int)

    # squeeze=False always returns a 2-D Axes array, so one panel needs no special case.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 4), squeeze=False)
    axes = axes[0]
    fig.suptitle('Sample Input Patches for Inference', fontsize=16)

    for ax, idx in zip(axes, indices):
        try:
            image = dataset[idx]

            if isinstance(image, torch.Tensor):
                # detach() so tensors carrying autograd history can still become numpy
                image_np = image.detach().squeeze().cpu().numpy()
            else:
                image_np = np.asarray(image).squeeze()

            ax.imshow(image_np, cmap='gray')
            ax.set_title(f'Patch {idx}\nShape: {image_np.shape}')
            ax.axis('off')

            min_val, max_val = image_np.min(), image_np.max()
            mean_val = image_np.mean()
            ax.text(0.02, 0.98, f'Range: [{min_val:.3f}, {max_val:.3f}]\nMean: {mean_val:.3f}',
                    transform=ax.transAxes, fontsize=8,
                    verticalalignment='top', color='white',
                    bbox=dict(boxstyle='round', facecolor='black', alpha=0.5))

        except Exception as e:
            # Best-effort preview: draw a placeholder instead of aborting the figure.
            ax.text(0.5, 0.5, f'Error loading\nindex {idx}',
                    ha='center', va='center', transform=ax.transAxes)
            ax.axis('off')
            logger.exception(f"Error visualizing sample {idx}: {e}")

    plt.tight_layout()
    plt.show()

# Preview up to four evenly spaced input patches
n_preview = min(4, len(inference_dataset))
if n_preview:
    visualize_sample_patches(inference_dataset, n_preview)
else:
    print("No patches available for visualization")

## 2.4 Model Inference

Run inference on all patches with performance monitoring.

In [None]:
# Run model evaluation
def run_inference_evaluation() -> Dict:
    """Run ``evaluate_model`` over ``inference_loader`` and attach timing metrics.

    Uses the module-level ``model``, ``inference_loader``, ``inference_dataset``,
    ``MODEL_WEIGHTS_PATH``, ``INFERENCE_CONFIG`` and ``DEVICE``. This is an
    inference-only pass; no ground-truth masks are supplied.

    Returns:
        The dict returned by ``evaluate_model``, which must contain a
        ``'batches'`` count. It is updated in place with ``'total_patches'``,
        ``'inference_time'`` (seconds), ``'patches_per_second'`` and
        ``'time_per_batch'``.
    """
    print("Starting model inference...")
    print(f"Device: {DEVICE}")
    print(f"Model weights: {MODEL_WEIGHTS_PATH}")
    print(f"Threshold: {INFERENCE_CONFIG['threshold']}")

    # perf_counter is monotonic and high-resolution; time.time can jump with clock changes.
    start_time = time.perf_counter()

    results = evaluate_model(
        model=model,
        test_loader=inference_loader,
        model_weights_path=MODEL_WEIGHTS_PATH,
        threshold=INFERENCE_CONFIG['threshold'],
        device=DEVICE
    )

    # CUDA work is queued asynchronously; wait for it so the timer covers the
    # actual computation rather than just the kernel launches.
    if DEVICE.type == 'cuda':
        torch.cuda.synchronize(DEVICE)
    inference_time = time.perf_counter() - start_time

    total_patches = len(inference_dataset)
    patches_per_second = total_patches / inference_time if inference_time > 0 else 0
    time_per_batch = inference_time / results['batches'] if results['batches'] > 0 else 0

    results.update({
        'total_patches': total_patches,
        'inference_time': inference_time,
        'patches_per_second': patches_per_second,
        'time_per_batch': time_per_batch
    })

    return results

# Run inference and report throughput
inference_results = run_inference_evaluation()

summary_lines = (
    "\nInference Results:",
    f"  Total patches processed: {inference_results['total_patches']}",
    f"  Total batches: {inference_results['batches']}",
    f"  Inference time: {inference_results['inference_time']:.2f} seconds",
    f"  Processing speed: {inference_results['patches_per_second']:.2f} patches/second",
    f"  Average time per batch: {inference_results['time_per_batch']:.4f} seconds",
)
for line in summary_lines:
    print(line)

# Coarse speed rating; thresholds are in patches per second
throughput = inference_results['patches_per_second']
if throughput > 50:
    verdict = "  Performance: ✓ Excellent processing speed"
elif throughput > 20:
    verdict = "  Performance: ✓ Good processing speed"
else:
    verdict = "  Performance: ⚠ Consider optimizing batch size or using GPU"
print(verdict)

### Prediction Visualization

Visualize model predictions on sample patches.

In [None]:
# Visualize predictions
print("Generating prediction visualizations...")

try:
    # NOTE(review): no device or weights argument is passed here; this relies on
    # `model` already being in the state evaluate_model left it in. Confirm.
    visualize_segmentation_predictions(
        model=model,
        test_dataset=inference_dataset,
        num_images=4
    )
except Exception as e:
    # Best-effort step: a plotting failure must not block saving predictions.
    # logger.exception keeps the traceback that logger.error dropped.
    logger.exception(f"Visualization failed: {e}")
    print(f"Error during visualization: {e}")
    print("Continuing with prediction saving...")
else:
    print("Prediction visualization completed successfully.")

## 2.5 Save Predictions

Save all predictions as NIfTI files with structured organization.

In [None]:
# Save all predictions
def save_inference_results() -> List[str]:
    """Predict masks for every patch and write them as NIfTI files.

    Delegates to ``predict_and_save_masks`` with the module-level ``model``,
    ``inference_dataset`` / ``inference_dirs``, threshold and device, writing
    under ``OUTPUT_DIRS['predictions']``. Prints timing statistics.

    Returns:
        The list of saved paths returned by ``predict_and_save_masks``.
    """
    print("Saving prediction results...")
    print(f"Output directory: {OUTPUT_DIRS['predictions']}")

    start_time = time.perf_counter()

    saved_paths = predict_and_save_masks(
        model=model,
        test_dataset=inference_dataset,
        test_image_dirs=inference_dirs,
        output_base_dir=str(OUTPUT_DIRS['predictions']),
        threshold=INFERENCE_CONFIG['threshold'],
        device=DEVICE
    )

    save_time = time.perf_counter() - start_time
    n_saved = len(saved_paths)

    print(f"\nPrediction saving completed:")
    print(f"  Total predictions saved: {n_saved}")
    print(f"  Save time: {save_time:.2f} seconds")
    if n_saved:
        print(f"  Average time per prediction: {save_time / n_saved:.4f} seconds")
    else:
        # Previously a ZeroDivisionError; report the empty result instead.
        logger.warning("predict_and_save_masks returned no saved paths")

    return saved_paths

# Run prediction + export; keep the returned paths for the consistency checks below
saved_prediction_paths = save_inference_results()

# Verify saved structure
def verify_prediction_structure(output_dir: Path) -> Dict[str, int]:
    """Count saved ``*.nii.gz`` predictions per side under ``output_dir``.

    Expects the layout ``output_dir/<side>/<patient>/*.nii.gz`` for each side
    in ``INPUT_SIDES`` and prints a per-patient breakdown as it goes.

    Returns:
        Mapping of side name -> number of prediction files found; 0 for a
        side whose directory is absent.
    """
    counts_by_side: Dict[str, int] = {}

    print("\nSaved prediction structure:")

    for side_name in INPUT_SIDES:
        side_dir = output_dir / side_name
        if not side_dir.exists():
            print(f"  {side_name}/: No predictions found")
            counts_by_side[side_name] = 0
            continue

        patient_names = [entry.name for entry in side_dir.iterdir() if entry.is_dir()]
        print(f"  {side_name}/: {len(patient_names)} patients")

        side_total = 0
        for patient_name in sorted(patient_names):
            n_preds = sum(1 for _ in (side_dir / patient_name).glob('*.nii.gz'))
            side_total += n_preds
            print(f"    {patient_name}: {n_preds} predictions")

        counts_by_side[side_name] = side_total

    return counts_by_side

# Cross-check the files on disk against the paths reported by the save step
prediction_structure = verify_prediction_structure(OUTPUT_DIRS['predictions'])
total_saved = sum(prediction_structure.values())

print(f"\nTotal predictions saved across all sides: {total_saved}")

# NOTE: nothing in this notebook clears the output directory, so files left
# over from an earlier run are also counted and would show up as a mismatch.
expected_saved = len(saved_prediction_paths)
if total_saved != expected_saved:
    print(f"⚠ Warning: Mismatch in prediction counts ({total_saved} vs {expected_saved})")
else:
    print("✓ Prediction count validation passed")

## Pipeline Summary

Complete inference pipeline execution summary and performance metrics.

In [None]:
# Generate comprehensive pipeline summary
def generate_inference_summary() -> Dict:
    """Print a report of the whole inference run and return it as a dict.

    Reads the module-level configuration (``MODEL_CONFIG``, ``INFERENCE_CONFIG``,
    ``DEVICE``, ``MODEL_WEIGHTS_PATH``, ``OUTPUT_DIRS``) and the results of the
    earlier cells (``SAMPLE_PATIENTS``, ``inference_dataset``, ``inference_dirs``,
    ``inference_results``, ``saved_prediction_paths``, ``prediction_structure``).

    Returns:
        Dict with 'model_config', 'inference_config', 'device', 'data_summary',
        'inference_results' and 'output_summary', plus:
        - 'quality_checks': check name -> bool (passed)
        - 'pipeline_status': the status string printed at the end. It is one of
          "COMPLETED SUCCESSFULLY", "COMPLETED WITH WARNINGS" or
          "COMPLETED WITH ISSUES".
    """
    n_patches = len(inference_dataset)
    n_saved = len(saved_prediction_paths)

    summary = {
        'model_config': MODEL_CONFIG,
        'inference_config': INFERENCE_CONFIG,
        'device': str(DEVICE),
        'data_summary': {
            'total_patients': len(SAMPLE_PATIENTS),
            'total_patches': n_patches,
            'total_directories': len(inference_dirs)
        },
        'inference_results': inference_results,
        'output_summary': {
            'total_predictions': n_saved,
            'predictions_by_side': prediction_structure,
            'output_directory': str(OUTPUT_DIRS['predictions'])
        }
    }

    print("=" * 60)
    print("DEEP LEARNING INFERENCE PIPELINE - SUMMARY")
    print("=" * 60)

    # Model summary
    print(f"Model Configuration:")
    print(f"  Architecture: 2D U-Net (MONAI)")
    print(f"  Weights: {MODEL_WEIGHTS_PATH}")
    print(f"  Device: {DEVICE}")
    print(f"  Input/Output channels: {MODEL_CONFIG['in_channels']} -> {MODEL_CONFIG['out_channels']}")

    # Data summary
    print(f"\nData Processing:")
    print(f"  Patients processed: {len(SAMPLE_PATIENTS)}")
    print(f"  Total patches: {n_patches}")
    print(f"  Batch size: {INFERENCE_CONFIG['batch_size']}")
    print(f"  Total batches: {inference_results['batches']}")

    # Performance summary
    print(f"\nPerformance Metrics:")
    print(f"  Inference time: {inference_results['inference_time']:.2f} seconds")
    print(f"  Processing speed: {inference_results['patches_per_second']:.2f} patches/second")
    print(f"  Average batch time: {inference_results['time_per_batch']:.4f} seconds")

    # Output summary
    print(f"\nOutput Results:")
    print(f"  Total predictions saved: {n_saved}")
    print(f"  Output directory: {OUTPUT_DIRS['predictions']}")
    print(f"  Predictions by side:")
    for side, count in prediction_structure.items():
        print(f"    {side}: {count} predictions")

    # Quality checks: (name, passed, message if passed, message if failed).
    # Keeping them in one list means the total can never drift from the checks.
    quality_checks = [
        ("predictions_generated", n_saved > 0,
         "  ✓ Predictions generated successfully",
         "  ✗ No predictions generated"),
        ("all_patches_processed", n_saved == n_patches,
         "  ✓ All input patches processed",
         f"  ⚠ Patch count mismatch: {n_saved} vs {n_patches}"),
        ("acceptable_speed", inference_results['patches_per_second'] > 10,
         "  ✓ Acceptable processing speed",
         "  ⚠ Processing speed below optimal"),
    ]

    print(f"\nQuality Checks:")
    for _name, passed, ok_msg, fail_msg in quality_checks:
        print(ok_msg if passed else fail_msg)

    checks_passed = sum(1 for _, passed, _, _ in quality_checks if passed)
    total_checks = len(quality_checks)

    # Final status: all passed / at most one failed / more failed
    if checks_passed == total_checks:
        status = "COMPLETED SUCCESSFULLY"
    elif checks_passed >= total_checks - 1:
        status = "COMPLETED WITH WARNINGS"
    else:
        status = "COMPLETED WITH ISSUES"

    print(f"\nPipeline Status: {status} ({checks_passed}/{total_checks} checks passed)")
    print("=" * 60)

    # Expose the verdict so callers can act on it (new keys; existing ones unchanged).
    summary['quality_checks'] = {name: passed for name, passed, _, _ in quality_checks}
    summary['pipeline_status'] = status

    return summary

# Generate summary
pipeline_summary = generate_inference_summary()

# Only announce success when the quality checks agree. `.get` keeps this cell
# working if the summary dict carries no 'pipeline_status' entry.
final_status = pipeline_summary.get('pipeline_status')
if final_status in (None, "COMPLETED SUCCESSFULLY"):
    print(f"\nInference pipeline completed successfully!")
    print(f"Predictions are ready for 3D reconstruction in the next notebook.")
else:
    print(f"\nInference pipeline finished with status: {final_status}")
    print("Review the quality checks above before running 3D reconstruction.")
print(f"\nNext steps:")
print(f"  1. Run 03_reconstruction_analysis.ipynb for 3D volume reconstruction")
print(f"  2. Analyze segmentation quality and results")
print(f"  3. Export final results for clinical analysis")