# Comparing Ketamine and Midazolam after treatment in ROIs
- Analysis in HPC
- focus on end of treatment
- Amygdala
- vmPFC
- Hippocampus

#### Small explanation for the code for future reference (as it is a bit messy)
- I take all functional files of trauma vs. relaxed (first script, first 1 min)
- I mask for amygdala - run analysis per session (1,2,3) and calculate effect (using pyMC3)
- I mask for vmPFC and do the same
- I mask for hippocampus and do the same
- Effects reported: the amygdala and hippocampus show a significant difference in the post-treatment scan. The amygdala also shows an effect at the 30-day follow-up; the hippocampus doesn't. vmPFC and OFC don't show anything.

In [None]:
# import relevant packages
import glob
import numpy as np
import scipy
import nilearn
import nilearn.image
import nilearn.plotting
import nilearn.input_data
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pymc3 as pm
import arviz as az
from pymc3.glm import GLM

In [None]:
# Choose which scan session to analyse
ses = 3

# Medication condition of every subject (used to compare between groups)
medication_cond = pd.read_csv('kpe_sub_condition.csv')

# Collect the cope7 image of every subject for the chosen session, in sorted path order
cope_pattern = '/gpfs/gibbs/pi/levy_ifat/Or/kpe/results/ScriptPart_ses%s/modelfit/_subject_id_*/modelestimate/results/cope7.nii.gz'
func_files = sorted(glob.glob(cope_pattern % ses))
len(func_files)

In [None]:
# remove 1315 (the original note limits this to sessions 1 and 2)
# Select the file by the subject id in its path rather than by list position:
# a fixed index removes a different subject whenever the globbed file list changes.
if ses in (1, 2):
    func_files = [f for f in func_files if '_subject_id_1315/' not in f]

In [None]:
## Amygdala as mask
mask_file = '/gpfs/gibbs/pi/levy_ifat/Or/ROI/amygdala_association-test_z_FDR_0.01.nii.gz'
mask_file = nilearn.image.math_img("a>=25", a=mask_file)
%matplotlib inline
nilearn.plotting.plot_roi(mask_file)


masker = nilearn.input_data.NiftiMasker(mask_img=mask_file, 
                                smoothing_fwhm=None, standardize=False,
                                        detrend=False, verbose=9).fit()

In [None]:

# Subject labels: 'KPE' followed by the id parsed from each file path
scr_id = ['KPE' + func.split('id_')[1].split('/')[0] for func in func_files]

mean_act = []
t_arr = []
for func in func_files:
    # apply the fitted masker, average along axis 1 and keep the first row
    t_map = masker.transform(func)
    t_arr.append(np.mean(t_map, axis=1)[0])
    


In [None]:
# Session-3 amygdala values joined with the medication table;
# the numeric med_cond codes become readable group labels.
df_ses3 = (
    pd.merge(medication_cond, pd.DataFrame({'scr_id': scr_id, 'amg3': t_arr}))
    .rename(columns={'med_cond': 'group'})
    .replace(to_replace={'group': {0.0: 'midazolam', 1.0: 'ketamine'}})
)

In [None]:
# Per-group descriptive statistics, reporting the 2.5th and 97.5th percentiles
df_ses3.groupby('group').describe(percentiles=[.025, 0.975])
#df_ses3.groupby('group').median()

In [None]:
# Amygdala values stored as 'amg2', joined with the medication table;
# the numeric med_cond codes become readable group labels.
df_ses2 = (
    pd.merge(medication_cond, pd.DataFrame({'scr_id': scr_id, 'amg2': t_arr}))
    .rename(columns={'med_cond': 'group'})
    .replace(to_replace={'group': {0.0: 'midazolam', 1.0: 'ketamine'}})
)

In [None]:
# Per-group descriptive statistics, reporting the 2.5th and 97.5th percentiles
df_ses2.groupby('group').describe(percentiles=[.025, 0.975])
#df_ses2.groupby('group').median()

In [None]:
# Box plot of amg2 per group with the individual points overlaid,
# followed by an independent-samples t-test between the two groups.
sns.set_style("ticks")
# x and y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally
sns.boxplot(x='group', y='amg2', data=df_ses2)
sns.stripplot(x='group', y='amg2', data=df_ses2)
scipy.stats.ttest_ind(df_ses2.amg2[df_ses2.group=='ketamine'],
                      df_ses2.amg2[df_ses2.group=='midazolam']
                     )

