# Functions and imports

In [None]:
import time
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from pandas.plotting import scatter_matrix
import os
import re
import pandas as pd
import numpy as np
from IPython.display import Audio
import seaborn as sns

import helpers # this is where the main training/decoding functions are, modified from teh original HIVAE main.py

#import warnings 
#warnings.filterwarnings('ignore') ########## NOTE: comment out for testing in case it's hiding problems

def set_settings(opts, nepochs=500, modload=False, save=True):
    """Build the HI-VAE command-line settings string for one input file.

    Fills the placeholders of a fixed argument template with values taken
    from ``opts`` and from the keyword arguments.

    Parameters
    ----------
    opts : mapping of column name -> pandas Series
        As used in this notebook, ``dict(...)`` of the one-row slice of the
        hyperparameter table for a file. Only the first element of
        ``'files'``, ``'nbatch'``, ``'ydims'``, ``'sdims'`` and ``'lrates'``
        is read.
    nepochs : int
        Value for ``--epochs``.
    modload : bool
        Sets ``--restore`` to ``1`` (True) or ``0`` (False).
        NOTE(review): the original author noted that model loading is
        hardcoded in helpers.py, so this flag may have no effect there.
        Confirm before relying on it.
    save : bool
        If True, ``--save`` is ``nepochs - 1``; otherwise it is
        ``nepochs * 2``. That is presumably an epoch that is never reached,
        i.e. no saving; confirm against helpers.py.

    Returns
    -------
    str
        The whitespace-separated settings string.
    """
    # Strip only a trailing ".csv". The previous unanchored, unescaped '.csv'
    # pattern also matched "<any char>csv" anywhere in the file name.
    inputf = re.sub(r'\.csv$', '', opts['files'].iloc[0])
    missf = inputf + '_missing.csv'
    typef = inputf + '_types.csv'

    template = '--epochs NEPOCHS --model_name model_HIVAE_inputDropout --restore MODLOAD \
        --data_file data_python/INPUT_FILE.csv --types_file data_python/TYPES_FILE \
         --batch_size NBATCH --save NEPFILL --save_file SAVE_FILE\
        --dim_latent_s SDIM --dim_latent_z 1 --dim_latent_y YDIM \
        --miss_percentage_train 0 --miss_percentage_test 0 \
        --true_miss_file data_python/MISS_FILE --learning_rate LRATE'

    if 'medhist' in inputf:
        # For 'medhist' inputs the --true_miss_file argument is omitted entirely.
        template = template.replace('--true_miss_file data_python/MISS_FILE', '')

    values = {
        'INPUT_FILE': inputf,
        'NBATCH': str(opts['nbatch'].iloc[0]),
        'NEPOCHS': str(nepochs),
        'NEPFILL': str(nepochs - 1) if save else str(nepochs * 2),
        'YDIM': str(opts['ydims'].iloc[0]),
        'SDIM': str(opts['sdims'].iloc[0]),
        'MISS_FILE': missf,
        'TYPES_FILE': typef,
        'SAVE_FILE': inputf,
        'LRATE': str(opts['lrates'].iloc[0]),
        'MODLOAD': '1' if modload else '0',
    }
    # Substitute all placeholders in a single pass. Inserted values (file names,
    # numbers) are never re-scanned for later placeholders, and they are inserted
    # literally, so backslashes are not treated as regex group references.
    # Longest names come first so that no placeholder can shadow a longer one.
    placeholder = re.compile('|'.join(re.escape(k) for k in sorted(values, key=len, reverse=True)))
    return placeholder.sub(lambda m: values[m.group(0)], template)

# General settings

In [None]:
# Batch size used when decoding; the decoding cells write it into each file's nbatch setting.
sample_size = 362
# Input data files: CSVs in data_python/, excluding the *_types / *_missing companion files.
files = [i for i in os.listdir('data_python/')
         if i.endswith('.csv') and '_type' not in i and '_missing' not in i]
# Per-file s-dimension, assigned positionally to the rows of results_PPMI.csv.
sds = [1]*20 + [2] + [1]*13
best_hyper = pd.read_csv('results_PPMI.csv')

# os.listdir order is arbitrary, so compare contents rather than positions and fail loudly:
# every later cell needs best_hyper['sdims'] and a row per data file.
if sorted(files) != sorted(best_hyper['files']):
    raise ValueError('Data files in data_python/ do not match results_PPMI.csv: '
                     'missing from CSV {}, missing from folder {}'.format(
                         sorted(set(files) - set(best_hyper['files'])),
                         sorted(set(best_hyper['files']) - set(files))))
if len(sds) != len(best_hyper):
    raise ValueError('sds has {} entries but results_PPMI.csv has {} rows'.format(len(sds), len(best_hyper)))

# Use the CSV's row order so that files, sds and best_hyper rows line up deterministically.
files = list(best_hyper['files'])
sdims = dict(zip(files, sds))
best_hyper['sdims'] = sds
best_hyper

# Virtual patient (VP) decoding

Run this section after `bnet.R`.

In [None]:
# TODO: set this to the virtual patient codes file. The original cell left a bare
# placeholder here, which is a SyntaxError. The code below expects a table with a
# zcode_<stem> column (and optionally scode_<stem>) per data file; it is assumed to be
# a CSV, like the counterfactual codes file read further down. Confirm.
VP_CODES_PATH = None
if VP_CODES_PATH is None:
    raise ValueError('Set VP_CODES_PATH to the virtual patient codes file before running this cell')
VPcodes = pd.read_csv(VP_CODES_PATH)

