In [None]:
import os
import glob
import numpy as np
import nibabel as nib
import pandas as pd
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nipype.interfaces import afni 
from joblib import Parallel, delayed
import logging
import uuid

# Make executables under ~/abin reachable for the nipype AFNI interface
# (presumably the local AFNI install directory -- confirm on this machine).
# Guarded so that re-running this cell does not append the same directory
# again, and so that an unset PATH does not raise KeyError.
_abin_dir = os.path.expanduser("~/abin")
_current_path = os.environ.get("PATH", "")
if _abin_dir not in _current_path.split(os.pathsep):
    os.environ["PATH"] = (
        _current_path + os.pathsep + _abin_dir if _current_path else _abin_dir
    )

In [None]:
# Project root. Defaults to the original workstation location; set the
# SBVSMB4_DIR environment variable to run the notebook from another machine
# without editing every path below.
PROJECT_DIR = os.environ.get("SBVSMB4_DIR", "/Users/labneuro2/Documents/lab/SBvsMB4")

# Paths to data
data_path = f"{PROJECT_DIR}/halfpipe"

# Confound columns to regress out (columns missing from a confounds file are
# skipped by calculate_correlation_min)
REGRESSORS_TO_USE = [
    "trans_x", "trans_y", "trans_z", "rot_x", "rot_y", "rot_z",
    "a_comp_cor_00", "a_comp_cor_01", "a_comp_cor_02", "a_comp_cor_03", "a_comp_cor_04"
]

# Band-pass cut-offs handed to 3dTproject, plus a label identifying this setting.
# NOTE: this is a single dict, not a list of settings.
FILTER_SETTINGS = {"high_pass": 0.008, "low_pass": 0.09, "label": "bp_008_090"}

# Framewise-displacement thresholds used to build the censor file
FD_THRESHOLDS = [0.4, 1.0]

# Atlas images and the matching output folder for each atlas
atlases_path = f"{PROJECT_DIR}/atlases"
parcels = ["100Parcels", "200Parcels", "400Parcels"]
tians = ["S1", "S2", "S4"]
atlas_filenames, output_paths = [], []
for parc, tian in zip(parcels, tians):
    atlas_filenames.append(f"{atlases_path}/Schaefer2018_{parc}_7Networks_order_Tian_Subcortex_{tian}_3T_MNI152NLin2009cAsym_2mm.nii.gz")
    atlas_name = f"schaefer_{parc}_Tian_{tian}"
    output_paths.append(f"{PROJECT_DIR}/correlation_matrices/{atlas_name}")


# Find all NIfTI files
nifti_files = sorted(glob.glob(f"{data_path}/**/*preproc_bold.nii.gz", recursive=True))
if not nifti_files:
    # Without this the later cells run to completion and write an empty summary.
    print(f"WARNING: no *preproc_bold.nii.gz files found under {data_path}")

# Temporary directory for intermediate files
temp_dir = f"{PROJECT_DIR}/temp_files"
os.makedirs(temp_dir, exist_ok=True)

