## Semi-synthetic data setup (IHDP) 

In [1]:
import pandas as pd 
import numpy as np

In [2]:
# Load data 
ihdp_table = pd.read_csv('ihdp.csv')
ihdp_full_table = pd.read_csv('ihdpFull.csv')
display(ihdp_full_table)

Unnamed: 0,sitenum,momrace,bmarr,workdur,cccats3,treat,prenatal,first,bw,preterm,...,other60,mom248,relative248,nanny248,familydc248,center248,other248,ccurr60,nvisitt,ncdct
0,1,3,1,1,3,1,1,1,1559,10,...,0,0,0,0,0,1,0,1,72,432
1,1,1,0,1,1,1,0,1,1420,4,...,0,0,1,0,0,0,0,1,60,391
2,1,1,0,1,2,0,1,1,1000,8,...,0,0,1,0,0,0,0,1,0,0
3,1,1,0,0,1,0,1,2,1430,6,...,0,0,0,0,0,1,0,1,0,0
4,1,1,0,0,2,0,1,2,1984,2,...,0,0,1,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,8,3,1,1,2,1,1,2,2140,5,...,0,0,0,0,0,1,0,1,56,188
981,8,1,0,1,2,0,1,1,2350,3,...,0,0,1,0,0,0,0,1,0,0
982,8,3,1,1,3,0,1,2,1670,3,...,0,0,1,0,0,0,0,1,0,0
983,8,3,1,1,1,0,1,2,1740,6,...,0,0,1,0,0,0,0,1,0,0


In [3]:
ihdp_table

Unnamed: 0,treat,bw,b.head,preterm,birth.o,nnhealth,momage,sex,twin,b.marr,...,ark,ein,har,mia,pen,tex,was,momwhite,momblack,momhisp
0,1,1559,28.648521,10,2,94,33,1,0,1,...,1,0,0,0,0,0,0,1,0,0
1,1,1420,27.000000,4,2,85,15,1,0,0,...,1,0,0,0,0,0,0,0,1,0
2,0,1000,25.000000,8,4,89,33,0,0,0,...,1,0,0,0,0,0,0,0,1,0
3,0,1430,29.000000,6,1,112,22,0,0,0,...,1,0,0,0,0,0,0,0,1,0
4,0,1984,31.000000,2,1,99,20,0,0,0,...,1,0,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,1,2140,29.000000,5,1,112,32,0,0,1,...,0,0,0,0,0,0,0,1,0,0
981,0,2350,30.000000,3,2,111,28,0,0,0,...,0,0,0,0,0,0,0,0,1,0
982,0,1670,29.000000,3,1,125,28,1,0,1,...,0,0,0,0,0,0,0,1,0,0
983,0,1740,31.000000,6,1,107,26,0,0,1,...,0,0,0,0,0,0,0,1,0,0


## Simulate Outcomes & Setup RCT 

Note that the covariates that we will use to determine our strata are birthweight (cutoff: 2000g) and marital status. 

$d = \{(\text{low birthweight}, \text{married}), (\text{high birthweight}, \text{married})\}$

$d' =\{(\text{low birthweight}, \text{single}), (\text{high birthweight}, \text{single})\}$ 

The second set contain strata that are not covered by the RCT.

In [4]:
ihdp_table.columns.values[1:].shape

(28,)

In [5]:
ihdp_table.columns.values

array(['treat', 'bw', 'b.head', 'preterm', 'birth.o', 'nnhealth',
       'momage', 'sex', 'twin', 'b.marr', 'mom.lths', 'mom.hs',
       'mom.scoll', 'cig', 'first', 'booze', 'drugs', 'work.dur',
       'prenatal', 'ark', 'ein', 'har', 'mia', 'pen', 'tex', 'was',
       'momwhite', 'momblack', 'momhisp'], dtype=object)

In [6]:
## 1 round 
# simulate outcomes
np.random.seed(4)
ihdp_table = pd.read_csv('ihdp.csv')
num_covariates = 28 # including ethnicity 
coefs = np.array([0,0.1,0.2,0.3,0.4])
probs = np.array([0.6,0.1,0.1,0.1,0.1])
omega = -23

# normalization 
cont = ihdp_table.iloc[:,1:7]
ihdp_table.iloc[:,1:7] = (cont - cont.mean())/cont.std() # Normalize continuous variables

