In [1]:
import numpy as np
np.random.seed(42)
import pandas as pd
import os
from tqdm import tqdm
from scipy.stats import spearmanr, pearsonr
from pingouin import partial_corr
np.random.seed(42)
from itertools import combinations
%matplotlib widget
import matplotlib.pyplot as plt
import statsmodels.api as sm
import random
random.seed(42)

In [2]:
### table1 info
df_table1_all = pd.read_csv('table1_information_deidentified.csv', index_col=0)
print(df_table1_all.shape)

(623, 1255)


In [3]:
df_table1_all['Age'] = df_table1_all.age

In [4]:
df_table1_all.head(2)

Unnamed: 0_level_0,InstitutionCode,CDACLastModifiedDTS,age,cci_score_unweighted,cci_score_weighted,dx_cci_mi,dx_cci_chf,dx_cci_pvd,dx_cci_cevd,dx_cci_dementia,...,so_pos_amp_f,so_pos_amp_c,so_pos_amp_o,so_rate_f,so_rate_c,so_rate_o,so_slope_neg2_f,so_slope_neg2_c,so_slope_neg2_o,Age
sid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
sid00,1,2022-05-10 13:03:27.1613110,66.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,234.319868,243.988988,268.167189,5.305994,5.029968,4.473186,1809.564362,1842.692265,1974.20975,66.8
sid01,1,2022-05-10 13:03:01.0037450,58.0,,,,,,,,...,38.866811,35.956362,42.167286,1.304802,0.304802,0.006263,322.30118,284.580296,192.42747,58.0


In [5]:
df = pd.read_csv('eeg_mri_cognition_deidentified.csv', index_col=0)
print(df.shape)
df.head(2)

(623, 853)


Unnamed: 0_level_0,interval_mri-eeg_abs1,report_date_time,report_description,report_status,report_type,age,sex,bmi,ahi,medbenzo,...,vol-ctx--insula,vol-total_ventricle,vol-ctx--anterior,vol-striatum,alpha_bandpower_mean_o_w,dt_eeg_mri_abs,dt_mmse_eeg_abs,dt_mmse_mri_abs,dt,dt_abs
sid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
sid00,258,,,,,66.832877,0,32.37,8.9,0,...,0.008634,0.026705,0.015385,0.016066,75.915455,0.706849,3.306849,2.6,0.706849,0.706849
sid01,1280,,,,,57.983562,1,28.0,4.5,0,...,0.008282,0.028709,0.013872,0.012763,13.780639,3.506849,3.654795,7.161644,3.506849,3.506849


In [6]:
df = df.join(df_table1_all[['dx_dementia', 'dx_no_dementia_cci<2', 'dx_dementia_cat']], on='sid')
df['dx_dementia'] = df['dx_dementia'].astype(int)
df['dx_no_dementia_cci<2'] = df['dx_no_dementia_cci<2'].astype(int)

In [7]:
print(df.columns[[14, 15, -66, -65]])


Index(['mmse', 'mean_gradient_f3-m2_w', 'fs_symm_c', 'vol-total intracranial'], dtype='object')


In [8]:
cols_mri = list([x for x in df.columns if 'vol-' in x])

idx_sleep_start = 15
idx_sleep_end = -66
print(df.columns[[14, 15, -66, -65]])
cols_sleep = list(df.columns[idx_sleep_start:idx_sleep_end]) + ['alpha_bandpower_mean_o_w']
print(len(cols_mri))

print(len(cols_sleep))


Index(['mmse', 'mean_gradient_f3-m2_w', 'fs_symm_c', 'vol-total intracranial'], dtype='object')
56
776


In [9]:
var_group1 = 'dx_dementia'
var_group2 = 'dx_no_dementia_cci<2'

df_sleep = df[cols_sleep]
df_mri = df[cols_mri]


In [10]:
df_table1_all = df_table1_all.join(df['sex'], on='sid')

df_dx = df_table1_all[[var_group1, var_group2, 'dx_dementia_cat']]
df_dx = df_dx.loc[df_sleep.index]
df_dx['dx_dementia'] = df_dx['dx_dementia'].astype(int)

df_covariates = df_table1_all[['sex', 'age']].copy()
df_covariates['sex'] = (df_covariates['sex'] == 'Male').astype(int)

df_covariates = df_covariates.loc[df_sleep.index]

### MMSE data

### what we need here:
df_sleep: sleep features  
df_mri: mri features  
df_dx: disease status  
df_covariates: age, sex  

