In [19]:
# Imports and configuration
import os
import sys
import pandas as pd
import ptoc_params as params
import numpy as np
import nibabel as nib
from nilearn import datasets, image, plotting
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.glm.first_level import compute_regressor
from scipy.stats import norm
from statsmodels.stats.multitest import fdrcorrection
import logging
import matplotlib.pyplot as plt

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set up directories and parameters
curr_dir = '/user_data/csimmon2/git_repos/ptoc'
study_dir = "/lab_data/behrmannlab/vlad/ptoc"
raw_dir = params.raw_dir
results_dir = f'{curr_dir}/results'
roi_dir = f'{curr_dir}/roiParcels'

# Load subject information
sub_info = pd.read_csv(f'{curr_dir}/sub_info.csv')
subjects_to_skip = ['sub-084']
subs = sub_info[(sub_info['group'] == 'control') & (~sub_info['sub'].isin(subjects_to_skip))]['sub'].tolist()

run_num = 3
runs = list(range(1, run_num + 1))

# FDR parameters
alpha = 0.05  # FDR alpha level
two_sided = False  # Only keep positive correlations

In [20]:
# Create Merged Atlas Function
def create_merged_atlas():
    """Create a merged atlas where Wang ROIs replace overlapping regions in Schaefer atlas"""
    logging.info("Creating merged atlas...")
    
    # Load Wang ROIs
    roi_files = {
        'pIPS': f'{roi_dir}/pIPS.nii.gz',
        'LO': f'{roi_dir}/LO.nii.gz'
    }
    
    rois = {}
    for roi_name, roi_path in roi_files.items():
        if os.path.exists(roi_path):
            rois[roi_name] = nib.load(roi_path)
            logging.info(f"Loaded {roi_name} ROI")
        else:
            logging.error(f"ROI file {roi_path} not found!")
            return None
    
    # Load Schaefer atlas
    n_rois = 200  # Schaefer atlas ROIs
    atlas = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois, yeo_networks=7, resolution_mm=2)
    atlas_img = nib.load(atlas.maps)
    atlas_labels = atlas.labels
    logging.info(f"Loaded Schaefer atlas with {len(atlas_labels)} parcels")
    
    # Get atlas data
    atlas_data = atlas_img.get_fdata()
    modified_atlas_data = atlas_data.copy()
    
    # Create a dictionary to store new labels
    new_labels = list(atlas_labels)
    
    # Assign values for new ROIs (continuing from the end of the Schaefer atlas)
    roi_values = {'pIPS': 201, 'LO': 202}
    overlap_info = {}
    
    # Process each ROI
    for roi_name, roi_img in rois.items():
        # Get ROI data and create mask
        roi_data = roi_img.get_fdata()
        roi_mask = roi_data > 0
        
        # Find overlapping parcels
        overlap_mask = (atlas_data > 0) & roi_mask
        overlapping_labels = np.unique(atlas_data[overlap_mask])
        overlapping_labels = overlapping_labels[overlapping_labels > 0]
        
        # Get number of voxels in overlap
        overlap_voxels = {}
        for label in overlapping_labels:
            label_mask = (atlas_data == label) & roi_mask
            overlap_voxels[int(label)] = np.sum(label_mask)
        
        # Store overlap information
        overlap_info[roi_name] = {
            'overlapping_labels': overlapping_labels.tolist(),
            'overlap_voxels': overlap_voxels
        }
        
        logging.info(f"{roi_name} overlaps with {len(overlapping_labels)} atlas parcels")
        for label, voxels in overlap_voxels.items():
            label_idx = int(label) - 1  # Convert to 0-indexed
            if 0 <= label_idx < len(atlas_labels):
                label_name = atlas_labels[label_idx]
                label_name = label_name.decode('utf-8') if isinstance(label_name, bytes) else str(label_name)
                logging.info(f"  Label {label} ({label_name}): {voxels} voxels")
        
        # Remove overlapping parcels from the atlas
        for label in overlapping_labels:
            label_mask = (modified_atlas_data == label) & roi_mask
            modified_atlas_data[label_mask] = 0
        
        # Add ROI with new label
        modified_atlas_data[roi_mask] = roi_values[roi_name]
        
        # Add new label name
        new_labels.append(f"Wang_{roi_name}")
    
    # Create the modified atlas
    merged_atlas_img = nib.Nifti1Image(modified_atlas_data, atlas_img.affine, atlas_img.header)
    merged_atlas_file = f'{results_dir}/schaefer_wang_merged.nii.gz'
    nib.save(merged_atlas_img, merged_atlas_file)
    logging.info(f"Saved merged atlas to: {merged_atlas_file}")
    
    # Save new labels array
    labels_file = f'{results_dir}/merged_atlas_labels.npy'
    np.save(labels_file, new_labels)
    logging.info(f"Saved merged labels to: {labels_file}")
    
    return merged_atlas_img, new_labels

