# In [None]:
# This is the code that was run on ACCESS OnDemand on 11/06/2025 (DJ).
import nibabel as nib
import os
from pathlib import Path
from ants import plot, image_read, resample_image, registration, image_write, get_mask, mask_image
import ants
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os

# Generate masks for the fixed image
def generate_mask_all(aparc_aseg_path, labels_subcortical=None, labels_cortex=None):
    """Build a binary mask of Tian-style subcortical structures plus cortex.

    Reads a FreeSurfer ``aparc+aseg`` label volume and returns a mask that is
    1 wherever a voxel carries one of the selected subcortical or cortical
    labels and 0 elsewhere. The selection mirrors the regions covered by the
    Tian subcortical parcellation.

    Default subcortical labels (FreeSurfer lookup table, left/right):
      - 17/53: hippocampus
      - 18/54: amygdala
      - 10/49: thalamus (pThalamus and aThalamus both map to 10)
      - 26/58: nucleus accumbens (NAc)
      - 13/52: globus pallidus (GP)
      - 12/51: putamen
      - 11/50: caudate
    Default cortical labels: 1000-1035 (lh) and 2000-2035 (rh), the cortical
    parcel labels of aparc+aseg.

    Parameters
    ----------
    aparc_aseg_path : str or path-like
        Path to the label volume.
    labels_subcortical : sequence of int, optional
        Overrides the default subcortical label list above.
    labels_cortex : sequence of int, optional
        Overrides the default cortical label list above.

    Returns
    -------
    ants.ANTsImage
        uint8 mask (0/1) on the same grid (origin, spacing, direction) as the
        input label volume.
    """
    if labels_subcortical is None:
        # Order within each list: hippocampus, amygdala, thalamus, accumbens,
        # pallidus, putamen, caudate.
        subcortical_lh = [17, 18, 10, 26, 13, 12, 11]
        subcortical_rh = [53, 54, 49, 58, 52, 51, 50]
        labels_subcortical = subcortical_lh + subcortical_rh
    if labels_cortex is None:
        labels_cortex = list(range(1000, 1036)) + list(range(2000, 2036))

    label_img = image_read(str(aparc_aseg_path))
    label_np = label_img.numpy()

    # Voxels belonging to any selected label. The labels are small integers,
    # so comparing them against the stored voxel values is exact even if
    # image_read returned a float image.
    keep = np.isin(label_np, list(labels_subcortical) + list(labels_cortex))

    # Build the mask directly instead of masking the label array and then
    # re-thresholding it at > 0. For the default (all-positive) labels, the
    # result is the same 0/1 uint8 image on the same grid.
    return ants.from_numpy(
        keep.astype(np.uint8),
        origin=label_img.origin,
        spacing=label_img.spacing,
        direction=label_img.direction,
    )

def resample(image_in, image_out, target_spacing, save_flag, interp_type=1):
    """Read an image, resample it to a physical voxel spacing, optionally save.

    Parameters
    ----------
    image_in : str or path-like
        Image to read.
    image_out : str or path-like
        Destination file; only used when ``save_flag`` is truthy.
    target_spacing : tuple of float
        Target spacing in physical units (``use_voxels=False``),
        e.g. ``(1.0, 1.0, 1.0)`` for 1 mm isotropic.
    save_flag : bool
        Write the resampled image to ``image_out`` when truthy.
    interp_type : int, default 1
        ANTs interpolation code: 0 linear, 1 nearest neighbor, 2 gaussian,
        3 windowed sinc, 4 B-spline. The default (nearest neighbor) keeps the
        historical behaviour. For intensity images, linear (0) is usually the
        better choice; nearest neighbor suits label volumes.

    Returns
    -------
    ants.ANTsImage
        The resampled image.
    """
    img_in = image_read(str(image_in))
    img_resampled = resample_image(
        img_in, target_spacing, use_voxels=False, interp_type=interp_type
    )
    # Truthiness rather than `is True`, so e.g. save_flag=1 also saves.
    if save_flag:
        image_write(img_resampled, str(image_out))
    return img_resampled

def register(fixed_data, moving_data, save_path, save_flag,
             type_of_transform='SyN',
             reg_iterations=(100, 70, 50, 20),
             aff_iterations=(10000, 10000, 10000, 10000)):
    """Register ``moving_data`` onto ``fixed_data`` with ANTs.

    Parameters
    ----------
    fixed_data : ants.ANTsImage
        Fixed (target) image.
    moving_data : ants.ANTsImage
        Moving image that is warped onto ``fixed_data``.
    save_path : str
        Directory used as the ANTs output prefix (``<save_path>/reg_*``) and,
        when ``save_flag`` is truthy, for the warped images. Assumed to
        already exist.
    save_flag : bool
        When truthy, write ``warped_template.nii.gz`` (moving warped to fixed)
        and, if ANTs returned it, ``warped_fixed.nii.gz`` (fixed warped to
        moving) into ``save_path``.
    type_of_transform : str, default 'SyN'
        ANTs transform type.
    reg_iterations : tuple of int, default (100, 70, 50, 20)
        Deformable-stage iterations per resolution level. These are more than
        the ANTs defaults.
    aff_iterations : tuple of int, default (10000, 10000, 10000, 10000)
        Affine-stage iterations per resolution level. These are also increased
        over the defaults.

    Returns
    -------
    dict
        The result of ``ants.registration``. It is used downstream for
        ``'fwdtransforms'``.
    """
    mytx = registration(
        fixed=fixed_data,
        moving=moving_data,
        type_of_transform=type_of_transform,
        reg_iterations=reg_iterations,
        aff_iterations=aff_iterations,
        verbose=False,
        outprefix=os.path.join(save_path, 'reg_'),
    )

    # Truthiness rather than `is True`, so e.g. save_flag=1 also saves.
    if save_flag:
        image_write(mytx['warpedmovout'], os.path.join(save_path, 'warped_template.nii.gz'))
        if 'warpedfixout' in mytx:
            image_write(mytx['warpedfixout'], os.path.join(save_path, 'warped_fixed.nii.gz'))

    return mytx

