# Renal MRI Segmentation Inference Demo

This notebook demonstrates how to load a pre-trained segmentation model and perform inference on new NIfTI files.

## Overview
1. Load a trained segmentation model (weights from a run with or without pre-training)
2. Load a NIfTI image and its corresponding mask (optional, for evaluation)
3. Preprocess the image (reorientation, cropping/padding to the model input size, normalization)
4. Run inference to get segmentation
5. Visualize and save results
6. Calculate evaluation metrics (if ground truth mask is provided)

In [None]:
# =============================================================================
# 1. Setup and Imports
# =============================================================================

import os
import sys
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from pathlib import Path
import tensorflow as tf

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Import project modules
from models.model_seg_gn import modelObj
from data.data_utils import normalize, crop_or_pad_3d

# Set up plotting style
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## 2. Configuration

Set the paths to your model and data files.

In [None]:
# =============================================================================
# 2. Configuration
# =============================================================================

# Model configuration. These values should match the configuration the weights
# below were trained with.
model_config = {
    'img_size_x': 256,
    'img_size_y': 256,
    'num_channels': 1,
    'latent_dim': 64,
    'num_classes': 3,  # 3 for whole kidney, 5 for cortex/medulla
}

# Path to pre-trained model weights (loaded with model.load_weights)
model_weights_path = "/path/to/your/weights.hdf5"

# Path to input NIfTI image
image_path = "/path/to/your/image.nii"

# Optional: path to a ground-truth mask for evaluation. Set to None or "" to
# skip evaluation; a path that does not exist is skipped as well.
mask_path = "/path/to/your/mask.nii.gz"

# Output directory for results (created if missing)
output_dir = "./inference_results"
os.makedirs(output_dir, exist_ok=True)

print(f"Model weights: {model_weights_path}")
print(f"Input image: {image_path}")
print(f"Ground truth mask: {mask_path if mask_path else 'None'}")
print(f"Output directory: {output_dir}")

## 3. Load Model

Load the pre-trained segmentation model.

In [None]:
# =============================================================================
# 3. Load Model
# =============================================================================

class Config:
    """Attribute-style wrapper around a plain dict of model settings.

    Every key of ``config_dict`` becomes an instance attribute, so
    ``Config({'num_classes': 3}).num_classes == 3``.
    """

    def __init__(self, config_dict):
        settings = config_dict.items()
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

# Create config object
cfg = Config(model_config)

# Initialize model utility
print("Initializing model...")
mm_utils = modelObj(cfg)

# Create segmentation model
model = mm_utils.seg_unet(num_classes=cfg.num_classes)

# Load weights. Without them the network keeps its random initialisation and
# every prediction and metric further down would be meaningless, so stop here
# instead of silently continuing.
if not os.path.exists(model_weights_path):
    print(f"✗ Model weights not found at: {model_weights_path}")
    raise FileNotFoundError(f"Model weights not found at: {model_weights_path}")

print(f"Loading weights from: {model_weights_path}")
model.load_weights(model_weights_path)
print("✓ Model loaded successfully")

# Print model summary
model.summary()

## 4. Load and Preprocess Image

Load the NIfTI image and preprocess it for inference.

In [None]:
# =============================================================================
# 4. Load and Preprocess Image
# =============================================================================

