In [1]:
## Resources: 
# Custom transformer: https://medium.com/analytics-vidhya/scikit-learn-pipelines-with-custom-transformer-a-step-by-step-guide-9b9b886fd2cc
# https://nilearn.github.io/auto_examples/03_connectivity/plot_group_level_connectivity.html
# See https://rmldj.github.io/hcp-utils/

import pandas as pd
import numpy as np
from os import path
import os
import hcp_utils as hcp
import nibabel as nib
from nilearn.connectome import ConnectivityMeasure
from nilearn import plotting
from nilearn import datasets
from glob import glob
from functools import reduce
from chord import Chord
import pickle
import sklearn
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc, accuracy_score, f1_score, classification_report, balanced_accuracy_score
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict, cross_validate, GridSearchCV
from sklearn.feature_selection import RFECV, SelectKBest, f_classif
from sklearn.pipeline import Pipeline
import itertools
from itertools import chain
import seaborn as sns
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.multivariate.manova import MANOVA
import statsmodels.stats.multicomp as mc
import matplotlib.pyplot as plt
from IPython.display import Image
import warnings
%matplotlib inline
#print(sklearn.__version__)

pixdim[1,2,3] should be non-zero; setting 0 dims to 1


In [6]:
# Set parameters
# Atlas used for parcellation. It must name an atlas attribute of hcp_utils,
# because it is looked up with getattr(hcp, atlas_name) below and in the RSFC cell.
atlas_name = 'yeo17'

# Silence DeprecationWarnings for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)