In [15]:
# utility functions 
def verify_standard_space(img):
    """Verify image is in 2mm standard space"""
    if img.shape[:3] != (91, 109, 91):
        logging.warning(f"Unexpected shape: {img.shape}")
        return False
    
    vox_size = np.sqrt(np.sum(img.affine[:3, :3] ** 2, axis=0))
    if not np.allclose(vox_size, [2., 2., 2.], atol=0.1):
        logging.warning(f"Unexpected voxel size: {vox_size}")
        return False
    
    return True

def get_condition_mask(run_num, ss, condition, n_timepoints):
    """Create a binary mask for timepoints during a specific condition"""
    cov_dir = f'{raw_dir}/{ss}/ses-01/covs'
    ss_num = ss.split('-')[1]
    
    # Load condition timing file
    cov_file = f'{cov_dir}/catloc_{ss_num}_run-0{run_num}_{condition}.txt'
    if not os.path.exists(cov_file):
        logging.warning(f'Covariate file not found: {cov_file}')
        return np.zeros(n_timepoints, dtype=bool)
    
    # Load timing data
    cov = pd.read_csv(cov_file, sep='\t', header=None, 
                     names=['onset', 'duration', 'value'])
    
    # Create timepoints array
    tr = 2.0  # TR in seconds
    times = np.arange(0, n_timepoints * tr, tr)
    
    # Convert timing to binary mask
    condition_reg, _ = compute_regressor(cov.to_numpy().T, 'spm', times)
    
    # Convert to binary mask and ensure it's 1D
    return (condition_reg > 0).ravel()

In [None]:
# create connectivity matrices with merged atlas
def create_connectivity_matrices():
    """Create unthresholded connectivity matrices for each subject and condition"""
    # Create merged atlas
    merged_atlas_img, _ = create_merged_atlas()
    
    for condition in ['Object', 'Scramble']:
        output_dir = f'{results_dir}/connectivity_merged_{condition.lower()}'
        os.makedirs(output_dir, exist_ok=True)
        
        for ss in subs:
            # Extract time series data
            all_runs_data = []
            for rn in runs:
                run_path = f'{raw_dir}/{ss}/ses-01/derivatives/reg_standard/filtered_func_run-0{rn}_standard.nii.gz'
                if not os.path.exists(run_path): continue
                
                masker = NiftiLabelsMasker(labels_img=merged_atlas_img, standardize='zscore_sample')
                time_series = masker.fit_transform(run_path)
                condition_mask = get_condition_mask(rn, ss, condition, time_series.shape[0])
                masked_time_series = time_series[condition_mask]
                
                if masked_time_series.shape[0] > 0:
                    all_runs_data.append(masked_time_series)
            
            if not all_runs_data: continue
            
            # Calculate correlation matrix
            full_time_series = np.concatenate(all_runs_data, axis=0)
            correlation_measure = ConnectivityMeasure(kind='correlation', standardize='zscore_sample')
            conn_matrix = correlation_measure.fit_transform([full_time_series])[0]
            
            # Save unthresholded matrix
            output_path = f'{output_dir}/{ss}_connectivity_{condition.lower()}.npy'
            np.save(output_path, conn_matrix)
            print(f"Saved unthresholded matrix for {ss}, condition {condition}")

In [29]:
# main function

def main():
    """Main function to create connectivity matrices with parcel-based FDR correction"""
    # Create merged atlas
    merged_atlas_img, merged_labels = create_merged_atlas()
    
    conditions = ['Object', 'Scramble']
    alpha = 0.05  # FDR threshold
    
    for condition in conditions:
        output_dir = f'{results_dir}/connectivity_merged_{condition.lower()}_parcel_fdr'
        os.makedirs(output_dir, exist_ok=True)
        
        for ss in subs:
            try:
                # Process subject data
                all_runs_data = []
                for rn in runs:
                    run_path = f'{raw_dir}/{ss}/ses-01/derivatives/reg_standard/filtered_func_run-0{rn}_standard.nii.gz'
                    if not os.path.exists(run_path):
                        continue
                    
                    masker = NiftiLabelsMasker(labels_img=merged_atlas_img, standardize='zscore_sample')
                    time_series = masker.fit_transform(run_path)
                    condition_mask = get_condition_mask(rn, ss, condition, time_series.shape[0])
                    masked_time_series = time_series[condition_mask]
                    
                    if masked_time_series.shape[0] > 0:
                        all_runs_data.append(masked_time_series)
                
                if not all_runs_data:
                    continue
                
                # Calculate correlation matrix
                full_time_series = np.concatenate(all_runs_data, axis=0)
                correlation_measure = ConnectivityMeasure(kind='correlation', standardize='zscore_sample')
                conn_matrix = correlation_measure.fit_transform([full_time_series])[0]
                
                # Apply FDR correction based on parcels (not connections)
                z_matrix = np.arctanh(conn_matrix)
                z_matrix[np.isinf(z_matrix)] = np.finfo(float).max
                
                # For each parcel, correct its connections
                n_parcels = conn_matrix.shape[0]
                thresholded = np.zeros_like(conn_matrix)
                
                for i in range(n_parcels):
                    # Get p-values for this parcel's connections
                    p_values = 1 - norm.cdf(z_matrix[i, :])
                    
                    # Apply FDR
                    significant, _ = fdrcorrection(p_values, alpha=alpha)
                    
                    # Set significant connections
                    thresholded[i, :] = conn_matrix[i, :] * significant
                
                # Set diagonal to 1
                np.fill_diagonal(thresholded, 1.0)
                
                # Save the matrix
                output_path = f'{output_dir}/{ss}_connectivity_{condition.lower()}.npy'
                np.save(output_path, thresholded)
                logging.info(f"Saved parcel-FDR matrix for {ss}, condition {condition}")
                
            except Exception as e:
                logging.error(f"Error processing {ss}, condition {condition}: {str(e)}")