def load_and_preprocess_image(image_path, target_size=(256, 256)):
    """
    Load a NIfTI image and preprocess it for slice-wise inference.

    The volume is reoriented the same way as during training, split into
    slices along its third axis, cropped/padded in-plane to ``target_size``
    with ``crop_or_pad_3d``, and normalised slice by slice with ``normalize``.

    Args:
        image_path: Path to a 3D or 4D (multi-channel) NIfTI file.
        target_size: Target in-plane (height, width) after cropping/padding.

    Returns:
        Tuple ``(processed_img, original_shape, original_affine)``.
        ``processed_img`` has shape (D, H, W) for 3D input, or (D, H, W, 1)
        for 4D input, because only the first channel is kept.
        ``original_shape`` is the data shape as stored in the file, and
        ``original_affine`` is the file's affine matrix.

    Raises:
        ValueError: If the image is not 3D or 4D, or contains no slices.
    """
    # Load NIfTI
    nii_img = nib.load(image_path)
    img_data = nii_img.get_fdata()

    # Keep the original shape and affine for saving results later
    original_shape = img_data.shape
    original_affine = nii_img.affine

    print(f"Original image shape: {original_shape}")
    print(f"Image dimensions: {img_data.ndim}D")

    # Only 3D volumes and 4D multi-channel volumes are supported. Other ranks
    # would be silently mis-indexed by the slice loop below.
    if img_data.ndim not in (3, 4):
        raise ValueError(
            f"Expected a 3D or 4D NIfTI image, got {img_data.ndim}D "
            f"data with shape {original_shape}"
        )

    # Apply the same reorientation as during training. For the first two axes,
    # this flip/rot90 pair is equivalent to swapping them.
    img_data = np.flip(np.rot90(img_data, -1), 1)

    # Move the slice axis to the front: (D, H, W) or (D, H, W, C)
    slice_first_axes = (2, 0, 1) if img_data.ndim == 3 else (2, 0, 1, 3)
    img_data = np.transpose(img_data, slice_first_axes)

    if img_data.shape[0] == 0:
        raise ValueError(f"Image {image_path} contains no slices")

    target_hw = tuple(target_size[:2])
    processed_slices = []
    for slice_idx in range(img_data.shape[0]):
        # Multi-channel input: only the first channel is fed to the model
        if img_data.ndim == 3:
            slice_data = img_data[slice_idx]
        else:
            slice_data = img_data[slice_idx, ..., 0]

        # Crop/pad to target size (crop_or_pad_3d expects a trailing axis)
        if slice_data.shape[:2] != target_hw:
            slice_data = crop_or_pad_3d(slice_data[..., np.newaxis], target_size)[..., 0]

        processed_slices.append(normalize(slice_data))

    # Stack back; restore a single channel axis for multi-channel input
    processed_img = np.stack(processed_slices, axis=0)
    if img_data.ndim == 4:
        processed_img = processed_img[..., np.newaxis]

    print(f"Processed image shape: {processed_img.shape}")

    return processed_img, original_shape, original_affine

# Load and preprocess image. Any failure is reported and then re-raised, so the
# notebook stops here instead of running inference on stale data.
try:
    img_processed, original_shape, original_affine = load_and_preprocess_image(image_path)
except Exception as err:
    print(f"✗ Error loading image: {err}")
    raise
else:
    print("✓ Image loaded and preprocessed successfully")

## 5. Run Inference

Run the model on the preprocessed image to get segmentation.

In [None]:
# =============================================================================
# 5. Run Inference
# =============================================================================

def run_inference(model, image, batch_size=8):
    """
    Segment a volume slice by slice with a Keras model.

    Slices are passed to ``model.predict`` in chunks of ``batch_size``. The
    per-pixel class scores are then reduced to label indices with an argmax
    over the last axis.

    Args:
        model: Keras model (anything exposing ``predict(batch, verbose=0)``).
        image: Preprocessed array of shape (D, H, W) or (D, H, W, C).
        batch_size: Number of slices per ``predict`` call.

    Returns:
        Label map of shape (D, H, W), assuming the model outputs scores of
        shape (N, H, W, num_classes).
    """
    # The network expects an explicit channel axis
    volume = image[..., np.newaxis] if image.ndim == 3 else image
    n_slices = volume.shape[0]

    chunk_scores = [
        model.predict(volume[start:start + batch_size], verbose=0)
        for start in range(0, n_slices, batch_size)
    ]
    scores = np.concatenate(chunk_scores, axis=0)

    # Class scores -> class indices
    return np.argmax(scores, axis=-1)

print("Running inference...")
segmentation = run_inference(model, img_processed)
found_labels = np.unique(segmentation)
print(f"✓ Inference complete. Segmentation shape: {segmentation.shape}")
print(f"Unique classes: {found_labels}")

## 6. Load Ground Truth (Optional)

If a ground truth mask is provided, load it for evaluation.

In [None]:
# =============================================================================
# 6. Load Ground Truth (Optional)
# =============================================================================