def fxn():
    """Emit a DeprecationWarning.

    NOTE(review): this and the catch_warnings block below look like the example
    from the Python `warnings` docs. The warning is raised inside a suppressed
    context, so the block has no effect on the analysis and could be removed.
    """
    warnings.warn("deprecated", DeprecationWarning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fxn()

# `labs` holds the ROI labels: atlas networks followed by the left and right
# habenula, matching the column order the RSFC cell builds (parcels, then HbL, HbR).
if atlas_name == 'yeo17':
    # Hand-curated short names for the 17 Yeo networks. These are used instead
    # of hcp_utils' own labels; presumably they are in hcp.yeo17 parcel order (verify).
    labs = ['VIS_1', 'VIS_2', 'MOT_1', 'MOT_2', 'DAN_1', 'DAN_2', 'SAL_VATTN_1', 'SAL_VATTN_2', 'LIM_1', 'LIM_2', 'FPN_1', 'Control_1', 'Control_2', 'TEMP_PAR', 'DMN_1', 'FPN_2', 'DMN_2', 'Hb_L', 'Hb_R']
else:
    atlas_obj = getattr(hcp, atlas_name)
    if atlas_name == 'yeo7':
        # Replace spaces with underscores in the yeo7 label names. Like the
        # original cell, this mutates hcp.yeo7.labels itself (shared module
        # state). Writing back under each label's own key avoids assuming the
        # keys are 0..n-1 in insertion order, which enumerate() did.
        for key, name in list(atlas_obj.labels.items()):
            atlas_obj.labels[key] = name.replace(" ", "_")
    # Keep labels whose key is > 0. Key 0 is skipped; presumably it is the
    # background/"other" label (verify against hcp_utils).
    labs = [atlas_obj.labels[x] for x in atlas_obj.labels if x > 0]
    labs.extend(['Hb_L', 'Hb_R'])

In [7]:
# Load demographic/clinical data

dem_path = '/ifshome/bwade/NARSAD/Aim_2/data/'
dem = pd.read_csv(path.join(dem_path, 'U01_Demographics_IDS.csv'))

# One entry per row of the demographics table: the subject ID (without the
# session suffix; write_baseline_rsfc appends '01' for baseline) and its arm.
# Series.tolist() already returns a fresh list, so no extra copy is needed.
sid_list = dem['Subject'].tolist()
arm_list = dem['arm'].tolist()

In [8]:
# Function to load and concatenate time series from whole brain and habenula
def img_loader(img_file, hbl_file, hbr_file):
    """Load three image files and join their data arrays column-wise.

    Each file is read with ``nib.load(...).get_fdata()``. The result is
    ``np.concatenate((img, hbl, hbr), axis=1)``: the whole-brain columns,
    then the left-habenula columns, then the right-habenula columns. Every
    dimension except axis 1 must match across the three arrays (axis 0 is
    presumably time points; confirm).
    """
    def _read(fname):
        return nib.load(fname).get_fdata()

    # Read order is image, right Hb, then left Hb.
    brain = _read(img_file)
    right_hb = _read(hbr_file)
    left_hb = _read(hbl_file)
    return np.concatenate((brain, left_hb, right_hb), axis=1)

# RSFC data generator
def _baseline_paths(subj, root_dir, hb_root_dir):
    """Build the six input paths for one subject's baseline session.

    The session directory is ``subj + '01'`` (this notebook's baseline suffix).
    Returns a dict with keys ``img_ap``/``img_pa`` (the AP run-01 and PA run-02
    rest ``.dtseries.nii`` files) and ``hbl_ap``, ``hbr_ap``, ``hbl_pa``,
    ``hbr_pa`` (left/right habenula mean time series for each direction).
    """
    sess = subj + '01'
    results_dir = os.path.join(root_dir, sess, 'MNINonLinear', 'Results')
    hb_dir = os.path.join(hb_root_dir, sess)
    return {
        'img_ap': os.path.join(results_dir, 'rest_acq-AP_run-01',
                               'rest_acq-AP_run-01_Atlas_MSMAll_Test_hp2000_clean.dtseries.nii'),
        'img_pa': os.path.join(results_dir, 'rest_acq-PA_run-02',
                               'rest_acq-PA_run-02_Atlas_MSMAll_Test_hp2000_clean.dtseries.nii'),
        'hbl_ap': os.path.join(hb_dir, 'HbL_{}_MeanTS_drest-AP.sdseries.nii'.format(sess)),
        'hbr_ap': os.path.join(hb_dir, 'HbR_{}_MeanTS_drest-AP.sdseries.nii'.format(sess)),
        'hbl_pa': os.path.join(hb_dir, 'HbL_{}_MeanTS_drest-PA.sdseries.nii'.format(sess)),
        'hbr_pa': os.path.join(hb_dir, 'HbR_{}_MeanTS_drest-PA.sdseries.nii'.format(sess)),
    }


def write_baseline_rsfc(sid_list, root_dir, hb_root_dir, atlas,
                        out_dir='/ifshome/bwade/NARSAD/Aim_1/data/atlas_rsfc/'):
    """Compute and save a baseline partial-correlation RSFC matrix per subject.

    For each subject:

    1. Load the AP and PA baseline rest runs with ``img_loader``. Each run
       holds the cortical data plus left and right habenula columns.
    2. Average the two runs element-wise and normalize with ``hcp.normalize``.
    3. Parcellate the non-habenula columns with the hcp_utils atlas named by
       ``atlas``.
    4. Turn the parcel time series plus the two habenula columns into a
       partial-correlation matrix with nilearn's ``ConnectivityMeasure``.

    Parameters
    ----------
    sid_list : iterable of str
        Subject IDs without the session suffix (``'01'`` is appended here).
    root_dir : str
        Directory containing the per-session pipeline outputs.
    hb_root_dir : str
        Directory containing the per-session habenula ROI time series.
    atlas : str
        Name of an hcp_utils atlas attribute, e.g. ``'yeo17'``.
    out_dir : str, optional
        Where matrices are written as comma-delimited text named
        ``<subj>_<atlas>_atlas`` (no extension). The directory is created if
        needed. Defaults to the path this notebook originally hard-coded.

    A subject is skipped, with a printed message, when any of its six input
    files is missing or when the normalized data contain NaNs. Returns None.
    """
    for index, subj in enumerate(sid_list):
        print(index, subj)
        files = _baseline_paths(subj, root_dir, hb_root_dir)

        # Check every input, not just three of them, so a missing HbR or PA
        # habenula file skips the subject instead of crashing in nib.load.
        missing = [f for f in files.values() if not os.path.exists(f)]
        if missing:
            print('Missing input file(s)... passing:', ', '.join(missing))
            continue

        # Load the AP and PA baseline acquisitions.
        img_ap = img_loader(img_file=files['img_ap'],
                            hbl_file=files['hbl_ap'],
                            hbr_file=files['hbr_ap'])
        img_pa = img_loader(img_file=files['img_pa'],
                            hbl_file=files['hbl_pa'],
                            hbr_file=files['hbr_pa'])

        # Element-wise average of the two runs (shapes must match), then normalize.
        img_baseline = np.mean((img_ap, img_pa), axis=0)
        img_baseline_norm = hcp.normalize(img_baseline)

        if np.isnan(img_baseline_norm).any():
            print('Normalized data contains NAN Values... passing')
            continue

        # The last two columns are the habenula (HbL, HbR, in img_loader's
        # order). Everything before them is parcellated with the chosen atlas.
        # Assumes rows are time points; TODO confirm.
        hb_timeseries = img_baseline_norm[:, -2:]
        parcel_timeseries = hcp.parcellate(img_baseline_norm[:, :-2], getattr(hcp, atlas))

        correlation_measure = ConnectivityMeasure(kind='partial correlation')
        correlation_matrix = correlation_measure.fit_transform(
            [np.concatenate((parcel_timeseries, hb_timeseries), axis=1)])[0]

        os.makedirs(out_dir, exist_ok=True)
        outfile = os.path.join(out_dir, '{}_{}_atlas'.format(subj, atlas))
        np.savetxt(outfile, correlation_matrix, delimiter=',')

# Compute and save the baseline RSFC matrix for every subject in sid_list.
write_baseline_rsfc(
    atlas=atlas_name,
    hb_root_dir='/nafs/narr/HCP_OUTPUT/Habenula/outputs/RSConn_HbSeed/ROI_Timeseries/HbROIs/',
    root_dir='/nafs/narr/canderson/new_pipeline_test_runs/out/',
    sid_list=sid_list,
)

0 e0010
Normalized data contains NAN Values... passing
1 e0011
2 e0012
3 e0014
4 e0015
5 e0016
6 e0017
7 e0018
8 e0019
9 e0021
10 e0022
11 e0023
12 e0025
13 e0026
14 e0028
15 e0031
16 e0032
17 e0033
18 e0034
19 e0035
20 e0037
21 e0038
22 e0041
23 e0043
24 e0046
25 e0047
26 e0048
27 e0050
28 e0053
29 e0056
30 e0058
31 e0059
32 k0004
33 k0008
34 k0009
35 k0010
36 k0011
37 k0014
38 k0017
39 k0019
40 k0026
41 k0027
42 k0031
43 k0035
44 k0036
45 k0047
46 k0049
47 k0050
48 k0051
49 k0053
50 k0055
51 k0056
52 k0060
53 k0066
54 k0067
55 k0071
56 k0074
57 k0075
58 k0076
59 k0078
60 k0085
61 k0089
62 k0090
63 k0094
64 k0103
65 k0105
66 k0114
67 k0115
68 k0117
69 k0124
70 k0125
71 k0131
72 k0133
73 k0139
74 k0144
75 k0147
76 k0149
77 k0151
78 k0155
79 k0157
80 k0166
81 k0167
82 k0172
83 k0191
84 k0192
85 k0195
86 k0196
87 k0199
88 k0203
89 k0205
90 k0206
91 k0207
92 s0030
93 s0032
94 s0034
95 s0035
96 s0043
97 s0044
98 s0045
99 s0046
100 s0049
101 s0050
102 s0054
103 s0055
104 s0063
105 s0066
106