In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import statsmodels.api as sm
import statsmodels.formula.api as smf
import sys
sys.path.insert(0,'../utils')
from read_math_utils_hbn import *
import ptitprince as pt


In [None]:
# Notebook-wide matplotlib defaults: figure size, plain white axes without a
# grid, and the font used by every figure below.
rc = {
    'figure.figsize': (10, 5),
    'axes.facecolor': 'white',
    'axes.grid': False,
    'grid.color': '.8',
    'font.family': 'Palatino Linotype',
    'font.size': 15,
}
plt.rcParams.update(rc)

## Plot Reading/Age Predictions and Model Scores - Figure 4

In [None]:
## Load Data

# Per-dataset model score tables (one CSV per cohort).
hbn_scores = pd.read_csv('../data/hbn/hbn_scores.csv')
ping_scores = pd.read_csv('../data/ping/ping_scores.csv')
abcd_scores = pd.read_csv('../data/abcd/abcd_scores.csv')

# HBN train / test prediction tables.
hbn_preds_train = pd.read_csv('../data/hbn/hbn_train_preds.csv')
hbn_preds_test = pd.read_csv('../data/hbn/hbn_test_preds.csv')

# Prediction tables for the white-matter models; the plotting cells below read
# their `pred` and `obs` columns.
age_wm_default_preds = pd.read_csv('../data/hbn/Age_wm_preds.csv')
read_wm_default_preds = pd.read_csv('../data/hbn/read_wm_preds.csv')

### Prediction vs Observed Plots

In [None]:
## Predicted Age from WM + Demos
# Scatter of observed vs. predicted age with a fitted regression line.
# x and y are passed by keyword: seaborn >= 0.12 made the data vectors
# keyword-only, so the positional form raises a TypeError there.
ax = sns.regplot(x=age_wm_default_preds.pred, y=age_wm_default_preds.obs, color='green')

ax.set(xlabel='Predicted Age', ylabel='Observed Age')
# plt.savefig('../figures/age_wm_demo_pred.pdf')

plt.show()


In [None]:
## Predicted Read from WM
# Scatter of observed vs. predicted reading score with a fitted regression line.
# x and y are passed by keyword: seaborn >= 0.12 made the data vectors
# keyword-only, so the positional form raises a TypeError there.
ax = sns.regplot(x=read_wm_default_preds['pred'], y=read_wm_default_preds['obs'],
                 color='red')

ax.set(xlabel='Predicted Reading Score', ylabel='Observed Reading Score')

plt.show()

### Model Performance Plots

In [None]:
# Set up plotting dataframe from original XGB score files

# Tag every score table with a merged "<target>_<set>" label and its dataset
# name. As before, these columns are added to the loaded frames themselves.
for dataset_label, scores in (('HBN', hbn_scores),
                              ('PING', ping_scores),
                              ('ABCD', abcd_scores)):
    scores['target_2'] = scores['target'] + '_' + scores['set']
    scores['dataset'] = dataset_label

# Merge all three dataframes into one. DataFrame.append was removed in
# pandas 2.0, so stack with pd.concat; ignore_index gives the combined frame a
# unique index instead of three overlapping ones.
full_scores_df = pd.concat([hbn_scores, ping_scores, abcd_scores], ignore_index=True)
full_scores_df['target_3'] = full_scores_df['target_2'] + '_' + full_scores_df['dataset']

# Remove all math/TOWRE related scores from predictions. The explicit copy makes
# plot_df an independent frame, so the relabelling below neither warns
# (SettingWithCopyWarning) nor depends on view/copy semantics.
excluded_targets = ['wiat_math_comp', 'math', 'TOWRE_Total_Scaled']
plot_df = full_scores_df[~full_scores_df['target'].isin(excluded_targets)].copy()

# Make target labels consistent across datasets
plot_df['target'] = plot_df['target'].replace({'wiat_reading_comp': 'reading'})

# Harmonise the dataset-specific target_2 labels and make them legible in one
# step (raw label or intermediate label -> display label).
target_2_labels = {
    'wiat_reading_comp_train': 'Reading Train',
    'wiat_reading_comp_test': 'Reading Test',
    'nihtbx_reading_agecorrected_train': 'Reading Train',
    'nihtbx_reading_agecorrected_test': 'Reading Test',
    'reading_train': 'Reading Train',
    'reading_test': 'Reading Test',
    'interview_age_train': 'Age Train',
    'interview_age_test': 'Age Test',
    'age_train': 'Age Train',
    'age_test': 'Age Test',
}
plot_df['target_2'] = plot_df['target_2'].replace(target_2_labels)