In [11]:
assert all(df_sleep.index == df_mri.index), "SID must match"
assert all(df_sleep.index == df_dx.index), "SID must match"
assert all(df_sleep.index == df_covariates.index), "SID must match"
print(f"N = {len(df_sleep)}")

N = 623


In [12]:
df_data = pd.concat([df_sleep, df_mri, df_covariates, df_dx[['dx_dementia', 'dx_dementia_cat']]], axis=1)
# df_data = df

In [13]:
sids_control = df_dx[(df_dx['dx_no_dementia_cci<2'] == 1) & \
                     (np.isin(df_dx['dx_dementia_cat'], ['excluded', 'no dementia']))].index
sids_dementia = df_dx[df_dx['dx_dementia'] == 1].index
df_data = df_data.loc[list(sids_dementia) + list(sids_control)]

print(f"N controls: {len(sids_control)}")
print(f"N dementia+MCI: {len(sids_dementia)}")

# restrict data to min age dementia
min_age_dementia = df_data.query("dx_dementia == 1").age.min()
print(min_age_dementia)
df_data = df_data.query(f"age >= {min_age_dementia}")
sids_control = df_data.query("dx_dementia == 0").index
sids_dementia = df_data.query("dx_dementia == 1").index

sids_mci = df_data.query("dx_dementia == 1").query('dx_dementia_cat == "mci"').index
sids_dementia = df_data.query("dx_dementia == 1").query('dx_dementia_cat == "dementia"').index
print('MCI, Dementia', len(sids_mci), len(sids_dementia))

print('after age restriction')
print(f"N controls: {len(sids_control)}")
print(f"N dementia, MCI: {len(sids_dementia), len(sids_mci)}")

print(f"Dementia Category controls: \n{df_data.loc[sids_control, 'dx_dementia_cat'].value_counts()}")
print(f"\nDementia Category demented: \n{df_data.loc[sids_dementia, 'dx_dementia_cat'].value_counts()}")

N controls: 152
N dementia+MCI: 107
42.2
MCI, Dementia 71 36
after age restriction
N controls: 127
N dementia, MCI: (36, 71)
Dementia Category controls: 
dx_dementia_cat
no dementia    90
excluded       37
Name: count, dtype: int64

Dementia Category demented: 
dx_dementia_cat
dementia    36
Name: count, dtype: int64


In [14]:
df_data.age.median()

61.55

In [15]:
age_min = np.ceil(min_age_dementia)
age_max = np.floor(df_data.age.max())
print(age_min, age_max)
age_range = np.arange(age_min, age_max)
idx_medianage = 18
print(age_range[idx_medianage])

43.0 87.0
61.0


### fit model 

First, with interactin term:
feature = f(disease, age, sex, disease*age)  
prediction_dementia = f(1, age_range, 0.5, 1*age)  
prediction_control = f(0, age_range, 0.5, 0*age)

However, slope is not significantly different between dementia and non-dementia group. therefore, we fit models without disease*age interaction term and show these. star indicates different offset.


In [16]:
structure_vars = ['vol-' + x for x in ['thalamus', 'hippocampus', 'ctx--anterior', 'ctx--isthmuscingulate',
                 'amygdala', 'brain-stem', 'total_ventricle']]
sleep_vars = [x.lower() for x in ['slowdelta_bandpower_total', 'SO_RATE_F', 'SS_DENS_F', 'FS_DENS_C', 'perc_r',
              'alpha_bandpower_mean_O_W']]
cognition_vars = ['mmse']

