In [1]:
# Environment variables are set before the numerical libraries are imported below.
# (Comments must stay on their own lines: anything after `%env NAME=` becomes part of the value.)
# Limit Intel MKL to four threads.
%env MKL_NUM_THREADS=4
# NOTE(review): presumably the undocumented MKL switch that forces the AVX2 code path
# on non-Intel CPUs — confirm the installed MKL version still honours it.
%env MKL_DEBUG_CPU_TYPE=5
# Theano configuration: compute on the CPU and use float32 as the default float type.
%env THEANO_FLAGS=device=cpu, floatX=float32

import pickle
import numpy as np
import pymc3 as pm
import arviz as az
import theano.tensor as tt
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from collections import Counter
from pathlib import Path
import os
import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides any warnings raised by PyMC3/ArviZ
# (sampler or WAIC/LOO diagnostics) — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
import logging
# Only let CRITICAL records from the 'matplotlib' logger through.
logging.getLogger('matplotlib').setLevel(level=logging.CRITICAL)

# Typeface used by every figure written by this notebook.
plt.rcParams['font.family'] = 'Nimbus Sans'

# --- Analysis configuration ---------------------------------------------------
# Pickled auxiliary statistics; the 'dis_to_idx' and 'dis_to_sys' entries are read below.
INPUT_DATA_AUX_STATS = '../data preparation/all_patients_firstDiag_03292021_otherStats.bpkl3'
# Pickled mapping from a disease-pair key to its sequence of time differences.
TIME_DIFF_DATA = '../data preparation/two_dis_time_diff_seqs_07152022.bpkl3'
# Disease treated as the outcome; every other disease is tried as the intervention.
OUTCOME_DIS = 'OCD'
# INTERVENTION = 'Abnormal_Spine_Curvature'
# Bin width used when the time differences are floor-divided (the plots label it in days).
TIME_DIFF_UNIT = 100
# Minimum number of time-difference observations a pair needs in order to be analysed.
# NOTE(review): "MIM" is presumably a typo for "MIN"; the name is kept because the
# analysis loop refers to it.
TIME_DIFF_MIM_THRESHOLD = 100
# Number of MCMC chains requested from PyMC3.
NUM_CHAINS = 4

# --- Load inputs --------------------------------------------------------------
# SECURITY: unpickling executes arbitrary code embedded in the file; only load
# files produced by the trusted data-preparation step.
with Path(INPUT_DATA_AUX_STATS).open('rb') as aux_stats_file:
    all_aux_stats = pickle.load(aux_stats_file)
DIS_IDX_DICT = all_aux_stats['dis_to_idx']
DIS_SYS_DICT = all_aux_stats['dis_to_sys']

with Path(TIME_DIFF_DATA).open('rb') as time_diff_file:
    ALL_DIS_PAIR_TIME_DIFF = pickle.load(time_diff_file)

# Seed shared by the ADVI initialisation and the NUTS sampler so that a
# Restart-and-Run-All reproduces the same draws.
RANDOM_SEED = 42

# Only consulted by the optional (commented-out) resume snippet inside the loop.
SKIP = True

# Variables reported in the trace plots / summaries of both models.
_COMMON_VAR_NAMES = ['b_shrinkage', 'fixed_b', 'ls', 'var', 'nu_y', 'sigma_y']