# Start script
base_path = "/ocean/projects/bio220042p/djung2/data/exvivo_mod/"
save_root = "/ocean/projects/bio220042p/djung2/exvivo_tian_registration"
template_dir = "/ocean/projects/bio220042p/djung2/template/mni_icbm152_nlin_asym_09c/template_individual_hemi_tian"

parcellation_path = '/ocean/projects/bio220042p/djung2/template/tian/Schaefer2018_400Parcels_7Networks_order_Tian_Subcortex_S1_3T_MNI152NLin2009cAsym_1mm.nii.gz'

# FreeSurfer label volume (per subject) used to build the registration mask
aparc_aseg_filename = "aparc+aseg.upsampled.nii"

# Every sub-directory of base_path is treated as a subject. Sort so the
# processing order is reproducible (os.listdir order is arbitrary).
subject_names = sorted(
    name for name in os.listdir(base_path)
    if os.path.isdir(os.path.join(base_path, name))
)

# The atlas is identical for every subject and hemisphere: read it once.
tian_img = image_read(parcellation_path)

# Iterate for each subject
for subject_name in subject_names:
    subject_dir = os.path.join(base_path, subject_name)

    # Process each hemisphere that has a white-surface GIFTI (lh before rh).
    available_hemis = [
        hemi for hemi in ("lh", "rh")
        if os.path.exists(os.path.join(subject_dir, "surf", f"{hemi}.white.surf.gii"))
    ]
    if not available_hemis:
        continue

    # Per-subject output directory
    save_path = os.path.join(save_root, subject_name)
    os.makedirs(save_path, exist_ok=True)

    # Ex vivo data (fixed), resampled to 1 mm isotropic and cached on disk.
    # This does not depend on the hemisphere, so it is done once per subject.
    dat_path = os.path.join(subject_dir, "mri", "native.nii")
    resampled_path = os.path.join(save_path, "orig_img_resampled.nii")
    if os.path.exists(resampled_path):
        orig_img_resampled = image_read(resampled_path)
    else:
        orig_img_resampled = resample(dat_path, resampled_path, (1.0, 1.0, 1.0), True)

    for hemi in available_hemis:
        # Template brain (moving) for this hemisphere. It should be in the same
        # space as the Tian parcellation.
        mni_template_path = os.path.join(
            template_dir, f"mni_icbm152_t2_tal_nlin_asym_09c_masked_{hemi}_only.nii.gz"
        )
        mni_template_img = image_read(mni_template_path)

        # Region-of-interest mask for the fixed image, cached on disk.
        # NOTE(review): generate_mask_all does not take the hemisphere and, by
        # default, selects labels from BOTH hemispheres, so maskall_lh and
        # maskall_rh have identical content. If a hemisphere-specific mask was
        # intended (to match the single-hemisphere template), restrict the
        # labels per hemi.
        mask_path = os.path.join(save_path, f"maskall_{hemi}.downsampled.nii")
        if os.path.exists(mask_path):
            mask_all = image_read(mask_path)
        else:
            mask_all = generate_mask_all(os.path.join(subject_dir, "mri", aparc_aseg_filename))
            mask_all = resample_image(mask_all, (1.0, 1.0, 1.0), use_voxels=False, interp_type=1)
            image_write(mask_all, mask_path)

        # Masked fixed image. It serves as the registration target and as the
        # reference grid for warping the atlas.
        # NOTE(review): ANTsPy image arithmetic expects both operands to occupy
        # the same physical space. Confirm that native.nii and
        # aparc+aseg.upsampled.nii share a grid after resampling.
        fixed_img = orig_img_resampled * mask_all

        # Register template (moving) -> masked ex vivo image (fixed)
        reg_dir = os.path.join(save_path, 'from_t2_tian', hemi)
        os.makedirs(reg_dir, exist_ok=True)
        mytx = register(fixed_img, mni_template_img, reg_dir, True)

        # Restrict the atlas to the voxels covered by this hemisphere's template
        tian_img_masked = tian_img * (mni_template_img > 0)

        # Warp the atlas with the template->image transforms. Use label-aware
        # interpolation so parcel IDs are not blended.
        mywarpedimage = ants.apply_transforms(
            fixed=fixed_img,
            moving=tian_img_masked,
            interpolator='genericLabel',
            transformlist=mytx['fwdtransforms'],
        )
        warped_save_path = os.path.join(reg_dir, "parcellated_img.nii")
        image_write(mywarpedimage, warped_save_path)

        # Save quick-look images
        plot(mywarpedimage, axis=0, cmap='jet_r', filename=os.path.join(save_path, f'tian_{hemi}_registered.png'))
        plot(orig_img_resampled, axis=0, cmap='jet_r', filename=os.path.join(save_path, f'orig_img_{hemi}.png'))