dfs = list()
virt = list()
for f in files:
    stem = re.sub(r'\.csv$', '', f)
    # This file's tuned hyperparameters, with the batch size overridden to sample_size
    # (built without chained assignment on a slice).
    opts = dict(best_hyper.loc[best_hyper['files'] == f].assign(nbatch=sample_size))
    settings = set_settings(opts, nepochs=1, modload=True, save=False)

    # Decode. Files without an s-code column get all-zero s-codes.
    zcodes = VPcodes['zcode_' + stem]
    scode_col = 'scode_' + stem
    scodes = VPcodes[scode_col] if scode_col in VPcodes.columns else np.zeros(zcodes.shape)

    dec = helpers.dec_network(settings, zcodes, scodes, VP=True)
    subj = pd.read_csv('python_names/' + stem + '_subj.csv')['x']
    names = pd.read_csv('python_names/' + stem + '_cols.csv')['x']
    dat = pd.DataFrame(dec)
    dat.columns = names
    dat['SUBJID'] = subj
    virt.append(dec)
    dfs.append(dat)

virt_dic = dict(zip(files, virt))
decoded = helpers.merge_dat(dfs)
decoded.to_csv('decodedVP.csv', index=False)

Compute the virtual patients' log-likelihoods (saved to `virtual_logliks.csv`) for the R plot.

In [None]:
# TODO: set this to the virtual patient codes file (the same file as the VP decoding
# cell). The original cell left a bare placeholder here, which is a SyntaxError; it is
# assumed to be a CSV with zcode_<stem> / scode_<stem> columns. Confirm.
VP_CODES_PATH = None
if VP_CODES_PATH is None:
    raise ValueError('Set VP_CODES_PATH to the virtual patient codes file before running this cell')
VPcodes = pd.read_csv(VP_CODES_PATH)

dfs = list()
for f in files:
    stem = re.sub(r'\.csv$', '', f)
    # This file's tuned hyperparameters, with the batch size overridden to sample_size.
    opts = dict(best_hyper.loc[best_hyper['files'] == f].assign(nbatch=sample_size))
    settings = set_settings(opts, nepochs=1, modload=True, save=False)

    # Evaluate log-likelihoods. Files without an s-code column get all-zero s-codes.
    zcodes = VPcodes['zcode_' + stem]
    scode_col = 'scode_' + stem
    scodes = VPcodes[scode_col] if scode_col in VPcodes.columns else np.zeros(zcodes.shape)

    loglik = helpers.dec_network_loglik(settings, zcodes, scodes, VP=True)
    # NaN-ignoring mean over the last axis of the transposed helper output.
    # Presumably this gives one value per virtual patient; confirm the shape in helpers.py.
    loglik = np.nanmean(np.array(loglik).T, axis=1)
    subj = pd.read_csv('python_names/' + stem + '_subj.csv')['x']
    dat = pd.DataFrame(loglik)
    dat.columns = [f]
    dat['SUBJID'] = subj
    dfs.append(dat)

decoded = helpers.merge_dat(dfs)
decoded.to_csv('virtual_logliks.csv', index=False)

# Counterfactuals

Run counteractuals_bnlearn.R before running this!

Effect of age on UPDRS: counterfactual codes decoded in batches.

In [None]:
%%capture
f='UPDRS_VIS00.csv'
VPcodes = pd.read_csv('../../data/data_out/counter_updrs_age.csv')
# replace placeholders in template
opts=dict(best_hyper[best_hyper['files'].copy()==f])
opts['nbatch'].iloc[0]=sample_size
settings=set_settings(opts,nepochs=1,modload=True,save=False)

#run
zcodes=VPcodes['dv']
scodes=np.zeros(zcodes.shape)

decs=list()
n=362
for i in range(int(len(VPcodes['dv'])/n)):
    dec=helpers.dec_network(settings,zcodes[i*n:(i*n+n)],scodes[i*n:(i*n+n)],VP='nomiss');
    decs.append(dec)

In [None]:
# Assemble the decoded counterfactual patients; column names follow the decoded file f.
names = pd.read_csv('python_names/' + re.sub(r'\.csv$', '', f) + '_cols.csv')['x']
allPT = pd.DataFrame(np.vstack(decs))
allPT.columns = names
# VPcodes['level'] is aligned by index, so every counterfactual code must have been decoded.
assert len(allPT) == len(VPcodes), 'decoded rows do not match the number of counterfactual codes'
allPT['Intervention'] = VPcodes['level']
dfm = allPT.melt(var_name='columns', id_vars='Intervention')
# sns.distplot is deprecated; histplot with kde=True and stat='density' is its replacement.
g = sns.FacetGrid(dfm, col='columns', hue='Intervention', col_wrap=3, sharex=False, sharey=False)
g = g.map(sns.histplot, 'value', kde=True, stat='density').add_legend()
allPT.to_csv('CF_output.csv', index=False)

In [None]:
# Intervention levels to overlay, in legend order.
groups = ['No intervention', 'Age -20yrs', 'Age +20yrs']

fig, ax = plt.subplots()
for group in groups:  # distinct loop name so the list is not shadowed
    subset = allPT[allPT['Intervention'] == group]
    # histplot replaces the deprecated distplot; alpha=0.4 matches distplot's default bar alpha.
    sns.histplot(subset['UPDRS_UPDRS_VIS00'], kde=True, stat='density', alpha=0.4, label=group, ax=ax)

ax.legend(title='Intervention')
ax.set(title='UPDRS total by intervention', xlabel='UPDRS total', ylabel='Density');