def _build_rdd_model(x, x_poly_abs, y_obs, x_before, with_shift):
    """Build the regression-discontinuity model for one disease pair.

    The mean of the Student-t likelihood is ``f_gp + poly_abs`` where ``f_gp``
    is a latent GP (ExpQuad kernel) over ``x`` and ``poly_abs`` is an
    intercept plus a slope on ``|x|``.  With ``with_shift=True`` (alternative
    model) a Laplace-distributed ``shift`` is added wherever ``x_before`` is 1.

    Parameters
    ----------
    x : ndarray, shape (n, 1)
        Sorted, binned time differences.
    x_poly_abs : ndarray, shape (n, >=2)
        Absolute values of the polynomial design matrix; only the first two
        columns (``|1|`` and ``|x|``) are used.
    y_obs : ndarray, shape (n,)
        Observed count per bin.
    x_before : ndarray, shape (n,)
        1.0 where ``x < 0`` and 0.0 elsewhere.
    with_shift : bool
        Include the discontinuity term (alternative model) or not (null model).

    Returns
    -------
    pm.Model
    """
    model = pm.Model()

    with model:
        if with_shift:
            # The RDD shift at point zero
            ss = pm.HalfCauchy('ss', beta=5.0)
            shift = pm.Laplace('shift', mu=0.0, b=ss)

        # Polynomial fit
        b_shrinkage = pm.HalfCauchy('b_shrinkage', beta=1.0, shape=2)
        fixed_b = pm.Normal('fixed_b', mu=0, sigma=20.0*b_shrinkage, shape=2)

        # Gaussian process fit
        ls = pm.HalfCauchy('ls', beta=5.0)
        var = pm.HalfCauchy('var', beta=5.0)
        cov_func = var * pm.gp.cov.ExpQuad(1, ls=ls)
        gp = pm.gp.Latent(cov_func=cov_func)
        f_gp = gp.prior('f_gp', X=x)

        poly_abs = pm.Deterministic('poly_abs', tt.dot(x_poly_abs[:, :2], fixed_b))

        mean = f_gp + poly_abs
        if with_shift:
            mean = mean + x_before * shift
        l = pm.Deterministic('l', mean)

        nu_y = pm.Gamma('nu_y', alpha=2.0, beta=1.0)
        sigma_y = pm.Gamma('sigma_y', alpha=2.0, beta=1.0)
        pm.StudentT('y',
                    nu=nu_y,
                    mu=l,
                    sigma=sigma_y,
                    observed=y_obs)

    return model


def _sample_model(model):
    """Print the test-point log-probabilities of ``model`` and sample it with NUTS.

    NUTS is initialised with ``advi+adapt_diag``; both the initialisation and
    the sampler are seeded with ``RANDOM_SEED``.  Returns the trace produced by
    ``pm.sample`` (500 draws after 1500 tuning steps per chain).
    """
    # Sanity check: a non-finite logp at the test point means the model cannot start.
    for rv in model.basic_RVs:
        print(rv.name, rv.logp(model.test_point))

    with model:
        start, step = pm.init_nuts(init='advi+adapt_diag',
                                   chains=NUM_CHAINS,
                                   n_init=200000,
                                   target_accept=0.9,
                                   random_seed=RANDOM_SEED)
        return pm.sample(draws=500,
                         chains=NUM_CHAINS,
                         tune=1500,
                         start=start,
                         step=step,
                         discard_tuned_samples=True,
                         random_seed=RANDOM_SEED)


def _save_diagnostics(trace, var_names, prefix, plot_dir, data_dir):
    """Write the trace plot, the summary CSV and the energy plot of ``trace``.

    ``prefix`` ('alter' or 'null') selects the output file names.
    """
    az.style.use('arviz-whitegrid')
    az.plot_trace(trace, var_names=var_names)
    plt.savefig(f'{plot_dir}/{prefix}_model_trace_plot.pdf', format='pdf', bbox_inches='tight')
    plt.close()

    az.summary(trace, var_names=var_names).to_csv(f'{data_dir}/{prefix}_trace_summary.csv')

    az.plot_energy(trace)
    plt.savefig(f'{plot_dir}/{prefix}_model_energy_plot.pdf', format='pdf', bbox_inches='tight')
    plt.close()


def _annotate_axes(ax, intervention, ylabel):
    """Add the zero line, the outcome/intervention captions and the axis labels."""
    ax.axvline(color='gray')
    ax.text(0.0, 1.1, 'Outcome: '+OUTCOME_DIS, transform=ax.transAxes)
    ax.text(0.0, 1.05, 'Intervention: '+intervention, transform=ax.transAxes)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time difference between the outcome and the intervention\n'
                  f'(unit = {TIME_DIFF_UNIT} days)')


