# Spatial reproducibility between the training and the test set 

In [None]:
import os
import gc
import h5py
import glob
import dypac
import pickle
import numpy as np
import seaborn as sns
import pandas as pd
import nibabel as nb
from nilearn import plotting
from scipy.stats import pearsonr
from nilearn.input_data import NiftiMasker


import matplotlib.pyplot as plt
from nilearn import plotting
from nilearn.masking import unmask
from scipy.optimize import linear_sum_assignment

Match the components of the first set (training set) with the components of the second set (test set), using the Hungarian algorithm on the spatial (voxel-wise) correlation between component maps.

## Load the Dypac components

In [None]:
# Placeholder locations: replace with the real folders before running.
bg_path = 'path to mask'
main_path = 'main folder'

# Result folders for the within- and between-subject reproducibility analyses.
output_folder_within_rep, output_folder_btw_rep = (
    os.path.join(main_path, sub_folder)
    for sub_folder in ('within-subject-reproducibility',
                       'between-subject-reproducibility')
)

In [None]:
# Subjects to analyse.
list_subject = ['sub-%02d' % sub_num for sub_num in (1, 2)]

# Hyper-parameter grid (number of clusters, number of states, smoothing FWHM);
# these values are used below to build the model file names.
list_clusters, list_states, list_fwhm = [20], [150], [8]

In [None]:
# Display the FWHM values that the analyses below iterate over.
list_fwhm

In [None]:
def match_states(maps1, maps2):
    """
    Match states computed on two independent sets of data with the
    Hungarian algorithm (scipy's ``linear_sum_assignment``).

    The cost of pairing row ``i`` of ``maps1`` with row ``j`` of ``maps2`` is
    ``1 - r``, where ``r`` is the Pearson correlation of the two maps across
    voxels, so the assignment maximises the summed correlation of the pairs.

    :param maps1: array, shape (n_states1, n_voxels)
            one dynamic parcel (state map) per row for the set 1 of sessions
    :param maps2: array, shape (n_states2, n_voxels)
            one dynamic parcel (state map) per row for the set 2 of sessions.
            n_states2 may differ from n_states1; only
            min(n_states1, n_states2) pairs are then returned.
    :return: (list of tuple, list of float)
        the matched (index in maps1, index in maps2) pairs, in increasing
        maps1 index, and the Pearson correlation of each matched pair.
        A pair whose correlation is undefined (NaN, e.g. a constant map)
        is treated as uncorrelated (score 0).
    """
    n_states1 = maps1.shape[0]
    n_states2 = maps2.shape[0]

    # Every pair starts as "uncorrelated" (r = 0, i.e. cost 1) so that a pair
    # with an undefined correlation is neither favoured nor forbidden.
    # Writing at [idx1, idx2] directly keeps every value in its own column,
    # even when some correlations are skipped.
    states_corr_mtx = np.zeros((n_states1, n_states2))

    for idx1 in range(n_states1):
        map1 = maps1[idx1, :]
        for idx2 in range(n_states2):
            corr, _ = pearsonr(map1, maps2[idx2, :])
            # pearsonr yields NaN for constant input: keep the neutral 0.
            if not np.isnan(corr):
                states_corr_mtx[idx1, idx2] = corr

    states_cost_mtx = 1 - states_corr_mtx

    # Works for rectangular cost matrices as well as square ones.
    row_ind, col_ind = linear_sum_assignment(states_cost_mtx)

    matched_pairs = list(zip(row_ind, col_ind))
    l_tuple_matched_states = [(row, col) for row, col in matched_pairs]
    l_scores_states = [states_corr_mtx[row, col] for row, col in matched_pairs]

    return l_tuple_matched_states, l_scores_states

# Within-subject level spatial reproducibility 

In [None]:

def _load_projected_components(model_path, nifti_masker, n_voxels):
    """Load a pickled dypac model and express its components in a target mask.

    Each component row of ``model.components_`` is densified (``todense``),
    unmasked with the model's own masker (``model.masker_``) and re-masked
    with ``nifti_masker``, so that components from different models share
    the same voxel space.

    :param model_path: str, path to the pickled model
    :param nifti_masker: fitted NiftiMasker defining the target voxel space
    :param n_voxels: int, number of voxels inside ``nifti_masker``'s mask
    :return: ndarray, shape (n_components, n_voxels)
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(model_path, "rb") as openfile:
        model = pickle.load(openfile)

    n_components = model.components_.shape[0]
    components = np.zeros((n_components, n_voxels))
    for idx in range(n_components):
        img = model.masker_.inverse_transform(model.components_[idx, :].todense())
        components[idx, :] = nifti_masker.transform(img)

    del model
    gc.collect()
    return components


for subject in list_subject:

    out_filename_within = os.path.join(
        output_folder_within_rep,
        'within-' + subject + '-reproducibility-fwhm-5.h5')
    # NOTE(review): the file name hard-codes "fwhm-5" while the groups written
    # inside follow `list_fwhm`; confirm the intended name.

    ###  Load mask for each subject: voxels with a non-zero grey-matter
    ###  probability. It only depends on the subject, so it is built once here
    ###  rather than once per (fwhm, cluster, states) combination.
    bg_img = nb.load(os.path.join(
        bg_path, subject + '_space-MNI152NLin2009cAsym_label-GM_probseg.nii.gz'))
    # get_fdata(): get_data() was removed in nibabel 5.
    mask_binary = bg_img.get_fdata() > 0
    n_voxels = int(np.count_nonzero(mask_binary))
    # int8 rather than the default int: nibabel >= 5 rejects int64 image data
    # unless a dtype is passed explicitly.
    mask_img = nb.Nifti1Image(mask_binary.astype(np.int8), bg_img.affine)
    del mask_binary

    nifti_masker = NiftiMasker(mask_img=mask_img, standardize=False,
                               smoothing_fwhm=None, detrend=False,
                               memory="nilearn_cache", memory_level=1)
    nifti_masker.fit(bg_img)
    del mask_img
    gc.collect()

    # The context manager closes each subject's file, not only the last one.
    with h5py.File(out_filename_within, 'a') as h5_file:

        h5_subject = h5_file.create_group(subject)

        for fwhm in list_fwhm:

            h5_fwhm = h5_subject.create_group(str(fwhm))

            for cluster in list_clusters:

                h5_cluster = h5_fwhm.create_group(str(cluster))

                for states in list_states:

                    h5_states = h5_cluster.create_group(str(states))

                    model_suffix = ('_cluster-' + str(cluster)
                                    + '_states-' + str(states)
                                    + '_batches-1_reps-100_fwhm-' + str(fwhm)
                                    + '.pickle')

                    ## TRAINING DATA COMPONENTS (s01even models)
                    files_td = os.path.join(
                        main_path, 'models-embeddings-friends',
                        subject + '_dataset-friends_tasks-s01even' + model_suffix)
                    list_files_td = glob.glob(files_td, recursive=True)
                    print(files_td)
                    if not list_files_td:
                        # No training model for this combination: skip it.
                        continue

                    ## TEST DATA COMPONENTS (s01odd models)
                    files_tstd = os.path.join(
                        main_path, 'models-embeddings-friends',
                        subject + '_dataset-friends_tasks-s01odd' + model_suffix)
                    list_files_tstd = glob.glob(files_tstd, recursive=True)
                    if not list_files_tstd:
                        print('Missing test-set model: ' + files_tstd)
                        continue

                    stable_cp_td = _load_projected_components(
                        list_files_td[0], nifti_masker, n_voxels)
                    print('shape: ', stable_cp_td.shape[0])
                    stable_cp_tstd = _load_projected_components(
                        list_files_tstd[0], nifti_masker, n_voxels)

                    ## MATCHING COMPONENTS AT THE WITHIN-SUBJECT LEVEL
                    tuple_matched_states, scores_states = match_states(
                        stable_cp_td, stable_cp_tstd)
                    del stable_cp_td, stable_cp_tstd
                    gc.collect()

                    print(scores_states)

                    h5_states.create_dataset(
                        'Matched-state-maps',
                        data=np.asarray(tuple_matched_states))
                    h5_states.create_dataset(
                        'States-reproducibility',
                        data=np.asarray(scores_states))

    del nifti_masker
    gc.collect()

In [None]:
# Indices (np.where tuple) of the matched states whose spatial correlation is
# at least 0.8. `scores_states` is a Python list, which cannot be compared
# with a float directly, so it is converted to an array first.
high_idx_scores = np.where(np.asarray(scores_states) >= 0.8)

# Between-subject spatial reproducibility

In [None]:
def _load_projected_components(model_path, nifti_masker, n_voxels):
    """Load a pickled dypac model and express its components in a target mask.

    Each component row of ``model.components_`` is densified (``todense``),
    unmasked with the model's own masker (``model.masker_``) and re-masked
    with ``nifti_masker``, so that components from different models share
    the same voxel space.

    :param model_path: str, path to the pickled model
    :param nifti_masker: fitted NiftiMasker defining the target voxel space
    :param n_voxels: int, number of voxels inside ``nifti_masker``'s mask
    :return: ndarray, shape (n_components, n_voxels)
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(model_path, "rb") as openfile:
        model = pickle.load(openfile)

    n_components = model.components_.shape[0]
    components = np.zeros((n_components, n_voxels))
    for idx in range(n_components):
        img = model.masker_.inverse_transform(model.components_[idx, :].todense())
        components[idx, :] = nifti_masker.transform(img)

    del model
    gc.collect()
    return components