# simulation
beta_B = np.random.choice(coefs, size=[num_covariates,1], replace=True, p=probs)
print(ihdp_table.columns.values[1:])
print(f'beta B: {beta_B.squeeze()}')
W = np.zeros(num_covariates)+0.5
y0 = []; y1 = []
y0_noise = []; y1_noise = []
for idx, row in ihdp_table.iterrows(): 
    X = row.values[1:]
    mean0 = np.exp(np.matmul((X+W)[None,:],beta_B))[0]
    mean1 = np.matmul(X[None,:],beta_B)[0] - omega
    y0_noise.append(np.random.normal(mean0,1)[0])
    y1_noise.append(np.random.normal(mean1,1)[0])
    y0.append(mean0[0])
    y1.append(mean1[0])
    
# inserting potential columns
ihdp_table.insert(loc=0, column='y0_noise', value=y0_noise)
ihdp_table.insert(loc=0, column='y1_noise', value=y1_noise)
ihdp_table.insert(loc=0, column='y0', value=y0)
ihdp_table.insert(loc=0, column='y1', value=y1)
display(ihdp_table)

['bw' 'b.head' 'preterm' 'birth.o' 'nnhealth' 'momage' 'sex' 'twin'
 'b.marr' 'mom.lths' 'mom.hs' 'mom.scoll' 'cig' 'first' 'booze' 'drugs'
 'work.dur' 'prenatal' 'ark' 'ein' 'har' 'mia' 'pen' 'tex' 'was'
 'momwhite' 'momblack' 'momhisp']
beta B: [0.4 0.  0.4 0.2 0.1 0.  0.4 0.  0.  0.  0.2 0.  0.3 0.4 0.  0.  0.  0.
 0.  0.4 0.  0.4 0.2 0.3 0.  0.  0.1 0. ]


Unnamed: 0,y1,y0,y1_noise,y0_noise,treat,bw,b.head,preterm,birth.o,nnhealth,...,ark,ein,har,mia,pen,tex,was,momwhite,momblack,momhisp
0,24.026369,18.659756,25.573348,18.052869,1,-0.518060,-0.315951,1.126525,0.103688,-0.377548,...,1,0,0,0,0,0,0,1,0,0
1,23.052320,7.045016,23.098456,7.768358,1,-0.822071,-0.984367,-1.112884,0.103688,-0.944350,...,1,0,0,0,0,0,0,0,1,0
2,23.607771,12.277532,23.662204,11.294540,0,-1.740668,-1.795294,0.380055,2.106298,-0.692438,...,1,0,0,0,0,0,0,0,1,0
3,23.329436,9.294627,22.120488,9.454520,0,-0.800200,-0.173440,-0.366414,-0.897616,0.756055,...,1,0,0,0,0,0,0,0,1,0
4,23.435058,10.330059,23.829353,12.553420,0,0.411472,0.637487,-1.859354,-0.897616,-0.062658,...,1,0,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,23.701288,13.481097,24.273572,14.611561,1,0.752665,-0.173440,-0.739649,-0.897616,0.756055,...,0,0,0,0,0,0,0,1,0,0
981,23.680383,13.202192,23.006319,12.393562,0,1.211963,0.232024,-1.486119,0.103688,0.693077,...,0,0,0,0,0,0,0,0,1,0
982,23.473391,10.733726,22.977504,10.130477,0,-0.275288,-0.173440,-1.486119,-0.897616,1.574769,...,0,0,0,0,0,0,0,1,0,0
983,23.969152,17.622070,24.035466,17.349401,0,-0.122189,0.637487,-0.366414,-0.897616,0.441165,...,0,0,0,0,0,0,0,1,0,0


In [7]:
mean_bw = cont['bw'].mean()
std_bw  = cont['bw'].std()
bw_norm = (2000 - mean_bw) / std_bw

In [8]:
# split into four subpopulations
sg1 = ihdp_table[(ihdp_table['bw']<bw_norm) & (ihdp_table['b.marr']==1.)]
sg2 = ihdp_table[(ihdp_table['bw']>=bw_norm) & (ihdp_table['b.marr']==1.)]
sg3 = ihdp_table[(ihdp_table['bw']<bw_norm) & (ihdp_table['b.marr']==0.)]
sg4 = ihdp_table[(ihdp_table['bw']>=bw_norm) & (ihdp_table['b.marr']==0.)]