def _plot_posterior_fit(x, y_obs, trace, trend_x, trend_y, trend_label,
                        side_notes, intervention, out_file):
    """Plot the data, the posterior mean of ``l`` with its 95% HDI, the GP
    component and the parametric trend, then save the figure to ``out_file``.

    ``side_notes`` holds up to three strings written to the right of the axes.
    """
    plt.style.use('default')
    plt.rcParams['font.family'] = 'Nimbus Sans'
    fig, ax = plt.subplots(1, 1)

    latent = trace['l']
    ax.plot(x, y_obs, '.')
    ax.plot(x, latent.mean(axis=0), label='Posterior fit (95% HDI)')
    az.plot_hdi(x, latent, hdi_prob=0.95,
                ax=ax, smooth=False, fill_kwargs={'linewidth': 0.0})
    ax.plot(x, trace['f_gp'].mean(axis=0), label='GP tweak estimate')
    ax.plot(trend_x, trend_y.mean(axis=0), label=trend_label)

    for y_pos, note in zip((0.75, 0.7, 0.65), side_notes):
        ax.text(1.05, y_pos, note, fontsize='medium', transform=ax.transAxes)
    ax.legend(loc='upper right', bbox_to_anchor=(1.5, 1.0))

    _annotate_axes(ax, intervention, 'Number of diagnosis')
    plt.savefig(out_file, format='pdf', bbox_inches='tight')
    plt.close(fig)


