# In [None]:
import SimpleITK as sitk
import os
import glob
import argparse
from nilearn.datasets import fetch_icbm152_2009  # For the public fallback template

"""
Affine Registration Pipeline (Native Space -> MNI Space)
------------------------------------------------------
This script performs linear (affine) registration of skull-stripped brain images
to a standard template space.

Template Logic:
1. Priority: Uses the specialized 'JHU-MNI-SS' DWI template (Restricted Access).
2. Fallback: If the restricted template is not found, defaults to the standard
   MNI152 T1-weighted template (Open Source) for demonstration purposes.

Registration Parameters:
- Transform: Affine (12 Degrees of Freedom)
- Metric: Mattes Mutual Information (Robust for multi-modal registration)
- Optimizer: Gradient Descent
"""

# ==============================================================================
# CONFIGURATION
# ==============================================================================
# Users should place their restricted template here if they have access.
# get_registration_template() checks this path first and only falls back to
# the public MNI152 T1 template (fetched through nilearn) when it is missing.
RESTRICTED_TEMPLATE_PATH = "./templates/MuRadiolNormal_DWI_space-JHUMNI.nii.gz"

# Default for --input_dir; register_images() searches it for
# 'sub-*/*_brain.nii*' files.
DEFAULT_INPUT_DIR = "./data/skullstripped"
# Default for --output_dir; registered images are written here as
# '<name>_space-MNI.nii.gz'.
DEFAULT_OUTPUT_DIR = "./data/registered_mni"

def get_registration_template():
    """
    Return the fixed (reference) image used for registration.

    The file at RESTRICTED_TEMPLATE_PATH is preferred. When that path does not
    exist, the 't1' image of the ICBM152 2009 dataset is obtained through
    nilearn's fetch_icbm152_2009() and used instead.

    Returns:
        The template read by SimpleITK with pixel type sitkFloat32.

    Any exception raised by the fetch or by sitk.ReadImage propagates to the
    caller.
    """
    restricted_available = os.path.exists(RESTRICTED_TEMPLATE_PATH)

    # Guard clause: handle the public fallback first so the preferred path
    # reads straight through below.
    if not restricted_available:
        print(f"⚠️ Specialized template not found at {RESTRICTED_TEMPLATE_PATH}")
        print("   -> Downloading/Loading standard MNI152 template (FSL/Nilearn version)...")

        # The fetched dataset is indexed by name; the T1 entry is the one used.
        fallback_dataset = fetch_icbm152_2009()
        fallback_t1_path = fallback_dataset['t1']
        return sitk.ReadImage(fallback_t1_path, sitk.sitkFloat32)

    print(f"✅ Loading specialized study template: {RESTRICTED_TEMPLATE_PATH}")
    return sitk.ReadImage(RESTRICTED_TEMPLATE_PATH, sitk.sitkFloat32)

def register_images(input_dir, output_dir):

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    # 1. Load the fixed template
    try:
        fixed_image = get_registration_template()
    except Exception as e:
        print(f"ERROR: Could not load any registration template. {e}")
        return

    # 2. Find input files (Handles .nii and .nii.gz)
    search_pattern = os.path.join(input_dir, 'sub-*', '*_brain.nii*')
    file_list = glob.glob(search_pattern)

    if not file_list:
        print(f"ERROR: No NIfTI files found in {input_dir}")
        return

    print(f"\nFound {len(file_list)} files to process.")

    # 3. Batch Processing
    for moving_file_path in file_list:
        file_name = os.path.basename(moving_file_path)

        # Smart filename handling (removes .nii or .nii.gz)
        if file_name.endswith('.nii.gz'):
            base_name = file_name[:-7]
        elif file_name.endswith('.nii'):
            base_name = file_name[:-4]
        else:
            base_name = file_name

        output_file_name = f"{base_name}_space-MNI.nii.gz"
        output_path = os.path.join(output_dir, output_file_name)

        # Checkpoint: Skip if exists
        if os.path.exists(output_path):
            print(f"Skipping {file_name} (Already exists)")
            continue

        print(f"Processing: {file_name}...")

        try:
            moving_image = sitk.ReadImage(moving_file_path, sitk.sitkFloat32)

            # --- Registration Setup ---
            R = sitk.ImageRegistrationMethod()

            # Metric: Mattes MI is crucial here because the fallback template (T1)
            # might have different intensity distribution than input (DWI/FLAIR).
            R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)

            R.SetOptimizerAsGradientDescent(learningRate=0.1,
                                            numberOfIterations=500,
                                            convergenceMinimumValue=1e-6)

            # Initial Alignment (Center of Mass)
            initial_transform = sitk.CenteredTransformInitializer(
                fixed_image,
                moving_image,
                sitk.AffineTransform(fixed_image.GetDimension())
            )
            R.SetInitialTransform(initial_transform)
            R.SetInterpolator(sitk.sitkLinear)

            # Execute
            final_transform = R.Execute(fixed_image, moving_image)

            # Resample (Apply Transform)
            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(fixed_image)
            resampler.SetInterpolator(sitk.sitkLinear)
            resampler.SetDefaultPixelValue(0)
            resampler.SetTransform(final_transform)

            warped_image = resampler.Execute(moving_image)

            # Save
            sitk.WriteImage(warped_image, output_path)
            print(f"   -> Saved: {output_file_name}")

        except Exception as e:
            print(f"   -> FAILED: {e}")

if __name__ == "__main__":
    # Command-line entry point: both directories are optional flags that
    # default to the module-level DEFAULT_* paths.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input_dir",
        default=DEFAULT_INPUT_DIR,
        help="Path to skull-stripped raw data",
    )
    cli.add_argument(
        "--output_dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Path for registered output",
    )
    options = cli.parse_args()

    # Run the batch registration with the resolved directories.
    register_images(options.input_dir, options.output_dir)