sgs = [sg1,sg2,sg3,sg4]
sg_names = ['low-bw,married', 'high-bw,married', 'low-bw,single', 'high-bw,married']
for i,sg in enumerate(sgs): 
    # compute the CATE 
    cate = (sg['y1'] - sg['y0']).mean()
    print(f'CATE for subgroup {i+1} ({sg_names[i]}): {cate}')


CATE for subgroup 1 (low-bw,married): 4.190872951506137
CATE for subgroup 2 (high-bw,married): 1.3253262732258888
CATE for subgroup 3 (low-bw,single): 0.4243580884750287
CATE for subgroup 4 (high-bw,married): -2.8974668589645294


In [9]:
rct_table = ihdp_table[ihdp_table['b.marr'] == 1.]
rct_table

Unnamed: 0,y1,y0,y1_noise,y0_noise,treat,bw,b.head,preterm,birth.o,nnhealth,...,ark,ein,har,mia,pen,tex,was,momwhite,momblack,momhisp
0,24.026369,18.659756,25.573348,18.052869,1,-0.518060,-0.315951,1.126525,0.103688,-0.377548,...,1,0,0,0,0,0,0,1,0,0
9,24.524207,30.698280,25.362043,29.592696,0,0.536139,0.637487,0.753290,1.104993,-0.125636,...,1,0,0,0,0,0,0,1,0,0
10,24.117783,20.445911,25.032624,22.533782,0,-1.040785,-1.323942,1.126525,-0.897616,0.630099,...,1,0,0,0,0,0,0,1,0,0
11,24.351702,25.834271,25.148214,25.558067,0,0.774536,1.042951,0.006821,1.104993,-0.818394,...,1,0,0,0,0,0,0,0,1,0
13,23.893691,16.341218,23.884331,14.993758,1,-0.537744,-0.173440,0.380055,-0.897616,-1.637107,...,1,0,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
979,23.540033,11.473418,22.569308,10.106831,0,0.009039,-0.173440,-0.366414,0.103688,-0.377548,...,0,0,0,0,0,0,0,1,0,0
980,23.701288,13.481097,24.273572,14.611561,1,0.752665,-0.173440,-0.739649,-0.897616,0.756055,...,0,0,0,0,0,0,0,1,0,0
982,23.473391,10.733726,22.977504,10.130477,0,-0.275288,-0.173440,-1.486119,-0.897616,1.574769,...,0,0,0,0,0,0,0,1,0,0
983,23.969152,17.622070,24.035466,17.349401,0,-0.122189,0.637487,-0.366414,-0.897616,0.441165,...,0,0,0,0,0,0,0,1,0,0


In [None]:
# E[Y=1 | X\in low-birthweight,married]-E[Y=0 | X\inlow-birthweight,married]

## Reconstructing RCT & Constructing OBS Studies

How to introduce confounding? Ans: add confounding variables? add samples that have non-random treatment given to them? 

cvar1 (binary) -  ((y1 / 2) * treat > threshold) 
cvar2 

5 confounders case: X' 
to original functions, y1 and y0: add X'*\gamma (not an effect modifier!!) 
for RCT - X'[i] is 1 50% of the time and 0 50% of the time for treatment group (same procedure for control group)
for OBS - X'[i] is 1 75% of the time and 0 25% of the time for treatment group (same procedure for control group)

In [55]:
def get_confounding_variables(probs=[0.5,0.5], binary=True, num_samples=50): 
    Xprime = np.random.choice([0,1],size=num_samples,p=probs)
    if binary: 
        return Xprime
    normal1 = np.random.normal(0,1,size=num_samples)
    normal2 = np.random.normal(3,1,size=num_samples)
    return Xprime*normal1 + (1-Xprime)*normal2

In [71]:
## 1 round 
ihdp_table = pd.read_csv('ihdp.csv')
ihdp_trt = ihdp_table[ihdp_table['treat']==1]
ihdp_control = ihdp_table[ihdp_table['treat']==0]

# get confounding variable
np.random.seed(4)
num_continuous = 4
num_binary = 3 

Ntrt = ihdp_trt['treat'].values.shape[0]
Nctr = ihdp_control['treat'].values.shape[0]

Xprime_rct_trts = np.zeros((Ntrt,num_continuous+num_binary))
Xprime_rct_controls = np.zeros((Nctr,num_continuous+num_binary))
Xprime_obs_trts = np.zeros((Ntrt,num_continuous+num_binary))
Xprime_obs_controls = np.zeros((Nctr,num_continuous+num_binary))

