This notebook pulls together, in one place, a summary of what the other scripts here do, to make clear the exact next step I'm trying to accomplish.

- First, I load a voxel- and parcel-wise timeseries from our data.
- Then, I create functional connectomes from those (correlating each voxel with each parcel, i.e. the parcels are the connectivity targets).
- I didn't copy the details here, but we'll assess the reliability of these connectomes in a separate script.
- Now, I want to write code to derive parcel-wise shared response models (SRMs) from a subject-averaged connectome and concatenate the resulting transformation matrices into a brain-wide transformation matrix for each subject.
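One detail worth planning for in the concatenation step: a masker's voxel order is spatial, so the voxels of one parcel are generally not contiguous rows, and a plain block-diagonal stack of the per-parcel matrices (e.g. `scipy.linalg.block_diag`) would only line up if the voxels were first reordered by parcel. A minimal sketch, assuming the transforms are kept per parcel as a dict `{label: [W for each subject]}` and that a per-voxel label vector is available (both hypothetical here), writes each parcel's W into that parcel's rows and its own block of columns:

```python
import numpy as np


def assemble_brain_wide(transforms, voxel_labels):
    """Combine parcel-wise SRM transforms into one matrix per subject (hypothetical helper).

    transforms: dict {parcel_label: [W_subj0, W_subj1, ...]}, each W of shape
        (n_voxels_in_parcel, k), rows in the order that parcel's voxels appear in voxel_labels.
    voxel_labels: (n_voxels,) parcel label of every connectome row (0 = no parcel).
    Returns one (n_voxels, total_k) array per subject.
    """
    voxel_labels = np.asarray(voxel_labels)
    parcels = sorted(transforms)
    widths = [transforms[p][0].shape[1] for p in parcels]
    col_edges = np.concatenate([[0], np.cumsum(widths)]).astype(int)
    n_subjects = len(transforms[parcels[0]])

    brain_wide = []
    for s in range(n_subjects):
        W = np.zeros((voxel_labels.size, col_edges[-1]))
        for p, c0, c1 in zip(parcels, col_edges[:-1], col_edges[1:]):
            # a parcel's voxels are scattered in masker order, so select its rows by label
            W[voxel_labels == p, c0:c1] = transforms[p][s]
        brain_wide.append(W)
    return brain_wide
```

Projecting a connectome `X` (n_voxels x n_targets) is then `W.T @ X`, which equals stacking each parcel's `W_p.T @ X[voxel_labels == p]`. At whole-brain scale W has n_voxels x (n_parcels x k) entries and is almost all zeros, so keeping the per-parcel blocks, or a `scipy.sparse` matrix, may be more practical than a dense array.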

My main question: once I have read voxel data in through a masker, I no longer know which voxels correspond to which labels in the Schaefer atlas we are using here. Relatedly, I'm not sure exactly what to mask with, and whether I should be using the Schaefer parcellation to do so.
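On what to mask with, one option (an assumption on my part, not something the scripts currently do) is to use the Schaefer atlas itself as the mask: resample it onto the functional grid with nearest-neighbour interpolation, e.g. `nilearn.image.resample_to_img(atlas, brain_mask_img, interpolation="nearest")` with the run's fMRIPrep brain mask as target, binarize it with `nilearn.image.math_img("img > 0", img=...)`, and pass that as `mask_img` to `MultiNiftiMasker`. Every voxel column then belongs to a parcel, and the per-voxel labels are simply the atlas values inside the mask. The sketch shows this bookkeeping on plain arrays, assuming the masker flattens voxels by boolean indexing in NumPy's C order; `atlas_mask_and_labels` and `mask_timeseries` are hypothetical helpers.

```python
import numpy as np


def atlas_mask_and_labels(atlas_data):
    """Use the atlas as the brain mask (hypothetical helper).

    atlas_data: 3D integer array of parcel labels already on the functional grid
        (0 = background).
    Returns the boolean mask and the label of every in-mask voxel, in the order
    boolean indexing visits them (NumPy C order).
    """
    mask = atlas_data > 0
    return mask, atlas_data[mask]


def mask_timeseries(func_data, mask):
    """(n_TRs, n_voxels) matrix whose column j is the voxel labelled voxel_labels[j].

    Assumption: this mirrors how nilearn's maskers flatten a 4D image
    (boolean indexing with a 3D mask, then transpose).
    """
    return func_data[mask].T
```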

The input files are resting-state functional data, already preprocessed with fMRIPrep and in MNI152NLin2009cAsym space at res-2 (2 mm voxels).


Other to-do (just notes for myself):
- potentially write searchlight loading, to try searchlights as connectivity targets
- potentially use the FCMA library to have voxels as our connectivity targets (Cython issues...)
- Reliability.py: permutation tests?
- Write task_decoding.py

In [None]:
import glob
import numpy as np
from nilearn import datasets
from nilearn.maskers import MultiNiftiMasker, MultiNiftiLabelsMasker, NiftiSpheresMasker
from scipy.stats import pearsonr
from math import tanh

In [None]:
def load_data(FILE_PATHS=None,
              strategy='parcel',
              schaefer_n_rois=400,
              sphere_radius=8,
              sphere_spacing=6):
    """
    This function loads functional data in 3 different ways:
    1. strategy='voxel': extracts direct voxel timeseries
    2. strategy='parcel': extracts parcel timeseries from the Schaefer 2018 atlas with a given # of ROIs (400 for now)
    3. (NOT IMPLEMENTED) strategy='searchlight': extracts timeseries of spheres of a given radius and spacing

    NOTE to self: the multi-masker fit_transform returns a list with one array per file,
    each of shape (n_TRs x n_voxels) for 'voxel' or (n_TRs x n_parcels) for 'parcel'.
    """

    if not FILE_PATHS:  # None (or empty list) -> default set of files
        # all rest data in my current testing dir by default; sorted so that separate
        # voxel and parcel calls return the files in the same order
        FILE_PATHS = sorted(glob.glob('data/rest/*.nii.gz'))

    masker_args = {
        'standardize': 'zscore_sample',  # ? z-scores each timeseries using the sample std
        'n_jobs': -1,
        # add: mask_img from fmriprep brain mask?
        # not doing any: smoothing, detrending, filtering (low_pass, high_pass, t_r)
    }

    if strategy == 'voxel':
        masker = MultiNiftiMasker(**masker_args)  # mask_strategy='whole-brain-template' or 'gm-template'?

    elif strategy == 'parcel':
        schaefer_atlas = datasets.fetch_atlas_schaefer_2018(n_rois=schaefer_n_rois,
                                                            yeo_networks=7,
                                                            resolution_mm=2,  # because our data is too
                                                            data_dir='data/schaefer',
                                                            verbose=0)
        masker = MultiNiftiLabelsMasker(labels_img=schaefer_atlas.maps,
                                        labels=schaefer_atlas.labels,
                                        resampling_target='data',
                                        strategy='mean',
                                        **masker_args)

    elif strategy == 'searchlight':
        # planned: compute sphere centre coordinates on a grid with sphere_spacing, then
        # NiftiSpheresMasker(seeds=sphere_coords, radius=sphere_radius, **masker_args)
        # and fit_transform each file separately
        raise NotImplementedError("searchlight loading is not implemented yet")

    else:
        raise ValueError(f"unknown strategy: {strategy!r}")

    # this only works for the multi-maskers with voxel/parcel
    return masker.fit_transform(FILE_PATHS)
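For the 'parcel' strategy it also matters which column is which parcel (e.g. for the own-parcel question in `compute_fc_target`). As far as I understand, `NiftiLabelsMasker` with `strategy='mean'` returns one column per label present in the (resampled) atlas, in ascending label order, with the background dropped; this is worth verifying against the installed nilearn version, since a label that disappears after resampling would shift the columns. The sketch reproduces that convention on plain arrays; `parcel_means` is a hypothetical helper.

```python
import numpy as np


def parcel_means(func_data, atlas_data):
    """Mean timeseries per parcel, one column per label in ascending order.

    func_data: 4D array (x, y, z, n_TRs); atlas_data: 3D label array on the same grid.
    Returns the (n_TRs, n_parcels) timeseries and the label of each column.
    Assumption: this is the column convention of NiftiLabelsMasker(strategy='mean');
    the masker's standardization is not reproduced here.
    """
    parcel_labels = np.unique(atlas_data[atlas_data > 0])  # sorted, background dropped
    columns = [func_data[atlas_data == label].mean(axis=0) for label in parcel_labels]
    return np.column_stack(columns), parcel_labels
```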

In [None]:
def correlate_columns(mat1, mat2, zscore=False):
    """
    Helper function for compute_fc_target.
    Returns an (n_cols1 x n_cols2) matrix with the Pearson r of each column (voxel) of mat1
    with each column (parcel/target) of mat2. Both matrices must have the same rows (TRs).
    zscore=True applies the Fisher r-to-z transform (arctanh) to the result.
    """
    n_trs = mat1.shape[0]
    # standardize each column (ddof=1) so that a scaled dot product gives Pearson r
    z1 = (mat1 - mat1.mean(axis=0)) / mat1.std(axis=0, ddof=1)
    z2 = (mat2 - mat2.mean(axis=0)) / mat2.std(axis=0, ddof=1)
    correlation_matrix = z1.T @ z2 / (n_trs - 1)
    if zscore:
        # Fisher transformation; clip so that r = +/-1 does not become +/-inf
        correlation_matrix = np.arctanh(np.clip(correlation_matrix, -1 + 1e-7, 1 - 1e-7))
    return correlation_matrix


def compute_fc_target(voxel_timeseries, target_timeseries, zscore=True):
    """
    For each file, correlate each column of the voxel timeseries (across all TRs/rows)
    with each column of the target timeseries, giving one (n_voxels x n_targets) connectome per file.
    Connectivity targets could be the parcel timeseries or a searchlight timeseries.

    I'm not using nilearn's ConnectivityMeasure because it computes connectivity among the
    columns of a single matrix, not between two matrices.

    NOTE: the connectivity target in which a given voxel resides is not excluded yet. Should it be?
    """
    return [correlate_columns(vox, tgt, zscore) for vox, tgt in zip(voxel_timeseries, target_timeseries)]
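If the own parcel should be excluded (a methodological choice that is not settled here), one way is to blank out those entries after each connectome is computed. The sketch assumes a `voxel_labels` vector (the label of each row) and a sorted `parcel_labels` vector (the label of each column); both names and the helper `mask_own_parcel` are hypothetical. The fill value is a choice too: NaN makes the excluded entries explicit but would need handling before fitting an SRM.

```python
import numpy as np


def mask_own_parcel(connectome, voxel_labels, parcel_labels, fill_value=np.nan):
    """Blank out each voxel's connectivity with the parcel that contains it.

    connectome: (n_voxels, n_parcels); voxel_labels: label of each row;
    parcel_labels: label of each column, sorted ascending (np.unique order).
    Rows whose label is not a column label (e.g. 0 = outside the atlas) are unchanged.
    Returns a modified copy.
    """
    out = np.array(connectome, dtype=float)  # copy, leave the input untouched
    voxel_labels = np.asarray(voxel_labels)
    parcel_labels = np.asarray(parcel_labels)
    rows = np.flatnonzero(np.isin(voxel_labels, parcel_labels))
    cols = np.searchsorted(parcel_labels, voxel_labels[rows])  # column of each row's own parcel
    out[rows, cols] = fill_value
    return out
```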

In [None]:
# load data, calculate connectome
voxel_data = load_data(strategy='voxel')
parcel_data = load_data(strategy='parcel')
connectomes = compute_fc_target(voxel_data, parcel_data, zscore=True)
# there will be a step in here where the connectomes from each subject are averaged across sessions--ignore that for now
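For the later session-averaging step, a common convention is to average correlation matrices in Fisher z space rather than averaging r directly. If the connectomes already hold Fisher z values (arctanh of r), a plain mean over sessions is enough; if they hold r, transform, average, and map back. A minimal sketch with the hypothetical helper `average_sessions` (grouping files by subject is left out):

```python
import numpy as np


def average_sessions(session_connectomes, already_fisher_z=True):
    """Average one subject's connectomes across sessions.

    session_connectomes: list of (n_voxels, n_targets) arrays for one subject.
    already_fisher_z=True: inputs hold arctanh(r), so they are averaged directly.
    already_fisher_z=False: inputs hold r; they are averaged in z space and
    the result is mapped back to r.
    """
    stack = np.stack(session_connectomes)
    if already_fisher_z:
        return stack.mean(axis=0)
    z = np.arctanh(np.clip(stack, -0.999999, 0.999999))  # avoid infinities at r = +/-1
    return np.tanh(z.mean(axis=0))
```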

Now here is my problem:

I have connectomes of shape (n_voxels x n_parcels) to derive SRMs from, but I have lost track of which parcel each voxel belongs to.

The ideas that occur to me:
- If I better understood exactly which mask the masker applies to the data and how it flattens the result, I could resample the atlas to the original data's shape/affine and then mask and flatten it in the same way.
- Should I be using the nilearn maskers at all? 
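On the first idea, the flattening appears to be simple. A fitted `MultiNiftiMasker` stores its mask as `masker.mask_img_`, and, to my understanding, its voxel columns are the in-mask voxels taken by boolean indexing, i.e. in NumPy C order. Assuming that, resampling the atlas onto the mask (`resample_to_img(schaefer_atlas.maps, masker.mask_img_, interpolation="nearest")`) and indexing its data with the boolean mask gives one label per voxel column. This requires keeping the fitted masker (or its mask), which `load_data` currently discards. With a data-derived mask, some in-mask voxels will get label 0 (outside the atlas) and can be dropped. As a sanity check, `masker.inverse_transform(voxel_labels)` should look like the atlas restricted to the mask. The sketch shows the same logic on plain arrays; `labels_for_mask` and `unmask_vector` are hypothetical helpers.

```python
import numpy as np


def labels_for_mask(atlas_data, mask):
    """Label of every voxel kept by `mask`, in masker column order (assumed C order).

    atlas_data must already be on the mask's grid (nearest-neighbour resampled).
    Voxels inside the mask but outside the atlas get label 0.
    """
    return atlas_data[mask.astype(bool)]


def unmask_vector(flat, mask):
    """Inverse of boolean masking: put an (n_voxels,) vector back into a 3D volume."""
    vol = np.zeros(mask.shape, dtype=flat.dtype)
    vol[mask.astype(bool)] = flat
    return vol
```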


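In [None]:
# mock-up of what I want to run, written as a function.
# voxel_labels must hold the Schaefer label of each connectome row (0 = outside the atlas),
# in the same voxel order the voxel masker used; getting that vector is the open problem above
from brainiak.funcalign.srm import SRM


def fit_parcelwise_srm(connectomes, voxel_labels, features=50, n_iter=20):
    """
    connectomes: list with one (n_voxels x n_targets) array per subject
    Returns {parcel_label: list of per-subject transformation matrices W,
             each (n_voxels_in_parcel x features)}, to concatenate later
    """
    voxel_labels = np.asarray(voxel_labels)
    transforms = {}
    for parcel in np.unique(voxel_labels):
        if parcel == 0:  # background / outside the atlas
            continue
        in_parcel = voxel_labels == parcel
        # SRM takes one (voxels x samples) array per subject; here the samples are the connectivity targets
        parcel_connectomes = [subj_connectome[in_parcel, :] for subj_connectome in connectomes]
        # W has orthonormal columns, so features should not exceed the parcel's voxel count
        shared_model = SRM(n_iter=n_iter, features=features)
        shared_model.fit(parcel_connectomes)
        transforms[parcel] = shared_model.w_
    return transforms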