In [2]:
import os
import sys
import pandas as pd
import numpy as np
import nibabel as nib
from nilearn import image, input_data
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Import project parameters. The repo root must be on sys.path before the
# `ptoc_params` import below can resolve.
curr_dir = '/user_data/csimmon2/git_repos/ptoc'
sys.path.insert(0, curr_dir)
import ptoc_params as params

# Set up directories and parameters
study = 'ptoc'
study_dir = f"/lab_data/behrmannlab/vlad/{study}"
raw_dir = params.raw_dir  # data root provided by ptoc_params
results_dir = f'{curr_dir}/results'

# Glasser (HCP-MMP1) atlas image and its plain-text label file
atlas_path = f'{curr_dir}/glasser/HCP-MMP1_on_MNI152_ICBM2009a_nlin.nii.gz'
labels_path = f'{curr_dir}/glasser/HCP-MMP1_on_MNI152_ICBM2009a_nlin.txt'

# Subjects to analyse: every 'control' row of sub_info.csv whose ID is not
# listed in subjects_to_skip.
sub_info = pd.read_csv(f'{curr_dir}/sub_info.csv')
subjects_to_skip = []
subs = sub_info.loc[
    (sub_info['group'] == 'control') & ~sub_info['sub'].isin(subjects_to_skip),
    'sub',
].tolist()

# Run numbers 1..run_num (inclusive)
run_num = 3
runs = list(range(1, run_num + 1))

def read_glasser_labels(labels_path):
    """Read the Glasser atlas label file as a list of non-empty lines.

    Each non-blank line is returned with surrounding whitespace stripped.
    No further parsing is done, so any index column present in the file
    stays part of the label string.

    Parameters
    ----------
    labels_path : str or os.PathLike
        Path to the plain-text labels file (read as UTF-8).

    Returns
    -------
    list of str or None
        The stripped, non-empty lines. Returns ``None`` if the file could not
        be opened or decoded; the error is logged rather than raised.
    """
    # NOTE(review): the column layout of the .txt file is not visible here;
    # split each line into index/name once the format is confirmed.
    try:
        with open(labels_path, 'r', encoding='utf-8') as f:
            labels = [line.strip() for line in f if line.strip()]
    except (OSError, UnicodeDecodeError) as e:
        # Only I/O and decoding problems are treated as "no labels";
        # programming errors propagate instead of being silently swallowed.
        logging.error(f"Error reading labels file: {e}")
        return None
    logging.info(f"Loaded {len(labels)} region labels")
    return labels

def apply_glasser_atlas(subject_data, atlas_img):
    """Return one standardized signal per atlas label for ``subject_data``.

    The atlas is first resampled onto the subject image's grid with
    nearest-neighbour interpolation, which keeps label values discrete.
    A ``NiftiLabelsMasker`` then extracts the per-label signals.

    Parameters
    ----------
    subject_data : Niimg-like
        Image to extract signals from. Presumably a 4D functional run;
        confirm with the caller.
    atlas_img : Niimg-like
        Label (parcellation) image.

    Returns
    -------
    numpy.ndarray
        Output of ``NiftiLabelsMasker.fit_transform``.
    """
    # NOTE(review): the masker is configured to cache into a
    # 'nilearn_cache' path relative to the working directory.
    labels_on_subject_grid = image.resample_to_img(
        atlas_img, subject_data, interpolation='nearest'
    )
    masker_kwargs = dict(
        labels_img=labels_on_subject_grid,
        standardize=True,
        memory='nilearn_cache',
        verbose=0,
    )
    return NiftiLabelsMasker(**masker_kwargs).fit_transform(subject_data)

def create_connectivity_matrix(ss):
    """Build a region-by-region correlation matrix for one subject.

    Every run in the module-level ``runs`` list whose file exists is
    parcellated with the Glasser atlas via ``apply_glasser_atlas``. The
    per-run region time series are concatenated along axis 0, and a single
    correlation matrix is computed over the result.

    Parameters
    ----------
    ss : str
        Subject ID used in the data path (e.g. ``'sub-057'``).

    Returns
    -------
    numpy.ndarray or None
        The correlation matrix from ``ConnectivityMeasure``, or ``None`` when
        none of the subject's run files exist.

    Raises
    ------
    ValueError
        If runs yield different numbers of regions and so cannot be
        concatenated.
    """
    logging.info(f"Processing subject: {ss}")

    # Load Glasser atlas
    atlas_img = nib.load(atlas_path)
    logging.info(f"Loaded Glasser atlas with shape: {atlas_img.shape}")

    all_runs_data = []

    for rn in runs:
        # Zero-pad to two digits: run 3 -> 'run-03', run 10 -> 'run-10'
        # (the previous 'run-0{rn}' produced 'run-010').
        run_path = (
            f'{raw_dir}/{ss}/ses-01/derivatives/fsl/loc/run-{rn:02d}'
            '/1stLevel.feat/filtered_func_data_reg.nii.gz'
        )
        if not os.path.exists(run_path):
            logging.warning(f'Run data not found: {run_path}')
            continue

        subject_img = nib.load(run_path)
        logging.info(f"Loaded run {rn} with shape: {subject_img.shape}")

        # Apply atlas and extract time series
        run_data = apply_glasser_atlas(subject_img, atlas_img)
        logging.info(f"Extracted time series with shape: {run_data.shape}")
        all_runs_data.append(run_data)

    if not all_runs_data:
        logging.warning(f'No valid run data found for subject {ss}')
        return None

    # Runs are stacked along axis 0, so every run must share the same
    # column (region) count. Fail with an explicit message instead of
    # np.concatenate's generic one.
    region_counts = sorted({run_data.shape[1] for run_data in all_runs_data})
    if len(region_counts) > 1:
        raise ValueError(
            f'Runs for subject {ss} have differing region counts {region_counts}; '
            'cannot concatenate'
        )

    # Concatenate runs
    full_time_series = np.concatenate(all_runs_data, axis=0)

    # Compute connectivity matrix
    correlation_measure = ConnectivityMeasure(kind='correlation')
    connectivity_matrix = correlation_measure.fit_transform([full_time_series])[0]
    logging.info(f"Created {connectivity_matrix.shape} connectivity matrix")

    return connectivity_matrix