t1 = get_confounding_variables(probs=[0.5,0.5], binary=False, num_samples=Ntrt*num_continuous)
c1 = get_confounding_variables(probs=[0.5,0.5], binary=False, num_samples=Nctr*num_continuous)
t2 = get_confounding_variables(probs=[0.25,0.75], binary=False, num_samples=Ntrt*num_continuous)
c2 = get_confounding_variables(probs=[0.75,0.25], binary=False, num_samples=Nctr*num_continuous)

Xprime_rct_trts[:,:num_continuous] = np.reshape(t1, (Ntrt,num_continuous))
Xprime_rct_controls[:,:num_continuous] = np.reshape(c1, (Nctr,num_continuous))
Xprime_obs_trts[:,:num_continuous] = np.reshape(t2, (Ntrt,num_continuous))
Xprime_obs_controls[:,:num_continuous] = np.reshape(c2, (Nctr,num_continuous))

t1 = get_confounding_variables(probs=[0.5,0.5], binary=True, num_samples=Ntrt*num_binary)
c1 = get_confounding_variables(probs=[0.5,0.5], binary=True, num_samples=Nctr*num_binary)
t2 = get_confounding_variables(probs=[0.25,0.75], binary=True, num_samples=Ntrt*num_binary)
c2 = get_confounding_variables(probs=[0.75,0.25], binary=True, num_samples=Nctr*num_binary)

Xprime_rct_trts[:,num_continuous:] = np.reshape(t1, (Ntrt,num_binary))
Xprime_rct_controls[:,num_continuous:] = np.reshape(c1, (Nctr,num_binary))
Xprime_obs_trts[:,num_continuous:] = np.reshape(t2, (Ntrt,num_binary))
Xprime_obs_controls[:,num_continuous:] = np.reshape(c2, (Nctr,num_binary))

print(f'Xprime_rct_trts shape: {Xprime_rct_trts.shape}')
print(Xprime_rct_trts[:,:num_continuous])
print(Xprime_rct_trts[:,num_continuous:])


Xprime_rct_trts shape: (377, 7)
[[ 1.25660587  0.78097266 -2.2313993   0.84124181]
 [-1.26441476  3.48581635 -0.32451439  3.34133103]
 [ 3.18055307  3.02144991  0.8860924   3.70711163]
 ...
 [-2.93342082 -0.64815152 -0.19028483 -1.33060977]
 [-0.89429979  3.80058072 -0.39887409  0.62835001]
 [ 1.78323276  2.33722472 -0.44506429 -1.2140657 ]]
[[1. 1. 0.]
 [0. 0. 1.]
 [1. 0. 1.]
 ...
 [1. 1. 1.]
 [1. 0. 1.]
 [0. 1. 0.]]


In [72]:
# append 
Xprime_rct_trts_df = pd.DataFrame(Xprime_rct_trts, columns=[f'xprime_rct{i+1}' for i in range(num_continuous+num_binary)])
Xprime_rct_ctrs_df = pd.DataFrame(Xprime_rct_controls, columns=[f'xprime_rct{i+1}' for i in range(num_continuous+num_binary)])
Xprime_obs_trts_df = pd.DataFrame(Xprime_obs_trts, columns=[f'xprime_obs{i+1}' for i in range(num_continuous+num_binary)])
Xprime_obs_ctrs_df = pd.DataFrame(Xprime_obs_controls, columns=[f'xprime_obs{i+1}' for i in range(num_continuous+num_binary)])

ihdp_trt = pd.concat([ihdp_trt.reset_index(drop=True), Xprime_rct_trts_df, Xprime_obs_trts_df], axis=1, sort=False)
ihdp_control = pd.concat([ihdp_control.reset_index(drop=True), Xprime_rct_ctrs_df, Xprime_obs_ctrs_df], axis=1, sort=False)
ihdp_table = pd.concat([ihdp_trt, ihdp_control], ignore_index=True, sort=False)
display(ihdp_table)

