# 02. Registration to MNI Space

This notebook registers subject 7T data to MNI standard space using ANTs SyN.

## Why Registration Matters

The LC atlas is defined in MNI space. To extract signal from the LC region, we must:
1. Register subject T1w → MNI template
2. Apply the same transform to all contrast maps (R1, R2*, QSM)

Brainstem registration is challenging due to:
- Small structure sizes
- Lower contrast in deep brain regions
- Susceptibility artifacts near air-tissue interfaces

We use ANTs SyN (Symmetric Normalization) which is well-suited for subcortical alignment.

In [None]:
import sys
# Make the repository root importable so the local `src` package below resolves
# (the relative path assumes the notebook runs from its own directory — TODO confirm).
sys.path.append('../')
import os
import ants
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

# Project-local helpers; these imports rely on the sys.path entry added above.
from src.io import find_bids_files, load_nifti
from src.visualization import plot_slice_visualization

# Record the ANTsPy version in the notebook output.
print(f"ANTsPy version: {ants.__version__}")

## 1. Load MNI Template

We use the MNI152 ICBM 2009 template provided by nilearn's `fetch_icbm152_2009` (the nonlinear symmetric 2009a version) for registration.

In [None]:
from nilearn.datasets import fetch_icbm152_2009

# Fetch template; the T1-weighted image path is read from the 't1' entry.
mni = fetch_icbm152_2009()
mni_template_path = mni['t1']
print(f"MNI template: {mni_template_path}")

# Load with ANTs; `fixed` is the registration target passed to register_subject
# in the single-subject cell further down.
fixed = ants.image_read(mni_template_path)
print(f"Template shape: {fixed.shape}")
print(f"Template spacing: {fixed.spacing}")

## 2. Setup Paths

In [None]:
# Input data root and the directory that receives the warped images.
DATA_DIR = '../data'
OUTPUT_DIR = '../outputs/results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Find subjects: only directories named 'sub-*' count, so stray files such as
# 'sub-01.zip' are not mistaken for subjects and later reported as failures.
subjects = sorted(
    d for d in os.listdir(DATA_DIR)
    if d.startswith('sub-') and os.path.isdir(os.path.join(DATA_DIR, d))
)
print(f"Found {len(subjects)} subjects")

# Contrasts to transform (after T1w registration).
# Each entry maps a contrast name to the filename suffixes that identify it.
CONTRAST_SUFFIXES = {
    'R1': ['R1map.nii.gz', 'R1.nii.gz'],
    'R2star': ['R2starmap.nii.gz', 'R2star.nii.gz'],
    'QSM': ['QSM.nii.gz', 'Chimap.nii.gz'],
    'T2starw': ['T2starw.nii.gz']
}

## 3. Register Single Subject (Development)

First, we develop the registration on one subject before batch processing.