def load_ground_truth(mask_path, original_shape, target_size=(256, 256)):
    """
    Load a ground-truth label mask on the same grid as the preprocessed image.

    The mask goes through the same reorientation and in-plane crop/pad as
    ``load_and_preprocess_image``, without intensity normalisation. This lets
    it be compared voxel-for-voxel with the (D, H, W) prediction.

    Args:
        mask_path: Path to the ground-truth NIfTI mask. ``None``, ``""`` or a
            missing file all mean "no ground truth".
        original_shape: Shape of the source image as stored in its file, as
            returned by ``load_and_preprocess_image``. Used to check that the
            mask belongs to that image; pass ``None`` to skip the check.
        target_size: In-plane (height, width) the image was cropped/padded to.
            It must match the value used for the image.

    Returns:
        Integer label array of shape (D, H, W), with (H, W) expected to equal
        ``target_size``. Returns ``None`` when no mask is available.

    Raises:
        ValueError: If the mask is not a 3D volume, or if its shape differs
            from the spatial shape of the source image.
    """
    if not mask_path or not os.path.exists(mask_path):
        return None

    # Load mask
    mask_data = np.asarray(nib.load(mask_path).get_fdata())

    # Tolerate masks stored with trailing singleton axes, e.g. (X, Y, Z, 1)
    while mask_data.ndim > 3 and mask_data.shape[-1] == 1:
        mask_data = mask_data[..., 0]
    if mask_data.ndim != 3:
        raise ValueError(f"Expected a 3D mask, got shape {mask_data.shape}")
    if original_shape is not None and mask_data.shape != tuple(original_shape[:3]):
        raise ValueError(
            f"Mask shape {mask_data.shape} does not match image shape "
            f"{tuple(original_shape[:3])}"
        )

    # Apply the same reorientation as the image, slice axis first: (D, H, W)
    mask_data = np.flip(np.rot90(mask_data, -1), 1)
    mask_data = np.transpose(mask_data, (2, 0, 1))

    # Apply the same in-plane crop/pad as the image so both share one grid
    target_hw = tuple(target_size[:2])
    mask_slices = []
    for mask_slice in mask_data:
        if mask_slice.shape != target_hw:
            mask_slice = crop_or_pad_3d(mask_slice[..., np.newaxis], target_size)[..., 0]
        mask_slices.append(mask_slice)

    # get_fdata() returns floats; round to exact integer labels so that
    # equality tests against class indices are reliable.
    return np.rint(np.stack(mask_slices, axis=0)).astype(np.int16)

# Load ground truth if available (mask_path may be None/"" to skip evaluation)
ground_truth = None
has_mask = bool(mask_path) and os.path.exists(mask_path)
if not has_mask:
    print("No ground truth mask provided or file not found.")
else:
    ground_truth = load_ground_truth(mask_path, original_shape)
    print(f"✓ Ground truth loaded. Shape: {ground_truth.shape}")

## 7. Calculate Metrics (Optional)

If ground truth is available, calculate Dice scores and other metrics.

In [None]:
# =============================================================================
# 7. Calculate Metrics (Optional)
# =============================================================================

def calculate_dice_score(pred, gt, class_idx, smooth=1e-6):
    """
    Dice coefficient of one label between a prediction and a reference mask.

    Dice = (2 * |P ∩ G| + smooth) / (|P| + |G| + smooth), where P and G are the
    voxels labelled ``class_idx`` in ``pred`` and ``gt``. ``smooth`` keeps the
    ratio finite and gives 1.0 when the label is absent from both masks.

    Args:
        pred: Predicted label array.
        gt: Reference label array (same shape as ``pred``).
        class_idx: Label to score.
        smooth: Additive smoothing term for numerator and denominator.

    Returns:
        Dice score as a scalar.
    """
    in_pred = np.equal(pred, class_idx).astype(np.float32)
    in_gt = np.equal(gt, class_idx).astype(np.float32)

    overlap = (in_pred * in_gt).sum()
    total = in_pred.sum() + in_gt.sum()
    return (2. * overlap + smooth) / (total + smooth)

if ground_truth is not None:
    print("\n=== Segmentation Metrics ===")

    # Dice is computed voxel-wise, so prediction and reference must share one
    # grid. Fail clearly instead of hitting a broadcasting error or scoring
    # misaligned arrays.
    if ground_truth.shape != segmentation.shape:
        raise ValueError(
            f"Ground truth shape {ground_truth.shape} does not match "
            f"segmentation shape {segmentation.shape}"
        )

    # Class names based on number of classes
    if cfg.num_classes == 3:
        class_names = ['Background', 'Right Kidney', 'Left Kidney']
    elif cfg.num_classes == 5:
        class_names = ['Background', 'Right Cortex', 'Left Cortex',
                       'Right Medulla', 'Left Medulla']
    else:
        # Generic names for any other label layout
        class_names = ['Background'] + [f'Class {i}' for i in range(1, cfg.num_classes)]

    # Calculate Dice for each class
    dice_scores = {}
    for class_idx in range(1, cfg.num_classes):  # Skip background
        dice = calculate_dice_score(segmentation, ground_truth, class_idx)
        dice_scores[class_names[class_idx]] = dice
        print(f"{class_names[class_idx]}: {dice:.4f}")

    # Calculate mean Dice (excluding background); undefined with no foreground classes
    if dice_scores:
        mean_dice = np.mean(list(dice_scores.values()))
        print(f"\nMean Dice (excluding background): {mean_dice:.4f}")
