# Current To-Do:
- Contrast map comparison
- Conjunction analysis

Backburner questions:
- XCP-D?
- DiFuMo atlas instead of Schaefer?
- Different SRM distance penalties (distance as penalty instead of parcelwise? Searchlights instead of parcels?)

[ROADMAP DOC](https://docs.google.com/document/d/13P4QTHxrT5lZfCOXtN59xCKpJfnObtqh3uZkuRqPxR4/edit?pli=1#heading=h.2qncjqtc0b5j)

# Testing

In [None]:
import nibabel as nib
PATH = '/oak/stanford/groups/russpold/data/network_grant/discovery_BIDS_21.0.1/derivatives/output_optcom_MNI/flanker_lev1_output/task_flanker_rtmodel_rt_centered/contrast_estimates/'
testfile = 'sub-s43_task-flanker_contrast-incongruent - congruent_rtmodel-rt_centered_stat-fixed-effects_t-test.nii.gz'
test = nib.load(PATH + testfile)

testfile2 = 'sub-s19_task-flanker_contrast-incongruent - congruent_rtmodel-rt_centered_stat-fixed-effects_t-test.nii.gz'
test2 = nib.load(PATH + testfile2)

In [None]:
"""CONJUNCTION ANALYSIS"""

import os, sys, glob, json, itertools
import numpy as np
import nibabel as nib
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

CONTRAST_PATH = '/oak/stanford/groups/russpold/data/network_grant/discovery_BIDS_21.0.1/derivatives/output_optcom_MNI/'
SRM_DIR = '/scratch/users/csiyer/srm_outputs/'
srm_files = glob.glob(SRM_DIR + '*transform*')

target_contrasts = {
    'cuedTS': 'cuedTS_contrast-task_switch_cost',
    'directedForgetting': 'directedForgetting_contrast-neg-con',
    'flanker': 'flanker_contrast-incongruent - congruent',
    'nBack': 'nBack_contrast-twoBack-oneBack',
    'spatialTS': 'spatialTS_contrast-task_switch_cost',
    'shapeMatching': 'shapeMatching_contrast-main_vars',
    'goNogo': 'goNogo_contrast-nogo_success-go',
    'stopSignal': 'stopSignal_contrast-stop_failure-go'
}

def srm_transform(map, transform, zscore=True):
    out = np.dot(map.get_fdata(), transform)
    if zscore:
        out = StandardScaler().fit_transform(out) 
    return out

def dice_coef(map1, map2):
    """
    Takes in two nibabel Nifti objects (binarized maps); binarizes them and returns:
        1) map of overlapping voxels
        2) count of overlapping voxels
        3) dice coefficient of the maps
    """
    if isinstance(map1, nib.Nifti1Image) and isinstance(map2, nib.Nifti1Image):
        data1 = map1.get_fdata()
        data2 = map2.get_fdata()
    elif isinstance(map1, np.ndarray) and isinstance(map2, np.ndarray):
        data1 = map1
        data2 = map2

    if map1.shape != map2.shape:
        raise ValueError("ERROR: shape mismatch")
    if np.unique(map1) != [0,1] or np.unique(map2) != [0,1]:
        raise ValueError("ERROR: non-binarized maps")

    overlap_map = data1*data2
    intersection = np.sum(overlap_map)
    sum_binarized = np.sum(data1) + np.sum(data2)

    if sum_binarized == 0:
        return 1.0 if intersection == 0 else 0.0
    
    dice = 2.0 * intersection / sum_binarized
    return overlap_map, intersection, dice

output = {}
for task in ['flanker','spatialTS','cuedTS','directedForgetting','stopSignal','goNogo', 'shapeMatching', 'nBack']:
    output[task] = {
        'srm': {'all_overlap': [], 'all_dice': []},
        'nosrm': {'all_overlap': [], 'all_dice': []}
    }
    
    subjects = np.unique(np.load(f'/scratch/users/csiyer/glm_outputs/{task}_subjects.npy'))
    full_task_fname = CONTRAST_PATH + f'{task}_lev1_output/task_{task}_rtmodel_rt_centered/contrast_estimates/'

    # create a data dictionary mapping subject names to both SRM transforms and contrast maps
    sub_dict = {sub: {
        'srm_transform': np.load([s for s in srm_files if sub in s][0]),
        'contrast_map': nib.load( glob.glob(full_task_fname + f'{sub}*{target_contrasts[task]}*t-test.nii.gz')[0] )
    } for sub in subjects}

    for sub1, sub2 in itertools.combinations(subjects, 2):

        _, overlap, dice = dice_coef(sub_dict[sub1]['contrast_map'], sub_dict[sub2]['contrast_map'])
        output[task]['nosrm']['all_overlap'].append(overlap)
        output[task]['nosrm']['all_dice'].append(dice)

        _, overlap, dice = dice_coef(
            srm_transform(sub_dict[sub1]['contrast_map'], sub_dict[sub1]['srm_transform']),
            srm_transform(sub_dict[sub2]['contrast_map'], sub_dict[sub2]['srm_transform']),
        )
        output[task]['srm']['all_overlap'].append(overlap)
        output[task]['srm']['all_dice'].append(dice)

    for method in ['srm', 'nosrm']:
        output[task][method]['avg_overlap'] = np.mean(output[task][method]['all_overlap'])
        output[task][method]['avg_dice'] = np.mean(output[task][method]['all_dice'])
    
# save outputs
OUTPATH = '/scratch/users/csiyer/conjunction_analysis/'
if not os.isdir(OUTPATH):
    os.mkdir(OUTPATH)
with open(OUTPATH + 'outputs.json', 'w') as file:
    json.dump(output, file, indent=4)
    

In [23]:

output = {}
for task in ['flanker','spatialTS','cuedTS','directedForgetting','stopSignal','goNogo', 'shapeMatching', 'nBack']:
    output[task] = {
        'srm': {'all_overlap': [], 'all_dice': []},
        'nosrm': {'all_overlap': [], 'all_dice': []}
    }
    subjects = ['sub-s01', 'sub-s02', 'sub-s03', 'sub-s04', 'sub-s05']

    for sub1, sub2 in itertools.combinations(subjects, 2):
        output[task]['nosrm']['all_overlap'].append(np.random.randint(0,10,5))
        output[task]['nosrm']['all_dice'].append(np.random.randint(0,10,5))
        output[task]['srm']['all_overlap'].append(np.random.randint(0,10,5))
        output[task]['srm']['all_dice'].append(np.random.randint(0,10,5))

    for method in ['srm', 'nosrm']:
        output[task][method]['avg_overlap'] = np.mean(output[task][method]['all_overlap'])
        output[task][method]['avg_dice'] = np.mean(output[task][method]['all_dice'])
    

In [22]:
import numpy as np
import matplotlib.pyplot as plt

def plot_results(output, save=True):
    tasks = list(output.keys())
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    
    # Plot for Overlapping Voxels
    axes[0].set_title("Contrast Map Overlapping Voxels")
    axes[0].set_ylabel("Average Overlap")
    axes[0].set_xticks(range(len(tasks)))
    axes[0].set_xticklabels(tasks, rotation=45, ha='right')
    
    # Plot for Dice Coefficients
    axes[1].set_title("Contrast Map Dice Coefficients")
    axes[1].set_ylabel("Average Dice Coefficient")
    axes[1].set_xticks(range(len(tasks)))
    axes[1].set_xticklabels(tasks, rotation=45, ha='right')
    
    # Plot data
    for i, task in enumerate(tasks):
        srm_avg_overlap = output[task]['srm']['avg_overlap']
        nosrm_avg_overlap = output[task]['nosrm']['avg_overlap']
        srm_overlap_err = np.std(output[task]['srm']['all_overlap'])
        nosrm_overlap_err = np.std(output[task]['nosrm']['all_overlap'])

        srm_avg_dice = output[task]['srm']['avg_dice']
        nosrm_avg_dice = output[task]['nosrm']['avg_dice']
        srm_dice_err = np.std(output[task]['srm']['all_dice'])
        nosrm_dice_err = np.std(output[task]['nosrm']['all_dice'])
        
        axes[0].bar(i - 0.2, srm_avg_overlap, yerr=srm_overlap_err, width=0.4, label='SRM' if i == 0 else "", color='green', capsize=5)
        axes[0].bar(i + 0.2, nosrm_avg_overlap, yerr=nosrm_overlap_err, width=0.4, label='No SRM' if i == 0 else "", color='blue', capsize=5)
        
        axes[1].bar(i - 0.2, srm

plot_results(results)

SyntaxError: '(' was never closed (2067608576.py, line 37)