In [None]:
def calculate_correlation_min(base_nifti_file, fmri_file, confound_file, high_pass, low_pass,
                              mask_file, atlas_filenames, output_paths, confound_vars,
                              temp_id, minutes, fd_threshold, cenmode='KILL'):
    """Denoise one fMRI file with AFNI 3dTproject and save atlas correlation matrices.

    The confound columns and a framewise-displacement censor vector are written
    to temporary ``.1D`` files in the module-level ``temp_dir``, 3dTproject is
    run through nipype, and for every atlas a ROI-by-ROI correlation matrix is
    written as CSV into the matching output directory.

    Parameters
    ----------
    base_nifti_file : str
        Path whose basename is used to derive the CSV file name
        (``setting-preproc_bold.nii.gz`` is replaced by
        ``trimmed_{minutes}min_correlation.csv``).
    fmri_file : str
        4D NIfTI file given to 3dTproject.
    confound_file : str
        Tab-separated confounds table, one row per volume of ``fmri_file``.
    high_pass, low_pass : float
        Band-pass limits passed to 3dTproject as ``(high_pass, low_pass)``.
    mask_file : str
        Mask image used both by 3dTproject and by the labels masker.
    atlas_filenames : sequence of str
        Label images; paired element-wise with ``output_paths``.
    output_paths : sequence of str
        Output directory for each atlas (created if missing).
    confound_vars : sequence of str
        Requested confound column names; names absent from the table are skipped.
    temp_id : str or int
        Suffix that keeps this call's temporary files distinct from other calls.
    minutes : int
        Only used to build the CSV file name.
    fd_threshold : float
        Volumes with framewise displacement >= this value are censored.
    cenmode : str, default 'KILL'
        Forwarded to 3dTproject's ``cenmode`` input.

    Returns
    -------
    tuple
        ``(correlation_matrix, afni_dof)``. The matrix is the one computed for
        the LAST atlas (``None`` if no atlas was given). ``afni_dof`` is the
        integer parsed from AFNI's stderr, or ``None`` if it was not found.
        ``(None, None)`` is returned when running 3dTproject raises.
    """
    import re
    logging.getLogger("nipype").setLevel(logging.CRITICAL)

    tr = nib.load(fmri_file).header.get_zooms()[3]

    # Load confounds and keep only the requested columns that actually exist
    confounds_df = pd.read_csv(confound_file, sep='\t')
    valid_confounds = [col for col in confound_vars if col in confounds_df.columns]
    missing_confounds = [col for col in confound_vars if col not in confounds_df.columns]
    if missing_confounds:
        # Make the silent drop visible: these regressors will NOT be removed.
        print(f"WARNING: confounds missing from {confound_file}: {missing_confounds}")

    confounds_1d_path = f"{temp_dir}/motion_regressors_{temp_id}.1D"
    censor_1d_path = f"{temp_dir}/censor_{temp_id}.1D"
    nifti_output_afni = f"{temp_dir}/filtered_fmri_{temp_id}.nii.gz"
    temp_files = [nifti_output_afni, confounds_1d_path, censor_1d_path]

    # The CSV name does not depend on the atlas, so build it once.
    new_filename_csv = os.path.basename(base_nifti_file).replace(
        "setting-preproc_bold.nii.gz",
        f"trimmed_{minutes}min_correlation.csv"
    )

    correlation_matrix = None
    try:
        # Drop a leftover output from an interrupted earlier run so it cannot
        # collide with the file this call is about to write.
        if os.path.exists(nifti_output_afni):
            os.remove(nifti_output_afni)

        # Save confounds to .1D file for AFNI
        confounds = confounds_df[valid_confounds].fillna(0).values
        np.savetxt(confounds_1d_path, confounds, fmt="%.6f")

        # Censor vector: 1 = keep the volume, 0 = censor it.
        if "framewise_displacement" in confounds_df.columns:
            # NaN < threshold is False, so a missing FD value would be censored
            # by accident; fill with 0 exactly like the other confounds above.
            fd_values = confounds_df["framewise_displacement"].fillna(0).to_numpy()
            censor_vector = (fd_values < fd_threshold).astype(int)
        else:
            # No FD column: keep every volume. (Plain ndarray -- it has no
            # ``.values`` attribute, which is why it is passed to savetxt directly.)
            censor_vector = np.ones(len(confounds_df), dtype=int)
        np.savetxt(censor_1d_path, censor_vector, fmt="%d")

        # Set up AFNI TProject for temporal filtering and nuisance regression
        tproject = afni.TProject()
        tproject.inputs.in_file = fmri_file
        tproject.inputs.out_file = nifti_output_afni
        tproject.inputs.bandpass = (high_pass, low_pass)
        tproject.inputs.polort = 1
        tproject.inputs.ort = confounds_1d_path
        tproject.inputs.mask = mask_file
        tproject.inputs.censor = censor_1d_path
        tproject.inputs.cenmode = cenmode

        try:
            print(f"🖥️ AFNI 3dTproject CMD: {tproject.cmdline}")
            res = tproject.run()
            stderr = res.runtime.stderr

            # Parse degrees of freedom (DOF) from AFNI output
            dof_match = re.search(r"==>\s+(\d+)\s+D\.O\.F\. left", stderr)
            if dof_match:
                afni_dof = int(dof_match.group(1))
                print(f"DOF from AFNI: {afni_dof}")
            else:
                afni_dof = None
                print("Could not extract DOF from AFNI stderr.")
        except Exception as e:
            # Best effort: report and let the caller record a missing result.
            print(f"Error running AFNI 3dTproject: {e}")
            return None, None

        # Load filtered data
        img_filtered = nib.load(nifti_output_afni)

        for atlas_filename, output_path in zip(atlas_filenames, output_paths):
            # Extract time series using atlas
            masker = NiftiLabelsMasker(labels_img=atlas_filename, mask_img=mask_file, standardize=True, t_r=tr)
            time_series = masker.fit_transform(img_filtered)

            # Compute correlation matrix
            connectivity_measure = ConnectivityMeasure(kind='correlation')
            correlation_matrix = connectivity_measure.fit_transform([time_series])[0]

            # Save correlation matrix to CSV
            output_file_csv = os.path.join(output_path, new_filename_csv)
            os.makedirs(os.path.dirname(output_file_csv), exist_ok=True)
            np.savetxt(output_file_csv, correlation_matrix, delimiter=",")
    finally:
        # Clean up temporary files on every exit path (success, AFNI failure,
        # or an exception while extracting time series).
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)

    return correlation_matrix, afni_dof