In [17]:
def model_and_plot_routine_per_feature(df_data, feature, interaction_age_disease=True, plot=True):

    vars_predictors = ['age', 'sex', 'dx_dementia']
    df_data_feature = df_data[[feature] + vars_predictors].copy()
    if interaction_age_disease:
        df_data_feature['age_disease'] = df_data_feature['age'] * df_data_feature['dx_dementia']
        vars_predictors += ['age_disease']
        
    mod = sm.OLS(df_data_feature[feature], sm.add_constant(df_data_feature[vars_predictors]))
    res = mod.fit()

    df_artificial = pd.DataFrame(columns=['const'] + vars_predictors)
    df_artificial['age'] = age_range
    df_artificial['sex'] = 0.5
    df_artificial['const'] = 1

    # get predictions (linefits) for controls and dementia
    df_artificial['dx_dementia'] = 0
    if interaction_age_disease: 
        df_artificial['age_disease'] = df_artificial['age'] * df_artificial['dx_dementia']
    preds_control = res.predict(df_artificial)
    slope_control = preds_control[1] - preds_control[0]
    value_medianage_control = preds_control[idx_medianage]
    
    df_artificial['dx_dementia'] = 1
    if interaction_age_disease:
        df_artificial['age_disease'] = df_artificial['age'] * df_artificial['dx_dementia']
    preds_dementia = res.predict(df_artificial)
    slope_dementia = preds_dementia[1] - preds_dementia[0]
    value_medianage_dementia = preds_dementia[idx_medianage]
    
    if plot:
        alpha_scatter = 0.6

        fig, ax = plt.subplots(1, 1, figsize=(7, 4))
        ax.plot(age_range, preds_control, c='orange')
        ax.plot(age_range, preds_dementia, c='gray')

        df_data_feature_control = df_data_feature[df_data_feature['dx_dementia'] == 0]
        df_data_feature_dementia = df_data_feature[df_data_feature['dx_dementia'] == 1]

        ax.scatter(df_data_feature_control['age'], df_data_feature_control[feature], 
                   s=3, color='orange', alpha=alpha_scatter, marker='o')
        ax.scatter(df_data_feature_dementia['age'], df_data_feature_dementia[feature], 
                   s=2.75, color='gray', alpha=alpha_scatter, marker='s')
        ax.set_title(feature)
        

        if interaction_age_disease:
            (ci_control, ci_dementia, significant, symbol) = results[(feature, 'slope')]
            precision = int(np.ceil(max(-np.log10(abs(min(slope_control, slope_dementia))), 0))) + 1
            ci_control = np.round(np.array(ci_control), precision)
            ci_dementia = np.round(np.array(ci_dementia), precision)
            ax.text(1, 1,
                     f"Slopes:   {symbol}                                                     \n\
                    Control {np.round(slope_control, precision)} [{ci_control[0]}, {ci_control[1]}]\n\
                     Dementia {np.round(slope_dementia, precision)} [{ci_dementia[0]}, {ci_dementia[1]}]",
                    transform=ax.transAxes, ha='right', va='top', fontsize=6)

        
        (ci_control, ci_dementia, significant, symbol) = results[(feature, 'value_median_age')]
        precision = int(np.ceil(max(-np.log10(abs(min(value_medianage_control, value_medianage_dementia))), 0))) + 1
        ci_control = np.round(np.array(ci_control), precision)
        ci_dementia = np.round(np.array(ci_dementia), precision)
        ax.text(1, 0.9,
                 f"Value at median age:   {symbol}                                        \n\
                Control {np.round(value_medianage_control, precision)} [{ci_control[0]}, {ci_control[1]}]\n\
                 Dementia {np.round(value_medianage_dementia, precision)} [{ci_dementia[0]}, {ci_dementia[1]}]",
                transform=ax.transAxes, ha='right', va='top', fontsize=6)
        
    return slope_control, value_medianage_control, slope_dementia, value_medianage_dementia

### bootstrappin'

In [18]:
def compute_overlap(ci_control, ci_dementia):
    
    significant = True
    if ci_control[0] <= ci_dementia[0] <= ci_control[1]:
        significant = False
    elif ci_control[0] <= ci_dementia[1] <= ci_control[1]:
        significant = False
    elif (ci_control[0] <= ci_dementia[0]) & (ci_control[1] >= ci_dementia[1]):
        significant = False
    elif (ci_control[0] >= ci_dementia[0]) & (ci_control[1] <= ci_dementia[1]):
        significant = False

    return significant