def calculate_group_matrices(subjects=None, base_dir=None):
    """Compute and save element-wise mean and std of subject matrices.

    For each subject, the function loads
    ``<base_dir>/connectivity_matrices/<sub>_glasser_connectivity_matrix.npy``
    if it exists. The loaded matrices are stacked, and their element-wise
    mean and (population) standard deviation are written to
    ``<base_dir>/connectivity_data/glasser_group_{mean,std}_matrix.npy``.

    Parameters
    ----------
    subjects : list of str, optional
        Subject IDs to aggregate. Defaults to the module-level ``subs``.
    base_dir : str, optional
        Results root directory. Defaults to the module-level ``results_dir``.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray) or None
        ``(mean_matrix, std_matrix)``, or ``None`` if no subject matrix was
        found, in which case nothing is saved.

    Raises
    ------
    ValueError
        If the subject matrices do not all have the same shape.
    """
    if subjects is None:
        subjects = subs
    if base_dir is None:
        base_dir = results_dir

    all_matrices = []
    missing = []
    for ss in subjects:
        matrix_path = f'{base_dir}/connectivity_matrices/{ss}_glasser_connectivity_matrix.npy'
        if os.path.exists(matrix_path):
            all_matrices.append(np.load(matrix_path))
        else:
            missing.append(ss)

    if missing:
        logging.warning(f'No connectivity matrix found for {len(missing)} subject(s): {missing}')
    if not all_matrices:
        logging.warning('No subject connectivity matrices found; group-level matrices not saved')
        return None

    # np.stack (unlike np.array) rejects ragged input with a clear error,
    # e.g. if subjects ended up with different numbers of regions.
    stacked = np.stack(all_matrices)
    # NOTE(review): raw correlation values are averaged here; averaging
    # Fisher z-transformed values (np.arctanh) is a common alternative.
    mean_matrix = np.mean(stacked, axis=0)
    std_matrix = np.std(stacked, axis=0)

    # Save group-level matrices
    output_dir = f'{base_dir}/connectivity_data'
    os.makedirs(output_dir, exist_ok=True)
    np.save(f'{output_dir}/glasser_group_mean_matrix.npy', mean_matrix)
    np.save(f'{output_dir}/glasser_group_std_matrix.npy', std_matrix)
    logging.info(f'Saved group-level matrices ({len(all_matrices)} subjects)')

    return mean_matrix, std_matrix

def main():
    """Run the pipeline: per-subject connectivity matrices, then group stats.

    The atlas label file is read only so its status is logged; the labels
    are not used further. For each subject in ``subs`` that yields a matrix,
    the matrix is saved as
    ``<results_dir>/connectivity_matrices/<sub>_glasser_connectivity_matrix.npy``.
    ``calculate_group_matrices`` is called at the end.
    """
    if read_glasser_labels(labels_path) is None:
        logging.warning("Proceeding without labels")

    matrix_dir = f'{results_dir}/connectivity_matrices'
    for subject in subs:
        matrix = create_connectivity_matrix(subject)
        if matrix is None:
            # Nothing to save for subjects without usable run data.
            continue
        os.makedirs(matrix_dir, exist_ok=True)
        np.save(f'{matrix_dir}/{subject}_glasser_connectivity_matrix.npy', matrix)
        logging.info(f'Saved connectivity matrix for {subject}')

    calculate_group_matrices()

# Script entry point. When this code runs as a notebook cell, __name__ is
# "__main__" too, so executing the cell runs the whole pipeline.
if __name__ == "__main__":
    main()

2024-10-25 17:52:46,303 - INFO - Loaded 180 region labels
2024-10-25 17:52:46,304 - INFO - Processing subject: sub-057
2024-10-25 17:52:46,307 - INFO - Loaded Glasser atlas with shape: (197, 233, 189)
2024-10-25 17:52:46,588 - INFO - Loaded run 1 with shape: (176, 256, 256, 184)


KeyboardInterrupt: 