In [None]:
# Reference: https://docs.doubleml.org/stable/examples/py_double_ml_cate.html
# for personal study purposes only

In [3]:
import numpy as np
import pandas as pd
import doubleml as dml

from doubleml.datasets import make_heterogeneous_data


In [4]:
np.random.seed(42)
data_dict = make_heterogeneous_data(
    n_obs=2000,
    p=10,
    support_size=5,
    n_x=1,
    binary_treatment=True,
)
treatment_effect = data_dict['treatment_effect']
data = data_dict['data']
data.head()

Unnamed: 0,y,d,X_0,X_1,X_2,X_3,X_4,X_5,X_6,X_7,X_8,X_9
0,4.8033,1.0,0.259828,0.886086,0.89569,0.297287,0.229994,0.411304,0.240532,0.672384,0.826065,0.673092
1,5.655547,1.0,0.82435,0.396992,0.156317,0.737951,0.360475,0.671271,0.270644,0.08123,0.992582,0.156202
2,1.878402,0.0,0.988421,0.97728,0.793818,0.659423,0.577807,0.866102,0.28944,0.467681,0.61939,0.41119
3,6.94144,1.0,0.427486,0.330285,0.564232,0.850575,0.201528,0.934433,0.689088,0.823273,0.556191,0.779517
4,1.703049,1.0,0.0162,0.81838,0.040139,0.889913,0.991963,0.294067,0.210319,0.765363,0.253026,0.865562


In [6]:
data['d'].value_counts()

1.0    1348
0.0     652
Name: d, dtype: int64

In [7]:
data_dml_base = dml.DoubleMLData(
    data,
    y_col='y',
    d_cols='d',
)

In [8]:
# first stage estimation
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
rf_reg = RandomForestRegressor(n_estimators=500)
rf_class = RandomForestClassifier(n_estimators=500)

dml_irm = dml.DoubleMLIRM(
    data_dml_base,
    ml_g=rf_reg,
    ml_m=rf_class,
    trimming_threshold=0.05,
    n_folds=5,
)
print('Training IRM model')
dml_irm.fit()

print(dml_irm.summary)

Training IRM model
       coef   std err           t  P>|t|     2.5 %    97.5 %
d  4.463269  0.040742  109.549722    0.0  4.383417  4.543122


In [9]:
import patsy
design_matrix = patsy.dmatrix("bs(x, df=5, degree=2)", {"x": data['X_0']})
spline_basis = pd.DataFrame(design_matrix)

In [10]:
cate = dml_irm.cate(spline_basis)
print(cate)


------------------ Fit summary ------------------
       coef   std err          t          P>|t|    [0.025    0.975]
0  0.664965  0.160692   4.138127   3.647140e-05  0.349823  0.980108
1  2.368294  0.267523   8.852679   1.844643e-18  1.843641  2.892948
2  4.895260  0.171981  28.463894  7.630576e-150  4.557978  5.232543
3  4.782712  0.205983  23.219002  9.290475e-106  4.378748  5.186675
4  3.731584  0.209254  17.832803   4.072845e-66  3.321204  4.141963
5  4.328082  0.224902  19.244278   7.765736e-76  3.887014  4.769150