else:
    print("No ground truth available for metric calculation.")

## 8. Visualize Results

Visualize the segmentation results on sample slices.

In [None]:
# =============================================================================
# 8. Visualize Results
# =============================================================================

def visualize_segmentation(image, segmentation, ground_truth=None, num_slices=5, save_path=None,
                           num_classes=None):
    """
    Plot evenly spaced slices with predicted (and optional ground-truth) overlays.

    Each row shows one slice in up to three panels:
    - the raw image;
    - the prediction overlaid on the image;
    - the reference overlaid on the image, if ``ground_truth`` is given.

    Background (label 0) is transparent. Label ``k`` is drawn in the k-th
    colour of the label palette.

    Args:
        image: Image to display, shape (D, H, W).
        segmentation: Predicted label map, shape (D, H, W).
        ground_truth: Optional reference label map, shape (D, H, W).
        num_slices: Number of slices (rows) to display.
        save_path: If given, the figure is also written to this path.
        num_classes: Number of labels including background. Defaults to
            ``cfg.num_classes`` from the model-loading cell.
    """
    from matplotlib.colors import ListedColormap

    if num_classes is None:
        num_classes = cfg.num_classes

    # Select evenly spaced slices
    total_slices = image.shape[0]
    slice_indices = np.linspace(0, total_slices - 1, num_slices, dtype=int)

    # One colour per label: background, right, left (, right medulla, left medulla)
    if num_classes == 3:
        colors = ['black', 'red', 'blue']
    else:
        colors = ['black', 'red', 'blue', 'yellow', 'green']

    # Map label k to colors[k]. vmin/vmax are shifted by 0.5 so each integer
    # label falls in the middle of its own colour bin. If there are more labels
    # than colours, fall back to a generic qualitative colormap.
    if num_classes <= len(colors):
        label_cmap = ListedColormap(colors[:num_classes])
    else:
        label_cmap = 'tab10'
    overlay_kwargs = dict(cmap=label_cmap, alpha=0.5, vmin=-0.5, vmax=num_classes - 0.5)

    # Create figure
    has_gt = ground_truth is not None
    n_cols = 3 if has_gt else 2
    fig, axes = plt.subplots(num_slices, n_cols, figsize=(15 if has_gt else 12, 4 * num_slices))

    # plt.subplots returns a 1-D axes array for a single row
    if num_slices == 1:
        axes = axes.reshape(1, -1)

    for row, slice_idx in enumerate(slice_indices):
        # Original image
        axes[row, 0].imshow(image[slice_idx], cmap='gray')
        axes[row, 0].set_title(f'Original Image - Slice {slice_idx}')
        axes[row, 0].axis('off')

        # Overlay segmentation on image
        axes[row, 1].imshow(image[slice_idx], cmap='gray')
        seg_overlay = np.ma.masked_where(segmentation[slice_idx] == 0, segmentation[slice_idx])
        axes[row, 1].imshow(seg_overlay, **overlay_kwargs)
        axes[row, 1].set_title(f'Predicted Segmentation - Slice {slice_idx}')
        axes[row, 1].axis('off')

        # Ground truth if available
        if has_gt:
            axes[row, 2].imshow(image[slice_idx], cmap='gray')
            gt_overlay = np.ma.masked_where(ground_truth[slice_idx] == 0, ground_truth[slice_idx])
            axes[row, 2].imshow(gt_overlay, **overlay_kwargs)
            axes[row, 2].set_title(f'Ground Truth - Slice {slice_idx}')
            axes[row, 2].axis('off')

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")

    plt.show()

# Visualize results. The model input may carry a trailing channel axis
# (D, H, W, 1); drop it before plotting.
print("Visualizing segmentation results...")
display_image = img_processed[..., 0] if img_processed.ndim != 3 else img_processed
figure_path = os.path.join(output_dir, 'segmentation_visualization.png')
visualize_segmentation(
    display_image,
    segmentation,
    ground_truth,
    num_slices=min(5, img_processed.shape[0]),
    save_path=figure_path,
)

## 9. Save Segmentation Mask

Save the segmentation mask as a NIfTI file.

In [None]:
# =============================================================================
# 9. Save Segmentation Mask
# =============================================================================