In [35]:
# Run the complete analysis
print(f"Starting full analysis for {len(subs)} subjects...")
main()
print("Analysis complete. Results saved to:")
print(f"1. Merged atlas: {results_dir}/schaefer_wang_merged.nii.gz")
print(f"2. FDR-thresholded matrices: {results_dir}/connectivity_merged_object_fdr/ and {results_dir}/connectivity_merged_scramble_fdr/")

Starting full analysis for 18 subjects...




Analysis complete. Results saved to:
1. Merged atlas: /user_data/csimmon2/git_repos/ptoc/results/schaefer_wang_merged.nii.gz
2. FDR-thresholded matrices: /user_data/csimmon2/git_repos/ptoc/results/connectivity_merged_object_fdr/ and /user_data/csimmon2/git_repos/ptoc/results/connectivity_merged_scramble_fdr/




In [None]:
# visualization
def visualize_group_connectivity():
    """Create group average connectivity matrices with white diagonal and difference map"""
    import seaborn as sns
    
    # Set up visualization directory
    viz_dir = f'{results_dir}/connectivity_visualizations'
    os.makedirs(viz_dir, exist_ok=True)
    
    # Dictionary to store group matrices
    group_matrices = {}
    
    # Process each condition
    for condition in ['object', 'scramble']:
        matrix_dir = f'{results_dir}/connectivity_merged_{condition}'
        
        # Find all matrix files
        matrix_files = [f for f in os.listdir(matrix_dir) if f.endswith('.npy')]
        print(f"Creating group average for '{condition}' from {len(matrix_files)} subjects")
        
        # Load and average matrices
        matrices = []
        for matrix_file in matrix_files:
            matrix_path = f'{matrix_dir}/{matrix_file}'
            matrix = np.load(matrix_path)
            # Keep only positive correlations
            matrix[matrix < 0] = 0
            matrices.append(matrix)
        
        # Create group average
        group_matrix = np.mean(matrices, axis=0)
        group_matrices[condition] = group_matrix.copy()  # Save a copy for the difference map
        
        # Set diagonal to NaN so it will appear white in the visualization
        np.fill_diagonal(group_matrix, np.nan)
        
        # Create heatmap
        plt.figure(figsize=(10, 8))
        sns.heatmap(group_matrix, cmap="Reds", vmin=0, vmax=1, square=True)
        plt.title(f"Group Average Connectivity Matrix - {condition.capitalize()}")
        
        # Save figure
        out_path = f"{viz_dir}/group_average_{condition}.png"
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved group average to {out_path}")
    
    # Create difference map (Object - Scramble)
    if 'object' in group_matrices and 'scramble' in group_matrices:
        diff_matrix = group_matrices['object'] - group_matrices['scramble']
        
        # Set diagonal to NaN for white display
        np.fill_diagonal(diff_matrix, np.nan)
        
        # Create heatmap with diverging colormap
        plt.figure(figsize=(10, 8))
        sns.heatmap(diff_matrix, cmap="RdBu_r", center=0, square=True)
        plt.title("Difference Map (Object - Scramble)")
        
        # Save figure
        out_path = f"{viz_dir}/difference_map_object_minus_scramble.png"
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved difference map to {out_path}")
    
    print(f"Group visualizations saved to {viz_dir}")
    return viz_dir

# Visualize connectivity matrices
visualize_group_connectivity()

Creating group average for 'object' from 18 subjects
Saved group average to /user_data/csimmon2/git_repos/ptoc/results/connectivity_visualizations/group_average_object.png
Creating group average for 'scramble' from 18 subjects
Saved group average to /user_data/csimmon2/git_repos/ptoc/results/connectivity_visualizations/group_average_scramble.png
Saved difference map to /user_data/csimmon2/git_repos/ptoc/results/connectivity_visualizations/difference_map_object_minus_scramble.png
Group visualizations saved to /user_data/csimmon2/git_repos/ptoc/results/connectivity_visualizations


'/user_data/csimmon2/git_repos/ptoc/results/connectivity_visualizations'