# Current To-Do:
- Conjunction analysis: leave-one-subject-out (LOSO) version

Backburner questions:
- XCP-D?
- DiFuMo atlas instead of Schaefer?
- Different SRM distance penalties (use distance as a penalty instead of fitting parcel-wise? Use searchlights instead of parcels?)

[ROADMAP DOC](https://docs.google.com/document/d/13P4QTHxrT5lZfCOXtN59xCKpJfnObtqh3uZkuRqPxR4/edit?pli=1#heading=h.2qncjqtc0b5j)

# Testing

In [None]:
from contrast_map_checks.conjunction_analysis import *
sub_files = load_all_filenames()
parcel_map = load_parcel_map(n_dimensions=100)
task = 'flanker'
sub_list = list(sub_files.keys())
loso_idx = 0
# Held-out subject, derived from loso_idx so the index and the name cannot drift apart.
loso_sub = sub_list[loso_idx]
data_list = [np.load(sub_files[s]['connectome']) for s in sub_list]

sub_transforms = compute_loso_srm(data_list, sub_list, loso_sub, parcel_map, n_features=50, save=False)
contrast_maps = [mask(sub_files[s][task]) for s in sub_list]
# The held-out subject's map is kept as-is. Every other subject's map is passed through
# srm_and_loso_native with that subject's transform and the held-out subject's transform.
contrast_maps_srm = [
    m if s == loso_sub else srm_and_loso_native(m, t, sub_transforms[loso_idx])
    for s, m, t in zip(sub_list, contrast_maps, sub_transforms)
]

# Materialize the pairs: itertools.combinations returns a one-shot iterator, so reusing
# it across the four comparisons below would leave every list after the first empty.
pairs = list(itertools.combinations(range(len(sub_list)), 2))

# NOTE(review): threshold_val is not defined in this cell; presumably it comes from the
# star import of contrast_map_checks.conjunction_analysis -- confirm.
print(
    'dice (native):',
    [dice_coef(contrast_maps[i], contrast_maps[j], threshold_val=threshold_val) for i, j in pairs],
)
print(
    'dice (srm):',
    [dice_coef(contrast_maps_srm[i], contrast_maps_srm[j], threshold_val=threshold_val) for i, j in pairs],
)
print(
    'pearson r (native):',
    [pearsonr(contrast_maps[i], contrast_maps[j])[0] for i, j in pairs],
)
print(
    'pearson r (srm):',
    [pearsonr(contrast_maps_srm[i], contrast_maps_srm[j])[0] for i, j in pairs],
)

In [None]:
# Scratch directory that single_parcel_srm hands to FastSRM for its temporary files.
temp_dir = '/scratch/users/csiyer/temp_dir'

# Separate the held-out (LOSO) subject's data from the training subjects' data,
# keeping the training subjects in their original sub_list order.
loso_idx = sub_list.index(loso_sub)
loso_data = data_list[loso_idx]
train_data = data_list[:loso_idx] + data_list[loso_idx + 1:]
def single_parcel_srm(train_data, loso_data, parcel_map, parcel_label, n_features):
    """Fit a FastSRM on one parcel's voxels, then add the held-out subject to it.

    Parameters
    ----------
    train_data : list of arrays
        Training subjects' data; rows are indexed by voxel (``d[parcel_idxs]``).
    loso_data : array
        The held-out subject's data, row-indexed the same way as ``train_data``.
    parcel_map : array
        Parcel label of each voxel.
    parcel_label : scalar
        Label of the parcel to fit.
    n_features : int
        Number of shared SRM components.

    Returns
    -------
    bases : list of arrays
        One basis per subject as loaded from ``srm.basis_list``: the training
        subjects in ``train_data`` order, then the held-out subject's basis
        (added last via ``add_subjects``).
    parcel_idxs : array of int
        Voxel indices belonging to ``parcel_label``.
    """
    parcel_idxs = np.where(parcel_map == parcel_label)[0]
    train_data_parcel = [d[parcel_idxs] for d in train_data]
    # The held-out subject must be restricted to the same parcel voxels as the training
    # subjects; passing the full-brain array gave it a basis of the wrong voxel count.
    loso_data_parcel = loso_data[parcel_idxs]

    # temp_dir is the module-level scratch directory defined earlier in this cell.
    srm = FastSRM(n_components=n_features, n_iter=20, n_jobs=1, aggregate='mean', temp_dir=temp_dir)
    try:
        reduced_sr = srm.fit_transform(train_data_parcel)
        # NOTE(review): aggregate is reset before add_subjects; the reason is not visible here.
        srm.aggregate = None
        srm.add_subjects([loso_data_parcel], reduced_sr)
        # basis_list entries are paths to files in temp_dir, so load them into memory
        # before the temporary files are removed below.
        bases = [np.load(path) for path in srm.basis_list]
    finally:
        # Delete this model's temporary files so hundreds of parallel parcel fits
        # do not accumulate on scratch.
        srm.clean()
    return bases, parcel_idxs

# Number of shared SRM components; used both for fitting and for sizing the transforms.
n_features = 50

# np.unique already returns the labels sorted, so no extra np.sort is needed.
srm_outputs = Parallel(n_jobs=-1)(
    delayed(single_parcel_srm)(train_data, loso_data, parcel_map, parcel_label, n_features)
    for parcel_label in np.unique(parcel_map)
)

# One voxel-by-component transform per subject, filled in parcel by parcel.
subject_transforms = [np.zeros((len(parcel_map), n_features)) for _ in range(len(sub_list))]  # empty initialize
for parcel_bases, parcel_idxs in srm_outputs:  # for each parcel
    if len(parcel_bases) != len(subject_transforms):
        raise ValueError(
            f'expected {len(subject_transforms)} subject bases per parcel, got {len(parcel_bases)}'
        )
    for sub_transform, basis in zip(subject_transforms, parcel_bases):  # for each subject
        # FastSRM bases are (n_components, n_voxels); transpose into this subject's
        # (n_voxels, n_features) rows for the parcel's voxels.
        sub_transform[parcel_idxs, :] = np.asarray(basis).T

# Reorder to match sub_list: the loso_sub's transform is at the end, so move it back to its original index.
subject_transforms.insert(loso_idx, subject_transforms.pop())