In [None]:
# Amygdala values stored as 'amg1', joined with the medication table;
# the numeric med_cond codes become readable group labels.
df_ses1 = (
    pd.merge(medication_cond, pd.DataFrame({'scr_id': scr_id, 'amg1': t_arr}))
    .rename(columns={'med_cond': 'group'})
    .replace(to_replace={'group': {0.0: 'midazolam', 1.0: 'ketamine'}})
)

In [None]:

# Per-group descriptive statistics, reporting the 2.5th and 97.5th percentiles
df_ses1.groupby('group').describe(percentiles=[.025, 0.975])

In [None]:
# Display the session-1 table
df_ses1

In [None]:
# Join `df` onto the session-1 table and store the difference between
# `meanAct` and the session-1 amygdala value as `amg_change`.
# NOTE(review): `df` is first assigned in a later cell of this notebook, so this
# cell depends on execution order; confirm which version of `df` (and its
# `meanAct` column) is meant here.
df_ses1 = pd.merge(df, df_ses1)
df_ses1['amg_change'] = df_ses1.meanAct - df_ses1.amg1

In [None]:
# Kept in case the drop in amygdala reactivation from before to after
# treatment needs to be shown: amg1 per group as a box plot with points.
amg1_by_group = dict(y='amg1', x='group', data=df_ses1)
sns.boxplot(**amg1_by_group)
sns.stripplot(**amg1_by_group)

In [None]:
# Join `df` onto the session-3 table and display the result.
# NOTE(review): `df` is first assigned in a later cell of this notebook, so this
# cell depends on execution order.
df_ses3 = pd.merge(df, df_ses3)
df_ses3

In [None]:
# Bar plot of session-3 amygdala values per group (ci=95)
sns.barplot(x='group', y='amg3', data=df_ses3, ci=95)

# Independent-samples t-test, ketamine vs. midazolam
amg3_ketamine = df_ses3.loc[df_ses3['group'] == 'ketamine', 'amg3']
amg3_midazolam = df_ses3.loc[df_ses3['group'] == 'midazolam', 'amg3']
scipy.stats.ttest_ind(amg3_ketamine, amg3_midazolam)

In [None]:
# Rebuild the session-1 table with a `meanAct_ses1` column, join it with the
# medication table and turn the numeric med_cond codes into group labels.
# NOTE(review): the only assignments to `average` visible in this notebook store
# a single `np.mean(...)` value inside the ROI loops, so `average[0]` looks
# stale; confirm what `average` is expected to hold here.
df_ses1 = []
df_ses1 = pd.DataFrame({'scr_id': scr_id, 'meanAct_ses1': average[0]})
df_ses1 = pd.merge(medication_cond, df_ses1)
df_ses1 = df_ses1.rename(columns={'med_cond': 'group'})
#df['group'] = medication_cond['med_cond']
df_ses1 = df_ses1.replace(to_replace={'group': {0.0:'midazolam', 1.0:'ketamine'}})

In [None]:
# Numeric group code for the session-3 table: ketamine -> 1, midazolam -> 0
group = dict(ketamine=1, midazolam=0)
df_ses3['groupIdx'] = [group[label] for label in df_ses3['group']]

In [None]:
# Current ROI values stored as 'meanAct', joined with the medication table;
# the numeric med_cond codes become readable group labels.
df = (
    pd.merge(medication_cond, pd.DataFrame({'scr_id': scr_id, 'meanAct': t_arr}))
    .rename(columns={'med_cond': 'group'})
    .replace(to_replace={'group': {0.0: 'midazolam', 1.0: 'ketamine'}})
)

## Combine all three sessions

In [None]:
# Left-join sessions 2 and 3 onto session 1, so subjects without a value in a
# later session stay in the table (as NaN) instead of being dropped.
df = df_ses1.merge(df_ses2, how='left').merge(df_ses3, how='left')

# Numeric group code: ketamine -> 1, midazolam -> 0
group = dict(ketamine=1, midazolam=0)
df['groupIdx'] = [group[label] for label in df['group']]

# Persist the combined table
df.to_csv('threeSessions_amg_TraumavsRelax.csv', index=False)

In [None]:
# Per-group descriptive statistics of the combined table
df.groupby(['group']).describe()