In [None]:
def process_nifti(nifti_file, temp_id):
    """Process one preprocessed BOLD file at several scan lengths and FD thresholds.

    For every duration from 5 to 13 minutes the run is truncated to the first
    ``int(minutes * 60 / TR)`` volumes (capped at the run length), written to
    ``temp_dir`` together with the matching rows of the confounds table, and
    passed to ``calculate_correlation_min`` once per FD threshold and filter
    setting.

    Correlation matrices are written below each atlas folder in
    ``output_paths``, in a sub-folder named after the setting label
    (``<filter label>_fd_<threshold*1000>``), so results for different FD
    thresholds do not overwrite one another.
    NOTE(review): this sub-folder layout is the reviewer's reading of the
    previously undefined ``output_paths_setting``; confirm it matches what
    downstream code expects.

    Parameters
    ----------
    nifti_file : str
        Path to a ``*preproc_bold.nii.gz`` file. The confounds and mask paths
        are derived from it by string replacement.
    temp_id : int or str
        Identifier unique to this call, used in temporary file names so that
        parallel workers do not collide.

    Returns
    -------
    list of dict
        One record per (minutes, FD threshold, filter setting) with keys
        ``file``, ``minutes``, ``volumes``, ``fd_threshold`` and ``afni_dof``
        (``afni_dof`` is ``None`` when AFNI failed or the DOF was not found).
    """
    confounds_file = nifti_file.replace("bold.nii.gz", "desc-confounds_regressors.tsv")
    # NOTE(review): replaces every occurrence of "bold" in the path, including
    # directory names -- verify against the dataset's naming scheme.
    mask_file = nifti_file.replace("bold", "desc-brain_mask")

    # FILTER_SETTINGS is a single dict in this notebook. Iterating over it
    # directly would loop over its KEYS and repeat identical work once per key,
    # so normalise to a list of setting dicts (a list/tuple is accepted too).
    if isinstance(FILTER_SETTINGS, dict):
        filter_settings = [FILTER_SETTINGS]
    else:
        filter_settings = list(FILTER_SETTINGS)

    img = nib.load(nifti_file)
    data = img.get_fdata()
    affine = img.affine
    header = img.header
    tr = header.get_zooms()[3]
    total_volumes = data.shape[3]

    confounds_df = pd.read_csv(confounds_file, sep='\t', na_values='n/a')

    dof_records = []
    for minutes in range(5, 14):
        # Never ask for more volumes than the run contains.
        target_volumes = min(int(minutes * 60 / tr), total_volumes)

        trimmed_img = nib.Nifti1Image(data[:, :, :, :target_volumes], affine, header)
        trimmed_confounds = confounds_df.iloc[:target_volumes]

        # The trimmed inputs depend only on `minutes`, so they are written once
        # here and reused for every FD threshold / filter setting.
        trimmed_fmri_file = f"{temp_dir}/trimmed_fmri_{temp_id}_{minutes}.nii.gz"
        trimmed_confounds_filename = f"{temp_dir}/trimmed_confounds_{temp_id}_{minutes}.tsv"

        try:
            nib.save(trimmed_img, trimmed_fmri_file)
            trimmed_confounds.to_csv(trimmed_confounds_filename, sep='\t', index=False)

            for fd_threshold in FD_THRESHOLDS:
                # round() guards against float error such as 0.29 * 1000 -> 289.99...
                fd_label = f"fd_{int(round(fd_threshold * 1000)):03d}"

                for setting in filter_settings:
                    setting_label = f"{setting['label']}_{fd_label}"
                    temp_suffix = f"{temp_id}_{minutes}_{setting_label}"

                    # One output directory per atlas and setting
                    output_paths_setting = [
                        os.path.join(path, setting_label) for path in output_paths
                    ]
                    for path in output_paths_setting:
                        os.makedirs(path, exist_ok=True)

                    _, afni_dof = calculate_correlation_min(
                        nifti_file,
                        trimmed_fmri_file, trimmed_confounds_filename,
                        setting["high_pass"], setting["low_pass"],
                        mask_file, atlas_filenames, output_paths_setting,
                        REGRESSORS_TO_USE,
                        temp_suffix, minutes,
                        fd_threshold=fd_threshold
                    )

                    dof_records.append({
                        "file": os.path.basename(nifti_file),
                        "minutes": minutes,
                        "volumes": target_volumes,
                        "fd_threshold": fd_threshold,
                        "afni_dof": afni_dof
                    })
        finally:
            # Clean up temporary files even if a step above raised
            for f in [trimmed_fmri_file, trimmed_confounds_filename]:
                if os.path.exists(f):
                    os.remove(f)

    return dof_records


In [None]:
# One job per input file. Every worker loads a complete 4D run into memory
# (process_nifti calls get_fdata), so the worker count is capped by the
# original setting (48), the number of cores, and the number of files
# instead of always starting 48 processes.
n_workers = max(1, min(48, os.cpu_count() or 1, len(nifti_files)))
results = Parallel(n_jobs=n_workers)(
    delayed(process_nifti)(nifti_file, i)
    for i, nifti_file in enumerate(nifti_files)
)

In [None]:
# `results` holds one list of DOF records per input file; merge them into a
# single flat list, preserving order.
flattened_dof_records = []
for per_file_records in results:
    flattened_dof_records.extend(per_file_records)

# Write one row per record to a CSV file in the current working directory
dof_df = pd.DataFrame(flattened_dof_records)
dof_df.to_csv("dof_summary.csv", index=False)