In [19]:
# Imports and configuration
import os
import sys
import pandas as pd
import ptoc_params as params
import numpy as np
import nibabel as nib
from nilearn import datasets, image, plotting
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.glm.first_level import compute_regressor
from scipy.stats import norm
from statsmodels.stats.multitest import fdrcorrection
import logging
import matplotlib.pyplot as plt

# Logging: timestamped INFO-level progress messages
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Paths: repo checkout (outputs, ROIs, subject table) and lab study/raw data locations
curr_dir = '/user_data/csimmon2/git_repos/ptoc'
study_dir = "/lab_data/behrmannlab/vlad/ptoc"
raw_dir = params.raw_dir
results_dir = f'{curr_dir}/results'
roi_dir = f'{curr_dir}/roiParcels'

# Subjects: control group only, minus anyone listed in subjects_to_skip
sub_info = pd.read_csv(f'{curr_dir}/sub_info.csv')
subjects_to_skip = ['sub-084']
subs = sub_info.loc[
    (sub_info['group'] == 'control') & ~sub_info['sub'].isin(subjects_to_skip),
    'sub',
].tolist()

# Runs are numbered 1..run_num
run_num = 3
runs = [*range(1, run_num + 1)]

# FDR settings: alpha is the false-discovery-rate level; two_sided=False means a
# one-sided test for positive correlations (negative ones are discarded downstream)
alpha = 0.05
two_sided = False

In [20]:
# Create Merged Atlas Function
def create_merged_atlas(n_rois=200):
    """Create a merged atlas where Wang ROIs replace overlapping regions in the Schaefer atlas.

    Loads the Wang pIPS and LO ROI masks from ``roi_dir`` and the Schaefer 2018
    atlas (7 networks, 2 mm) via nilearn. Every voxel inside an ROI mask (value > 0)
    is relabelled with a new integer label appended after the Schaefer labels:
    ``n_rois + 1``, ``n_rois + 2``, ... in ``roi_files`` order (201/202 for the default
    200-parcel atlas). Voxels outside the ROI masks keep their Schaefer label. Where
    ROIs overlap each other, the later ROI in ``roi_files`` wins.

    Side effects: writes ``{results_dir}/schaefer_wang_merged.nii.gz`` and
    ``{results_dir}/merged_atlas_labels.npy`` (creating ``results_dir`` if needed).

    Args:
        n_rois: Number of Schaefer parcels to fetch (default 200).

    Returns:
        tuple: ``(merged_atlas_img, new_labels)``. ``new_labels`` is a list of str: the
        Schaefer labels followed by one ``"Wang_<roi>"`` entry per ROI. If an ROI file
        is missing, returns ``(None, None)`` so callers can unpack and then test
        ``merged_atlas_img is None``.
    """
    logging.info("Creating merged atlas...")

    roi_files = {
        'pIPS': f'{roi_dir}/pIPS.nii.gz',
        'LO': f'{roi_dir}/LO.nii.gz'
    }

    rois = {}
    for roi_name, roi_path in roi_files.items():
        if not os.path.exists(roi_path):
            logging.error(f"ROI file {roi_path} not found!")
            # A pair, not a bare None: `img, labels = create_merged_atlas()` would
            # otherwise raise TypeError before the caller could check for failure.
            return None, None
        rois[roi_name] = nib.load(roi_path)
        logging.info(f"Loaded {roi_name} ROI")

    atlas = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois, yeo_networks=7, resolution_mm=2)
    # load_img accepts either a file path or an already-loaded image
    atlas_img = image.load_img(atlas.maps)
    # Normalise to str: labels may come back as bytes, and mixing bytes with the
    # str "Wang_*" names would make the saved label array ambiguous.
    atlas_labels = [
        lbl.decode('utf-8') if isinstance(lbl, bytes) else str(lbl)
        for lbl in atlas.labels
    ]
    logging.info(f"Loaded Schaefer atlas with {len(atlas_labels)} parcels")
    # NOTE(review): some nilearn versions prepend a 'Background' entry to the labels.
    # Parcel value v is then at index v instead of v - 1. This offset is 0 or 1
    # in those two cases.
    label_offset = len(atlas_labels) - n_rois

    atlas_data = atlas_img.get_fdata()
    modified_atlas_data = atlas_data.copy()
    new_labels = list(atlas_labels)

    for roi_number, (roi_name, roi_img) in enumerate(rois.items(), start=1):
        # Schaefer parcels use 1..n_rois, so new ROI labels continue from n_rois + 1
        roi_value = n_rois + roi_number

        # The voxelwise masks below require the ROI to share the atlas grid
        if roi_img.shape[:3] != atlas_img.shape[:3] or not np.allclose(roi_img.affine, atlas_img.affine):
            logging.info(f"Resampling {roi_name} ROI to the atlas grid")
            roi_img = image.resample_to_img(roi_img, atlas_img, interpolation='nearest')
        roi_mask = roi_img.get_fdata() > 0

        # Count voxels per Schaefer parcel inside the ROI in a single pass
        parcel_values, parcel_counts = np.unique(atlas_data[roi_mask], return_counts=True)
        in_parcel = parcel_values > 0
        overlapping_labels = parcel_values[in_parcel]
        overlap_voxels = {
            int(lbl): int(cnt)
            for lbl, cnt in zip(overlapping_labels, parcel_counts[in_parcel])
        }

        logging.info(f"{roi_name} overlaps with {len(overlapping_labels)} atlas parcels")
        for label, voxels in overlap_voxels.items():
            label_idx = label - 1 + label_offset
            if 0 <= label_idx < len(atlas_labels):
                logging.info(f"  Label {label} ({atlas_labels[label_idx]}): {voxels} voxels")

        # Assigning the ROI label overwrites any Schaefer parcel (or earlier ROI) at
        # these voxels, so no separate "remove overlapping parcels" pass is needed.
        modified_atlas_data[roi_mask] = roi_value
        new_labels.append(f"Wang_{roi_name}")

    os.makedirs(results_dir, exist_ok=True)

    # Labels are discrete IDs: store them as integers rather than inheriting the
    # template header's (possibly float/scaled) on-disk dtype.
    merged_atlas_img = nib.Nifti1Image(
        np.rint(modified_atlas_data).astype(np.int32), atlas_img.affine, atlas_img.header
    )
    merged_atlas_img.set_data_dtype(np.int32)
    merged_atlas_file = f'{results_dir}/schaefer_wang_merged.nii.gz'
    nib.save(merged_atlas_img, merged_atlas_file)
    logging.info(f"Saved merged atlas to: {merged_atlas_file}")

    labels_file = f'{results_dir}/merged_atlas_labels.npy'
    np.save(labels_file, new_labels)
    logging.info(f"Saved merged labels to: {labels_file}")

    return merged_atlas_img, new_labels

