# Group Brain Registration Using pirt

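This notebook implements group registration for brain images from the MouseCity3 dataset using the pirt library with DiffeomorphicDemonsRegistration. Note that each 3D image stack is reduced to its middle slice (and, by default, downscaled) before registration, so the registration and the resulting average template are 2D.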

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import pirt
import tifffile
from pathlib import Path
from skimage import io, transform, exposure

## Configuration

In [None]:
# Root data directory; animal sub-directories (ANM*) are searched below it
data_dir = Path('/nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/')

# Registration settings.
# The first three entries are copied onto the pirt registration object's
# `params` in register_brain_images; 'downscale_factor' is only used when
# loading the images.
registration_settings = {
    # NOTE(review): per the pirt docs this is "to what extent the grid sampling
    # scales with the scale" (a value between 0 and 1); the grid sampling at the
    # final level is a separate parameter, `final_grid_sampling` — TODO confirm
    'grid_sampling_factor': 1,
    'scale_sampling': 20,       # The amount of iterations for each level
    'speed_factor': 2,          # The relative force of the transform
    'downscale_factor': 0.25,   # Downscale images to speed up registration
}

## Find all brain images

In [3]:
def find_brain_images(data_dir):
    """Find all itk/ch0.tif files below the animal directories of data_dir

    Only immediate sub-directories whose name starts with ``ANM`` are
    considered. They are visited in sorted order so that the returned mapping
    (and everything that depends on its order, e.g. the order in which images
    are handed to the registration) is reproducible between runs;
    ``Path.glob`` on its own yields entries in arbitrary order.

    Args:
        data_dir (Path): Path to the data directory

    Returns:
        dict: Dictionary mapping animal IDs (the directory names) to the path
            of their ``itk/ch0.tif`` file, ordered by animal ID. Animal
            directories without that file are skipped.
    """
    brain_images = {}

    for animal_dir in sorted(data_dir.glob('ANM*')):
        # Plain files that happen to match the pattern are not animals
        if not animal_dir.is_dir():
            continue

        ch0_file = animal_dir / 'itk' / 'ch0.tif'
        if ch0_file.exists():
            brain_images[animal_dir.name] = ch0_file

    return brain_images

# Find all brain images.
# `brain_images` maps animal ID -> path of its ch0.tif and is reused by the cells below.
brain_images = find_brain_images(data_dir)
print(f"Found {len(brain_images)} brain images:")
# List every animal ID together with the file that was found for it
for animal_id, file_path in brain_images.items():
    print(f"  {animal_id}: {file_path}")

Found 6 brain images:
  ANM555974: /nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/ANM555974/itk/ch0.tif
  ANM555975: /nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/ANM555975/itk/ch0.tif
  ANM555976: /nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/ANM555976/itk/ch0.tif
  ANM556543: /nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/ANM556543/itk/ch0.tif
  ANM556544: /nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/ANM556544/itk/ch0.tif
  ANM556545: /nearline/spruston/Boaz/DELTA/I2/20250111_IDISCO_MouseCity3/ANM556545/itk/ch0.tif


## Load and preprocess images

In [None]:
def load_and_preprocess_image(file_path, downscale_factor=0.25):
    """Load a brain image and preprocess it into a 2D float image

    Processing steps, in order:
      1. Read the TIFF file.
      2. If the image is 3D, keep only the middle slice along the first axis.
      3. Min-max rescale the intensities to the range 0-1.
      4. Optionally downscale the image.
      5. Enhance contrast with adaptive histogram equalization.

    Args:
        file_path (Path): Path to the image file
        downscale_factor (float): Factor to downscale the image. Values below
            1.0 shrink the image; values of 1.0 or larger leave its size
            unchanged.

    Returns:
        np.ndarray: Preprocessed image
    """
    # Load image
    print(f"Loading {file_path}...")
    img = tifffile.imread(file_path)

    # Get image info
    print(f"  Original shape: {img.shape}, dtype: {img.dtype}")

    # For 3D images, take the middle slice for demonstration.
    # (A maximum projection, np.max(img, axis=0), would be an alternative.)
    if img.ndim == 3:
        middle_slice = img.shape[0] // 2
        img = img[middle_slice]
        print(f"  Taking middle slice: {middle_slice}, new shape: {img.shape}")

    # Rescale to 0-1 float
    img = img.astype(np.float32)
    img_min = img.min()
    intensity_range = img.max() - img_min
    if intensity_range > 0:
        img = (img - img_min) / intensity_range
    else:
        # A constant image (e.g. an empty slice) has no intensity range;
        # dividing by zero would fill the image with NaNs, so map it to zeros.
        img = np.zeros_like(img)

    # Downscale if needed
    if downscale_factor < 1.0:
        # Clamp to at least one pixel per axis so that very small images
        # (or very small factors) do not produce an empty target shape.
        new_shape = (
            max(1, int(img.shape[0] * downscale_factor)),
            max(1, int(img.shape[1] * downscale_factor)),
        )
        img = transform.resize(img, new_shape, anti_aliasing=True, preserve_range=True)
        print(f"  Downscaled to: {img.shape}")

    # Enhance contrast
    img = exposure.equalize_adapthist(img)

    return img

# Preview the first discovered image as a sanity check of the preprocessing
if not brain_images:
    print("No brain images found!")
else:
    sample_id = next(iter(brain_images))
    sample_img = load_and_preprocess_image(
        brain_images[sample_id],
        registration_settings['downscale_factor'],
    )

    # Show the preprocessed sample with an intensity scale
    plt.figure(figsize=(10, 8))
    plt.imshow(sample_img, cmap='gray')
    plt.title(f"Sample image: {sample_id}")
    plt.colorbar()
    plt.show()

## Load all brain images

In [None]:
def load_all_brain_images(brain_images, downscale_factor):
    """Load and preprocess every brain image, showing a progress bar

    Args:
        brain_images (dict): Dictionary mapping animal IDs to file paths
        downscale_factor (float): Factor to downscale the images

    Returns:
        dict: Dictionary mapping animal IDs to preprocessed images, in the
            same order as ``brain_images``
    """
    progress = tqdm(brain_images.items(), desc="Loading images")
    return {
        animal_id: load_and_preprocess_image(path, downscale_factor)
        for animal_id, path in progress
    }

# Load and preprocess every discovered image with the configured downscale factor.
# `loaded_images` maps animal ID -> preprocessed image.
loaded_images = load_all_brain_images(brain_images, registration_settings['downscale_factor'])

## Visualize all brain images

In [None]:
def visualize_brain_images(loaded_images):
    """Visualize all brain images in a grid of at most three columns

    Each image is drawn in grayscale with its animal ID as title. If there
    are no images, a message is printed and no figure is created.

    Args:
        loaded_images (dict): Dictionary mapping animal IDs to preprocessed images
    """
    num_images = len(loaded_images)
    if num_images == 0:
        # With zero images the column count would be 0 and the row
        # computation below would raise ZeroDivisionError.
        print("No brain images to display.")
        return

    cols = min(3, num_images)
    rows = (num_images + cols - 1) // cols  # ceiling division

    plt.figure(figsize=(15, 5 * rows))

    for i, (animal_id, img) in enumerate(loaded_images.items()):
        plt.subplot(rows, cols, i + 1)  # subplot indices are 1-based
        plt.imshow(img, cmap='gray')
        plt.title(animal_id)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show all preprocessed images side by side before registration
visualize_brain_images(loaded_images)

## Register brain images using pirt

In [None]:
def register_brain_images(images, settings):
    """Run pirt's DiffeomorphicDemonsRegistration on a set of brain images

    All images are passed to one registration object. After registration,
    the deformation obtained for each image is applied to that image.

    Args:
        images (dict): Dictionary mapping animal IDs to preprocessed images
        settings (dict): Registration settings; the keys
            'grid_sampling_factor', 'scale_sampling' and 'speed_factor' are
            copied onto the registration parameters

    Returns:
        tuple: (registration object, deformation fields, transformed images),
            where the deformation fields are a list in the order of ``images``
            and the transformed images are a dict keyed by animal ID
    """
    # Freeze the iteration order once so that index i always refers to the
    # same animal in every step below
    animal_ids = list(images)
    image_list = [images[key] for key in animal_ids]

    # Create registration object
    print(f"Registering {len(image_list)} images...")
    reg = pirt.DiffeomorphicDemonsRegistration(*image_list)

    # Copy the tunable parameters from the settings onto the registration
    for param_name in ('grid_sampling_factor', 'scale_sampling', 'speed_factor'):
        setattr(reg.params, param_name, settings[param_name])

    # Perform registration
    reg.register(verbose=1)

    # Collect the deformation of each image and apply it to that image
    deforms = []
    transformed_images = {}
    for index, (animal_id, image) in enumerate(zip(animal_ids, image_list)):
        deform = reg.get_deform(index)
        deforms.append(deform)
        transformed_images[animal_id] = deform.apply_deformation(image)

    return reg, deforms, transformed_images

# Register all loaded images together; keep the registration object, the
# per-image deformations and the deformed images for the cells below
reg, deforms, transformed_images = register_brain_images(loaded_images, registration_settings)

## Visualize registration results

In [None]:
def visualize_registration_results(original_images, transformed_images):
    """Visualize registration results, one row of three panels per animal

    The panels of a row show the original image, the registered image and
    the absolute difference between the two (color scale fixed to 0-0.5).
    If there are no images, a message is printed and no figure is created.

    Args:
        original_images (dict): Dictionary mapping animal IDs to original images
        transformed_images (dict): Dictionary mapping animal IDs to transformed
            images; must contain every animal ID of ``original_images``
    """
    num_images = len(original_images)
    if num_images == 0:
        # A figure of height 0 and a subplot grid with 0 rows are both
        # invalid, so there is nothing sensible to draw.
        print("No registration results to display.")
        return

    plt.figure(figsize=(15, 5 * num_images))

    for row, animal_id in enumerate(original_images):
        original = original_images[animal_id]
        registered = transformed_images[animal_id]
        first_panel = row * 3 + 1  # subplot indices are 1-based

        # Original image
        plt.subplot(num_images, 3, first_panel)
        plt.imshow(original, cmap='gray')
        plt.title(f"{animal_id} - Original")
        plt.axis('off')

        # Transformed image
        plt.subplot(num_images, 3, first_panel + 1)
        plt.imshow(registered, cmap='gray')
        plt.title(f"{animal_id} - Registered")
        plt.axis('off')

        # Difference image
        plt.subplot(num_images, 3, first_panel + 2)
        diff = np.abs(original - registered)
        plt.imshow(diff, cmap='hot', vmin=0, vmax=0.5)
        plt.title(f"{animal_id} - Difference")
        plt.axis('off')
        plt.colorbar()

    plt.tight_layout()
    plt.show()

# Compare every image before and after registration
visualize_registration_results(loaded_images, transformed_images)

## Calculate average brain template

In [None]:
def calculate_average_template(transformed_images):
    """Average the registered images pixel-wise into a single template

    Args:
        transformed_images (dict): Dictionary mapping animal IDs to transformed images

    Returns:
        np.ndarray: Average brain template
    """
    # Stack along a new leading axis, then average that axis away
    registered = [transformed_images[animal_id] for animal_id in transformed_images]
    return np.mean(np.stack(registered), axis=0)

# Calculate the average template from the registered images
template = calculate_average_template(transformed_images)

# Display the template in grayscale with an intensity scale
plt.figure(figsize=(10, 8))
plt.imshow(template, cmap='gray')
plt.title("Average Brain Template")
plt.colorbar()
plt.show()

## Save registration results

In [None]:
def save_registration_results(transformed_images, template, output_dir):
    """Write the template and all registered images as TIFF files

    The template is stored as ``template.tif`` and every registered image as
    ``<animal_id>_registered.tif`` inside ``output_dir``, which is created
    (including parents) when it does not exist yet.

    Args:
        transformed_images (dict): Dictionary mapping animal IDs to transformed images
        template (np.ndarray): Average brain template
        output_dir (Path): Path to the output directory
    """
    def _write_tiff(target, data):
        # Write one array and hand back the path for reporting
        tifffile.imwrite(target, data)
        return target

    output_dir.mkdir(parents=True, exist_ok=True)

    # Template first
    template_path = _write_tiff(output_dir / 'template.tif', template)
    print(f"Saved template to {template_path}")

    # Then one file per animal
    for animal_id in transformed_images:
        img_path = _write_tiff(
            output_dir / f"{animal_id}_registered.tif",
            transformed_images[animal_id],
        )
        print(f"Saved {animal_id} to {img_path}")

# Build the output path inside the data directory (the directory itself is
# created by save_registration_results)
output_dir = data_dir / 'registration_results'

# Save the template and the registered images
save_registration_results(transformed_images, template, output_dir)