In [19]:
def bootstrappin():

    n_bootstraps = 500
    k = len(df_data)
    index_original = np.arange(k)

    results = {}

    for feature in structure_vars + sleep_vars:
        # print(f'\n{feature}')
        list_slope_control = []
        list_value_medianage_control = []
        list_slope_dementia = []
        list_value_medianage_dementia = []

        for n in range(n_bootstraps):
            index_bootstrap = random.choices(index_original, k=len(df_data)) # random selection (with replacement) of index
            df_data_bootstrap = df_data.iloc[index_bootstrap]
            slope_control, value_medianage_control, slope_dementia, value_medianage_dementia = model_and_plot_routine_per_feature(df_data_bootstrap,
                                                                                                                                  feature,
                                                                                                                                  interaction_age_disease=interaction_age_disease,
                                                                                                                                  plot=False)
            list_slope_control.append(slope_control)
            list_value_medianage_control.append(value_medianage_control)
            list_slope_dementia.append(slope_dementia)
            list_value_medianage_dementia.append(value_medianage_dementia)

        ci_slope_control = [np.percentile(list_slope_control, 2.5), np.percentile(list_slope_control, 97.5)]
        ci_slope_dementia = [np.percentile(list_slope_dementia, 2.5), np.percentile(list_slope_dementia, 97.5)]
        significant = compute_overlap(ci_slope_control, ci_slope_dementia)
        symbol = '*' if significant else ''
        # print('slope:')
        # print(ci_slope_control)
        # print(ci_slope_dementia)
        # print(significant)
        results[(feature, 'slope')] = (ci_slope_control, ci_slope_dementia, significant, symbol)

        ci_median_control = [np.percentile(list_value_medianage_control, 2.5), np.percentile(list_value_medianage_control, 97.5)]
        ci_median_dementia = [np.percentile(list_value_medianage_dementia, 2.5), np.percentile(list_value_medianage_dementia, 97.5)]
        significant = compute_overlap(ci_median_control, ci_median_dementia)
        symbol = '*' if significant else ''

        # print('median age value:')
        # print(ci_median_control)
        # print(ci_median_dementia)
        # print(significant)
        results[(feature, 'value_median_age')] = (ci_median_control, ci_median_dementia, significant, symbol)
        
    return results

In [20]:
interaction_age_disease = True
results = bootstrappin()

%matplotlib inline
plt.close('all')
for feature in structure_vars + sleep_vars:
    slope_control, value_medianage_control, slope_dementia, value_medianage_dementia = model_and_plot_routine_per_feature(df_data, 
                                                                                                                          feature,
                                                                                                                      interaction_age_disease=interaction_age_disease, plot=True)

### Slope was never significantly different. Hence, we fit a model without the age * disease interaction term.

In [21]:
interaction_age_disease = False
results = bootstrappin()

%matplotlib inline
plt.close('all')
for feature in structure_vars + sleep_vars:
    slope_control, value_medianage_control, slope_dementia, value_medianage_dementia = model_and_plot_routine_per_feature(df_data, 
                                                                                                                          feature,
                                                                                                                          interaction_age_disease=interaction_age_disease,
                                                                                                                         plot=True)

In [22]:
print("SIGNIFICANT OFFSETS:")
for feature in structure_vars + sleep_vars:
    if results[feature, 'value_median_age'][2] == True:
        print(feature)

SIGNIFICANT OFFSETS:
vol-thalamus
vol-hippocampus
vol-amygdala
vol-brain-stem
vol-total_ventricle


In [23]:
vars_to_plot = ['vol-thalamus', 'vol-hippocampus', 'vol-ctx--anterior',
                'vol-ctx--isthmuscingulate', 'vol-amygdala', 'vol-brain-stem',
                'vol-total_ventricle',
                'slowdelta_bandpower_total', 'so_rate_f',
                'ss_dens_f', 'fs_dens_c',
                'perc_r', 'alpha_bandpower_mean_o_w',
               ]

df_data = df_data.loc[:, ['age'] + vars_to_plot]

df_data.columns

Index(['age', 'vol-thalamus', 'vol-hippocampus', 'vol-ctx--anterior',
       'vol-ctx--isthmuscingulate', 'vol-amygdala', 'vol-brain-stem',
       'vol-total_ventricle', 'slowdelta_bandpower_total', 'so_rate_f',
       'ss_dens_f', 'fs_dens_c', 'perc_r', 'alpha_bandpower_mean_o_w'],
      dtype='object')

In [24]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Create a scaler object
scaler = StandardScaler()
# Standardize the data
X_standardized = scaler.fit_transform(df_data[vars_to_plot])

# Perform PCA on X, extracting the first two principal components
pca = PCA(n_components=4)
principal_components = pca.fit_transform(X_standardized)
# Get the scores for each sample along the first two principal components
cols_pca = ['PCA1', 'PCA2', 'PCA3', 'PCA4']
df_data[cols_pca] = pca.transform(X_standardized)

vars_to_plot += cols_pca

print("Explained variance ratio:", pca.explained_variance_ratio_)


# # Perform UMAP on the standardized data
# import umap
# umap_model = umap.UMAP(n_components=2)
# umap_components = umap_model.fit_transform(X_standardized)
# cols_umap = ['UMAP1', 'UMAP2'] # , 'PCA3', 'PCA4']
# df_data[cols_umap] = umap_components
# vars_to_plot += cols_umap

Explained variance ratio: [0.3034257  0.15001288 0.10481791 0.08907935]