# Make model labels more legible
model_labels = {'demo': 'Demo.', 'wm': 'WM', 'wm_demo': 'Demo. + WM'}
plot_df['model'] = plot_df['model'].replace(model_labels)

# Get scores in the correct order so plot colors show up correctly. Models
# without an entry keep NaN, and the column stays float as in the .loc version.
model_sort_key = {'Demo.': 3, 'WM': 2, 'Demo. + WM': 1}
plot_df['sort'] = plot_df['model'].map(model_sort_key).astype(float)

plot_df = plot_df.sort_values(by=['target_2', 'sort'], ascending=False)


In [None]:
# Grouped bar chart of model scores: one panel per dataset, models on the
# x-axis, one bar per target_2 level.
g = sns.catplot(
    data=plot_df,
    kind="bar",
    x="model", y="score", col="dataset", hue="target_2",
    palette=['red', 'red', 'lightgreen', 'lightgreen'],
    edgecolor='black',
    saturation=.5,
    legend=False,  # start out w/o legend; it is added after hatching
)

# Hatch the first three patches of every run of six in each panel.
# NOTE(review): this relies on the order in which seaborn creates the bar
# patches; presumably it marks one of the train/test splits, so confirm
# against the rendered figure.
for facet_ax in g.axes_dict.values():
    for bar_idx, bar in enumerate(facet_ax.patches):
        if bar_idx % 6 < 3:
            bar.set_hatch('///')

# add legend to get hatches
g.add_legend()

# plt.savefig('../figures/pdfs/hbn_ping_model_scores_no_math.pdf')
# plt.savefig('../figures/pngs/hbn_ping_model_scores_no_math.png')

plt.show()

## Plot Tract Profiles - Figure 3

In [None]:
# load pheno + harmonized tract profiles df and create fa df and pheno df

tract_pheno_df = pd.read_csv('../data/hbn/combo_df.csv')
# del tract_pheno_df['Unnamed: 0']

# Diffusion table: the subject id plus every column whose name matches 'dki_fa'.
diff_df = tract_pheno_df.filter(regex='dki_fa|subjectID')

# Phenotype table: every column whose name does not contain 'dki_'. New columns
# are assigned to pheno_df later in the notebook, so take an explicit copy; this
# avoids SettingWithCopyWarning and any dependence on view/copy semantics.
is_dki_column = tract_pheno_df.columns.str.contains('dki_')
pheno_df = tract_pheno_df.loc[:, ~is_dki_column].copy()

In [None]:
## Label participant as low reading/math/both depending on scores

# A composite score below 85 counts as "low"; missing scores compare False.
reading_is_low = pheno_df['wiat_reading_comp'] < 85
math_is_low = pheno_df['wiat_math_comp'] < 85

pheno_df['low_reading'] = np.where(reading_is_low, 1, 0)
pheno_df['low_math'] = np.where(math_is_low, 1, 0)

# First matching condition wins: both low -> 'low_mr', only math -> 'low_m',
# only reading -> 'low_r', otherwise 'other'.
pheno_df['score_group'] = np.select(
    [reading_is_low & math_is_low, math_is_low, reading_is_low],
    ['low_mr', 'low_m', 'low_r'],
    default='other')

# Age bins: 1 = up to 9, 2 = above 9 and up to 14, 3 = above 14.
# Rows with a missing Age stay NaN.
age = pheno_df['Age']
pheno_df['age_bin'] = np.select(
    [age <= 9, (age > 9) & (age <= 14), age > 14],
    [1, 2, 3],
    default=np.nan)

# Collapse to a two-level reading grouping: any low reading score vs. the rest.
pheno_df['reading_group'] = np.where(
    pheno_df['score_group'].isin(['low_r', 'low_mr']),
    'low_r',
    'other')

In [None]:
# pivot diffusion df longer: one row per subject x bundle x node

# Bundles without a left/right suffix; their column names are
# '<prefix><bundle>_<node>' rather than '<prefix><bundle>_<hemisphere>_<node>'.
no_lat_tracts = ['Motor','FA','FP','Occipital','Orbital','PostParietal','SupFrontal','SupParietal','Temporal','AntFrontal']

diff_df_long = pd.melt(diff_df, id_vars=['subjectID'], value_name='dki_fa')

# Remove the literal 'dki_fa_' prefix. str.lstrip('dki_fa_') strips any leading
# run of the characters {d, k, i, _, f, a}, which would also eat the start of a
# bundle name beginning with one of those letters, so anchor a regex instead.
diff_df_long['variable'] = diff_df_long['variable'].str.replace(r'^dki_fa_', '', regex=True)