In [15]:
# utility functions 
def verify_standard_space(img):
    """Check that an image lies on the 2 mm standard-space grid.

    Two checks run in order, and the first failure logs a warning:
      1. the first three dimensions equal (91, 109, 91);
      2. every voxel edge length, taken as the column norms of the affine's 3x3
         block, is within 0.1 of 2.0.

    Returns:
        bool: True if both checks pass, otherwise False.
    """
    expected_shape = (91, 109, 91)
    shape_ok = img.shape[:3] == expected_shape
    if not shape_ok:
        logging.warning(f"Unexpected shape: {img.shape}")
        return False

    # Edge lengths are sign-independent, so flipped axes (e.g. -2 in x) still pass
    vox_size = np.sqrt((img.affine[:3, :3] ** 2).sum(axis=0))
    if np.allclose(vox_size, [2., 2., 2.], atol=0.1):
        return True

    logging.warning(f"Unexpected voxel size: {vox_size}")
    return False

def get_condition_mask(run_num, ss, condition, n_timepoints, tr=2.0, threshold=0.0):
    """Create a boolean mask of the timepoints that belong to a condition.

    Reads the tab-separated, header-less timing file
    ``{raw_dir}/{ss}/ses-01/covs/catloc_<subnum>_run-<NN>_<condition>.txt`` with
    columns (onset, duration, value). The file is convolved with the SPM HRF via
    ``compute_regressor``, sampled at the scan times ``0, tr, 2*tr, ...``. A
    timepoint is included when the convolved regressor exceeds ``threshold``.

    Args:
        run_num: Run number; zero-padded to two digits in the filename
            (``run-01``, ..., ``run-10``).
        ss: Subject ID such as ``'sub-025'``. The part after ``'-'`` is used in the filename.
        condition: Condition suffix of the timing file (e.g. ``'Object'``).
        n_timepoints: Number of volumes in the run. The returned mask has this length.
        tr: Repetition time in seconds (default 2.0, the previously hard-coded value).
        threshold: Regressor value a timepoint must exceed (default 0.0, the original
            rule). Because the regressor is HRF-convolved, ``> 0`` also includes the
            positive tail of the response after each block ends.

    Returns:
        np.ndarray: 1-D boolean array of length ``n_timepoints``. It is all False if
        the timing file is missing or contains no events.
    """
    cov_dir = f'{raw_dir}/{ss}/ses-01/covs'
    ss_num = ss.split('-')[1]

    # :02d yields the same 'run-01'..'run-09' names as before and also works for run 10+
    cov_file = f'{cov_dir}/catloc_{ss_num}_run-{int(run_num):02d}_{condition}.txt'
    empty_mask = np.zeros(n_timepoints, dtype=bool)
    if not os.path.exists(cov_file):
        logging.warning(f'Covariate file not found: {cov_file}')
        return empty_mask

    try:
        cov = pd.read_csv(cov_file, sep='\t', header=None,
                          names=['onset', 'duration', 'value'])
    except pd.errors.EmptyDataError:
        cov = None
    if cov is None or cov.empty:
        logging.warning(f'Covariate file has no events: {cov_file}')
        return empty_mask

    # Integer arange times tr gives exactly n_timepoints samples. A float-step
    # arange can yield n + 1 samples for non-integer TRs.
    frame_times = np.arange(n_timepoints) * tr

    condition_reg, _ = compute_regressor(cov.to_numpy().T, 'spm', frame_times)

    # compute_regressor returns shape (n_timepoints, 1); flatten to a 1-D mask
    return (condition_reg > threshold).ravel()