In [None]:
def find_contrast_file(data_dir, sub_id, patterns):
    """Find a file under the subject directory matching any of the patterns.

    The subject directory ``<data_dir>/<sub_id>`` is searched recursively.
    Patterns are filename suffixes and are tried in the order given, so an
    earlier pattern wins over a later one when several files match. Within
    one pattern the first match in sorted directory/filename order is
    returned, which makes the result independent of filesystem listing order.

    Args:
        data_dir: Root directory containing the subject folders.
        sub_id: Subject folder name (e.g. ``'sub-01'``).
        patterns: Filename suffix, or iterable of suffixes in priority order.

    Returns:
        Path to the matching file, or ``None`` if nothing matches (including
        when the subject directory does not exist).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(patterns, str):
        patterns = [patterns]
    else:
        patterns = list(patterns)

    sub_dir = os.path.join(data_dir, sub_id)

    # Walk once, in a deterministic order, and remember every file seen.
    candidates = []
    for root, dirs, files in os.walk(sub_dir):
        dirs.sort()  # in-place sort fixes the order os.walk descends in
        candidates.extend((name, os.path.join(root, name)) for name in sorted(files))

    for pattern in patterns:
        for name, path in candidates:
            if name.endswith(pattern):
                return path
    return None

def register_subject(sub_id, fixed_img, data_dir, output_dir):
    """
    Register subject T1w to MNI and apply transforms to all contrast maps.

    The T1w image (suffix 'T1w.nii.gz' or 'UNIT1.nii.gz') is registered to
    ``fixed_img`` with ANTs SyN. The resulting forward transforms are then
    applied to every contrast listed in CONTRAST_SUFFIXES that can be found
    for the subject. All outputs are written to ``<output_dir>/<sub_id>/``.

    NOTE(review): the contrast maps are warped with the T1w transforms
    directly, which assumes they are already aligned with the T1w image —
    confirm for this dataset.

    Args:
        sub_id: Subject folder name under ``data_dir``.
        fixed_img: ANTs image used as the registration target (the template).
        data_dir: Root directory containing the subject folders.
        output_dir: Root directory for results; a per-subject folder is created.

    Returns:
        dict with 'subject_id' and 'success'. If the T1w is missing, only
        these two keys are present and 'success' is False. Otherwise it also
        holds 'T1w_MNI', one '<contrast>_MNI' path per contrast found,
        'transforms' (forward transform files) and 'inv_transforms'
        (inverse transform files).
    """
    results = {'subject_id': sub_id, 'success': False}

    # Find T1w
    t1w_path = find_contrast_file(data_dir, sub_id, ['T1w.nii.gz', 'UNIT1.nii.gz'])
    if not t1w_path:
        print(f"  T1w not found for {sub_id}")
        return results

    # Load moving image
    moving = ants.image_read(t1w_path)
    print(f"  T1w shape: {moving.shape}, spacing: {moving.spacing}")

    # Create subject output directory
    sub_out_dir = os.path.join(output_dir, sub_id)
    os.makedirs(sub_out_dir, exist_ok=True)

    # Perform registration.
    # outprefix makes ANTsPy write the transform files next to the warped
    # images instead of a temporary location, so the paths stored in the
    # results stay valid after this session ends.
    print("  Running ANTs SyN registration...")
    registration = ants.registration(
        fixed=fixed_img,
        moving=moving,
        type_of_transform='SyN',
        outprefix=os.path.join(sub_out_dir, f'{sub_id}_'),
        syn_metric='CC',        # Cross-correlation for structural
        syn_sampling=4,
        reg_iterations=(100, 70, 50, 20),
        verbose=False
    )

    # Save warped T1w
    t1w_mni_path = os.path.join(sub_out_dir, f'{sub_id}_T1w_MNI.nii.gz')
    ants.image_write(registration['warpedmovout'], t1w_mni_path)
    results['T1w_MNI'] = t1w_mni_path
    print(f"  Saved warped T1w: {t1w_mni_path}")

    # Apply transforms to other contrasts
    for contrast_name, patterns in CONTRAST_SUFFIXES.items():
        contrast_path = find_contrast_file(data_dir, sub_id, patterns)

        if not contrast_path:
            print(f"  {contrast_name} not found")
            continue

        contrast_img = ants.image_read(contrast_path)

        # Apply same transforms, resampling onto the template grid
        warped = ants.apply_transforms(
            fixed=fixed_img,
            moving=contrast_img,
            transformlist=registration['fwdtransforms'],
            interpolator='linear'
        )

        # Save
        output_path = os.path.join(sub_out_dir, f'{sub_id}_{contrast_name}_MNI.nii.gz')
        ants.image_write(warped, output_path)
        results[f'{contrast_name}_MNI'] = output_path
        print(f"  Saved {contrast_name}: {output_path}")

    results['success'] = True
    results['transforms'] = registration['fwdtransforms']
    # Inverse transforms are kept as well so template-space data can later be
    # mapped back to the subject.
    results['inv_transforms'] = registration['invtransforms']

    return results

In [None]:
# Process first subject as test.
# Defaults are set first so later cells that read `results` / `test_sub`
# do not raise NameError when no subjects were found.
results = {'success': False}
test_sub = subjects[0] if subjects else None

if test_sub is None:
    print("No subjects found; nothing to register.")
else:
    print(f"Processing {test_sub}...")

    results = register_subject(test_sub, fixed, DATA_DIR, OUTPUT_DIR)

    if results['success']:
        print("\nRegistration successful!")
    else:
        print("\nRegistration failed.")

## 4. Visual QC

Check registration quality by overlaying warped T1w on template.

In [None]:
# Visual QC: plot the warped T1w at three axial brainstem levels and draw the
# template's edges on top, so misalignment shows up as contours that do not
# follow the underlying anatomy.
if results.get('success'):
    from nilearn import plotting

    # Load warped image and the template it was registered to
    warped_t1w = nib.load(results['T1w_MNI'])
    template_img = nib.load(mni_template_path)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Axial at brainstem level (z ~ -25 in MNI)
    slice_specs = [
        ((-25,), 'Brainstem (z=-25)'),
        ((-30,), 'Pons (z=-30)'),
        ((-35,), 'Lower Pons (z=-35)')
    ]
    for ax, (cut_coords, title) in zip(axes, slice_specs):
        display = plotting.plot_anat(
            warped_t1w,
            cut_coords=cut_coords,
            display_mode='z',
            axes=ax,
            title=title,
            draw_cross=False
        )
        # Overlay template contours on the warped image
        display.add_edges(template_img)

    plt.suptitle(f"{test_sub} - Warped to MNI (Brainstem Slices)", y=1.02)
    plt.tight_layout()

    qc_path = f'../outputs/figures/{test_sub}_registration_qc.png'
    # The figures directory is not created anywhere else in this notebook;
    # without it savefig fails with FileNotFoundError.
    os.makedirs(os.path.dirname(qc_path), exist_ok=True)
    plt.savefig(qc_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"QC figure saved: {qc_path}")

## 5. Batch Processing (Optional)

Process all subjects. This will take significant time.

In [None]:
# Uncomment to run batch processing
# WARNING: This will take a long time (10-30 min per subject)

# all_results = []
# 
# for i, sub_id in enumerate(subjects):
#     print(f"\n[{i+1}/{len(subjects)}] Processing {sub_id}...")
#     
#     # Skip if already processed
#     sub_out_dir = os.path.join(OUTPUT_DIR, sub_id)
#     if os.path.exists(os.path.join(sub_out_dir, f'{sub_id}_T1w_MNI.nii.gz')):
#         print(f"  Already processed, skipping.")
#         continue
#     
#     results = register_subject(sub_id, fixed, DATA_DIR, OUTPUT_DIR)
#     all_results.append(results)
#     
#     # Save progress
#     import json
#     with open(os.path.join(OUTPUT_DIR, 'registration_log.json'), 'w') as f:
#         json.dump(all_results, f, indent=2, default=str)
# 
# print(f"\nProcessed {len(all_results)} subjects.")
# print(f"Successful: {sum(r['success'] for r in all_results)}")

## 6. Summary

This notebook:
1. Registered T1w images to MNI152 space using ANTs SyN
2. Applied the same transforms to the R1, R2*, QSM, and T2*w images (where available)
3. Generated QC figures for visual inspection

**Next**: Apply LC atlas to extract signal (Notebook 03)