In [1]:
%env MKL_NUM_THREADS=4
%env MKL_DEBUG_CPU_TYPE=5
%env THEANO_FLAGS=device=cpu, floatX=float32

import pickle
import pandas as pd
import numpy as np
import pymc3 as pm
import arviz as az
import theano.tensor as tt
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from collections import Counter
from pathlib import Path
import os
import warnings
warnings.filterwarnings('ignore')
import logging
logging.getLogger('matplotlib').setLevel(level=logging.CRITICAL)

plt.rcParams['font.family'] = 'Nimbus Sans'

# ---- Configuration --------------------------------------------------------
# Pre-computed inputs produced by the data-preparation step.
INPUT_DATA_AUX_STATS = '../data preparation/all_patients_firstDiag_03292021_otherStats.bpkl3'
TIME_DIFF_DATA = '../data preparation/two_dis_time_diff_seqs_07152022.bpkl3'
OUTCOME_DIS = 'OCD'
# NOTE(review): TIME_DIFF_UNIT and NUM_CHAINS are not used in this notebook —
# presumably settings of the model-fitting run; confirm or remove.
TIME_DIFF_UNIT = 100
# Disease pairs whose time-difference sequence has fewer entries than this are skipped.
TIME_DIFF_MIN_THRESHOLD = 100
# Legacy (misspelled) name kept as an alias so existing references keep working.
TIME_DIFF_MIM_THRESHOLD = TIME_DIFF_MIN_THRESHOLD
NUM_CHAINS = 4

# Probability mass of the highest-density interval used to judge the shift.
HDI_PROB = 0.95

# NOTE: pickle.load can execute arbitrary code; only load these locally generated files.
with open(INPUT_DATA_AUX_STATS, 'rb') as f:
    all_aux_stats = pickle.load(f)
DIS_IDX_DICT = all_aux_stats['dis_to_idx']
DIS_SYS_DICT = all_aux_stats['dis_to_sys']

with open(TIME_DIFF_DATA, 'rb') as f:
    ALL_DIS_PAIR_TIME_DIFF = pickle.load(f)

def _hdi_excludes_zero(hdi):
    """Return True when the interval ``hdi = (low, high)`` lies strictly on one side of zero.

    The bounds' signs are compared directly rather than testing ``low * high > 0``:
    with float32 traces the product of two tiny same-signed bounds can underflow
    to 0 and be misreported as non-significant.
    """
    low, high = hdi[0], hdi[1]
    return bool((low > 0.0 and high > 0.0) or (low < 0.0 and high < 0.0))


def _waic_is_better(candidate, reference):
    """Return True when ``candidate`` has a better WAIC than ``reference``.

    On ArviZ's "deviance" and "negative_log" scales a lower value is better;
    on the "log" (ELPD) scale a higher value is better.
    """
    # NOTE(review): assumes the stored ArviZ result (a pandas Series subclass)
    # carries a 'waic_scale' entry; when it is absent we fall back to the
    # lower-is-better comparison this notebook originally used.
    scale = candidate.get('waic_scale', 'deviance')
    if scale == 'log':
        return bool(candidate.waic > reference.waic)
    return bool(candidate.waic < reference.waic)


# One entry per summarised intervention in each list (columns of the result table).
dis_name_list = []
shift_mean_list = []
shift_hdi_list = []
shift_sig_list = []
alter_waic_list = []
null_waic_list = []
alter_better_list = []
for INTERVENTION in DIS_IDX_DICT:
    print('\033[94m'+OUTCOME_DIS+' x '+INTERVENTION+'\033[0m')

    # The outcome is never its own intervention; skip it before the pair lookup
    # so the self-pair does not need to exist in ALL_DIS_PAIR_TIME_DIFF.
    if INTERVENTION == OUTCOME_DIS:
        continue

    TIME_DIFF = ALL_DIS_PAIR_TIME_DIFF[(OUTCOME_DIS,
                                        INTERVENTION,
                                        DIS_IDX_DICT[OUTCOME_DIS],
                                        DIS_IDX_DICT[INTERVENTION])]
    # Pairs below the threshold are not summarised.
    if len(TIME_DIFF) < TIME_DIFF_MIM_THRESHOLD:
        continue

    data_out_path = f'RDD_experiment_{OUTCOME_DIS}_07152022/{INTERVENTION}'

    # NOTE: pickle.load can execute arbitrary code; only load locally generated results.
    with open(os.path.join(data_out_path, 'analysis_results_and_plot_data.bpkl3'), 'rb') as f:
        rdd_result = pickle.load(f)
    alter_trace = rdd_result['alter.trace']
    alter_waic = rdd_result['alter.waic']
    null_trace = rdd_result['null.trace']
    null_waic = rdd_result['null.waic']
    # Posterior draws of the 'shift' parameter of the shifted ("alter") model.
    shift_alter_trace = alter_trace['shift']

    shift_mean = shift_alter_trace.mean()
    shift_hdi = az.hdi(shift_alter_trace, hdi_prob=HDI_PROB)
    shift_sig = _hdi_excludes_zero(shift_hdi)
    alter_better = _waic_is_better(alter_waic, null_waic)

    # Append only once every quantity for this intervention is available, so a
    # failure part-way through (e.g. a missing results file) cannot leave the
    # column lists with different lengths.
    dis_name_list.append(INTERVENTION)
    shift_mean_list.append(shift_mean)
    shift_hdi_list.append(list(shift_hdi))
    shift_sig_list.append(shift_sig)
    alter_waic_list.append(alter_waic.waic)
    null_waic_list.append(null_waic.waic)
    alter_better_list.append(alter_better)

