In [None]:
## Reference: https://www.youtube.com/watch?v=xekqR10lQNo
## For personal study purposes only

In [3]:
from dowhy import datasets

import pandas as pd
import numpy as np

from causalinference import CausalModel

## Create Dataset

In [14]:
# Fix the random seed so a fresh "Restart & Run All" regenerates the same
# synthetic sample, and therefore the same estimates in the cells below.
# NOTE(review): this assumes dowhy's generator draws from NumPy's global RNG;
# confirm for the installed dowhy version.
SEED = 42
np.random.seed(SEED)

# Simulate a linear data-generating process with a binary treatment, a
# continuous outcome and four common causes (the W0..W3 columns shown by
# df.head()). `beta` is passed as 10; the covariate-adjusted OLS estimates
# further down should land close to that value.
data = datasets.linear_dataset(
    beta=10,
    num_common_causes=4,
    num_samples=10_000,
    treatment_is_binary=True,
    outcome_is_binary=False,
)

# Build the analysis frame in a single idempotent chain: give dowhy's generic
# column names ('v0', 'y') descriptive ones and store the treatment flag as
# 0/1 integers. Re-running this cell always starts again from data['df'].
df = (
    data['df']
    .rename(columns={'v0': 'treatment', 'y': 'outcome'})
    .astype({'treatment': int})
)

df.head()

Unnamed: 0,W0,W1,W2,W3,treatment,outcome
0,0.314915,-0.094394,-2.343535,-0.604731,0,-7.852001
1,-0.509568,0.029425,0.428977,0.69664,1,10.346186
2,1.129318,0.541612,-0.579068,0.002772,1,15.421845
3,1.289633,-1.115158,0.233141,-1.93935,0,0.404312
4,2.288021,-1.096157,-0.198825,-0.496505,1,17.25201


## Raw Difference

In [16]:
# Columns handed to the model as covariates (X).
covariate_columns = ['W0', 'W1', 'W2', 'W3']

# Wrap the simulated data in a CausalModel: Y is the outcome, D the 0/1
# treatment column and X the covariate matrix, all passed as NumPy arrays.
causal = CausalModel(
    Y=df['outcome'].to_numpy(),
    D=df['treatment'].to_numpy(),
    X=df[covariate_columns].to_numpy(),
)

# Per-group means and standard deviations, the raw difference in Y, and the
# normalized differences of each covariate between treated and controls.
print(causal.summary_stats)


Summary Statistics

                      Controls (N_c=5591)        Treated (N_t=4409)             
       Variable         Mean         S.d.         Mean         S.d.     Raw-diff
--------------------------------------------------------------------------------
              Y       -5.262        5.185       14.468        5.132       19.731

                      Controls (N_c=5591)        Treated (N_t=4409)             
       Variable         Mean         S.d.         Mean         S.d.     Nor-diff
--------------------------------------------------------------------------------
             X0        0.168        0.876        1.183        0.861        1.169
             X1       -0.265        0.990       -0.036        1.008        0.229
             X2       -1.156        0.946       -0.650        0.968        0.529
             X3       -0.705        0.887        0.280        0.885        1.112



## Treatment Effect using OLS

In [19]:
# Run the OLS estimator once per adjustment level and print each result.
# Per the causalinference documentation for est_via_ols:
#   adj=0 -> no covariates (reproduces the raw difference in means),
#   adj=1 -> treatment indicator and covariates enter separately,
#   adj=2 -> additionally includes treatment-covariate interaction terms.
# Each call overwrites the OLS entry of causal.estimates, so the estimates
# are printed immediately after fitting.
for adj in (0, 1, 2):
    causal.est_via_ols(adj=adj)
    print(f'adj={adj}', causal.estimates)

adj=0 
Treatment Effect Estimates: OLS

                     Est.       S.e.          z      P>|z|      [95% Conf. int.]
--------------------------------------------------------------------------------
           ATE     19.731      0.104    190.038      0.000     19.527     19.934

adj=1 
Treatment Effect Estimates: OLS

                     Est.       S.e.          z      P>|z|      [95% Conf. int.]
--------------------------------------------------------------------------------
           ATE     10.000      0.000  32276.321      0.000      9.999     10.001

adj=2 
Treatment Effect Estimates: OLS

                     Est.       S.e.          z      P>|z|      [95% Conf. int.]
--------------------------------------------------------------------------------
           ATE     10.000      0.000  31816.835      0.000      9.999     10.000
           ATC      9.999      0.000  24594.510      0.000      9.999     10.000
           ATT     10.000      0.000  26664.389      0.000     10.00