# Split the remaining name into at most three '_'-separated parts. Always
# materialise three columns so the assignment also works when every bundle in
# the table is unlateralised (only two parts).
name_parts = diff_df_long['variable'].str.split('_', expand=True)
if name_parts.shape[1] > 3:
    raise ValueError("tract column names must have at most three '_'-separated parts after the 'dki_fa_' prefix")
name_parts = name_parts.reindex(columns=range(3))
diff_df_long['tmp'] = name_parts[0]
diff_df_long['bundle'] = name_parts[1]
diff_df_long['node'] = name_parts[2]

# Unlateralised bundles: the second part is the node and the first part alone
# is the bundle name. Lateralised bundles: bundle is '<name>_<hemisphere>'.
is_unlateralised = diff_df_long['tmp'].isin(no_lat_tracts)
diff_df_long['node'] = diff_df_long['node'].where(~is_unlateralised, diff_df_long['bundle'])
diff_df_long['bundle'] = diff_df_long['tmp'].where(
    is_unlateralised, diff_df_long['tmp'] + '_' + diff_df_long['bundle'])

diff_df_long['node'] = pd.to_numeric(diff_df_long['node'])
diff_df_long = diff_df_long.drop(['variable', 'tmp'], axis=1).sort_values(['subjectID', 'bundle', 'node'])


In [None]:
# Combine the long tract-profile table with the phenotype table, passing
# "subjectID" and the Age / age_bin / reading_group column names to the helper.
# NOTE(review): merge_behavioral is not defined in this notebook; presumably it
# comes from the star import of read_math_utils_hbn and joins on subjectID.
# Confirm against that module. Later cells read `bundle`, `node`, `dki_fa`,
# `reading_group` and `age_bin` from the result.
tract_pheno_df_wide = merge_behavioral(diff_df_long, pheno_df, "subjectID",['Age','age_bin','reading_group'])

### Plot Tract Profiles - HBN Sample

In [None]:
## Node-wise OLS: dki_fa ~ reading group (reference = 'other') + age_bin
# One ordinary-least-squares model is fitted per tract and node. The formula
# contains no scanner-site term and no random effects.
from tqdm import tqdm

tracts = ['ARC_L','ILF_L','SLF_L','Occipital','ARC_R','ILF_R','SLF_R','SupFrontal']
nodes = tract_pheno_df_wide.node.unique()

beta_columns = ['tractID','nodeID','low_r','other','pval']
ci_columns = ['tractID','nodeID','low_r','other']

# Collect one record per (tract, node) and build each DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0, and growing a frame row by
# row is quadratic.
beta_rows = []
lowerCI_rows = []
upperCI_rows = []

for tract in tqdm(tracts):
    # Filter by tract once per tract rather than once per node.
    tract_rows = tract_pheno_df_wide[tract_pheno_df_wide['bundle'] == tract]

    for node in nodes:
        node_rows = tract_rows[tract_rows['node'] == node]

        # specify OLS formula and pull out params/conf. intervals
        mdf = sm.OLS.from_formula(
            "dki_fa ~ C(reading_group, Treatment(reference='other'))+age_bin",
            node_rows).fit()

        # Parameter order follows the formula: [0] intercept, [1] reading-group
        # term, [2] age_bin. Use .iloc throughout; integer keys on these
        # label-indexed results are deprecated positional lookups in pandas.
        coefs = mdf.params
        ci = mdf.conf_int(alpha = 0.05)  # column 0 = lower bound, column 1 = upper bound
        intercept = coefs.iloc[0]
        group_beta = coefs.iloc[1]

        # p-value on beta-weight for reading group. Significance indicates group difference
        pval = mdf.pvalues.iloc[1]

        other = intercept
        low_r = group_beta + intercept

        # NOTE(review): arithmetic kept exactly as in the original analysis. The
        # low_r lower bound is offset from the low_r estimate, while the upper
        # bound is offset from the intercept, and both use the absolute value of
        # the group term's CI limits. Confirm this asymmetry is intended.
        lower_ci_low_r = low_r - abs(ci.iloc[1, 0])
        lower_ci_other = ci.iloc[0, 0]

        upper_ci_low_r = intercept + abs(ci.iloc[1, 1])
        upper_ci_other = ci.iloc[0, 1]

        lowerCI_rows.append([tract, node, lower_ci_low_r, lower_ci_other])
        upperCI_rows.append([tract, node, upper_ci_low_r, upper_ci_other])
        beta_rows.append([tract, node, low_r, other, pval])