In [None]:
# Bar plot of amg2 per group (ci=95), then an independent-samples t-test
sns.barplot(x='group',y='amg2', data=df, ci=95)
# `group` holds the labels 'ketamine'/'midazolam' (the numeric codes were replaced
# when the per-session tables were built), so rows are selected by label;
# comparing the column to 1/0 matches no rows and leaves both samples empty.
scipy.stats.ttest_ind(df.amg2[df['group']=='ketamine'],
                      df['amg2'][df['group']=='midazolam'])

In [None]:
# Test changes between sessions: join `df` with the session-1 table and store
# `meanAct - meanAct_ses1` as `amg2_1`.
# NOTE(review): this needs `df` to carry a `meanAct` column and `df_ses1` to carry
# `meanAct_ses1`; which versions of those tables exist depends on the order the
# cells were run in.
df2ses = pd.merge(df, df_ses1)
df2ses['amg2_1'] = df2ses.meanAct - df2ses.meanAct_ses1


In [None]:
# Bar plot of the between-session change per group (ci=68)
sns.barplot(x='group', y='amg2_1', data=df2ses, ci=68)

# Independent-samples t-test, ketamine vs. midazolam
change_ketamine = df2ses.loc[df2ses['group'] == 'ketamine', 'amg2_1']
change_midazolam = df2ses.loc[df2ses['group'] == 'midazolam', 'amg2_1']
scipy.stats.ttest_ind(change_ketamine, change_midazolam)

## Use PyMC3 for Bayesian analysis

In [None]:
# Add a numeric group index to the combined table (1 = ketamine, 0 = midazolam)
group = dict(ketamine=1, midazolam=0)
df['groupIdx'] = [group[label] for label in df['group']]

In [None]:
# Full model: linear regression of amygdala activation (meanAct) on group index
with pm.Model() as model_1:
    # Data
    # NOTE: this rebinds the notebook-level name `group` (a label -> code dict
    # in other cells) to the shared data container.
    group = pm.Data('group', df.groupIdx)
    amg = pm.Data('amg', df.meanAct)

    # Priors (scale passed as `sigma` throughout; `sd` is only an alias for it)
    alpha = pm.Normal('alpha', mu=5, sigma=5)
    beta = pm.Normal('beta', mu=-5, sigma=5)
    sigma = pm.HalfNormal('sigma', sigma=5)

    # Regression
    mu = alpha + beta * group
    diff_group = pm.Normal('diff_group', mu=mu, sigma=sigma, observed=amg)

    # Prior predictive samples, then the posterior trace
    prior = pm.sample_prior_predictive()
    posterior_1 = pm.sample(draws=4000, tune=4000) # this is the trace sampling
   # posterior_pred_1 = pm.sample_posterior_predictive(posterior_1)

In [None]:
# Posterior summary of model_1 with hdi_prob=0.95
#az.summary(posterior_1, credible_interval=.95).round(2) # adding round to make shorter floats
pm.summary(posterior_1, hdi_prob=0.95)#, alpha=.05).round(2)# also possible

In [None]:
# Same kind of regression through the pymc3 GLM helper: amg3 explained by group index
with pm.Model() as model_glm:
    GLM.from_formula(formula='amg3 ~ groupIdx', data=df_ses3)
    trace = pm.sample(draws=5000, tune=3000)

In [None]:
# Posterior summary of the GLM trace with hdi_prob=.95, rounded to two decimals
pm.summary(trace, hdi_prob=.95).round(2)

In [None]:
# Posterior plot of the groupIdx samples
pm.plot_posterior(trace['groupIdx'])

In [None]:
# Distribution of the groupIdx samples
sns.distplot(trace.groupIdx)
# Fraction of groupIdx samples that are above zero
np.mean(trace['groupIdx'] > 0)

In [None]:
# Two-panel figure: per-group data (strip + box plot) on the left, the
# distribution of the groupIdx samples with its 2.5%-97.5% quantile range on
# the right. Saved to 'amygdalaReactivity.png'.
# NOTE(review): the left panel plots `meanAct` from `df`, while `trace` was
# sampled from a model of `amg3` in `df_ses3`; confirm both refer to the same
# session's values.
# set variables
sns.set_style("ticks") # set style
y = 'meanAct'
dfPlot = df
# 2.5% and 97.5% quantiles of the groupIdx samples
ci = np.quantile(trace.groupIdx, [.025,.975])
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(3, 5),gridspec_kw={'width_ratios': [1, .2],
                                                        'wspace':.1})