Unnamed: 0,treat,bw,b.head,preterm,birth.o,nnhealth,momage,sex,twin,b.marr,...,xprime_rct5,xprime_rct6,xprime_rct7,xprime_obs1,xprime_obs2,xprime_obs3,xprime_obs4,xprime_obs5,xprime_obs6,xprime_obs7
0,1,1559,28.648521,10,2,94,33,1,0,1,...,1.0,1.0,0.0,1.031226,-1.068553,-0.605096,-1.938174,0.0,1.0,0.0
1,1,1420,27.000000,4,2,85,15,1,0,0,...,0.0,0.0,1.0,3.476216,-1.635055,1.666710,-1.249143,1.0,0.0,0.0
2,1,2240,31.000000,3,2,105,22,1,0,0,...,1.0,0.0,1.0,-0.131051,1.726972,0.068636,-0.270446,1.0,1.0,0.0
3,1,1900,30.000000,6,1,110,13,1,0,0,...,1.0,0.0,0.0,-0.046200,-0.491871,0.309800,3.491736,1.0,1.0,0.0
4,1,1550,29.000000,8,1,74,25,1,0,1,...,0.0,0.0,1.0,-0.288896,1.182057,0.311918,3.590592,1.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,0,1650,31.000000,6,2,105,34,1,1,1,...,1.0,1.0,1.0,1.490505,2.139331,2.959090,0.130328,0.0,0.0,0.0
981,0,1800,29.000000,6,2,94,36,0,0,1,...,1.0,1.0,0.0,2.543749,-0.796368,-0.426383,3.148424,0.0,1.0,0.0
982,0,2350,30.000000,3,2,111,28,0,0,0,...,0.0,0.0,1.0,3.382456,2.301142,4.363305,0.493031,0.0,0.0,0.0
983,0,1670,29.000000,3,1,125,28,1,0,1,...,1.0,0.0,1.0,3.631826,2.305167,-0.994249,2.752103,1.0,0.0,1.0


In [73]:
# simulate outcomes
num_covariates = 28 # including ethnicity 
coefs = np.array([0,0.1,0.2,0.3,0.4])
probs = np.array([0.6,0.1,0.1,0.1,0.1])
omega = -23

gamma_coefs = [0.5,1.5,2.5,3.5,4.5]
gamma_probs = [0.2,0.2,0.2,0.2,0.2]
gamma = np.random.choice(gamma_coefs, size=[num_continuous+num_binary,1], replace=True, p=gamma_probs)

# normalization 
cont = ihdp_table.iloc[:,1:7]
ihdp_table.iloc[:,1:7] = (cont - cont.mean())/cont.std() # Normalize continuous variables

# simulation
np.random.seed(4)
beta_B = np.random.choice(coefs, size=[num_covariates,1], replace=True, p=probs)
print(beta_B.squeeze())
W = np.zeros(num_covariates)+0.5
y0_rct = []; y1_rct = []
y0_obs = []; y1_obs = []
y_rct  = []; y_obs  = []
for idx, row in ihdp_table.iterrows(): 
    X = row.values[1:-((num_continuous+num_binary)*2)]
    assert X.shape[0] == num_covariates
    mean0_rct = np.exp(np.matmul((X+W)[None,:],beta_B))[0][0] \
        + np.matmul(row.loc['xprime_rct1':'xprime_rct7'].values[None,:],gamma)[0][0]
    mean1_rct = np.matmul(X[None,:],beta_B)[0][0] - omega \
        + np.matmul(row.loc['xprime_rct1':'xprime_rct7'].values[None,:],gamma)[0][0]
    mean0_obs = np.exp(np.matmul((X+W)[None,:],beta_B))[0][0] \
        + np.matmul(row.loc['xprime_obs1':'xprime_obs7'].values[None,:],gamma)[0][0]
    mean1_obs = np.matmul(X[None,:],beta_B)[0][0] - omega \
        + np.matmul(row.loc['xprime_obs1':'xprime_obs7'].values[None,:],gamma)[0][0]
    y0_rct.append(mean0_rct); y1_rct.append(mean1_rct)
    y0_obs.append(mean0_obs); y1_obs.append(mean1_obs)
    
    y_rct_noise = np.random.normal(row['treat']*mean1_rct + (1-row['treat'])*mean0_rct, scale=1.)
    y_obs_noise = np.random.normal(row['treat']*mean1_obs + (1-row['treat'])*mean0_obs, scale=1.)
    y_rct.append(y_rct_noise); y_obs.append(y_obs_noise)
    