In [25]:
name_dict = {'vol-ctx--anterior': 'Anterior cortex',
                'vol-thalamus': 'Thalamus', 
             'vol-hippocampus': 'Hippocampus',
                'vol-ctx--isthmuscingulate': 'Isthmus', 
             'vol-amygdala': 'Amygdala', 
             'vol-brain-stem': 'Brainstem',
                'slowdelta_bandpower_total': 'Total delta', 
             'so_rate_f': 'Slow oscillations',
                'ss_dens_f': 'Slow spindles', 
             'fs_dens_c': 'Fast spindles',
                'perc_r': 'REM', 
             'alpha_bandpower_mean_o_w': 'Mean alpha',
             'vol-total_ventricle': 'Ventricles',
            }


unit_dict = {'vol-ctx--anterior': 'Vol (%)',
                'vol-thalamus': 'Vol (%)',
             'vol-hippocampus': 'Vol (%)',
                'vol-ctx--isthmuscingulate': 'Vol (%)',
             'vol-amygdala': 'Vol (%)',
             'vol-brain-stem': 'Vol (%)',
                'slowdelta_bandpower_total': r'Power(mV$^2$)', 
             'vol-total_ventricle': 'Vol (%)',
             'so_rate_f': 'Rate (/min N2+N3)',
                'ss_dens_f': 'Rate (/min N2)', 
             'fs_dens_c': 'Rate (/min N2)',
                'perc_r': 'Fraction (%)',
             'alpha_bandpower_mean_o_w':r'Power(uV$^2$)',
             'PCA1': '',
             'PCA2': '',
             'PCA3': '',
             'PCA4': '',
             'UMAP1': '',
             'UMAP2': '',
            }

In [26]:
df_data[['vol-ctx--anterior', 'vol-thalamus', 'vol-hippocampus',
       'vol-ctx--isthmuscingulate', 'vol-amygdala', 'vol-brain-stem', 'vol-total_ventricle',
       'perc_r']] *= 100 # in percentage

In [27]:
df_data['slowdelta_bandpower_total'] /= 1e6 #uV^2 to mV^2
vars_to_plot.remove('alpha_bandpower_mean_o_w')
vars_to_plot

['vol-thalamus',
 'vol-hippocampus',
 'vol-ctx--anterior',
 'vol-ctx--isthmuscingulate',
 'vol-amygdala',
 'vol-brain-stem',
 'vol-total_ventricle',
 'slowdelta_bandpower_total',
 'so_rate_f',
 'ss_dens_f',
 'fs_dens_c',
 'perc_r',
 'PCA1',
 'PCA2',
 'PCA3',
 'PCA4']

In [28]:
df_control = df_data.loc[sids_control, ['age'] + vars_to_plot]
df_mci = df_data.loc[sids_mci, ['age'] + vars_to_plot]
df_dementia =df_data.loc[sids_dementia, ['age'] + vars_to_plot]
             
%matplotlib inline
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns

plt.rcParams.update({'font.size': 8})
plt.close('all')

alpha_scatter = 0.9
fig, ax = plt.subplots(4, 4, figsize=(7, 7), sharex='col')
ax = ax[:].flatten()


for i in range(len(vars_to_plot)):
        
    feature = vars_to_plot[i]
    try:
        label = name_dict[feature]
    except:
        label = feature
        
    unit = unit_dict[feature]
    
    ax[i].scatter(df_control['age'], df_control[feature], 
               s=1, color='mediumaquamarine', alpha=alpha_scatter, marker='o')
    ax[i].scatter(df_mci['age'], df_mci[feature], 
               s=1, color='orange', alpha=alpha_scatter, marker='o')
    ax[i].scatter(df_dementia['age'], df_dementia[feature], 
               s=1, color='tomato', alpha=alpha_scatter, marker='o')
    
    
    order = 1
    if feature in ['vol-total_ventricle']:
        order = 2
    sns.regplot(x=df_control['age'], y=df_control[feature], ci=95, n_boot=1000, seed=43, order=order, x_ci=5,
                color='mediumaquamarine', scatter=False, ax=ax[i], scatter_kws={'alpha': 1}, label='Control')
    sns.regplot(x=df_mci['age'], y=df_mci[feature], ci=95, n_boot=1000, seed=43, order=order, x_ci=5,
                color="orange", scatter=False, ax=ax[i], label='MCI')
    sns.regplot(x=df_dementia['age'], y=df_dementia[feature], ci=95, n_boot=1000, seed=43, order=order, x_ci=5,
                color="tomato", scatter=False, ax=ax[i], label='Dementia')

    all_y_vals = np.concatenate([df_control[feature], 
                                 df_mci[feature],
                                 df_dementia[feature],
                                ])
    if feature == 'amygdala':
        ax[i].set_yticks(np.arange(0, 1, 0.1))
    if feature == 'ctx--isthmuscingulate':
        ax[i].set_yticks(np.arange(0, 1, 0.1)) 
    if feature == 'brain-stem':
        ax[i].set_yticks(np.arange(0, 5, 0.5)) 
        
    ax[i].set_ylim(np.percentile(all_y_vals, 0.5), np.percentile(all_y_vals, 99.5))
    ax[i].set_title(label, pad=-3)                                                                    
    ax[i].set_ylabel(unit, labelpad=0.1)
    ax[i].tick_params(length=1.5, pad=0, labelsize=8)
    ax[i].set_xlim([40, 95])
    
    ax[i].set_xlabel('')
    if i >= 12:
        ax[i].set_xlabel('Age (years)')
        

        