g1 = sns.stripplot(y= y, x='group', data=dfPlot, size = 8, ax=ax1)
sns.boxplot(y= y, x='group', data=dfPlot,  ax=ax1,
            boxprops=dict(alpha=.3))
g2 = sns.distplot(trace['groupIdx'], ax = ax2, vertical=True)
# vertical bar spanning the quantile range
ax2.vlines(x=0.001,ymin=ci[0], ymax=ci[1], color='black', 
           linewidth = 2, linestyle = "-")

#g3.set_ylim(-.7, .7)
#ax1.set_ylim(-.7,.7)
ax2.set_ylim(g1.get_ylim()) # use first graph's limits to get the relevant for this one
ax2.yaxis.tick_right()
ax2.set_xticks([])
ax2.set_ylabel("Difference between groups", fontsize=14) 
ax2.yaxis.set_label_position("right")
ax1.set_ylabel("Amg reactivity to traumatic script", fontsize=12)
ax1.set_xlabel("Group", fontsize=14)
fig.savefig('amygdalaReactivity.png', dpi=600, bbox_inches='tight')

### Creating mixed level model

In [None]:
# Reshape to long format: one row per subject and session column
# (amg1/amg2/amg3), then save it to disk.
df_long = df.melt(
    id_vars=['scr_id', 'groupIdx'],
    value_vars=['amg1', 'amg2', 'amg3'],
)
df_long.to_csv('amygdala.csv', index=False)

In [None]:
# GLM on the long table: value explained by group index and session variable
with pm.Model() as model_glm1:
    GLM.from_formula(formula='value ~ groupIdx + variable', data=df_long)
    trace_mixed = pm.sample(draws=2000, tune=2000)

In [None]:
# Posterior summary of the long-format model
pm.summary(trace_mixed)

#### There is a main effect for the group (ketamine lower than midazolam)

## Next we do the same for vmPFC

In [None]:
# now lets do the same with vmPFC
mask_file = '/gpfs/gibbs/pi/levy_ifat/Or/ROI/vmpfc_association-test_z_FDR_0.01.nii.gz'
mask_file = nilearn.image.math_img("a>=5", a=mask_file)
%matplotlib inline
nilearn.plotting.plot_roi(mask_file)
masker = nilearn.input_data.NiftiMasker(mask_img=mask_file, 
                               sessions=None, smoothing_fwhm=None,
                                        standardize=False, detrend=False, verbose=5)

In [None]:
# Subject labels: 'KPE' followed by the id parsed from each file path
scr_id = ['KPE' + func.split('id_')[1].split('/')[0] for func in func_files]

# One mean value per file inside the vmPFC mask
mean_act_vmpfc = []
for func in func_files:
    t_map = masker.fit_transform(func)
    average = np.mean(np.array(t_map))
    mean_act_vmpfc.append(average)


In [None]:
# Attach the vmPFC means to the combined table and plot them per group
df["vmpfc"] = mean_act_vmpfc
sns.boxplot(x='group', y='vmpfc', data=df)

# Independent-samples t-test, ketamine vs. midazolam
vmpfc_ketamine = df.loc[df['group'] == 'ketamine', 'vmpfc']
vmpfc_midazolam = df.loc[df['group'] == 'midazolam', 'vmpfc']
scipy.stats.ttest_ind(vmpfc_ketamine, vmpfc_midazolam)

In [None]:
# Show the per-subject vmPFC values. They are stored in the `vmpfc` column of
# `df`; a separate `df_vmpfc` table is never created (its construction is
# commented out), so referring to it raised a NameError.
df[['scr_id', 'group', 'vmpfc']]

In [None]:
# GLM: vmPFC activation explained by group index
with pm.Model() as model_glm:
    # the column is stored in lower case as 'vmpfc'; formula terms are
    # looked up case-sensitively, so 'vmPFC' would not be found
    GLM.from_formula('vmpfc ~ groupIdx', df)
    trace_vmpfc = pm.sample(draws=4000, tune=3000)

In [None]:
# Posterior summary of the vmPFC model. The interval width is passed as
# `hdi_prob`, matching the other summary calls in this notebook;
# `credible_interval` is the older name of that keyword.
pm.summary(trace_vmpfc, hdi_prob=.95).round(2)

