In [1]:
#n.p check input at native or standard, check mask as native or standard, check time series, check parameters too of course

In [None]:
import os
import pandas as pd
import numpy as np
import nibabel as nib
from nilearn import image
from nilearn.glm.first_level import compute_regressor
import logging
from brainiak.searchlight.searchlight import Searchlight
from statsmodels.tsa.stattools import grangercausalitytests
import sys
import gc
import time

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Import your parameters
curr_dir = '/user_data/csimmon2/git_repos/ptoc'
sys.path.insert(0, curr_dir)
import ptoc_params as params

# Set up directories and parameters
study = 'ptoc'
study_dir = f"/lab_data/behrmannlab/vlad/{study}"
localizer = 'Scramble'  # scramble or object. This is the localizer task.
results_dir = '/user_data/csimmon2/git_repos/ptoc/results'
raw_dir = params.raw_dir

# Load subject information
sub_info = pd.read_csv(f'{curr_dir}/sub_info.csv')
sub_info = sub_info[sub_info['group'] == 'control']
subs = sub_info['sub'].tolist()
# For testing, you can uncomment the line below to include specific subjects
# subs = ['sub-025']

run_num = 3
runs = list(range(1, run_num + 1))
run_combos = [[rn1, rn2] for rn1 in range(1, run_num + 1) for rn2 in range(rn1 + 1, run_num + 1)]

# Searchlight parameters
searchlight_radius = 2  # in voxels, adjust as needed
max_blk_edge = 10
pool_size = 1

# Constants
VOL_PER_RUN = 184
TR = 2.0

def load_and_prepare_data(sub, run_combo):
    logging.info(f"Loading data for subject {sub}, runs {run_combo}")

    # Load and combine run data
    run_data_list = []
    for run in run_combo:
        run_file = f'{raw_dir}/{sub}/ses-01/derivatives/fsl/loc/run-0{run}/1stLevel.feat/filtered_func_data_reg.nii.gz'
        run_img = image.load_img(run_file)
        run_data = image.clean_img(run_img, standardize='zscore_sample')
        run_data_list.append(run_data)
    
    # Concatenate run data
    fmri_data = image.concat_imgs(run_data_list)
    
    # Load new whole brain mask
    mask_file = f'/user_data/csimmon2/git_repos/analyses/roiParcels/mruczek_parcels/binary/all_visual_areas.nii.gz'
    whole_brain_mask = nib.load(mask_file).get_fdata().astype(bool)

    # Load pIPS mask
    pips_mask_file = f'/user_data/csimmon2/git_repos/ptoc/roiParcels/pIPS.nii.gz' #not sure if this is sufficient to direct to file
    
    pips_mask = nib.load(pips_mask_file).get_fdata().astype(bool)

    # Generate psychological covariate
    psy = make_psy_cov(run_combo, sub)

    # Ensure fMRI data and psy covariate have the same number of time points
    if fmri_data.shape[-1] != len(psy):
        raise ValueError(f"Mismatch in volumes: fMRI data has {fmri_data.shape[-1]}, psy has {len(psy)}")
    
    # Convert fmri_data to 4D numpy array if it's not already
    if isinstance(fmri_data, nib.Nifti1Image):
        fmri_data = fmri_data.get_fdata()
    
    # Ensure fmri_data is 4D
    if fmri_data.ndim != 4:
        raise ValueError(f"fMRI data must be 4D, but got shape {fmri_data.shape}")
    
    logging.info(f"Data shapes - fMRI: {fmri_data.shape}, Whole Brain Mask: {whole_brain_mask.shape}, pIPS Mask: {pips_mask.shape}, PSY: {psy.shape}")
    
    # Check if dimensions match
    if fmri_data.shape[:3] != whole_brain_mask.shape:
        raise ValueError(f"Whole brain mask dimensions {whole_brain_mask.shape} do not match fMRI data dimensions {fmri_data.shape[:3]}")
    
    return fmri_data, whole_brain_mask, pips_mask, psy

def make_psy_cov(runs, ss):
    """
    Create psychological covariate data for the specified runs and subject.
    """
    temp_dir = f'{raw_dir}/{ss}/ses-01'
    cov_dir = f'{temp_dir}/covs'
    total_vols = VOL_PER_RUN * len(runs)
    times = np.arange(0, total_vols * TR, TR)
    full_cov = pd.DataFrame(columns=['onset', 'duration', 'value'])

    for i, rn in enumerate(runs):
        ss_num = ss.split('-')[1]
        obj_cov_file = f'{cov_dir}/catloc_{ss_num}_run-0{rn}_{localizer}.txt'

        if not os.path.exists(obj_cov_file):
            logging.warning(f'Covariate file not found for run {rn}')
            return np.zeros((total_vols, 1))  # Return a zeros array if file not found

        obj_cov = pd.read_csv(obj_cov_file, sep='\t', header=None, names=['onset', 'duration', 'value'])
        
        if i > 0:
            obj_cov['onset'] += i * VOL_PER_RUN * TR
        
        full_cov = pd.concat([full_cov, obj_cov])

    full_cov = full_cov.sort_values(by=['onset']).reset_index(drop=True)
    cov = full_cov.to_numpy()
    valid_onsets = cov[:, 0] < times[-1]
    cov = cov[valid_onsets]

    if cov.shape[0] == 0:
        logging.warning('No valid covariate data after filtering. Returning zeros array.')
        return np.zeros((total_vols, 1))

    psy, _ = compute_regressor(cov.T, 'spm', times)
    psy[psy > 0] = 1
    psy[psy <= 0] = 0
    return psy