env: MKL_NUM_THREADS=4
env: MKL_DEBUG_CPU_TYPE=5
env: THEANO_FLAGS=device=cpu, floatX=float32
[94mOCD x ADHD[0m
[94mOCD x Abnormal_Spine_Curvature[0m
[94mOCD x Acne[0m
[94mOCD x Acquired_Coagulation_Defect[0m
[94mOCD x Acquired_Hemolytic_Anemias[0m
[94mOCD x Acquired_Hypothyroidism[0m
[94mOCD x Acquired_Limb_Deformities[0m
[94mOCD x Acquired_Other_Myopathies[0m
[94mOCD x Acquired_Retinal_Defects[0m
[94mOCD x Acquired_Visual_Disturbances[0m
[94mOCD x Acrocephalosyndactyly[0m
[94mOCD x Acute_Bronchitis[0m
[94mOCD x Acute_Glomerulonephritis[0m
[94mOCD x Acute_Renal_Failure[0m
[94mOCD x Acute_Sinusitis[0m
[94mOCD x Acute_Upper_Respiratory_Infection[0m
[94mOCD x Addisons_Disease[0m
[94mOCD x Adjustment_Disorder[0m
[94mOCD x Adrenogenital_Disorder[0m
[94mOCD x Allergic_Rhinitis[0m
[94mOCD x Alopecia[0m
[94mOCD x Alpha-1-Antitrypsin_Deficiency[0m
[94mOCD x Alveolar_Disease[0m
[94mOCD x Alzheimers_Disease[0m
[94mOCD x Amino_Acid_Transport_Disord

In [2]:
# Assemble the summary table: one row per intervention collected by the loop in
# the first cell, one column per summary list.
# NOTE(review): "mdodel" is a typo, but this header is written to the result CSVs
# and is left unchanged so existing readers of those files keep working.
rdd_result_df = pd.DataFrame.from_dict({
    'Intervention': dis_name_list,
    'Shift mean': shift_mean_list,
    'Shift HDI': shift_hdi_list,
    'Shift HDI significance': shift_sig_list,
    'Shifted model WAIC': alter_waic_list,
    'Non-shifted model WAIC': null_waic_list,
    'Shifted mdodel better?': alter_better_list,
})

In [3]:
# Persist the full summary. The file name is derived from OUTCOME_DIS so that
# running this notebook for one outcome cannot overwrite or mislabel another
# outcome's results (it was previously hard-coded to "Schizophrenia_06202022").
# NOTE(review): the date tag matches the RDD_experiment_<outcome>_07152022 input
# directories, and "linearTrend" is carried over from the old file name —
# confirm it describes the models loaded in the first cell.
rdd_result_df.to_csv(f'RDD_experiment_{OUTCOME_DIS}_07152022_linearTrend_result.csv')
rdd_result_df.sort_values(by=['Shift mean'])

Unnamed: 0,Intervention,Shift mean,Shift HDI,Shift HDI significance,Shifted model WAIC,Non-shifted model WAIC,Shifted mdodel better?
29,Gestational_Pregnancy_Related_Disorder,-18.034492,"[-31.275208, 1.9894949]",False,229.912589,233.264027,True
18,Ear_Infection,-4.290842,"[-12.367907, 2.8115604]",False,327.494104,328.34702,True
2,Acute_Bronchitis,-4.206587,"[-12.750944, 2.1288836]",False,328.523127,331.71213,True
4,Acute_Upper_Respiratory_Infection,-1.619924,"[-7.7698574, 3.4477892]",False,323.825353,324.593307,True
19,Esophageal_Disease,-1.224313,"[-4.3902235, 1.8760378]",False,214.470921,215.591795,True
31,Influenza,-0.838387,"[-4.223043, 2.7974114]",False,223.929368,226.963983,True
50,Viral_Warts_HPV,-0.360516,"[-2.5306785, 1.5840946]",False,167.238823,166.932091,False
9,Atopic_Contact_Dermatitis,-0.23373,"[-3.5203292, 3.0711532]",False,282.815549,282.134361,False
41,Oro-Facial_Congenital_Anomaly,-0.112894,"[-4.135776, 2.6248693]",False,191.871846,191.422231,False
25,Fungal_Infection,-0.042255,"[-1.8198237, 2.041373]",False,231.790428,231.128041,False


In [4]:
# Keep only interventions whose shift HDI (at HDI_PROB) excludes zero, ordered by
# the posterior mean shift. reset_index() without drop keeps the row label from
# rdd_result_df as an 'index' column so rows can be matched back to the full table.
rdd_result_df_sig = (
    rdd_result_df
    .loc[rdd_result_df['Shift HDI significance'].astype(bool)]
    .sort_values(by=['Shift mean'])
    .reset_index()
)
# File name follows OUTCOME_DIS (was hard-coded to "Schizophrenia_06202022").
# NOTE(review): confirm the "07152022_linearTrend" tag matches the loaded models.
rdd_result_df_sig.to_csv(f'RDD_experiment_{OUTCOME_DIS}_07152022_linearTrend_sig_result.csv')
rdd_result_df_sig

Unnamed: 0,index,Intervention,Shift mean,Shift HDI,Shift HDI significance,Shifted model WAIC,Non-shifted model WAIC,Shifted mdodel better?
0,44,Speech_Language_Disorder,3.980531,"[1.8444709, 5.90592]",True,181.663969,192.048439,True
1,5,Adjustment_Disorder,4.062031,"[0.82260656, 6.9315577]",True,156.536138,167.872312,True
2,38,Non-Specific_Pain,4.080336,"[1.408563, 7.514552]",True,209.903595,215.11367,True
3,12,Chronic_Sinusitis,5.160396,"[3.0415404, 7.143796]",True,171.659757,181.580647,True
4,6,Allergic_Rhinitis,6.172914,"[0.10910781, 10.371922]",True,279.758836,278.369277,False