for sub_id1, subject1 in enumerate(list_subject):

    # Each unordered pair of subjects is processed exactly once.
    for subject2 in list_subject[sub_id1 + 1:]:

        print(subject1 + '-' + subject2)

        ###  Common mask: voxels with a non-zero grey-matter probability in
        ###  BOTH subjects. (An element-wise `mask1 == mask2` would also keep
        ###  every voxel lying outside both masks, since False == False.)
        ###  Built once per pair: it does not depend on fwhm/cluster/states.
        bg_img1 = nb.load(os.path.join(
            bg_path, subject1 + '_space-MNI152NLin2009cAsym_label-GM_probseg.nii.gz'))
        bg_img2 = nb.load(os.path.join(
            bg_path, subject2 + '_space-MNI152NLin2009cAsym_label-GM_probseg.nii.gz'))
        # get_fdata(): get_data() was removed in nibabel 5.
        mask_binary = np.logical_and(bg_img1.get_fdata() > 0,
                                     bg_img2.get_fdata() > 0)
        n_voxels = int(np.count_nonzero(mask_binary))

        print('n_voxels')
        print(n_voxels)

        # Assumes both probability maps share subject1's grid (same template
        # space), since subject1's affine is used for the common mask.
        # int8: nibabel >= 5 rejects int64 image data without an explicit dtype.
        bg_img = nb.Nifti1Image(mask_binary.astype(np.int8), bg_img1.affine)
        del mask_binary

        nifti_masker = NiftiMasker(mask_img=bg_img, standardize=False,
                                   smoothing_fwhm=None, detrend=False,
                                   memory="nilearn_cache", memory_level=1)
        nifti_masker.fit(bg_img)

        for fwhm in list_fwhm:

            out_filename_between = os.path.join(
                output_folder_btw_rep, 'draft-between-subjects-reproducibility',
                'between-subjects-reproducibility' + subject1 + '-' + subject2
                + '-reproducibility-fwhm' + str(fwhm) + '.h5')

            # The context manager closes every output file, not only the last.
            with h5py.File(out_filename_between, 'a') as h5_file:

                h5_subject = h5_file.create_group(subject1 + '-' + subject2)

                h5_fwhm = h5_subject.create_group(str(fwhm))

                for cluster in list_clusters:

                    h5_cluster = h5_fwhm.create_group(str(cluster))

                    for states in list_states:

                        h5_states = h5_cluster.create_group(str(states))

                        model_suffix = ('_cluster-' + str(cluster)
                                        + '_states-' + str(states)
                                        + '_batches-1_reps-100_fwhm-' + str(fwhm)
                                        + '.pickle')

                        ## SUBJECT 1 COMPONENTS (s01odd models)
                        files_td1 = os.path.join(
                            main_path, 'models-embeddings-friends',
                            subject1 + '_dataset-friends_tasks-s01odd' + model_suffix)
                        print(files_td1)
                        list_files_td1 = glob.glob(files_td1, recursive=True)
                        print(list_files_td1)

                        ## SUBJECT 2 COMPONENTS (s01even models)
                        files_td2 = os.path.join(
                            main_path, 'models-embeddings-friends',
                            subject2 + '_dataset-friends_tasks-s01even' + model_suffix)
                        list_files_tstd = glob.glob(files_td2, recursive=True)

                        if not list_files_td1 or not list_files_tstd:
                            print('Missing model file(s), skipping: '
                                  + files_td1 + ' / ' + files_td2)
                            continue

                        stable_cp_td = _load_projected_components(
                            list_files_td1[0], nifti_masker, n_voxels)
                        stable_cp_tstd = _load_projected_components(
                            list_files_tstd[0], nifti_masker, n_voxels)

                        ## MATCHING COMPONENTS AT THE BETWEEN-SUBJECT LEVEL
                        tuple_matched_states, scores_states = match_states(
                            stable_cp_td, stable_cp_tstd)
                        del stable_cp_tstd, stable_cp_td
                        gc.collect()

                        h5_states.create_dataset(
                            'Matched-state-maps',
                            data=np.asarray(tuple_matched_states))
                        h5_states.create_dataset(
                            'States-reproducibility',
                            data=np.asarray(scores_states))

                        del tuple_matched_states
                        gc.collect()

                        print(h5_file.keys())

        del nifti_masker, bg_img
        gc.collect()

In [None]:
# Display the reproducibility scores left by the most recent match_states call
# in the cells above (i.e. the last model pair processed).
scores_states

In [None]:
# NOTE(review): this cell only contained the bare name `a`, which is never
# defined in this notebook and raised NameError when run; it was removed.