handles, labels = ax[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=3, frameon=False, bbox_to_anchor=(0.5, 1))

sns.despine(left=True, bottom=True)

plt.subplots_adjust(wspace=0.2, left=0.06, right=1, top=0.92, bottom=0.05)
fig.align_ylabels(ax)

plt.savefig('./figures/mri_eeg_dementia_control_3groups.png', dpi=600)

In [29]:
df_control = df_data.loc[sids_control, ['age'] + vars_to_plot]
df_mci_dementia = df_data.loc[list(sids_mci) + list(sids_dementia), ['age'] + vars_to_plot]

from matplotlib.ticker import FormatStrFormatter
import seaborn as sns

plt.rcParams.update({'font.size': 8})
plt.close('all')

# alpha_scatter = 0.5
fig, ax = plt.subplots(4, 4, figsize=(7, 7), sharex='col')
ax = ax[:].flatten()


for i in range(len(vars_to_plot)):
        
    feature = vars_to_plot[i]
    try:
        label = name_dict[feature]
    except:
        label = feature
        
    unit = unit_dict[feature]
    
    ax[i].scatter(df_control['age'], df_control[feature], 
               s=1, color='mediumaquamarine', alpha=alpha_scatter, marker='o')
    ax[i].scatter(df_mci_dementia['age'], df_mci_dementia[feature], 
               s=1, color='tomato', alpha=alpha_scatter, marker='o')
    
    order = 1
    if feature in ['vol-total_ventricle']:
        order = 2
    sns.regplot(x=df_control['age'], y=df_control[feature], ci=95, n_boot=1000, seed=43, order=order, x_ci=5,
                color='mediumaquamarine', scatter=False, ax=ax[i], scatter_kws={'alpha': 1}, label='Control')
    sns.regplot(x=df_mci_dementia['age'], y=df_mci_dementia[feature], ci=95, n_boot=1000, seed=43, order=order, x_ci=5,
                color="tomato", scatter=False, ax=ax[i], label='MCI and Dementia')

    all_y_vals = np.concatenate([df_control[feature], 
                                 df_mci[feature],
                                 df_dementia[feature],
                                ])
    if feature == 'amygdala':
        ax[i].set_yticks(np.arange(0, 1, 0.1))
    if feature == 'ctx--isthmuscingulate':
        ax[i].set_yticks(np.arange(0, 1, 0.1)) 
    if feature == 'brain-stem':
        ax[i].set_yticks(np.arange(0, 5, 0.5)) 
        
    ax[i].set_ylim(np.percentile(all_y_vals, 0.5), np.percentile(all_y_vals, 99.5))
    ax[i].set_title(label, pad=-3)                                                                    
    ax[i].set_ylabel(unit, labelpad=0.1)
    ax[i].tick_params(length=1.5, pad=0, labelsize=8)
    ax[i].set_xlim([40, 95])
    
    ax[i].set_xlabel('')
    if i >= 12:
        ax[i].set_xlabel('Age (years)')

        
handles, labels = ax[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=2, frameon=False, bbox_to_anchor=(0.5, 1))

sns.despine(left=True, bottom=True)

plt.subplots_adjust(wspace=0.2, left=0.06, right=1, top=0.92, bottom=0.05)
fig.align_ylabels(ax)
    
    

plt.savefig('./results/mri_eeg_dementia_control_2groups.png', dpi=600)