In [16]:
# Create and Threshold Connectivity Matrix
def _extract_condition_time_series(ss, condition, merged_atlas_img):
    """Concatenate one subject's in-condition regional time series across all ``runs``.

    For each run, this loads the standard-space data and skips the run if the file
    is missing or the image fails ``verify_standard_space``. It extracts z-scored
    regional signals with the labels atlas, keeps only the timepoints flagged by
    ``get_condition_mask``, and concatenates the runs along time (axis 0).

    Returns:
        np.ndarray of shape (timepoints, regions), or None if no run yielded any
        in-condition timepoints.
    """
    # The masker configuration does not depend on the run, so build it once
    masker = NiftiLabelsMasker(
        labels_img=merged_atlas_img,
        standardize='zscore_sample',
        memory=None,
        verbose=0
    )

    all_runs_data = []
    for rn in runs:
        run_path = f'{raw_dir}/{ss}/ses-01/derivatives/reg_standard/filtered_func_run-0{rn}_standard.nii.gz'
        if not os.path.exists(run_path):
            logging.warning(f'Standard space data not found: {run_path}')
            continue

        subject_img = nib.load(run_path)
        if not verify_standard_space(subject_img):
            logging.warning(f"Data not in expected standard space for {ss} run-{rn}")
            continue

        time_series = masker.fit_transform(subject_img)
        logging.info(f"Time series shape before masking: {time_series.shape}")

        condition_mask = get_condition_mask(rn, ss, condition, time_series.shape[0])
        logging.info(f"Condition mask shape: {condition_mask.shape}")

        masked_time_series = time_series[condition_mask]
        logging.info(f"Time series shape after masking: {masked_time_series.shape}")

        if masked_time_series.shape[0] > 0:
            all_runs_data.append(masked_time_series)

    if not all_runs_data:
        return None
    return np.concatenate(all_runs_data, axis=0)


def _fdr_threshold_connectivity(connectivity_matrix, n_timepoints):
    """FDR-threshold the off-diagonal entries of a correlation matrix.

    Each correlation r becomes z = arctanh(r) * sqrt(n_timepoints - 3), which is
    approximately N(0, 1) under H0 (r = 0) for independent samples. z is converted
    to a p-value, one- or two-sided according to the global ``two_sided``.
    Benjamini-Hochberg FDR at the global ``alpha`` is applied across the upper
    triangle.

    NOTE(review): fMRI samples are autocorrelated, so the effective sample size is
    smaller than n_timepoints. This test is therefore anti-conservative.

    Returns:
        np.ndarray: symmetric matrix holding r where significant and 0 elsewhere,
        with diagonal 1.0. Negative entries are zeroed when ``two_sided`` is False.
        Returns None when n_timepoints <= 3, because the Fisher standard error
        1/sqrt(n - 3) is then undefined.
    """
    if n_timepoints <= 3:
        logging.warning(f"Only {n_timepoints} timepoints; cannot compute Fisher-z p-values")
        return None

    # Clip so |r| == 1 (e.g. identical regional signals) maps to a large finite z, not inf
    r_clipped = np.clip(connectivity_matrix, -1 + 1e-12, 1 - 1e-12)
    # Scaling by sqrt(n - 3) is what makes Fisher z a standard-normal test statistic;
    # without it the p-values ignore sample size and are hugely inflated.
    z_matrix = np.arctanh(r_clipped) * np.sqrt(n_timepoints - 3)

    # sf = 1 - cdf, computed without losing precision in the far tail
    if two_sided:
        p_matrix = 2 * norm.sf(np.abs(z_matrix))
    else:
        p_matrix = norm.sf(z_matrix)

    # Correct only the unique pairs: upper triangle, excluding the diagonal
    upper = np.triu(np.ones(connectivity_matrix.shape, dtype=bool), k=1)
    p_values = p_matrix[upper]
    # Zero-variance regions give NaN correlations; treat those pairs as non-significant
    p_values = np.where(np.isnan(p_values), 1.0, p_values)

    if p_values.size:
        significant, _ = fdrcorrection(p_values, alpha=alpha)
    else:
        significant = np.zeros(0, dtype=bool)
    logging.info(f"FDR thresholding: {np.sum(significant)} of {len(p_values)} connections significant at α={alpha}")

    thresholded_matrix = np.zeros_like(connectivity_matrix)
    thresholded_matrix[upper] = np.where(significant, connectivity_matrix[upper], 0.0)

    # Mirror to the lower triangle; correlation matrices are symmetric
    thresholded_matrix = thresholded_matrix + thresholded_matrix.T
    np.fill_diagonal(thresholded_matrix, 1.0)

    if not two_sided:
        thresholded_matrix[thresholded_matrix < 0] = 0

    return thresholded_matrix


