In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import rbf_kernel, laplacian_kernel

import warnings
warnings.filterwarnings("ignore")

from SyntheticDataModule import *
from estimators import *
from utils import *

In [None]:
# ---------------------------------------------------------------------------
# Global experiment settings
# ---------------------------------------------------------------------------
save_df = True      # forwarded as the first argument of SyntheticDataModule
d = 1               # forwarded as the second argument of SyntheticDataModule
rct_size = 1000     # sample size requested for the S=0 (RCT) module
m = 4               # observational study is m times larger than the RCT
obs_size = m * rct_size


def _cox_entry(beta, lam, p=4):
    """Build one entry of a 'cox_args' dict: coefficients plus 'lambda' and 'p'."""
    return {'beta': beta, 'lambda': lam, 'p': p}


# ---------------------------------------------------------------------------
# RCT data generating model parameters
# ---------------------------------------------------------------------------
px_dist_r = 'Gaussian'
px_args_r = {'mean': [0], 'cov': [[1]]}
prop_fn_r = 'sigmoid'
prop_args_r = {'beta': [0, 1e-4]}
tte_params_r = {
    'model': 'coxph',
    'hazard': 'weibull',
    'cox_args': {
        'Y0': _cox_entry([0, 1], 0.8),
        'Y1': _cox_entry([0, 0.25], 0.4),
        'C0': _cox_entry([0, 0], 0.3),
        'C1': _cox_entry([0, 0], 0.3),
    },
}

# ---------------------------------------------------------------------------
# OBS data generating model parameters
# (differs from the RCT in the covariate distribution, the propensity
#  coefficients, and the 'lambda' of Y1: 0.45 instead of 0.4)
# ---------------------------------------------------------------------------
px_dist_o = 'Gaussian'
px_args_o = {'mean': [-0.5], 'cov': [[1.5]]}
prop_fn_o = 'sigmoid'
prop_args_o = {'beta': [0.8, 0.25]}
tte_params_o = {
    'model': 'coxph',
    'hazard': 'weibull',
    'cox_args': {
        'Y0': _cox_entry([0, 1], 0.8),
        'Y1': _cox_entry([0, 0.25], 0.45),
        'C0': _cox_entry([0, 0], 0.3),
        'C1': _cox_entry([0, 0], 0.3),
    },
}

# ---------------------------------------------------------------------------
# Instantiate both studies (fourth argument: 0 for the RCT, 1 for the OBS),
# fetch their dataframes, then print/plot the summaries.
# ---------------------------------------------------------------------------
RCTData = SyntheticDataModule(
    save_df, d, rct_size, 0,
    px_dist_r, px_args_r, prop_fn_r, prop_args_r, tte_params_r,
)
OBSData = SyntheticDataModule(
    save_df, d, obs_size, 1,
    px_dist_o, px_args_o, prop_fn_o, prop_args_o, tte_params_o,
)

df_rct_oracle, df_rct = RCTData.get_df()
df_obs_oracle, df_obs = OBSData.get_df()

RCTData.summary(plot=True)
OBSData.summary(plot=True)

In [None]:
# Repeated MMR-test experiment.
#
# For every observational-study size multiplier in `m_arr`, draw `num_exp`
# fresh (RCT, OBS) dataset pairs, estimate the nuisance quantities on the
# pooled data, and store the two values returned by `mmr_test`.
# Row i of `mmr_results` / `mmr_pvals` corresponds to `m_arr[i]`; column n
# corresponds to the n-th repetition.
B = 100  # num. samples to model the null distribution in every single run
num_exp = 50
m_arr = [2, 3, 4]  # multiplier to get the observational study size.
mmr_results = np.zeros((len(m_arr), num_exp))
mmr_pvals = np.zeros((len(m_arr), num_exp))


# Iterate with enumerate: the result arrays have len(m_arr) rows, so they must
# be indexed by the *position* of the multiplier, not by its value (2, 3, 4).
for m_idx, m in enumerate(m_arr):
    obs_size = rct_size * m

    for n in range(num_exp):
        # The first argument (the save flag) is passed positionally as False:
        # a keyword argument may not precede positional arguments in a call.
        RCTData = SyntheticDataModule(False, d, rct_size, 0, px_dist_r, px_args_r, prop_fn_r, prop_args_r, tte_params_r)
        # NOTE(review): the OBS module still receives the notebook-level `save_df`
        # flag while the RCT module gets False -- confirm whether that is intended.
        OBSData = SyntheticDataModule(save_df, d, obs_size, 1, px_dist_o, px_args_o, prop_fn_o, prop_args_o, tte_params_o)

        df_rct_oracle, df_rct = RCTData.get_df()
        df_obs_oracle, df_obs = OBSData.get_df()

        df_combined = pd.concat([df_rct, df_obs], axis=0, ignore_index=True)  # merge the dataframes into one
        cov_list = RCTData.get_covs()

        # Estimate the nuisance parameters

        # Study-membership score, fitted on the pooled sample.
        df_combined['P(S=1|X)'] = prop_score_est(df_combined.copy(), 'S', cov_list, 'logistic')

        # Treatment propensity, fitted separately within each study.
        df_combined.loc[df_combined.S==0, 'P(A=1|X,S)'] = prop_score_est(df_combined.query('S==0').copy(), 'A', cov_list, 'logistic')
        df_combined.loc[df_combined.S==1, 'P(A=1|X,S)'] = prop_score_est(df_combined.query('S==1').copy(), 'A', cov_list, 'logistic')

        # Return values are not captured here; presumably these helpers add
        # columns to df_combined in place (e.g. 'ipcw_est_S0'/'ipcw_est_S1'
        # used below) -- TODO confirm against estimators.py.
        gc_est(df_combined, cov_list, tte_model='coxph')

        ipcw_est(df_combined, S=0)
        ipcw_est(df_combined, S=1)

        mmr_results[m_idx, n], mmr_pvals[m_idx, n] = mmr_test(df_combined, cov_list, B=B, kernel=rbf_kernel, signal0='ipcw_est_S0', signal1='ipcw_est_S1')

In [None]:
# --- Figure 1: oracle propensity score of OBSData over a grid of covariate values ---
cov_name = 'X1'
x_space = np.linspace(-10, 10, 401)
obs_oracle_prop = OBSData.calc_oracle_prop(x_space, cov_name)

fig_prop, ax_prop = plt.subplots()
ax_prop.plot(x_space, obs_oracle_prop)
ax_prop.set_xlabel(cov_name)
ax_prop.set_ylabel(f'P(A=1|{cov_name},S=1)')
ax_prop.set_title(f'Oracle propensity score in study S=1 wrt covariate {cov_name}')
plt.show()

# --- Figure 2: oracle survival curves of RCTData at a fixed covariate vector ---
t = np.linspace(0, 10, 101)
cov_vals = [0, 0]
surv_keys = ('Y0', 'Y1', 'C0', 'C1')  # the four entries of 'cox_args', in plotting order
tbs_Y0, tbs_Y1, tbs_C0, tbs_C1 = [
    RCTData.get_oracle_surv_curve(t, cov_vals, surv_key) for surv_key in surv_keys
]

fig_surv, ax_surv = plt.subplots()
for surv_key, surv_curve in zip(surv_keys, (tbs_Y0, tbs_Y1, tbs_C0, tbs_C1)):
    # Each curve is labelled with the key it was requested under.
    ax_surv.plot(t, surv_curve, label=surv_key)
ax_surv.set_xlabel('t')
ax_surv.set_ylabel(r'$S(t)$')
ax_surv.set_title(f'True survival curves in study S=0 with X={cov_vals}')
ax_surv.legend()
plt.show()