In [None]:
## Hippocampus
mask_file = '/gpfs/gibbs/pi/levy_ifat/Or/ROI/hippocampus_association-test_z_FDR_0.01.nii.gz'
mask_file = nilearn.image.math_img("a>=15", a=mask_file)
%matplotlib inline
nilearn.plotting.plot_roi(mask_file)
masker = nilearn.input_data.NiftiMasker(mask_img=mask_file, 
                               sessions=None, smoothing_fwhm=None,
                                        standardize=False, detrend=False, verbose=5)

In [None]:
# Subject labels: 'KPE' followed by the id parsed from each file path
scr_id = ['KPE' + func.split('id_')[1].split('/')[0] for func in func_files]

# One mean value per file inside the hippocampus mask
mean_act_hippo = []
for func in func_files:
    t_map = masker.fit_transform(func)
    average = np.mean(np.array(t_map))
    mean_act_hippo.append(average)


In [None]:
# Attach the hippocampus means to the session-3 table.
# NOTE(review): assigns by position, so this assumes `mean_act_hippo` is in the
# same subject order as the rows of `df_ses3` — confirm.
df_ses3['hippo3'] = mean_act_hippo

In [None]:
# Per-group descriptive statistics of the session-3 table
df_ses3.groupby('group').describe()

In [None]:
# Bar plot of session-3 hippocampus values per group (ci=95)
sns.barplot(x='group', y='hippo3', data=df_ses3, ci=95)

# Independent-samples t-test, ketamine vs. midazolam
hippo3_ketamine = df_ses3.loc[df_ses3['group'] == 'ketamine', 'hippo3']
hippo3_midazolam = df_ses3.loc[df_ses3['group'] == 'midazolam', 'hippo3']
scipy.stats.ttest_ind(hippo3_ketamine, hippo3_midazolam)

In [None]:
# Difference between the `hippo2` and `hippo1` columns per subject, plotted per
# group (ci=68) and compared with an independent-samples t-test.
# NOTE(review): no visible cell of this notebook creates `hippo1` or `hippo2`
# on `df`; confirm where these columns come from before running.
df['hippo_21'] = df.hippo2 - df.hippo1
sns.barplot(x='group',y='hippo_21', data=df, ci=68)
scipy.stats.ttest_ind(df.hippo_21[df['group']=='ketamine'], df['hippo_21'][df['group']=='midazolam'])

In [None]:
# GLM: session-3 hippocampus activation explained by group index (fixed seed)
with pm.Model() as model_glm:
    GLM.from_formula(formula='hippo3 ~ groupIdx', data=df_ses3)
    trace_hippo = pm.sample(draws=2000, tune=2000, random_seed=113)
# Posterior summary with hdi_prob=.95, rounded to two decimals
pm.summary(trace_hippo, hdi_prob=.95).round(2)

In [None]:
# Two-panel figure: per-group hippocampus data (strip + box plot) on the left,
# the distribution of the groupIdx samples with its 2.5%-97.5% quantile range
# on the right. Saved to 'hippoReactivity.png'.
sns.set_style("ticks")
# Plot the data the model was fitted on: the hippocampus means are stored as
# `hippo3` on `df_ses3` (a `hippo` column on `df` is never created).
y = 'hippo3'
dfPlot = df_ses3
# 2.5% and 97.5% quantiles of the groupIdx samples
ci = np.quantile(trace_hippo.groupIdx, [.025,.975])
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(3, 5),gridspec_kw={'width_ratios': [1, .2],
                                                        'wspace':.1})
g1 = sns.stripplot(y= y, x='group', data=dfPlot, size = 8, ax=ax1)
sns.boxplot(y= y, x='group', data=dfPlot,  ax=ax1,
            boxprops=dict(alpha=.3))
g2 = sns.distplot(trace_hippo['groupIdx'], ax = ax2, vertical=True)
# vertical bar spanning the quantile range
ax2.vlines(x=0.001,ymin=ci[0], ymax=ci[1], color='black',
           linewidth = 2, linestyle = "-")


ax2.set_ylim(g1.get_ylim()) # use first graph's limits to get the relevant for this one
ax2.yaxis.tick_right()
ax2.set_xticks([])
ax2.set_ylabel("Difference between groups", fontsize=14)
ax2.yaxis.set_label_position("right")
ax1.set_ylabel("Hippocampus reactivity to traumatic script", fontsize=12)
ax1.set_xlabel("Group", fontsize=14)
fig.savefig('hippoReactivity.png', dpi=600, bbox_inches='tight')

In [None]:
# Per-group descriptive statistics of the combined table
df.groupby('group').describe()