# inserting potential columns
ihdp_table.insert(loc=0, column='y0_rct', value=y0_rct)
ihdp_table.insert(loc=0, column='y1_rct', value=y1_rct)
ihdp_table.insert(loc=0, column='y0_obs', value=y0_obs)
ihdp_table.insert(loc=0, column='y1_obs', value=y1_obs)
ihdp_table.insert(loc=0, column='y_obs', value=y_obs)
ihdp_table.insert(loc=0, column='y_rct', value=y_rct)
display(ihdp_table)

[0.4 0.  0.4 0.2 0.1 0.  0.4 0.  0.  0.  0.2 0.  0.3 0.4 0.  0.  0.  0.
 0.  0.4 0.  0.4 0.2 0.3 0.  0.  0.1 0. ]


Unnamed: 0,y_rct,y_obs,y1_obs,y0_obs,y1_rct,y0_rct,treat,bw,b.head,preterm,...,xprime_rct5,xprime_rct6,xprime_rct7,xprime_obs1,xprime_obs2,xprime_obs3,xprime_obs4,xprime_obs5,xprime_obs6,xprime_obs7
0,32.722376,23.159024,21.612044,16.245431,33.329264,27.962650,1,-0.518060,-0.315951,1.126525,...,1.0,1.0,0.0,1.031226,-1.068553,-0.605096,-1.938174,0.0,1.0,0.0
1,38.311378,36.145087,36.098951,20.091647,37.588036,21.580732,1,-0.822071,-0.984367,-1.112884,...,0.0,0.0,1.0,3.476216,-1.635055,1.666710,-1.249143,1.0,0.0,0.0
2,60.540740,31.952752,31.898319,22.254600,61.523732,51.880012,1,0.971378,0.637487,-1.486119,...,1.0,0.0,1.0,-0.131051,1.726972,0.068636,-0.270446,1.0,1.0,0.0
3,33.213622,29.691932,30.900880,27.429193,33.053729,29.582041,1,0.227753,0.232024,-0.366414,...,1.0,0.0,0.0,-0.046200,-0.491871,0.309800,3.491736,1.0,1.0,0.0
4,59.449926,35.978917,35.584622,28.032149,57.226566,49.674093,1,-0.537744,-0.173440,0.380055,...,0.0,0.0,1.0,-0.288896,1.182057,0.311918,3.590592,1.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,22.569941,33.958471,42.607142,33.386187,30.660433,21.439478,0,-0.319031,0.637487,-0.366414,...,1.0,1.0,1.0,1.490505,2.139331,2.959090,0.130328,0.0,0.0,0.0
981,28.950211,24.041998,36.782678,24.716063,41.825456,29.758841,0,0.009039,-0.173440,-0.366414,...,1.0,1.0,0.0,2.543749,-0.796368,-0.426383,3.148424,0.0,1.0,0.0
982,16.672115,43.265856,54.239933,43.761743,27.753554,17.275364,0,1.211963,0.232024,-1.486119,...,0.0,0.0,1.0,3.382456,2.301142,4.363305,0.493031,0.0,0.0,0.0
983,49.129798,43.848122,56.521473,43.781808,62.142132,49.402467,0,-0.275288,-0.173440,-1.486119,...,1.0,0.0,1.0,3.631826,2.305167,-0.994249,2.752103,1.0,0.0,1.0


In [74]:
# crude CATEs + stratified CATEs
columns_to_drop = [f'xprime_rct{i+1}' for i in range(num_continuous+num_binary)]
obs_table = ihdp_table.drop(columns=['y1_rct','y0_rct']+columns_to_drop,inplace=False) # remove y1_rct, y0_rct, xprime_rct
columns_to_drop = [f'xprime_obs{i+1}' for i in range(num_continuous+num_binary)]
rct_table = ihdp_table.drop(columns=['y1_obs','y0_obs']+columns_to_drop,inplace=False) # remove y1_obs, y0_obs, xprime_obs
print('OBS data')
display(obs_table)
print('RCT data')
display(rct_table)

OBS data