def extract_pips_timeseries(fmri_data, pips_mask):
    """
    Extract the mean time series from the pIPS region.
    """
    fmri_2d = fmri_data.reshape(fmri_data.shape[3], -1)
    pips_ts = fmri_2d[:, pips_mask.flatten()]
    return np.mean(pips_ts, axis=1)

def gca_measure(data, mask, myrad, bcvar):
    """
    Perform Granger Causality Analysis on the searchlight sphere.
    """
    pips_ts, psy = bcvar
    
    # Reshape data to 2D: (125, time_points)
    data_2d = data[0].reshape(-1, data[0].shape[-1])
    
    # Apply mask to get the searchlight sphere time series
    sphere_ts = data_2d[mask.flatten()].mean(axis=0)
    
    # Ensure all time series have the same length
    min_length = min(sphere_ts.shape[0], pips_ts.shape[0], psy.shape[0])
    sphere_ts = sphere_ts[:min_length]
    pips_ts = pips_ts[:min_length]
    psy = psy[:min_length]
    
    # Perform Granger Causality tests
    gc_pips_to_sphere = grangercausalitytests(np.column_stack((sphere_ts, pips_ts, psy.flatten())), maxlag=1, verbose=False)
    gc_sphere_to_pips = grangercausalitytests(np.column_stack((pips_ts, sphere_ts, psy.flatten())), maxlag=1, verbose=False)
    
    # Calculate the difference in F-statistics
    f_diff = gc_pips_to_sphere[1][0]['ssr_ftest'][0] - gc_sphere_to_pips[1][0]['ssr_ftest'][0]
    
    return f_diff

def run_searchlight(fmri_data, whole_brain_mask, pips_mask, psy):
    """
    Run searchlight analysis on the fMRI data.
    """
    assert whole_brain_mask.ndim == 3 and whole_brain_mask.dtype == bool, "Invalid whole_brain_mask"
    assert fmri_data.shape[0:3] == whole_brain_mask.shape, "Whole brain mask dimensions do not match fMRI data."

    # Extract pIPS time series before searchlight
    pips_ts = extract_pips_timeseries(fmri_data, pips_mask)

    sl = Searchlight(sl_rad=searchlight_radius, max_blk_edge=max_blk_edge)
    fmri_data_4d = fmri_data.transpose(3, 0, 1, 2)
    
    sl.distribute([fmri_data_4d], whole_brain_mask)
    
    # Broadcast pips_ts and psy as a tuple
    sl.broadcast((pips_ts, psy))

    sl_result = sl.run_searchlight(gca_measure, pool_size=pool_size)
    
    return sl_result

if __name__ == "__main__":
    for sub in subs:
        for run_combo in run_combos:
            try:
                fmri_data, whole_brain_mask, pips_mask, psy = load_and_prepare_data(sub, run_combo)
                logging.info(f"Processing subject {sub}, runs {run_combo}")
                logging.info(f"Data shapes - fMRI: {fmri_data.shape}, Whole Brain Mask: {whole_brain_mask.shape}, pIPS Mask: {pips_mask.shape}, PSY: {psy.shape}")
                
                sl_result = run_searchlight(fmri_data, whole_brain_mask, pips_mask, psy)
                
                result_file = f'{results_dir}/searchlight_gca_{sub}_runs{"_".join(map(str, run_combo))}.nii.gz'
                affine = nib.load(f'{raw_dir}/{sub}/ses-01/derivatives/fsl/loc/run-01/1stLevel.feat/filtered_func_data_reg.nii.gz').affine
                nib.save(nib.Nifti1Image(sl_result, affine), result_file)
                
                logging.info(f"Completed analysis for subject {sub}, runs {run_combo}")
                
            except Exception as e:
                logging.error(f"Error processing subject {sub}, runs {run_combo}: {str(e)}")
            finally:
                gc.collect()

2024-10-07 23:36:07,023 - INFO - Loading data for subject sub-025, runs [1, 2]
2024-10-07 23:42:10,936 - ERROR - Error processing subject sub-025, runs [1, 2]: No such file or no access: '/user_data/csimmon2/git_repos/analyses/roiParcels/mruczek_parcels/binary/all_visual_areas.nii.gz'
2024-10-07 23:42:13,727 - INFO - Loading data for subject sub-025, runs [1, 3]
2024-10-07 23:48:12,783 - ERROR - Error processing subject sub-025, runs [1, 3]: No such file or no access: '/user_data/csimmon2/git_repos/analyses/roiParcels/mruczek_parcels/binary/all_visual_areas.nii.gz'
2024-10-07 23:48:14,976 - INFO - Loading data for subject sub-025, runs [2, 3]
2024-10-07 23:54:13,910 - ERROR - Error processing subject sub-025, runs [2, 3]: No such file or no access: '/user_data/csimmon2/git_repos/analyses/roiParcels/mruczek_parcels/binary/all_visual_areas.nii.gz'
2024-10-07 23:54:16,087 - INFO - Loading data for subject sub-038, runs [1, 2]
2024-10-08 00:00:15,682 - ERROR - Error processing subject sub-