def save_segmentation_mask(segmentation, original_affine, output_path):
    """
    Write a (D, H, W) label map to disk as an int16 NIfTI file.

    The reorientation done in ``load_and_preprocess_image`` is undone: the
    slice axis is moved back to the end, then the inverse flip/rot90 is
    applied. The volume is therefore laid out like the source file and can
    reuse its affine.

    Note: only the reorientation is reversed, not the in-plane crop/pad to
    the model input size. If the source image was not already that size, the
    saved mask's shape differs from the source image's shape.

    Args:
        segmentation: Label map of shape (D, H, W).
        original_affine: Affine matrix of the source image.
        output_path: Destination file path (e.g. ``*.nii.gz``).
    """
    # (D, H, W) -> slice axis last
    restored = np.transpose(segmentation, (1, 2, 0))
    # Inverse of the np.flip(np.rot90(x, -1), 1) applied at load time
    restored = np.flip(np.rot90(restored, 1), 0)

    mask_img = nib.Nifti1Image(restored.astype(np.int16), original_affine)
    nib.save(mask_img, output_path)
    print(f"Segmentation mask saved to: {output_path}")

# Save segmentation mask into the output directory
mask_filename = 'segmentation_mask.nii.gz'
output_mask_path = os.path.join(output_dir, mask_filename)
save_segmentation_mask(segmentation, original_affine, output_path=output_mask_path)

## 10. Batch Processing (Optional)

Process multiple images in a directory.

In [None]:
# =============================================================================
# 10. Batch Processing (Optional)
# =============================================================================

def _nifti_stem_and_ext(filename):
    """Split a NIfTI file name into (stem, extension), defaulting to '.nii.gz'."""
    for ext in ('.nii.gz', '.nii'):
        if filename.endswith(ext):
            return filename[:-len(ext)], ext
    # Matched the glob but has an unexpected extension: keep the whole name as
    # the stem and save the result as compressed NIfTI.
    return filename, '.nii.gz'


def batch_process_directory(input_dir, output_dir, model, file_pattern="*.nii*"):
    """
    Segment every NIfTI file in a directory and save one mask per file.

    Files are processed in sorted order. Each ``<name>.nii[.gz]`` produces
    ``<name>_seg.nii[.gz]`` in ``output_dir``. Inputs whose stem already ends
    in ``_seg`` are skipped, so earlier outputs stored in the same directory
    are not segmented again. A failure on one file is reported, and the
    remaining files are still processed.

    Args:
        input_dir: Input directory containing NIfTI files.
        output_dir: Output directory for segmentation masks (created if missing).
        model: Keras model passed to ``run_inference``.
        file_pattern: Glob pattern (relative to ``input_dir``) selecting inputs.

    Returns:
        List of output paths that were written successfully.
    """
    import glob

    os.makedirs(output_dir, exist_ok=True)
    # glob order is arbitrary; sort for reproducible processing/logging
    nifti_files = sorted(glob.glob(os.path.join(input_dir, file_pattern)))

    print(f"Found {len(nifti_files)} NIfTI files to process")

    written, failed = [], []
    for i, file_path in enumerate(nifti_files, start=1):
        file_name = os.path.basename(file_path)
        print(f"\nProcessing [{i}/{len(nifti_files)}]: {file_name}")

        stem, ext = _nifti_stem_and_ext(file_name)
        if stem.endswith('_seg'):
            print(f"Skipping {file_name}: looks like a segmentation output")
            continue
        output_path = os.path.join(output_dir, f"{stem}_seg{ext}")

        try:
            img_processed, _, original_affine = load_and_preprocess_image(file_path)
            segmentation = run_inference(model, img_processed)
            # Reuse the single-image saver so orientation handling stays in one place
            save_segmentation_mask(segmentation, original_affine, output_path)
        except Exception as e:
            # Best-effort batch: report and continue with the next file
            print(f"✗ Error processing {file_path}: {e}")
            failed.append(file_path)
        else:
            written.append(output_path)

    print(f"\nDone: {len(written)} saved, {len(failed)} failed")
    return written

# Uncomment to run batch processing
# input_directory = "/path/to/your/images"
# batch_output_dir = os.path.join(output_dir, "batch_results")
# batch_process_directory(input_directory, batch_output_dir, model)

## Summary

This notebook demonstrated:
1. Loading a pre-trained segmentation model
2. Preprocessing a NIfTI image for inference
3. Running slice-by-slice inference
4. Visualizing segmentation results
5. Calculating evaluation metrics (if ground truth available)
6. Saving segmentation masks as NIfTI files
7. Optionally batch-processing a directory of NIfTI files

The saved segmentation mask is written to `segmentation_mask.nii.gz` inside `output_dir` (by default `./inference_results/segmentation_mask.nii.gz`).