# Renal MRI Registration Inference Demo

This notebook demonstrates how to load pre-trained affine and non-rigid registration models to align all MRI contrasts to a DIXON template and propagate segmentation masks.

## Overview
1. Load pre-trained affine and non-rigid registration models
2. Load DIXON template image and its segmentation mask
3. Load the moving images (the contrasts to be registered) from the resampled NIfTI data
4. Apply affine registration to align moving images to DIXON space
5. Apply non-rigid registration for fine-grained alignment
6. Propagate DIXON mask to all other contrasts using the computed deformations
7. Visualize and save results

In [None]:
# =============================================================================
# 1. Setup and Imports
# =============================================================================

import os
import sys
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from pathlib import Path
import tensorflow as tf
import glob
from typing import List, Tuple, Dict, Optional

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Import project modules
from models.model_reg_gn import modelObj
from data.data_utils import normalize, crop_or_pad_3d

# Set up plotting style
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

# Suppress TensorFlow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

## 2. Configuration

Set the paths to your models and data.

In [None]:
# =============================================================================
# 2. Configuration
# =============================================================================

# Model configuration
model_config = {
    'img_size_x': 256,
    'img_size_y': 256,
    'num_channels': 1,
    'num_classes': 1,
    'num_contrasts': 6,  # Number of moving contrasts (must equal len(moving_contrasts))
    'weighted': True      # Match training configuration
}

# Paths to pre-trained registration models
affine_model_path = "/path/to/weights/registration_affine_checkpoint.hdf5"
nonrigid_model_path = "/path/to/weights/registration_nonrigid_checkpoint.hdf5"

# Subject to process
subject_id = "subj"  # Change to your subject
data_dir = "/path/to/data/"

# Template contrast (fixed image)
template_contrast = "DIXON"

# Moving contrasts (images to register to template).
# NOTE(review): the template contrast (DIXON) is also listed here, so it is
# registered to itself as one of the moving inputs -- presumably to fill the
# model's fixed number of contrast slots; confirm against the training setup.
moving_contrasts = ["DIXON", "BOLD", "T1_mapping_VIBE", "T1_mapping_fl2d", "ASL", "Diffusion"]

# model_config['num_contrasts'] is handed to the model builder, so the list
# above and the model configuration must agree; fail early if they drift apart.
if len(moving_contrasts) != model_config['num_contrasts']:
    raise ValueError(
        f"moving_contrasts has {len(moving_contrasts)} entries but "
        f"model_config['num_contrasts'] is {model_config['num_contrasts']}"
    )

# All contrasts (for reference; DIXON appears twice because it is both the
# template and one of the moving contrasts)
all_contrasts = [template_contrast] + moving_contrasts

# Output directory for results
output_dir = f"./registration_results/{subject_id}"
os.makedirs(output_dir, exist_ok=True)

print(f"Subject: {subject_id}")
print(f"Template contrast: {template_contrast}")
print(f"Moving contrasts: {moving_contrasts}")
print(f"Output directory: {output_dir}")
print(f"Affine model: {affine_model_path}")
print(f"Non-rigid model: {nonrigid_model_path}")

## 3. Helper Functions

Define functions for loading data and applying registration.

In [None]:
# =============================================================================
# 3. Helper Functions
# =============================================================================

class Config:
    """Attribute bag built from a dictionary.

    Each key of ``config_dict`` becomes an attribute of the instance holding
    the matching value, e.g. ``Config({'img_size_x': 256}).img_size_x == 256``.
    """
    def __init__(self, config_dict):
        for name in config_dict:
            setattr(self, name, config_dict[name])


def load_nifti(file_path: str) -> Tuple[np.ndarray, np.ndarray, Tuple]:
    """
    Read a NIfTI volume from disk.

    Args:
        file_path: Location of the NIfTI file

    Returns:
        Tuple of (voxel data as returned by ``get_fdata``, the image's affine
        matrix, the shape of the voxel data)
    """
    image = nib.load(file_path)
    voxels = image.get_fdata()
    shape = voxels.shape
    return voxels, image.affine, shape


def find_contrast_files(subject_dir: str, contrast: str) -> List[str]:
    """
    List the image files of one contrast.

    Looks for ``*.nii*`` files in ``<subject_dir>/<contrast>/imagesTr``.

    Args:
        subject_dir: Subject directory
        contrast: Contrast name

    Returns:
        Matching paths in sorted order (empty list if none exist)
    """
    search_glob = os.path.join(subject_dir, contrast, "imagesTr", "*.nii*")
    matches = glob.glob(search_glob)
    matches.sort()
    return matches


def find_mask_files(subject_dir: str, contrast: str, side: Optional[str] = None) -> List[str]:
    """
    Find mask (label) files for a given contrast.

    Masks are looked up in ``<subject_dir>/<contrast>/labelsTr``. When ``side``
    is given, only files whose *file name* contains it (case-insensitively) are
    kept. The directory part of the path is deliberately ignored so that a
    parent folder such as ``right_cohort`` does not make every mask match.

    Args:
        subject_dir: Subject directory
        contrast: Contrast name
        side: 'left', 'right', or None for no filtering

    Returns:
        Sorted list of mask file paths. Sorting makes callers that pick the
        first match deterministic across filesystems.
    """
    pattern = os.path.join(subject_dir, contrast, "labelsTr", "*.nii*")
    all_masks = sorted(glob.glob(pattern))

    if side:
        wanted = side.lower()
        return [m for m in all_masks if wanted in os.path.basename(m).lower()]
    return all_masks


def load_and_preprocess_contrast(subject_dir: str, contrast: str) -> np.ndarray:
    """
    Load every image of one contrast and arrange it for the registration model.

    Each file is re-oriented exactly as during training (``rot90`` by -1, then
    a flip along axis 1), and multiple files are treated as channels.

    Args:
        subject_dir: Subject directory
        contrast: Contrast name

    Returns:
        Array of shape (D, H, W, C) with one channel per file, or None when no
        image file exists for the contrast.
    """
    paths = find_contrast_files(subject_dir, contrast)
    if len(paths) == 0:
        print(f"⚠ No images found for {contrast}")
        return None

    # Re-orient each volume the same way the training pipeline did
    volumes = [
        np.flip(np.rot90(nib.load(path).get_fdata().astype(np.float32), -1), 1)
        for path in paths
    ]

    if len(volumes) > 1:
        # Files become channels: (H, W, D, C) -> (D, H, W, C)
        return np.transpose(np.stack(volumes, axis=-1), (2, 0, 1, 3))

    # Single file: (H, W, D) -> (D, H, W), then append a channel axis
    return np.transpose(volumes[0], (2, 0, 1))[..., np.newaxis]


def load_mask(subject_dir: str, contrast: str) -> np.ndarray:
    """
    Merge the separate left- and right-side masks of a contrast into one label map.

    Args:
        subject_dir: Subject directory
        contrast: Contrast name

    Returns:
        Float32 array of shape (D, H, W) with 0 = background, 1 = left and
        2 = right (right wins where both masks are set). Returns None if
        either side is missing or the two masks differ in shape.
    """
    def _read_oriented(path):
        # Same re-orientation as applied to the images at load time
        voxels = nib.load(path).get_fdata().astype(np.float32)
        return np.flip(np.rot90(voxels, -1), 1)

    left_candidates = find_mask_files(subject_dir, contrast, side='left')
    right_candidates = find_mask_files(subject_dir, contrast, side='right')
    if not (left_candidates and right_candidates):
        print(f"⚠ Missing masks for {contrast}")
        return None

    left = _read_oriented(left_candidates[0])
    right = _read_oriented(right_candidates[0])
    if left.shape != right.shape:
        print(f"⚠ Mask shape mismatch: left {left.shape}, right {right.shape}")
        return None

    # Right takes precedence over left, background everywhere else
    labels = np.where(right > 0, 2, np.where(left > 0, 1, 0)).astype(np.float32)

    # Slices first: (D, H, W)
    return np.transpose(labels, (2, 0, 1))


def warp_mask(mask: np.ndarray, flow_field: np.ndarray) -> np.ndarray:
    """
    Warp a label mask slice-by-slice with a 2-D displacement field (forward splat).

    Each labelled (non-zero) pixel ``(y, x)`` of slice ``d`` is moved to the
    pixel nearest to ``(y + flow_field[d, y, x, 1], x + flow_field[d, y, x, 0])``,
    clipped to the image bounds.

    - Coordinates are rounded rather than truncated, so a 0.9 px shift moves
      the label by one pixel instead of being dropped.
    - Only labelled pixels are splatted, so background never overwrites a label
      when two source pixels land on the same target. If several labelled
      pixels collide, one of them is kept.
    - Forward splatting can leave holes where the field expands; no hole
      filling is done.

    Args:
        mask: Input mask (D, H, W)
        flow_field: Flow field (D, H, W, 2) with pixel displacements in x
            (channel 0) and y (channel 1)

    Returns:
        Warped mask (D, H, W), same dtype as ``mask``
    """
    D, H, W = mask.shape

    # Pixel-coordinate grids, each of shape (H, W)
    xx, yy = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')

    warped_mask = np.zeros_like(mask)

    for d in range(D):
        flow = flow_field[d]  # (H, W, 2)

        # Nearest target pixel for every source pixel, clipped to the image
        new_x = np.clip(np.rint(xx + flow[..., 0]), 0, W - 1).astype(np.int32)
        new_y = np.clip(np.rint(yy + flow[..., 1]), 0, H - 1).astype(np.int32)

        # Splat only labelled pixels so background cannot overwrite labels
        labelled = mask[d] != 0
        warped_mask[d, new_y[labelled], new_x[labelled]] = mask[d][labelled]

    return warped_mask


def save_nifti(data: np.ndarray, affine: np.ndarray, filename: str):
    """
    Write an array to ``filename`` as a float32 NIfTI image.

    A 3-D input is assumed to be in the notebook's slices-first (D, H, W)
    layout. It is mapped back to the on-disk orientation by undoing the
    loading transforms: transpose to (H, W, D), then ``rot90`` plus a flip.
    Arrays of any other rank are written unchanged.

    Args:
        data: Image or mask array
        affine: Voxel-to-world affine stored in the header
        filename: Destination path
    """
    out = data
    if out.ndim == 3:
        # (D, H, W) -> (H, W, D), then invert the flip/rot90 used at load time
        out = np.flip(np.rot90(np.transpose(out, (1, 2, 0)), 1), 0)

    nib.save(nib.Nifti1Image(out.astype(np.float32), affine), filename)
    print(f"  Saved: {filename}")

## 4. Load Models

Load the pre-trained affine and non-rigid registration models.

In [None]:
# =============================================================================
# 4. Load Models
# =============================================================================

cfg = Config(model_config)
mm_utils = modelObj(cfg)

# Fail fast if a checkpoint is missing. Continuing would run inference with
# randomly initialised weights and silently produce meaningless registrations,
# and the non-rigid builder below is also handed the affine checkpoint path.
for _label, _path in (("Affine", affine_model_path), ("Non-rigid", nonrigid_model_path)):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"✗ {_label} model not found at: {_path}")

# Load affine model
print("Loading affine registration model...")
affine_model = mm_utils.reg_groupwise_affine(weighted=cfg.weighted)
affine_model.load_weights(affine_model_path)
print(f"✓ Affine model loaded from: {affine_model_path}")

# Load non-rigid model. It is built with weighted=False; the non-rigid
# inference code in this notebook feeds it no weight map.
print("\nLoading non-rigid registration model...")
nonrigid_model = mm_utils.reg_groupwise_affine_nonrigid(
    checkpoint_affine=affine_model_path,
    test_mode=True, weighted=False
)
nonrigid_model.load_weights(nonrigid_model_path)
print(f"✓ Non-rigid model loaded from: {nonrigid_model_path}")

## 5. Load Subject Data

Load DIXON template, its mask, and all moving contrasts.

In [None]:
# =============================================================================
# 5. Load Subject Data 
# =============================================================================

# Per-subject folder; images and labels are looked up under
# <subject_dir>/<contrast>/imagesTr and <subject_dir>/<contrast>/labelsTr
subject_dir = os.path.join(data_dir, subject_id)
print(f"Loading data for {subject_id} from {subject_dir}")

def _match_channels(img_stack: np.ndarray, target_channels: int) -> np.ndarray:
    """Trim, or cyclically repeat, the last axis of a (D, H, W, C) array to ``target_channels``."""
    n_channels = img_stack.shape[-1]
    if n_channels > target_channels:
        return img_stack[..., :target_channels]
    if n_channels < target_channels:
        repeats = -(-target_channels // n_channels)  # ceiling division
        return np.tile(img_stack, (1, 1, 1, repeats))[..., :target_channels]
    return img_stack


def load_and_preprocess_contrast(subject_dir: str, contrast: str, target_channels: int = 1) -> Optional[np.ndarray]:
    """
    Load all images for a contrast and preprocess for registration.

    Every file gets the same re-orientation as in training (``rot90`` by -1,
    then a flip along axis 1). Channel handling:

    - A single 3-D file becomes one channel.
    - A single 4-D file keeps its own channels.
    - Several 3-D files are stacked as channels.

    The channel count is then trimmed, or cyclically repeated, to
    ``target_channels``.

    Args:
        subject_dir: Subject directory
        contrast: Contrast name
        target_channels: Target number of channels (default: 1)

    Returns:
        Preprocessed float32 array with shape (D, H, W, target_channels), or
        None if no images exist, dimensions are unsupported, or multiple files
        have inconsistent shapes.

    Raises:
        ValueError: if ``target_channels`` is smaller than 1.
    """
    if target_channels < 1:
        raise ValueError(f"target_channels must be >= 1, got {target_channels}")

    img_files = find_contrast_files(subject_dir, contrast)
    if not img_files:
        print(f"⚠ No images found for {contrast}")
        return None

    print(f"    Found {len(img_files)} image(s) for {contrast}")

    images = []
    for img_file in img_files:
        data = nib.load(img_file).get_fdata().astype(np.float32)
        # Apply same transformations as in training
        images.append(np.flip(np.rot90(data, -1), 1))

    if len(images) == 1:
        img_data = images[0]
        if img_data.ndim == 3:
            # 3D image: (H, W, D) -> (D, H, W, 1)
            img_stack = np.transpose(img_data, (2, 0, 1))[..., np.newaxis]
        elif img_data.ndim == 4:
            # 4D image: (H, W, D, C) -> (D, H, W, C)
            img_stack = np.transpose(img_data, (2, 0, 1, 3))
        else:
            print(f"        Unexpected dimensions: {img_data.ndim}D")
            return None
    else:
        # Multiple files case - treat each file as a channel
        print(f"        Stacking {len(images)} images as channels")

        processed = []
        for img_data in images:
            if img_data.ndim != 3:
                print(f"        Unexpected dimensions for multi-file: {img_data.ndim}D")
                return None
            processed.append(np.transpose(img_data, (2, 0, 1)))  # (D, H, W)

        # np.stack needs identical shapes; report a mismatch instead of crashing
        shapes = sorted({p.shape for p in processed})
        if len(shapes) > 1:
            print(f"        Inconsistent image shapes for {contrast}: {shapes}")
            return None

        img_stack = np.stack(processed, axis=-1)  # (D, H, W, C)

    img_stack = _match_channels(img_stack, target_channels)
    print(f"    Final shape: {img_stack.shape}")
    return img_stack


def load_mask(subject_dir: str, contrast: str) -> Optional[np.ndarray]:
    """
    Load and combine left/right masks for a contrast.

    Only the first matching file for each side is used. Both masks get the
    same re-orientation as the images (``rot90`` by -1, then a flip along
    axis 1).

    Args:
        subject_dir: Subject directory
        contrast: Contrast name

    Returns:
        Combined float32 mask with values 0=bg, 1=left, 2=right (right wins
        where both masks are set), shape (D, H, W). Returns None if a side is
        missing or the left/right masks differ in shape.
    """
    # Find left and right masks
    left_masks = find_mask_files(subject_dir, contrast, side='left')
    right_masks = find_mask_files(subject_dir, contrast, side='right')

    if not left_masks or not right_masks:
        print(f"    ⚠ Missing masks for {contrast}")
        return None

    # Load left mask
    left_nii = nib.load(left_masks[0])
    left_data = left_nii.get_fdata().astype(np.float32)
    left_data = np.flip(np.rot90(left_data, -1), 1)

    # Load right mask
    right_nii = nib.load(right_masks[0])
    right_data = right_nii.get_fdata().astype(np.float32)
    right_data = np.flip(np.rot90(right_data, -1), 1)

    # Without this check, mismatched masks make the boolean indexing below
    # fail with an opaque IndexError instead of a readable warning.
    if left_data.shape != right_data.shape:
        print(f"    ⚠ Mask shape mismatch for {contrast}: "
              f"left {left_data.shape}, right {right_data.shape}")
        return None

    # Create combined mask: 1=left, 2=right (right is written last, so it wins on overlap)
    combined = np.zeros(left_data.shape, dtype=np.float32)
    combined[left_data > 0] = 1
    combined[right_data > 0] = 2

    # Transpose to (D, H, W)
    return np.transpose(combined, (2, 0, 1))


# Load template (DIXON) - force to 1 channel
print(f"\nLoading template contrast: {template_contrast}")
template_data = load_and_preprocess_contrast(subject_dir, template_contrast, target_channels=1)

if template_data is None:
    raise ValueError(f"Could not load {template_contrast} for {subject_id}")

print(f"  Template shape: {template_data.shape}")  # Should be (D, H, W, 1)

# Load template mask
template_mask = load_mask(subject_dir, template_contrast)

if template_mask is not None:
    print(f"  Template mask shape: {template_mask.shape}")
    print(f"  Mask values: {np.unique(template_mask)}")
else:
    print(f"  ⚠ No mask found for {template_contrast}")

# Load moving contrasts
moving_data = []
moving_masks = []
valid_contrasts = []

print("\nLoading moving contrasts:")
for contrast in moving_contrasts:
    print(f"  {contrast}:")

    # Load image data - force to 1 channel
    data = load_and_preprocess_contrast(subject_dir, contrast, target_channels=1)
    if data is None:
        print(f"    Skipping {contrast} due to loading error")
        continue

    moving_data.append(data)

    # Load mask; contrasts without one get an all-background placeholder
    mask = load_mask(subject_dir, contrast)
    if mask is not None:
        print(f"    Mask shape: {mask.shape}")
        moving_masks.append(mask)
    else:
        print("    ⚠ No mask found, creating empty mask")
        moving_masks.append(np.zeros(data.shape[:3], dtype=np.float32))

    valid_contrasts.append(contrast)

if not moving_data:
    raise ValueError("No moving contrasts could be loaded")

# Get dimensions
D, H, W, C = template_data.shape
num_moving = len(valid_contrasts)
print(f"\nDimensions: D={D}, H={H}, W={W}, C={C}, num_moving={num_moving}")

if num_moving != model_config['num_contrasts']:
    # Skipped contrasts change the number of contrast slots fed to the model,
    # which was configured with model_config['num_contrasts'].
    print(f"⚠ Loaded {num_moving} moving contrasts but model_config['num_contrasts'] "
          f"is {model_config['num_contrasts']}; inference may fail")

# Validate shapes up front so a mismatching contrast yields a clear error rather
# than an opaque failure inside np.stack. A moving volume may have more slices
# than the template; the extra slices are ignored.
for contrast, vol, msk in zip(valid_contrasts, moving_data, moving_masks):
    if vol.shape[0] < D or vol.shape[1:] != (H, W, C):
        raise ValueError(
            f"{contrast}: image shape {vol.shape} does not match the template "
            f"(need >= {D} slices of shape {(H, W, C)})"
        )
    if msk.shape[0] < D or msk.shape[1:] != (H, W):
        raise ValueError(
            f"{contrast}: mask shape {msk.shape} does not match the template "
            f"(need >= {D} slices of shape {(H, W)})"
        )

# The model takes one batch item per slice, each holding that slice of every
# moving contrast: a list of (D, H, W, C) volumes -> (D, num_moving, H, W, C).
moving_data_stack = np.stack([vol[:D] for vol in moving_data], axis=1)
print(f"Moving data stack shape: {moving_data_stack.shape}")

# Same arrangement for the masks: (D, num_moving, H, W)
moving_masks_stack = np.stack([msk[:D] for msk in moving_masks], axis=1)
print(f"Moving masks stack shape: {moving_masks_stack.shape}")

# Each template slice is repeated once per moving contrast: (D, num_moving, H, W, C)
template_stack = np.repeat(template_data[:, np.newaxis], num_moving, axis=1)
print(f"Template stack shape: {template_stack.shape}")

print("\nFinal shapes for model input:")
print(f"  Template: {template_stack.shape}")
print(f"  Moving: {moving_data_stack.shape}")
print(f"  Moving masks: {moving_masks_stack.shape}")

In [None]:
# Quick shape sanity check: template (D, H, W, C), moving stack (D, num_moving, H, W, C), masks (D, num_moving, H, W)
template_data.shape, moving_data_stack.shape, moving_masks_stack.shape

## 6. Apply Affine Registration

First, apply affine registration to get coarse alignment.

In [None]:
# =============================================================================
# 6. Apply Affine Registration
# =============================================================================

# Announce the start of the affine (coarse) registration stage
print("\nApplying affine registration...")

def apply_affine_registration(model, template, moving_images, moving_masks):
    """
    Run the affine registration model on slice-batched inputs.

    The model is called with four inputs, all batched over slices:
    the template, the moving images, the moving labels (with a trailing
    channel axis added) and an all-ones float32 weight map.

    Args:
        model: Affine registration model exposing a Keras-style ``predict``
        template: Template images (D, num_moving, H, W, C)
        moving_images: Moving images (D, num_moving, H, W, C)
        moving_masks: Moving masks (D, num_moving, H, W)

    Returns:
        Dict with ``'warped_images'`` (the first model output reshaped to
        (D, num_moving, H, W, -1)) and ``'full_outputs'`` (the raw predict result).
    """
    n_slices, n_moving, height, width, n_chan = template.shape
    print(f"  D={n_slices}, num_moving={n_moving}, H={height}, W={width}, C={n_chan}")

    print(f"  Template input shape: {template.shape}")
    print(f"  Moving input shape: {moving_images.shape}")

    # Labels need an explicit channel axis: (D, num_moving, H, W, 1)
    labels = np.expand_dims(moving_masks, axis=-1)
    print(f"  Moving labels shape: {labels.shape}")

    # Uniform weighting of every pixel
    weights = np.ones((n_slices, n_moving, height, width, 1), dtype=np.float32)
    print(f"  Weight map shape: {weights.shape}")

    outputs = model.predict([template, moving_images, labels, weights], verbose=1)

    target_shape = (n_slices, n_moving, height, width, -1)
    if isinstance(outputs, list):
        print(f"  Number of outputs: {len(outputs)}")
        warped = outputs[0].reshape(target_shape)
        print(f"  Reshaped warped images: {warped.shape}")
    else:
        print(f"  Output shape: {outputs.shape}")
        warped = outputs.reshape(target_shape)

    return {'warped_images': warped, 'full_outputs': outputs}


# Coarse alignment: run the affine model on the slice-batched stacks
affine_results = apply_affine_registration(
    affine_model, template_stack, moving_data_stack, moving_masks_stack
)
affine_warped = affine_results['warped_images']
print(f"Affine warped images shape: {affine_warped.shape}")

## 7. Apply Non-Rigid Registration

Next, apply non-rigid registration for fine-grained alignment and obtain the dense flow fields.

In [None]:
# =============================================================================
# 7. Apply Non-Rigid Registration 
# =============================================================================

# Announce the start of the non-rigid (fine) registration stage
print("\nApplying non-rigid registration...")

def apply_nonrigid_registration(model, template, moving_images, moving_masks):
    """
    Run the non-rigid registration model on slice-batched inputs.

    The model receives three inputs (template, moving images, moving labels
    with an added channel axis); unlike the affine model, no weight map is passed.

    Args:
        model: Non-rigid registration model exposing a Keras-style ``predict``
        template: Template images (D, num_moving, H, W, C)
        moving_images: Moving images (D, num_moving, H, W, C)
        moving_masks: Moving masks (D, num_moving, H, W)

    Returns:
        Dict with:

        - ``'warped_images'``: the first output reshaped to (D, num_moving, H, W, -1).
        - ``'flow_fields'``: the third output reshaped to (D, num_moving, H, W, 2),
          or None when the model returns fewer than three outputs or a single array.
        - ``'full_outputs'``: the raw predict result.
    """
    n_slices, n_moving, height, width, n_chan = template.shape
    print(f"  D={n_slices}, num_moving={n_moving}, H={height}, W={width}, C={n_chan}")

    print(f"  Template input shape: {template.shape}")
    print(f"  Moving input shape: {moving_images.shape}")

    # Labels need an explicit channel axis: (D, num_moving, H, W, 1)
    labels = np.expand_dims(moving_masks, axis=-1)
    print(f"  Moving labels shape: {labels.shape}")

    # Only three inputs here -- the non-rigid model takes no weight map
    model_inputs = [template, moving_images, labels]
    print(f"  Total inputs: {len(model_inputs)}")

    outputs = model.predict(model_inputs, verbose=1)

    if not isinstance(outputs, list):
        print(f"  Output shape: {outputs.shape}")
        return {
            'warped_images': outputs.reshape(n_slices, n_moving, height, width, -1),
            'flow_fields': None,
            'full_outputs': outputs,
        }

    print(f"  Number of outputs: {len(outputs)}")
    for idx, out in enumerate(outputs):
        print(f"    Output {idx} shape: {out.shape}")

    # Output layout assumed: [warped images, predicted labels, flow, ...]
    raw_flow = outputs[2] if len(outputs) > 2 else None

    warped = outputs[0].reshape(n_slices, n_moving, height, width, -1)
    print(f"  Reshaped warped images: {warped.shape}")

    flows = None
    if raw_flow is not None:
        flows = raw_flow.reshape(n_slices, n_moving, height, width, 2)
        print(f"  Flow fields shape: {flows.shape}")
        print(f"  Flow field stats - min: {flows.min():.3f}, "
              f"max: {flows.max():.3f}, mean: {flows.mean():.3f}")

    return {'warped_images': warped, 'flow_fields': flows, 'full_outputs': outputs}


# Fine alignment: run the non-rigid model (it takes no weight map)
nonrigid_results = apply_nonrigid_registration(
    nonrigid_model, template_stack, moving_data_stack, moving_masks_stack
)

nonrigid_warped, flow_fields = (
    nonrigid_results['warped_images'],
    nonrigid_results['flow_fields'],
)

print(f"\nNon-rigid warped images shape: {nonrigid_warped.shape}")
if flow_fields is not None:
    print(f"Flow fields shape: {flow_fields.shape}")

## 8. Propagate DIXON Mask to Other Contrasts

Use the flow fields to warp the DIXON mask to each moving contrast.

In [None]:
# =============================================================================
# 8. Propagate DIXON Mask to Other Contrasts Using Flow Fields
# =============================================================================

if template_mask is not None and flow_fields is not None:
    print("\nPropagating DIXON mask to all moving contrasts...")

    print(f"Flow fields shape: {flow_fields.shape}")  # (D, num_moving, H, W, 2)
    print(f"Template mask shape: {template_mask.shape}")  # (D, H, W)
    print(f"Number of moving contrasts: {len(valid_contrasts)}")

    def warp_mask(mask: np.ndarray, flow_field: np.ndarray) -> np.ndarray:
        """
        Warp a label mask slice-by-slice by backward (pull) sampling.

        Output pixel ``(y, x)`` of slice ``d`` receives the label at the pixel
        nearest to ``(y + flow_field[d, y, x, 1], x + flow_field[d, y, x, 0])``,
        clipped to the image. Coordinates are rounded to the nearest pixel.
        Truncating them would floor every sample position and shift the
        propagated labels by up to one pixel towards the origin.

        NOTE(review): this assumes the flow gives, for every output pixel, its
        source location in ``mask``. Confirm the model's flow convention;
        propagating the template mask into moving-image space may require the
        inverse of the field the model returns.

        Args:
            mask: Input mask (D, H, W)
            flow_field: Flow field (D, H, W, 2) with pixel displacements in
                x (channel 0) and y (channel 1)

        Returns:
            Warped mask (D, H, W), same dtype as ``mask``
        """
        D, H, W = mask.shape
        print(f"    Warping: mask shape {mask.shape}, flow shape {flow_field.shape}")

        # Pixel-coordinate grids, each of shape (H, W)
        xx, yy = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')

        warped_mask = np.zeros_like(mask)

        for d in range(D):
            flow_d = flow_field[d]  # (H, W, 2)

            # Nearest source pixel for every output pixel, clipped to the image
            src_x = np.clip(np.rint(xx + flow_d[..., 0]), 0, W - 1).astype(np.int32)
            src_y = np.clip(np.rint(yy + flow_d[..., 1]), 0, H - 1).astype(np.int32)

            # Nearest-neighbour pull keeps label values exact (no blending)
            warped_mask[d] = mask[d, src_y, src_x]

        return warped_mask

    warped_masks = []

    for i, contrast in enumerate(valid_contrasts):
        print(f"\n  Processing {contrast} (index {i}):")

        # Flow for this contrast across all slices: (D, H, W, 2)
        contrast_flow = flow_fields[:, i, :, :, :]
        print(f"    Contrast flow shape: {contrast_flow.shape}")

        # Warp template mask using flow
        warped_mask = warp_mask(template_mask, contrast_flow)
        print(f"    Warped mask shape: {warped_mask.shape}")
        print(f"    Warped mask values: {np.unique(warped_mask)}")

        warped_masks.append(warped_mask)

    if warped_masks:
        # Stack warped masks
        warped_masks_stack = np.stack(warped_masks, axis=0)  # (num_moving, D, H, W)
        print(f"\nAll warped masks shape: {warped_masks_stack.shape}")
    else:
        warped_masks_stack = None
        print("\nNo warped masks generated")

else:
    print("\nCannot propagate masks: missing template mask or flow fields")
    warped_masks_stack = None

## 9. Visualize Results

Compare original, affine-warped, and non-rigid warped images.

In [None]:
# =============================================================================
# 9. Visualize Results
# =============================================================================

def visualize_registration_results(
    template,
    original_moving,
    affine_warped,
    nonrigid_warped,
    contrast_names,
    slice_idx=None,
    num_slices=3,
    save_path=None,
    stages=('nonrigid',),
):
    """
    Visualize registration results for multiple contrasts.

    The figure has one row per (selected slice, requested stage) pair.
    Column 0 always shows the DIXON template slice; column ``j + 1`` shows
    contrast ``j`` at the stage named in its title. With the default
    ``stages=('nonrigid',)`` only the non-rigid results are plotted; pass
    e.g. ``stages=('original', 'affine', 'nonrigid')`` to compare every
    stage of the pipeline side by side.

    Args:
        template: Template image (D, H, W, 1) - slices first
        original_moving: Original moving images (D, num_moving, H, W, 1) - slices first, then contrasts
        affine_warped: Affine-warped images (D, num_moving, H, W, 1) - slices first, then contrasts
        nonrigid_warped: Non-rigid warped images (D, num_moving, H, W, 1) - slices first, then contrasts
        contrast_names: Names of contrasts, ordered like axis 1 of the moving stacks
        slice_idx: Specific slice to show (if None, show slices around the middle)
        num_slices: Number of slices to show (ignored when slice_idx is given)
        save_path: Path to save figure (nothing is saved when None)
        stages: Stages to plot, any of 'original', 'affine', 'nonrigid'
            (a single stage name may also be passed as a plain string)

    Raises:
        ValueError: If ``stages`` is empty or names an unknown stage.
    """
    # Stage key -> (title label, image stack shaped (D, num_moving, H, W, 1))
    stage_data = {
        'original': ('Original', original_moving),
        'affine': ('Affine', affine_warped),
        'nonrigid': ('Non-rigid', nonrigid_warped),
    }
    if isinstance(stages, str):
        stages = (stages,)
    stages = tuple(stages)
    if not stages or any(stage not in stage_data for stage in stages):
        raise ValueError(
            f"stages must be a non-empty selection of {sorted(stage_data)}, got {stages!r}"
        )

    num_moving = len(contrast_names)
    D = template.shape[0]

    print(f"Debug - Template shape: {template.shape}")
    print(f"Debug - Original moving shape: {original_moving.shape}")
    print(f"Debug - Affine warped shape: {affine_warped.shape}")
    print(f"Debug - Non-rigid warped shape: {nonrigid_warped.shape}")

    # Select slices
    if slice_idx is None:
        # Show slices around the middle of the volume
        mid = D // 2
        slice_indices = np.linspace(mid - num_slices // 2, mid + num_slices // 2, num_slices, dtype=int)
        slice_indices = np.clip(slice_indices, 0, D - 1)
    else:
        slice_indices = [slice_idx]
        num_slices = 1

    print(f"Selected slice indices: {slice_indices}")

    num_rows = num_slices * len(stages)
    # squeeze=False keeps `axes` 2-D for every grid shape (single row,
    # single column, ...), so axes[row, col] indexing always works.
    fig, axes = plt.subplots(
        num_rows, num_moving + 1,  # +1 for template
        figsize=(4 * (num_moving + 1), 4 * num_rows),
        squeeze=False,
    )

    for i, current_slice in enumerate(slice_indices):
        for k, stage in enumerate(stages):
            row = i * len(stages) + k
            label, volume = stage_data[stage]

            # Template column - template has slices first
            axes[row, 0].imshow(template[current_slice, :, :, 0], cmap='gray')
            axes[row, 0].set_title(f'Template (DIXON)\nSlice {current_slice}')
            axes[row, 0].axis('off')

            # Moving contrasts - stacks are indexed [slice, contrast, H, W, channel]
            for j, name in enumerate(contrast_names):
                axes[row, j + 1].imshow(volume[current_slice, j, :, :, 0], cmap='gray')
                axes[row, j + 1].set_title(f'{name} ({label})')
                axes[row, j + 1].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")

    plt.show()


# Plot the DIXON template beside every registered contrast and save the figure
print("\nVisualizing registration results...")
visualize_registration_results(
    template_data,
    moving_data_stack,
    affine_warped,
    nonrigid_warped,
    valid_contrasts,
    num_slices=3,
    save_path=os.path.join(output_dir, 'registration_results.png'),
)

## 10. Visualize Warped Masks

Overlay the propagated DIXON mask on each non-rigidly registered contrast, with the original DIXON mask shown on the template for reference.

In [None]:
# =============================================================================
# 10. Visualize Warped Masks 
# =============================================================================

def visualize_warped_masks(
    template,
    template_mask,
    warped_images,
    warped_masks,
    contrast_names,
    slice_idx=None,
    num_slices=3,
    save_path=None
):
    """
    Visualize warped masks overlaid on registered images.

    Column 0 shows the template slice with the template mask; each further
    column shows one registered contrast with its propagated mask. Pixels
    where the mask is 0 are left transparent.

    Args:
        template: Template image (D, H, W, 1)
        template_mask: Template mask (D, H, W)
        warped_images: Warped images (D, num_moving, H, W, 1)
        warped_masks: Warped masks (num_moving, D, H, W)
        contrast_names: Names of contrasts
        slice_idx: Specific slice to show (if None, show slices around the middle)
        num_slices: Number of slices to show (ignored when slice_idx is given)
        save_path: Path to save figure (nothing is saved when None)
    """
    num_moving = len(contrast_names)
    D = template.shape[0]

    print(f"Debug - Template shape: {template.shape}")
    print(f"Debug - Template mask shape: {template_mask.shape}")
    print(f"Debug - Warped images shape: {warped_images.shape}")
    print(f"Debug - Warped masks shape: {warped_masks.shape}")

    # Select slices
    if slice_idx is None:
        mid = D // 2
        slice_indices = np.linspace(mid - num_slices // 2, mid + num_slices // 2, num_slices, dtype=int)
        slice_indices = np.clip(slice_indices, 0, D - 1)
    else:
        slice_indices = [slice_idx]
        num_slices = 1

    # squeeze=False keeps `axes` 2-D for every grid shape, so axes[i, j]
    # indexing works for a single row and for a single column alike.
    fig, axes = plt.subplots(
        num_slices, num_moving + 1,
        figsize=(4 * (num_moving + 1), 4 * num_slices),
        squeeze=False,
    )

    def _draw_overlay(ax, image_2d, mask_2d, title):
        """Draw a grayscale image with the non-zero mask pixels overlaid."""
        ax.imshow(image_2d, cmap='gray')
        mask_overlay = np.ma.masked_where(mask_2d == 0, mask_2d)
        ax.imshow(mask_overlay, cmap='tab10', alpha=0.5, vmin=0, vmax=2)
        ax.set_title(title)
        ax.axis('off')

    for i, current_slice in enumerate(slice_indices):
        # Template with its own mask
        _draw_overlay(
            axes[i, 0],
            template[current_slice, :, :, 0],
            template_mask[current_slice],
            f'Template + Mask\nSlice {current_slice}',
        )

        # Moving contrasts with warped masks.
        # warped_images is indexed [slice, contrast, ...] while
        # warped_masks is indexed [contrast, slice, ...].
        for j, name in enumerate(contrast_names):
            _draw_overlay(
                axes[i, j + 1],
                warped_images[current_slice, j, :, :, 0],
                warped_masks[j, current_slice],
                f'{name} + Warped Mask',
            )

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")

    plt.show()


# Visualize warped masks (only when mask propagation produced results)
if warped_masks_stack is not None:
    print("\nVisualizing warped masks...")
    visualize_warped_masks(
        template=template_data,
        template_mask=template_mask,
        warped_images=nonrigid_warped,
        warped_masks=warped_masks_stack,
        contrast_names=valid_contrasts,
        num_slices=3,
        save_path=os.path.join(output_dir, 'warped_masks.png')
    )
else:
    print("\nNo warped masks to visualize")

## 11. Save Results

Save the affine and non-rigid registered images, the warped masks, and the flow-field components as NIfTI files.

In [None]:
# =============================================================================
# 11. Save Results
# =============================================================================

# Report the shapes of everything that is about to be written to disk.
print("\nSaving registration results...")
print("Current data shapes:")
print(f"  nonrigid_warped: {nonrigid_warped.shape}")  # e.g. (32, 6, 256, 256, 1)
print(f"  affine_warped: {affine_warped.shape}")      # e.g. (32, 6, 256, 256, 1)
# The mask-propagation cell sets warped_masks_stack to None when nothing was
# propagated; calling .shape on None would raise, so require an actual array.
if locals().get('warped_masks_stack') is not None:
    print(f"  warped_masks_stack: {warped_masks_stack.shape}")  # e.g. (6, 32, 256, 256)
if flow_fields is not None:
    print(f"  flow_fields: {flow_fields.shape}")      # e.g. (32, 6, 256, 256, 2)

# Pick the affine used for every NIfTI file saved in this cell: the header
# affine of the first DIXON file when one exists, otherwise the identity.
dixon_files = find_contrast_files(subject_dir, template_contrast)
if not dixon_files:
    original_affine = np.eye(4)
    print("Warning: Using identity affine matrix")
else:
    reference_path = dixon_files[0]
    original_affine = nib.load(reference_path).affine
    print(f"Using affine from: {reference_path}")

# One output sub-directory per artefact type (images, masks, flows)
registered_dir, masks_dir, flow_dir = (
    os.path.join(output_dir, sub_name)
    for sub_name in ("registered_images", "warped_masks", "flow_fields")
)
for created_dir in (registered_dir, masks_dir, flow_dir):
    os.makedirs(created_dir, exist_ok=True)

print("\nOutput directories created:")
print(f"  Registered images: {registered_dir}")
print(f"  Warped masks: {masks_dir}")
print(f"  Flow fields: {flow_dir}")

# =============================================================================
# Save registered images for each contrast
# =============================================================================
print("\nSaving registered images...")

for i, contrast in enumerate(valid_contrasts):
    print(f"  Processing {contrast}...")

    # Non-rigid result for contrast i across all slices.
    # The stack is (D, num_moving, H, W, C), so this slice is (D, H, W).
    nonrigid_volume = nonrigid_warped[:, i, :, :, 0]
    print(f"    Non-rigid shape: {nonrigid_volume.shape}")

    # Reorient for saving: (D, H, W) -> (H, W, D), then rot90 + flip.
    # That rot90/flip pair just swaps the first two axes, so the saved
    # array is (W, H, D). It is meant to undo the load-time orientation
    # change (per the original note: np.flip(np.rot90(data, -1), 1)).
    nonrigid_out = np.flip(np.rot90(np.transpose(nonrigid_volume, (1, 2, 0)), 1), 0)
    nonrigid_name = f"{contrast}_registered_to_DIXON.nii.gz"
    save_nifti(nonrigid_out, original_affine, os.path.join(registered_dir, nonrigid_name))
    print(f"    Saved: {nonrigid_name}")

    # Affine-only result (optional), reoriented the same way
    affine_volume = affine_warped[:, i, :, :, 0]
    affine_out = np.flip(np.rot90(np.transpose(affine_volume, (1, 2, 0)), 1), 0)
    affine_name = f"{contrast}_affine_to_DIXON.nii.gz"
    save_nifti(affine_out, original_affine, os.path.join(registered_dir, affine_name))
    print(f"    Saved: {affine_name}")

# =============================================================================
# Save warped masks
# =============================================================================
if template_mask is None or 'warped_masks_stack' not in locals() or warped_masks_stack is None:
    print("\nNo warped masks to save")
else:
    print("\nSaving warped masks...")

    # warped_masks_stack is (num_moving, D, H, W), e.g. (6, 32, 256, 256)
    for i, contrast in enumerate(valid_contrasts):
        print(f"  Processing mask for {contrast}...")

        mask_volume = warped_masks_stack[i]  # (D, H, W)
        print(f"    Mask shape: {mask_volume.shape}")

        # Same save-time reorientation as the images: (D, H, W) -> (H, W, D)
        # followed by rot90 + flip (a swap of the first two axes).
        mask_out = np.flip(np.rot90(np.transpose(mask_volume, (1, 2, 0)), 1), 0)
        mask_name = f"DIXON_mask_warped_to_{contrast}.nii.gz"
        save_nifti(mask_out, original_affine, os.path.join(masks_dir, mask_name))
        print(f"    Saved: {mask_name}")

# =============================================================================
# Save flow fields (optional, for debugging)
# =============================================================================
if flow_fields is not None:
    print("\nSaving flow fields...")

    # flow_fields is (D, num_moving, H, W, 2); the last axis holds the x / y
    # displacement components.
    for i, contrast in enumerate(valid_contrasts):
        print(f"  Processing flow for {contrast}...")

        flow = flow_fields[:, i, :, :, :]  # (D, H, W, 2)
        print(f"    Flow shape: {flow.shape}")

        # Write each component as its own volume (x first, then y), using
        # the same (D, H, W) -> (H, W, D) -> rot90/flip reorientation as
        # the images.
        for axis_label, component in (("x", 0), ("y", 1)):
            component_out = np.flip(np.rot90(np.transpose(flow[..., component], (1, 2, 0)), 1), 0)
            save_nifti(
                component_out,
                original_affine,
                os.path.join(flow_dir, f"flow_{axis_label}_{contrast}.nii.gz"),
            )
        print(f"    Saved: flow_x_{contrast}.nii.gz, flow_y_{contrast}.nii.gz")

# =============================================================================
# Summary
# =============================================================================
rule = "=" * 60
print(f"\n{rule}\nSAVING COMPLETE\n{rule}")
print(f"All results saved to: {output_dir}")
print("\nDirectory structure:")

# Registered images: one non-rigid and one affine file per contrast
print(f"  {registered_dir}/")
for contrast in valid_contrasts:
    print(
        f"    ├── {contrast}_registered_to_DIXON.nii.gz\n"
        f"    └── {contrast}_affine_to_DIXON.nii.gz"
    )

# Warped masks are listed only when they were actually produced
if template_mask is not None and 'warped_masks_stack' in locals() and warped_masks_stack is not None:
    print(f"\n  {masks_dir}/")
    for contrast in valid_contrasts:
        print(f"    └── DIXON_mask_warped_to_{contrast}.nii.gz")

# Flow fields: separate x and y component files per contrast
if flow_fields is not None:
    print(f"\n  {flow_dir}/")
    for contrast in valid_contrasts:
        print(
            f"    ├── flow_x_{contrast}.nii.gz\n"
            f"    └── flow_y_{contrast}.nii.gz"
        )

print("\n✓ Done!")

## 12. Summary

This notebook demonstrated:

1. **Loading pre-trained registration models** (affine + non-rigid)
2. **Loading DIXON template and its segmentation mask**
3. **Loading all moving contrasts** from the resampled data structure
4. **Applying affine registration** for coarse alignment
5. **Applying non-rigid registration** for fine-grained alignment and obtaining flow fields
6. **Propagating the DIXON mask** to all other contrasts using the computed deformations
7. **Visualizing and saving results**

### Output Files

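All results are saved under the directory stored in the `output_dir` variable:

- `registered_images/` - Non-rigid (`*_registered_to_DIXON.nii.gz`) and affine-only (`*_affine_to_DIXON.nii.gz`) registered images for each contrast
- `warped_masks/` - DIXON mask warped to each contrast (`DIXON_mask_warped_to_*.nii.gz`), written only when mask propagation succeeded
- `flow_fields/` - x and y components of the non-rigid deformation field for each contrast, written only when flow fields are available
- `registration_results.png` - Visualization of registration results
- `warped_masks.png` - Visualization of warped masks (only when masks were propagated)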