# Dataframes holding the beta-weights and upper/lower confidence intervals
beta_df_baseline = pd.DataFrame(beta_rows, columns=beta_columns)
lowerCI_df_baseline = pd.DataFrame(lowerCI_rows, columns=ci_columns)
upperCI_df_baseline = pd.DataFrame(upperCI_rows, columns=ci_columns)

In [None]:
## Plot baseline model
# Tract profiles of the fitted group estimates with shaded bands, one panel per
# tract, plus markers at nodes where the group term has p <= 0.005.
tois = ['Left Arcuate','Right Arcuate',
        'Left ILF','Right ILF',
        'Left SLF', 'Right SLF',
        'Posterior Forceps','Anterior Forceps']

group_columns = ['low_r','other']
merge_keys = ["tractID",'nodeID','variable']

# Pivot estimates and both CI tables to long form (one row per tract/node/group).
beta_df_wide = (pd.melt(beta_df_baseline, id_vars=['tractID','nodeID','pval'], value_vars=group_columns)
                .rename(columns={'value':'DKI FA'}))
lowerCI_df_wide = (pd.melt(lowerCI_df_baseline, id_vars=['tractID','nodeID'], value_vars=group_columns)
                   .rename(columns={'value':'lowerCI'}))
upperCI_df_wide = (pd.melt(upperCI_df_baseline, id_vars=['tractID','nodeID'], value_vars=group_columns)
                   .rename(columns={'value':'upperCI'}))

# Attach the bounds to the estimates and switch to display column names.
beta_df_wide = (beta_df_wide
                .merge(lowerCI_df_wide, on=merge_keys)
                .merge(upperCI_df_wide, on=merge_keys)
                .rename(columns={'variable':'Reading Group','nodeID':'Node'}))

beta_df_wide['Reading Group'] = np.where(beta_df_wide['Reading Group'] == 'low_r',
                                         'Low Reading Score',
                                         'Average/Above Average Reading Score')

# Bundle code -> panel title; codes without an entry are left unchanged.
tract_titles = {
    'ARC_L': 'Left Arcuate',
    'ARC_R': 'Right Arcuate',
    'ILF_L': 'Left ILF',
    'ILF_R': 'Right ILF',
    'SLF_L': 'Left SLF',
    'SLF_R': 'Right SLF',
    'Occipital': 'Posterior Forceps',
    'SupFrontal': 'Anterior Forceps',
}
beta_df_wide['tractID'] = beta_df_wide['tractID'].map(lambda code: tract_titles.get(code, code))

g = sns.FacetGrid(beta_df_wide[beta_df_wide['tractID'].isin(tois)], col="tractID", col_order = tois,
                  col_wrap=2, hue = 'Reading Group',palette = ['red','grey'])
g.map(sns.lineplot, "Node",'DKI FA')
g.add_legend()

# Shaded bands are drawn in this order so the grey band sits on top of the red one.
band_styles = [('Low Reading Score', 'red'),
               ('Average/Above Average Reading Score', 'grey')]

for col_val, ax in g.axes_dict.items():

    data = beta_df_wide[beta_df_wide['tractID']==col_val]
    nodes = data['Node'].unique().tolist()

    for group_label, band_color in band_styles:
        group_rows = data[data['Reading Group']==group_label]
        ax.fill_between(nodes, group_rows['lowerCI'], group_rows['upperCI'],
                        color = band_color, alpha = 0.2)

    # Mark nodes whose reading-group p-value is <= 0.005 at a fixed height of 0.3.
    x = data[data['pval']<=0.005]['Node']
    y = [0.3] * len(x)
    ax.scatter(x, y, c='red')

g.set_titles(col_template="{col_name}")
# plt.savefig('../figures/tract_profiles_arc_ilf_slf.pdf')

plt.show()

### ABCD

In [None]:
# Load plotting data and filter for variables we want

final_df_t1_filt = pd.read_csv('../data/abcd/abcd_plot_data.csv')

# Tract columns to plot, and the columns carried along as identifiers.
dif_cols = ['ARC_L','ARC_R','ILF_L','ILF_R','SLF_L','SLF_R','CC']
id_cols = ['low_r','NIH_TBX_READ_CORR']

# Long format: one row per participant x tract, tract name in `variable` and
# its value in `value`.
plot_data_t1 = final_df_t1_filt[id_cols + dif_cols].melt(id_vars=id_cols, value_vars=dif_cols)

In [None]:
# Raincloud plots of FA by reading-score group, one panel per tract.
tois = ['Left Arcuate','Right Arcuate',
        'Left ILF','Right ILF',
        'Left SLF','Right SLF','Callosum']