Unnamed: 0,y_rct,y_obs,y1_obs,y0_obs,treat,bw,b.head,preterm,birth.o,nnhealth,...,momwhite,momblack,momhisp,xprime_obs1,xprime_obs2,xprime_obs3,xprime_obs4,xprime_obs5,xprime_obs6,xprime_obs7
0,32.722376,23.159024,21.612044,16.245431,1,-0.518060,-0.315951,1.126525,0.103688,-0.377548,...,1,0,0,1.031226,-1.068553,-0.605096,-1.938174,0.0,1.0,0.0
1,38.311378,36.145087,36.098951,20.091647,1,-0.822071,-0.984367,-1.112884,0.103688,-0.944350,...,0,1,0,3.476216,-1.635055,1.666710,-1.249143,1.0,0.0,0.0
2,60.540740,31.952752,31.898319,22.254600,1,0.971378,0.637487,-1.486119,0.103688,0.315209,...,0,1,0,-0.131051,1.726972,0.068636,-0.270446,1.0,1.0,0.0
3,33.213622,29.691932,30.900880,27.429193,1,0.227753,0.232024,-0.366414,-0.897616,0.630099,...,0,1,0,-0.046200,-0.491871,0.309800,3.491736,1.0,1.0,0.0
4,59.449926,35.978917,35.584622,28.032149,1,-0.537744,-0.173440,0.380055,-0.897616,-1.637107,...,0,1,0,-0.288896,1.182057,0.311918,3.590592,1.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,22.569941,33.958471,42.607142,33.386187,0,-0.319031,0.637487,-0.366414,0.103688,0.315209,...,1,0,0,1.490505,2.139331,2.959090,0.130328,0.0,0.0,0.0
981,28.950211,24.041998,36.782678,24.716063,0,0.009039,-0.173440,-0.366414,0.103688,-0.377548,...,1,0,0,2.543749,-0.796368,-0.426383,3.148424,0.0,1.0,0.0
982,16.672115,43.265856,54.239933,43.761743,0,1.211963,0.232024,-1.486119,0.103688,0.693077,...,0,1,0,3.382456,2.301142,4.363305,0.493031,0.0,0.0,0.0
983,49.129798,43.848122,56.521473,43.781808,0,-0.275288,-0.173440,-1.486119,-0.897616,1.574769,...,1,0,0,3.631826,2.305167,-0.994249,2.752103,1.0,0.0,1.0


RCT data


Unnamed: 0,y_rct,y_obs,y1_rct,y0_rct,treat,bw,b.head,preterm,birth.o,nnhealth,...,momwhite,momblack,momhisp,xprime_rct1,xprime_rct2,xprime_rct3,xprime_rct4,xprime_rct5,xprime_rct6,xprime_rct7
0,32.722376,23.159024,33.329264,27.962650,1,-0.518060,-0.315951,1.126525,0.103688,-0.377548,...,1,0,0,1.256606,0.780973,-2.231399,0.841242,1.0,1.0,0.0
1,38.311378,36.145087,37.588036,21.580732,1,-0.822071,-0.984367,-1.112884,0.103688,-0.944350,...,0,1,0,-1.264415,3.485816,-0.324514,3.341331,0.0,0.0,1.0
2,60.540740,31.952752,61.523732,51.880012,1,0.971378,0.637487,-1.486119,0.103688,0.315209,...,0,1,0,3.180553,3.021450,0.886092,3.707112,1.0,0.0,1.0
3,33.213622,29.691932,33.053729,29.582041,1,0.227753,0.232024,-0.366414,-0.897616,0.630099,...,0,1,0,0.939602,-0.632201,3.371046,-0.430913,1.0,0.0,0.0
4,59.449926,35.978917,57.226566,49.674093,1,-0.537744,-0.173440,0.380055,-0.897616,-1.637107,...,0,1,0,3.423506,1.615593,2.827099,3.021250,0.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
980,22.569941,33.958471,30.660433,21.439478,0,-0.319031,0.637487,-0.366414,0.103688,0.315209,...,1,0,0,0.243922,-0.438942,-1.026235,1.573569,1.0,1.0,1.0
981,28.950211,24.041998,41.825456,29.758841,0,0.009039,-0.173440,-0.366414,0.103688,-0.377548,...,1,0,0,0.858673,0.586323,1.338019,4.908156,1.0,1.0,0.0
982,16.672115,43.265856,27.753554,17.275364,0,1.211963,0.232024,-1.486119,0.103688,0.693077,...,0,1,0,-0.862547,-0.066519,0.458275,2.666693,0.0,0.0,1.0
983,49.129798,43.848122,62.142132,49.402467,0,-0.275288,-0.173440,-1.486119,-0.897616,1.574769,...,1,0,0,5.534432,-0.475747,2.786584,3.499358,1.0,0.0,1.0