# Fit the alternative (with shift) and null (without shift) model for every
# candidate intervention disease paired with OUTCOME_DIS.
for INTERVENTION in DIS_IDX_DICT:

    print('\033[94m'+OUTCOME_DIS+' x '+INTERVENTION+'\033[0m')

    # The outcome paired with itself is never analysed; skip it before the
    # lookup so a missing self-pair key cannot abort the whole run.
    if INTERVENTION == OUTCOME_DIS:
        continue

    TIME_DIFF = ALL_DIS_PAIR_TIME_DIFF[(OUTCOME_DIS,
                                        INTERVENTION,
                                        DIS_IDX_DICT[OUTCOME_DIS],
                                        DIS_IDX_DICT[INTERVENTION])]

    # Too few observations to fit the models.
    if len(TIME_DIFF) < TIME_DIFF_MIM_THRESHOLD:
        continue

    # Optional resume logic (disabled): start from a particular disease.
    #     if INTERVENTION == 'Rectal_Anal_Disorder':
    #         SKIP = False
    #     if SKIP:
    #         continue

    # Prepare data ------------------------------------------------------------
    # Floor the time differences into bins of TIME_DIFF_UNIT and count per bin.
    out_data_time_diff_rounded = np.array(TIME_DIFF) // TIME_DIFF_UNIT
    time_diff_counter = Counter(out_data_time_diff_rounded)
    time_diff_xaxis = list(time_diff_counter.keys())
    time_diff_yaxis = list(time_diff_counter.values())

    # Safely create output folder ---------------------------------------------
    data_out_path = f'RDD_experiment_{OUTCOME_DIS}_07152022/{INTERVENTION}'
    plot_out_path = data_out_path + '/plots'
    Path(plot_out_path).mkdir(parents=True, exist_ok=True)

    # Plot raw data points ----------------------------------------------------
    fig, ax = plt.subplots(1, 1)
    ax.plot(time_diff_xaxis, time_diff_yaxis, '.')
    _annotate_axes(ax, INTERVENTION, 'Number of diagnosis')
    plt.savefig(plot_out_path+'/raw_data_points.pdf', format='pdf', bbox_inches='tight')
    plt.close(fig)

    # Prepare MCMC data -------------------------------------------------------
    data_ord = np.argsort(time_diff_xaxis)
    data_X = np.array(time_diff_xaxis, dtype=np.float64)[data_ord, None]
    # Only the first two columns (bias and x) enter the models.
    data_X_poly = np.hstack((np.ones_like(data_X), data_X, data_X**2.0, data_X**3.0))
    data_X_poly_abs = np.abs(data_X_poly)
    data_y = np.array(time_diff_yaxis, dtype=np.float64)[data_ord]
    # Indicator of the bins left of zero, where the shift applies.
    data_x_before = 1.0*(data_X.flatten() < 0.0)
    data_x_flat = data_X.flatten()

    # Dense grid used to draw the parametric trend of both models.
    trend_x_range = np.arange(data_X_poly[0, 1], data_X_poly[-1, 1], 0.1)[:, None]
    trend_x_range_and_bias = np.hstack((np.ones_like(trend_x_range),
                                        trend_x_range))
    trend_design_abs = np.abs(np.transpose(trend_x_range_and_bias))

    # Alternative model, with a shift at point zero ---------------------------
    alter_model = _build_rdd_model(data_X, data_X_poly_abs, data_y,
                                   data_x_before, with_shift=True)
    alter_trace = _sample_model(alter_model)
    _save_diagnostics(alter_trace, ['ss', 'shift'] + _COMMON_VAR_NAMES,
                      'alter', plot_out_path, data_out_path)

    shift_alter_trace = alter_trace['shift']
    trend_y = alter_trace['fixed_b'] @ trend_design_abs
    trend_y += np.outer(shift_alter_trace, (trend_x_range < 0).flatten())

    waic = az.waic(alter_trace, pointwise=True, scale='deviance')
    loo = az.loo(alter_trace, pointwise=True, scale='deviance')
    # np.float_ was removed in NumPy 2.0; np.asarray(..., float64) is equivalent.
    shift_hdi = np.asarray(az.hdi(shift_alter_trace, hdi_prob=0.95), dtype=np.float64)
    slo, sup = round(shift_hdi[0], 4), round(shift_hdi[1], 4)

    _plot_posterior_fit(data_x_flat, data_y, alter_trace,
                        trend_x_range.flatten(), trend_y,
                        'Trend + shift estimate',
                        [f'WAIC = {np.around(waic.waic, 4)}',
                         f'LOO = {np.around(loo.loo, 4)}',
                         f'Shift 95% HDI = ({slo}, {sup})'],
                        INTERVENTION,
                        plot_out_path+'/alter_model_posterior_est_plot.pdf')

    # Null model, without a shift at point zero -------------------------------
    null_model = _build_rdd_model(data_X, data_X_poly_abs, data_y,
                                  data_x_before, with_shift=False)
    null_trace = _sample_model(null_model)
    _save_diagnostics(null_trace, _COMMON_VAR_NAMES,
                      'null', plot_out_path, data_out_path)

    null_trend_y = null_trace['fixed_b'] @ trend_design_abs

    null_waic = az.waic(null_trace, pointwise=True, scale='deviance')
    null_loo = az.loo(null_trace, pointwise=True, scale='deviance')

    # The null model has no shift term, so the trend line is labelled accordingly.
    _plot_posterior_fit(data_x_flat, data_y, null_trace,
                        trend_x_range.flatten(), null_trend_y,
                        'Trend estimate',
                        [f'WAIC = {np.around(null_waic.waic, 4)}',
                         f'LOO = {np.around(null_loo.loo, 4)}'],
                        INTERVENTION,
                        plot_out_path+'/null_model_posterior_est_plot.pdf')

    # ELPD plot ---------------------------------------------------------------
    # On the deviance scale WAIC = -2 * ELPD, hence
    # 0.5 * (null - alternative) = ELPD(alternative) - ELPD(null) per bin.
    plt.style.use('default')
    plt.rcParams['font.family'] = 'Nimbus Sans'
    fig, ax = plt.subplots(1, 1)
    ax.plot(data_x_flat,
            0.5*(null_waic.waic_i-waic.waic_i), '.')
    _annotate_axes(ax, INTERVENTION,
                   'ELPD difference (Alternative, shifted - Null, non-shifted)')
    plt.savefig(plot_out_path+'/model_compare_elpd.pdf', format='pdf', bbox_inches='tight')
    plt.close(fig)

    # Save data ---------------------------------------------------------------
    with open(data_out_path+'/analysis_results_and_plot_data.bpkl3', 'wb') as f:
        pickle.dump({'raw_data': (TIME_DIFF, TIME_DIFF_UNIT, out_data_time_diff_rounded),
                     'alter.trace': alter_trace,
                     'alter.model': alter_model,
                     'alter.waic': waic,
                     'alter.loo': loo,
                     'null.trace': null_trace,
                     'null.model': null_model,
                     'null.waic': null_waic,
                     'null.loo': null_loo,
                     'plot.trend_data': (trend_x_range, trend_y, null_trend_y)}, f)

    # break

env: MKL_NUM_THREADS=4
env: MKL_DEBUG_CPU_TYPE=5
env: THEANO_FLAGS=device=cpu, floatX=float32
[94mOCD x ADHD[0m
ss_log__ -1.1447301
shift -2.3025851
b_shrinkage_log__ -2.28946
fixed_b -7.829342
ls_log__ -1.1447301
var_log__ -1.1447301
f_gp_rotated_ -39.51436
nu_y_log__ -0.61370564
sigma_y_log__ -0.61370564


Initializing NUTS using advi+adapt_diag...


y -171.3781465143124


KeyboardInterrupt: 