plot_data2 = plot_data_t1.rename(
    columns={'value':'FA','variable':'tractID','low_r':'Reading Score Group'})

# NOTE(review): this assumes the low_r column holds the string 'low_r' for
# low-score participants. If it is a 0/1 flag, every row becomes
# 'Average/High'. Confirm against the CSV.
plot_data2['Reading Score Group'] = np.where(
    plot_data2['Reading Score Group'] == 'low_r', 'Low', 'Average/High')

# Tract column name -> panel title; names without an entry are left unchanged.
tract_titles = {
    'ARC_L': 'Left Arcuate',
    'ARC_R': 'Right Arcuate',
    'ILF_L': 'Left ILF',
    'ILF_R': 'Right ILF',
    'SLF_L': 'Left SLF',
    'SLF_R': 'Right SLF',
    'CC': 'Callosum',
}
plot_data2['tractID'] = plot_data2['tractID'].map(lambda name: tract_titles.get(name, name))

# Sort so 'Average/High' rows come before 'Low' rows; the palette below is
# positional (grey first, red second).
plot_data2 = plot_data2.sort_values(by=['Reading Score Group'],ascending=True)

g = sns.FacetGrid(plot_data2[plot_data2['tractID'].isin(tois)],
                  col="tractID", col_wrap=2, col_order=tois)

g.map_dataframe(pt.RainCloud, 'Reading Score Group', "FA", orient='v',
                palette = ['grey','red'], bw = 0.2, width_viol = .5)
g.add_legend()
sns.set_style("white")

g.set_titles(col_template="{col_name}")
plt.show()

### PING

In [None]:
# Load the PING table; the following cells filter it to the FA, reading-group
# and age columns and reshape it for plotting.
plot_data = pd.read_csv('../data/ping/final_df_ping.csv')

In [None]:
# Keep the columns whose name matches the FA / reading-group / age pattern,
# then discard any of those whose name contains 'all'.
selected = plot_data.filter(regex=r'(dti_fiber_fa|low_r|age)')
selected = selected.loc[:, ~selected.columns.str.contains('all')]

# Remaining FA columns become the melted values.
dif_cols = selected.filter(regex=r'(dti_fiber_fa)').columns

# Long format: one row per participant x FA column.
# NOTE: this overwrites `plot_data`, so the cell cannot be re-run without
# reloading the CSV first.
plot_data = selected.melt(id_vars=['low_r','interview_age'], value_vars=dif_cols)


In [None]:
# Raincloud plots of FA by reading-score group, one panel per tract (PING).
tois = ['Left Arcuate','Right Arcuate',
        'Left ILF','Right ILF',
        'Left SLF','Right SLF','Callosum']

plot_data2 = plot_data.rename(
    columns={'value':'FA','variable':'tractID','low_r':'Reading Score Group'})

# NOTE(review): this assumes the low_r column holds the string 'low_r' for
# low-score participants. If it is a 0/1 flag, every row becomes
# 'Average/High'. Confirm against the CSV.
plot_data2['Reading Score Group'] = np.where(
    plot_data2['Reading Score Group'] == 'low_r', 'Low', 'Average/High')

# FA column name -> panel title; names without an entry are left unchanged.
tract_titles = {
    'dti_fiber_fa_l_tslf': 'Left Arcuate',
    'dti_fiber_fa_r_tslf': 'Right Arcuate',
    'dti_fiber_fa_l_ilf': 'Left ILF',
    'dti_fiber_fa_r_ilf': 'Right ILF',
    'dti_fiber_fa_l_slf': 'Left SLF',
    'dti_fiber_fa_r_slf': 'Right SLF',
    'dti_fiber_fa_cc': 'Callosum',
}
plot_data2['tractID'] = plot_data2['tractID'].map(lambda name: tract_titles.get(name, name))

# The palette below is positional (grey first, red second) and the category
# order follows the row order. Sort so 'Average/High' always comes before
# 'Low', as the ABCD plot does; without this the colours depend on whichever
# group happens to appear first in the CSV.
plot_data2 = plot_data2.sort_values(by=['Reading Score Group'], ascending=True)

g = sns.FacetGrid(plot_data2[plot_data2['tractID'].isin(tois)],
                  col="tractID", col_wrap=2, col_order=tois)

g.map_dataframe(pt.RainCloud, 'Reading Score Group', "FA", orient='v',
                palette = ['grey','red'], bw = 0.2, width_viol = .5)
g.add_legend()
g.set_titles(col_template="{col_name}")
sns.set_style("white")
plt.show()