def create_and_threshold_connectivity_matrix(ss, condition, merged_atlas_img):
    """Create and FDR-threshold a region-to-region connectivity matrix for one condition.

    Steps:
      1. Extract in-condition regional time series from every usable run
         (``_extract_condition_time_series``).
      2. Correlate all region pairs with ``ConnectivityMeasure(kind='correlation')``.
      3. Save the uncorrected matrix to
         ``{results_dir}/connectivity_merged_<condition>/<ss>_connectivity_uncorrected_<condition>.npy``.
      4. FDR-threshold it (``_fdr_threshold_connectivity``).

    NOTE(review): NiftiLabelsMasker only returns columns for labels present in the
    atlas image. If a Schaefer parcel were fully replaced by a Wang ROI, matrix
    indices would no longer line up with the saved label list. Confirm that no
    parcel disappears.

    Args:
        ss: Subject ID, e.g. ``'sub-025'``.
        condition: Condition name used for timing files and output paths.
        merged_atlas_img: Labels image defining the regions.

    Returns:
        np.ndarray: symmetric thresholded correlation matrix (non-significant entries
        0, diagonal 1), or None if there was no usable data.
    """
    logging.info(f"Processing subject {ss} for condition {condition}")

    full_time_series = _extract_condition_time_series(ss, condition, merged_atlas_img)
    if full_time_series is None:
        logging.warning(f'No valid data found for subject {ss} condition {condition}')
        return None
    logging.info(f"Full time series shape: {full_time_series.shape}")

    correlation_measure = ConnectivityMeasure(
        kind='correlation',
        standardize='zscore_sample'
    )
    connectivity_matrix = correlation_measure.fit_transform([full_time_series])[0]

    # Keep the uncorrected matrix for reference / alternative thresholding
    output_dir = f'{results_dir}/connectivity_merged_{condition.lower()}'
    os.makedirs(output_dir, exist_ok=True)
    np.save(f'{output_dir}/{ss}_connectivity_uncorrected_{condition.lower()}.npy', connectivity_matrix)

    return _fdr_threshold_connectivity(connectivity_matrix, full_time_series.shape[0])

In [None]:
# Smoke test: check inputs exist, build the merged atlas, and run one subject end to end
print(f"Working directory: {os.getcwd()}")
print(f"Results directory exists: {os.path.exists(results_dir)}")
print(f"ROI directory exists: {os.path.exists(roi_dir)}")
print(f"Number of subjects: {len(subs)}")

# Check that the ROI files create_merged_atlas expects are present
roi_files = {
    'pIPS': f'{roi_dir}/pIPS.nii.gz',
    'LO': f'{roi_dir}/LO.nii.gz'
}
for name, path in roi_files.items():
    print(f"ROI file {name} exists: {os.path.exists(path)}")

try:
    print("Creating merged atlas...")
    atlas_result = create_merged_atlas()
    # create_merged_atlas may signal failure with a bare None. Normalise it before
    # unpacking so the failure branch below is actually reachable.
    merged_atlas_img, merged_labels = atlas_result if atlas_result is not None else (None, None)
    if merged_atlas_img is None:
        print("Failed to create merged atlas!")
    else:
        print("Merged atlas created successfully")

        # Run a single subject as an end-to-end check
        test_sub = subs[0] if subs else None
        if test_sub:
            print(f"Testing with subject {test_sub}...")
            result = create_and_threshold_connectivity_matrix(test_sub, "Object", merged_atlas_img)
            print(f"Test result: {'Success' if result is not None else 'Failed'}")
except Exception as e:
    print(f"Error: {str(e)}")
    import traceback
    traceback.print_exc()

Working directory: /user_data/csimmon2/git_repos/ptoc
Results directory exists: True
ROI directory exists: True
Number of subjects: 18
ROI file pIPS exists: True
ROI file LO exists: True
Creating merged atlas...
Merged atlas created successfully
Testing with subject sub-025...