CATE in RCT - 

CATE for subgroup 1 (low-bw,married): 4.190872951506137

CATE for subgroup 2 (high-bw,married): 1.3253262732258888

CATE for subgroup 3 (low-bw,single): 0.4243580884750287

CATE for subgroup 4 (high-bw,married): -2.8974668589645294

In [77]:
# CATEs in RCT 
# split into four subpopulations
mean_bw = cont['bw'].mean()
std_bw  = cont['bw'].std()
bw_norm = (2000 - mean_bw) / std_bw
print(bw_norm)
sg1 = rct_table[(rct_table['bw']<bw_norm) & (rct_table['b.marr']==1.)]
sg2 = rct_table[(rct_table['bw']>=bw_norm) & (rct_table['b.marr']==1.)]
sg3 = rct_table[(rct_table['bw']<bw_norm) & (rct_table['b.marr']==0.)]
sg4 = rct_table[(rct_table['bw']>=bw_norm) & (rct_table['b.marr']==0.)]

sgs = [sg1,sg2,sg3,sg4]
sg_names = ['low-bw,married', 'high-bw,married', 'low-bw,single', 'high-bw,single']
for i,sg in enumerate(sgs): 
    # compute the CATE
    print(f'true CATE: {sg["y1_rct"].mean() - sg["y0_rct"].mean()}')
    
    sg_trt = sg[sg['treat'] == 1]
    sg_ctr = sg[sg['treat'] == 0]
    cate = (sg_trt['y_rct'].mean() - sg_ctr['y_rct'].mean())
    print(f'noised CATE for subgroup {i+1} ({sg_names[i]}): {cate}')


0.4464661700041214
true CATE: 4.190872951506137
noised CATE for subgroup 1 (low-bw,married): 5.102056204276089
true CATE: 1.3253262732258904
noised CATE for subgroup 2 (high-bw,married): 4.474866560187003
true CATE: 0.4243580884750173
noised CATE for subgroup 3 (low-bw,single): 0.10413184147573418
true CATE: -2.8974668589645134
noised CATE for subgroup 4 (high-bw,single): -4.781357079996006


In [78]:
# CATEs in OBS
# split into four subpopulations
mean_bw = cont['bw'].mean()
std_bw  = cont['bw'].std()
bw_norm = (2000 - mean_bw) / std_bw
print(bw_norm)
sg1 = obs_table[(obs_table['bw']<bw_norm) & (obs_table['b.marr']==1.)]
sg2 = obs_table[(obs_table['bw']>=bw_norm) & (obs_table['b.marr']==1.)]
sg3 = obs_table[(obs_table['bw']<bw_norm) & (obs_table['b.marr']==0.)]
sg4 = obs_table[(obs_table['bw']>=bw_norm) & (obs_table['b.marr']==0.)]

sgs = [sg1,sg2,sg3,sg4]
sg_names = ['low-bw,married', 'high-bw,married', 'low-bw,single', 'high-bw,single']
for i,sg in enumerate(sgs): 
    # compute the CATE
    sg_trt = sg[sg['treat'] == 1]
    sg_ctr = sg[sg['treat'] == 0]
    cate = (sg_trt['y_obs'].mean() - sg_ctr['y_obs'].mean())
    print(f'noised CATE for subgroup {i+1} ({sg_names[i]}) in OBS data: {cate}')


0.4464661700041214
noised CATE for subgroup 1 (low-bw,married) in OBS data: -9.018581897418542
noised CATE for subgroup 2 (high-bw,married) in OBS data: -8.155859776340975
noised CATE for subgroup 3 (low-bw,single) in OBS data: -11.083398241592953
noised CATE for subgroup 4 (high-bw,single) in OBS data: -17.048845011862063


## Next Steps 

0. **communicate with mike**
1. **extend xprime to multiple confounding variables**
2. need to formalize + implement  
    - ML estimator (find packages for this) 
    - actual test (multivariate case) (\theta_1 - \theta_2)^T(\Sigma_1/n_1 + \Sigma_2/n_2)^-1 (\theta_1-\theta_2) \sim \chisquared {df dim}
3. exp 1: RCT, OBS1 - subset of confounders, OBS2 - all confounders [correct]
    - implement algorithm 
    - run exp 1.
    - baseline? do meta-analysis
    - metrics? coverage of the intervals, prob of selecting correct study?, prob of excluding incorrect study? 