In [15]:
import pandas as pd
import numpy as np

from econml.dr import LinearDRLearner
from xgboost import XGBRegressor, XGBClassifier

In [2]:
# Load the sample multi-attribution dataset (public CSV, fetched over the network).
file_url = "https://msalicedatapublic.z5.web.core.windows.net/datasets/ROI/multi_attribution_sample.csv"
multi_data = pd.read_csv(filepath_or_buffer=file_url)

In [4]:
# Partition the columns: outcome Y, binary treatment indicators T_bin,
# effect modifier X (passed as X to the learner), and everything else as controls W.
treatment_cols = ["Tech Support", "Discount"]
Y = multi_data["Revenue"]
T_bin = multi_data[treatment_cols]
X = multi_data[["Size"]]
W = multi_data.drop(columns=[*treatment_cols, "Revenue", "Size"])

In [6]:
# Mean "Size" within each (Tech Support, Discount) combination,
# truncated to int for a compact display.
(
    multi_data
    .groupby(["Tech Support", "Discount"], as_index=False)[["Size"]]
    .mean()
    .astype(int)
)

Unnamed: 0,Tech Support,Discount,Size
0,0,0,70943
1,0,1,96466
2,1,0,108978
3,1,1,171466


In [10]:
def treat_map(t):
    """Encode binary treatment indicators as a single integer treatment code.

    Indicator ``i`` (counting from 0 along the last axis) contributes ``2**i``.
    With the columns ``["Tech Support", "Discount"]`` the codes are
    0 = neither, 1 = Tech Support only, 2 = Discount only, 3 = both.

    Parameters
    ----------
    t : array-like
        Either a single row of 0/1 indicators (1-D) or a matrix with one row
        per observation (2-D). The encoding is applied along the last axis.

    Returns
    -------
    numpy integer scalar or numpy.ndarray
        The code for a 1-D input, or one code per row for a 2-D input.
    """
    arr = np.asarray(t)
    # Use the last axis so a whole (n_rows, n_indicators) matrix can be
    # encoded in one call; for a 1-D row this equals the original shape[0].
    weights = 2 ** np.arange(arr.shape[-1])
    return np.dot(arr, weights)


# treat_map's base-2 encoding is only collision-free for 0/1 indicators
# (e.g. a stray value of 2 in "Tech Support" would encode the same as
# "Discount only"), so verify the invariant before encoding.
assert T_bin.isin([0, 1]).to_numpy().all(), "treatment indicators must be binary (0/1)"
# Collapse each row of indicators into one categorical treatment code.
T = np.apply_along_axis(treat_map, 1, T_bin).astype(int)

In [20]:
# Doubly robust learner with a linear final stage in X; the XGBoost
# nuisance models (outcome regression and multi-class propensity)
# share the same boosting hyperparameters.
xgb_params = dict(learning_rate=.1, max_depth=3)
model = LinearDRLearner(
    model_regression=XGBRegressor(**xgb_params),
    model_propensity=XGBClassifier(**xgb_params, objective="multi:softprob"),
    random_state=1,
)

# "statsmodels" inference provides the stderr / p-values / CIs shown by summary().
model.fit(Y=Y, T=T, W=W, X=X, inference="statsmodels")

<econml.dr._drlearner.LinearDRLearner at 0x17d9e5400>

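# Understand Treatment Effects

Treatment codes come from `treat_map`: 0 = no treatment, 1 = Tech Support only, 2 = Discount only, 3 = both. Each summary below gives the linear effect (a `Size` slope plus an intercept) of one code relative to code 0, which econml uses as the control category by default.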

In [21]:
# Linear CATE coefficients (Size slope and intercept) for treatment code 1.
summary_t1 = model.summary(T=1)
summary_t1

0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
Size,0.021,0.012,1.749,0.08,-0.002,0.044

0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
cate_intercept,5326.611,845.551,6.3,0.0,3669.361,6983.861


In [23]:
# Show the linear CATE summary for every non-baseline treatment code.
# The codes are taken from the encoded treatment vector T instead of the
# private `model._d_t` attribute, which is not part of econml's public API.
# np.unique sorts, so [1:] skips the lowest code (0 = no treatment), which
# econml treats as the control category by default.
for code in np.unique(T)[1:]:
    display(model.summary(T=int(code)))

0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
Size,0.021,0.012,1.749,0.08,-0.002,0.044

0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
cate_intercept,5326.611,845.551,6.3,0.0,3669.361,6983.861


0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
Size,0.052,0.012,4.371,0.0,0.029,0.075

0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
cate_intercept,358.699,848.771,0.423,0.673,-1304.861,2022.258


0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
Size,0.074,0.012,6.292,0.0,0.051,0.096

0,1,2,3,4,5,6
,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
cate_intercept,4899.208,851.54,5.753,0.0,3230.22,6568.196
