# Current To-Do:
- Contrast map comparison
- Conjunction analysis

Backburner questions:
- XCP-D?
- DiFuMo atlas instead of Schaefer?
- Different SRM distance penalties (use distance as a penalty instead of fitting parcel-wise? Searchlights instead of parcels?)

[ROADMAP DOC](https://docs.google.com/document/d/13P4QTHxrT5lZfCOXtN59xCKpJfnObtqh3uZkuRqPxR4/edit?pli=1#heading=h.2qncjqtc0b5j)

# Testing

In [None]:
# load all data (connectomes, task) and subjects
# loop through subjects
#   derive loso SRM transforms
#   transform task data
#   compute conjunction
# plot

In [None]:
import os, sys, glob, json, itertools
import numpy as np
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from nilearn.maskers import NiftiMasker
from utils import srm_brainiak as srm
from joblib import Parallel, delayed
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from connectivity import get_combined_mask
from srm import load_parcel_map
from utils.fastsrm_brainiak import FastSRM 

In [None]:
# Root directory of the task contrast derivatives on Oak.
CONTRAST_PATH = '/oak/stanford/groups/russpold/data/network_grant/discovery_BIDS_21.0.1/derivatives/output_optcom_MNI/'

# One target contrast per task: task name -> contrast identifier (prefixed with the task name).
target_contrasts = dict(
    cuedTS='cuedTS_contrast-task_switch_cost',
    directedForgetting='directedForgetting_contrast-neg-con',
    flanker='flanker_contrast-incongruent - congruent',
    nBack='nBack_contrast-twoBack-oneBack',
    spatialTS='spatialTS_contrast-task_switch_cost',
    shapeMatching='shapeMatching_contrast-main_vars',
    goNogo='goNogo_contrast-nogo_success-go',
    stopSignal='stopSignal_contrast-stop_failure-go',
)

In [None]:
CONNECTOME_PATH = '/scratch/users/csiyer/connectomes/'
SRM_PATH = '/scratch/users/csiyer/srm_outputs/'


def _subject_id(path):
    """
    Return the 3-character subject ID (e.g. 's03') that follows the 4-character 'sub-' prefix in a file's base name.

    Raises ValueError if the file name contains no 'sub'.
    """
    name = os.path.basename(path)  # search the file name only, so a 'sub' in the directory path can't match
    start = name.find('sub')
    if start == -1:
        raise ValueError(f'No subject ID found in file name: {path}')
    return name[start + 4:start + 7]


# Connectome files matching '*avg*' (presumably subject-averaged connectomes). Sorted for a
# reproducible order; connectomes[i] and sub_list[i] always describe the same file.
connectome_files = sorted(glob.glob(os.path.join(CONNECTOME_PATH, '*avg*')))
connectomes = [np.load(f) for f in connectome_files]
sub_list = [_subject_id(f) for f in connectome_files]  # s03, s10, etc.
parcel_map = load_parcel_map(n_dimensions=100)

def _orthogonal_procrustes_transform(subject_data, shared_response):
    """
    Map one subject onto an existing shared response (a single SRM transform update).

    Solves min_W ||subject_data - W @ shared_response||_F over W with orthonormal columns:
    with A = subject_data @ shared_response.T = U S Vt (thin SVD), the solution is W = U @ Vt.
    This is the same update BrainIAK's SRM.transform_subject applies to a new subject.

    Inputs:
        - subject_data: array of shape (n_voxels, n_samples)
        - shared_response: array of shape (n_features, n_samples); n_samples must match subject_data
    Output: array of shape (n_voxels, n_features)
    """
    cross_product = subject_data @ shared_response.T
    u, _, vt = np.linalg.svd(cross_product, full_matrices=False)
    return u @ vt


def _fit_parcel_loso_srm(train_parcel_data, loso_parcel_data, n_features):
    """
    Fit an SRM on the training subjects' data for one parcel, then derive the left-out subject's transform.

    The left-out subject is only projected onto the already-fitted shared response, so their data
    has no influence on the shared model or on the training subjects' transforms.

    Output: (shared response, list of training-subject transforms, left-out-subject transform)
    """
    shared_model = srm.SRM(n_iter=20, features=n_features)
    shared_model.fit(train_parcel_data)
    loso_transform = _orthogonal_procrustes_transform(loso_parcel_data, shared_model.s_)
    return shared_model.s_, shared_model.w_, loso_transform


def compute_loso_srm(data_list, sub_list, loso_sub, parcel_map, n_features=100, save=True):
    """
    Compute parcel-wise leave-one-subject-out (LOSO) transforms with BrainIAK's Shared Response Model.

    One SRM is fit per parcel (as an anatomical constraint) using every subject EXCEPT `loso_sub`.
    The left-out subject's transform is then derived by projecting their data onto that fixed
    shared response (see _orthogonal_procrustes_transform), so their data never shapes the shared
    model. This eliminates data leakage when other subjects' data are later transformed into the
    left-out subject's native space.

    Parcels are fit in parallel with joblib's Parallel and delayed.

    Results are cached under SRM_PATH/loso/<loso_sub>/: if every subject's transform file is already
    there, they are loaded instead of re-derived.

    Inputs:
        - data_list: list of subject-specific connectomes; rows are selected by parcel_map
          (assumed (n_voxels, n_samples) with the same n_samples for every subject — TODO confirm)
        - sub_list: subject names, in the same order as data_list
        - loso_sub: which subject to leave out (must be in sub_list)
        - parcel_map: parcel label for each row of the connectomes
        - n_features: how many dimensions in the shared model
        - save: whether to save the transformation matrices (and the parcel-wise shared responses)
    Output: list of transformation matrices, one per subject in sub_list order, each of shape
        (n_voxels, n_features). The cached and freshly-computed paths return the same thing.
        NOTE: the parcel-wise shared responses are saved (when save=True) but not returned.
    Raises: ValueError if loso_sub is not in sub_list.
    """
    outpath = os.path.join(SRM_PATH, 'loso', loso_sub)
    transform_files = [
        os.path.join(outpath, f'sub-{sub}_srm_transform_loso-{loso_sub}.npy') for sub in sub_list
    ]
    # Only reuse past results if every subject's file exists (a partial run must be recomputed).
    if all(os.path.isfile(path) for path in transform_files):
        return [np.load(path) for path in transform_files]

    loso_idx = sub_list.index(loso_sub)
    loso_data = data_list[loso_idx]
    train_positions = [i for i, _ in enumerate(data_list) if i != loso_idx]
    train_data = [data_list[i] for i in train_positions]

    # np.unique already returns the labels sorted.
    labels = np.asarray(parcel_map).ravel()
    parcel_voxels = [np.flatnonzero(labels == label) for label in np.unique(labels)]

    # Slice each parcel before dispatch so workers only receive that parcel's rows.
    srm_outputs = Parallel(n_jobs=-1)(
        delayed(_fit_parcel_loso_srm)([d[idx] for d in train_data], loso_data[idx], n_features)
        for idx in parcel_voxels
    )

    # Stitch the parcel-wise transforms into one all-voxel matrix per subject (sub_list order).
    n_voxels = data_list[0].shape[0]
    subject_transforms = [np.zeros((n_voxels, n_features)) for _ in data_list]
    for idx, (_, train_w, loso_w) in zip(parcel_voxels, srm_outputs):
        for pos, w in zip(train_positions, train_w):
            subject_transforms[pos][idx, :] = w
        subject_transforms[loso_idx][idx, :] = loso_w

    if save:
        os.makedirs(outpath, exist_ok=True)
        np.save(
            os.path.join(outpath, f'parcelwise_shared_responses_loso-{loso_sub}.npy'),
            np.stack([shared for shared, _, _ in srm_outputs]),
        )
        for path, transform in zip(transform_files, subject_transforms):
            np.save(path, transform)

    return subject_transforms

    
