# Benchmark MultiTreatmentIRM vs DoubleML Multitreatment (APOS)

Compare `MultiTreatmentIRM` from Causalis with DoubleML's multi-treatment implementation (`DoubleMLAPOS`) on the same 3-arm data-generating process (DGP).

# DGP

Use `generate_multitreatment_irm_26()` with oracle effects so we can benchmark each active arm against control.

In [1]:
from causalis.scenarios.multi_unconfoundedness.dgp import generate_multitreatment_irm_26

# Sample size and seed of the synthetic 3-arm data set used throughout the notebook.
N_OBS = 25_000
SEED = 42

# include_oracle=True keeps the oracle columns (e.g. cate_d_1, cate_d_2) in the frame;
# return_causal_data=False asks for the plain frame instead of a wrapped object.
data = generate_multitreatment_irm_26(
    n=N_OBS, seed=SEED, include_oracle=True, return_causal_data=False
)
data.head()

Unnamed: 0,y,d_0,d_1,d_2,tenure_months,avg_sessions_week,spend_last_month,premium_user,urban_resident,support_tickets_q,...,m_obs_d_1,tau_link_d_1,m_d_2,m_obs_d_2,tau_link_d_2,g_d_0,g_d_1,g_d_2,cate_d_1,cate_d_2
0,1.541272,1.0,0.0,0.0,27.656605,3.198667,89.609464,0.0,1.0,0.0,...,0.246062,-0.352005,0.220606,0.220606,0.494166,3.279384,2.306314,5.375338,-0.97307,2.095954
1,6.802333,1.0,0.0,0.0,23.798386,3.362415,102.337236,0.0,0.0,3.0,...,0.178897,-0.30736,0.236716,0.236716,0.420278,2.80785,2.064853,4.27463,-0.742997,1.46678
2,8.079449,1.0,0.0,0.0,28.425009,3.391819,102.660712,0.0,1.0,1.0,...,0.210001,-0.320189,0.21804,0.21804,0.502415,3.069919,2.228798,5.073677,-0.841121,2.003758
3,2.13682,1.0,0.0,0.0,18.860066,4.071175,83.593417,0.0,0.0,2.0,...,0.176239,-0.316241,0.237394,0.237394,0.441677,2.716805,1.980234,4.225485,-0.736571,1.50868
4,1.555391,0.0,1.0,0.0,17.853087,3.140075,79.20987,0.0,1.0,1.0,...,0.231904,-0.35013,0.246832,0.246832,0.493624,3.224354,2.271869,5.282273,-0.952485,2.057919


In [2]:
# Covariates passed as confounders to both Causalis and DoubleML.
confounders = [
    "tenure_months", "avg_sessions_week", "spend_last_month", "premium_user",
    "urban_resident", "support_tickets_q", "discount_eligible", "credit_utilization",
]
# One indicator column per arm: d_0, d_1, d_2.
treatment_cols = [f"d_{arm}" for arm in range(3)]
outcome_col = "y"

# Column means: share of rows in each arm and the average outcome.
arms_and_outcome = data[[*treatment_cols, outcome_col]]
arms_and_outcome.mean().to_frame("mean")

Unnamed: 0,mean
d_0,0.50104
d_1,0.24572
d_2,0.25324
y,4.308808


In [3]:
# Oracle ATE per active arm: the mean of the oracle CATE column for that arm.
oracle_ate = {
    f"d_{arm} vs d_0": float(data[f"cate_d_{arm}"].mean())
    for arm in (1, 2)
}
oracle_ate

{'d_1 vs d_0': -1.199206862416331, 'd_2 vs d_0': 2.5379024492441777}

In [4]:
from causalis.data_contracts.multicausaldata import MultiCausalData

# Wrap the raw frame in the Causalis multi-treatment data container, declaring
# the outcome, the arm indicator columns, the confounders, and d_0 as control.
causaldata = MultiCausalData(
    df=data,
    outcome=outcome_col,
    treatment_names=treatment_cols,
    confounders=confounders,
    control_treatment="d_0",
)
# Display the container's repr as the cell output.
causaldata

MultiCausalData(df=(25000, 12), treatment_names=['d_0', 'd_1', 'd_2'], control_treatment='d_0')outcome='y', confounders=['tenure_months', 'avg_sessions_week', 'spend_last_month', 'premium_user', 'urban_resident', 'support_tickets_q', 'discount_eligible', 'credit_utilization'], user_id=None, 

# Causalis: MultiTreatmentIRM

In [5]:
from causalis.scenarios.multi_unconfoundedness import MultiTreatmentIRM

# Configure the estimator first, then fit; `model` is whatever fit() returns,
# exactly as in a chained `MultiTreatmentIRM(...).fit()` call.
irm = MultiTreatmentIRM(data=causaldata, n_folds=3, random_state=42)
model = irm.fit()

# ATE of each active arm against the control arm.
result_causalis = model.estimate(score="ATE")

In [6]:
# Show the summary table of the Causalis estimates (one column per contrast).
result_causalis.summary()

Unnamed: 0_level_0,d_1 vs d_0,d_2 vs d_0
field,Unnamed: 1_level_1,Unnamed: 2_level_1
estimand,ATE,ATE
model,MultiTreatmentIRM,MultiTreatmentIRM
value,"-1.2818 (ci_abs: -1.3781, -1.1855)","2.3674 (ci_abs: 2.1994, 2.5354)"
value_relative,"-32.0434 (ci_rel: -34.0579, -30.0289)","59.1812 (ci_rel: 54.4420, 63.9205)"
alpha,0.0500,0.0500
p_value,0.0000,0.0000
is_significant,True,True
n_treated,6143,6331
n_control,12526,12526
treatment_mean,2.9329,6.6318


# DoubleML: multi-treatment implementation (`DoubleMLAPOS`)

In [7]:
import doubleml as dml

# Fail fast with an actionable message when the installed DoubleML lacks DoubleMLAPOS.
apos_available = hasattr(dml, "DoubleMLAPOS")
if not apos_available:
    missing_apos_message = (
        "DoubleMLAPOS is required for this notebook. "
        "Please install a DoubleML version that includes DoubleMLAPOS."
    )
    raise ImportError(missing_apos_message)

In [8]:
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from catboost import CatBoostClassifier, CatBoostRegressor
import numpy as np


class SkCatBoostRegressor(RegressorMixin, BaseEstimator):
    """Scikit-learn-style wrapper around ``CatBoostRegressor``.

    All constructor keyword arguments are stored in ``self.params`` and passed
    unchanged to ``CatBoostRegressor``. ``get_params``/``set_params`` read and
    write that dictionary.
    """

    def __init__(self, **params):
        """Store the CatBoost keyword arguments and build an unfitted model."""
        self.params = params
        self.model_ = CatBoostRegressor(**params)

    def fit(self, X, y, **fit_params):
        """Fit a fresh ``CatBoostRegressor`` on ``(X, y)`` and return ``self``.

        Extra keyword arguments are forwarded to ``CatBoostRegressor.fit``.
        ``verbose`` defaults to ``False`` but can be overridden by the caller.
        """
        # Use setdefault instead of a hard-coded keyword so that an explicit
        # `verbose=...` from the caller no longer raises a duplicate-keyword TypeError.
        fit_params.setdefault("verbose", False)
        self.model_ = CatBoostRegressor(**self.params)
        self.model_.fit(X, y, **fit_params)
        # Record the feature count only when X exposes a 2-D (or higher) shape;
        # inputs without a usable shape (e.g. plain lists) are skipped.
        shape = getattr(X, "shape", None)
        if shape is not None and len(shape) > 1:
            self.n_features_in_ = shape[1]
        return self

    def predict(self, X):
        """Return the wrapped model's predictions for ``X``."""
        return self.model_.predict(X)

    def get_params(self, deep=True):
        """Return a copy of the stored CatBoost parameters (``deep`` is ignored)."""
        return dict(self.params)

    def set_params(self, **params):
        """Update the stored parameters, rebuild the unfitted model, return ``self``."""
        self.params.update(params)
        self.model_ = CatBoostRegressor(**self.params)
        return self


class SkCatBoostClassifier(ClassifierMixin, BaseEstimator):
    """Scikit-learn-style wrapper around ``CatBoostClassifier``.

    All constructor keyword arguments are stored in ``self.params`` and passed
    unchanged to ``CatBoostClassifier``. ``get_params``/``set_params`` read and
    write that dictionary.
    """

    def __init__(self, **params):
        """Store the CatBoost keyword arguments and build an unfitted model."""
        self.params = params
        self.model_ = CatBoostClassifier(**params)

    def fit(self, X, y, **fit_params):
        """Fit a fresh ``CatBoostClassifier`` on ``(X, y)`` and return ``self``.

        Extra keyword arguments are forwarded to ``CatBoostClassifier.fit``.
        ``verbose`` defaults to ``False`` but can be overridden by the caller.
        Sets ``classes_`` from the fitted model, or from ``np.unique(y)`` when
        the model does not expose it.
        """
        # Use setdefault instead of a hard-coded keyword so that an explicit
        # `verbose=...` from the caller no longer raises a duplicate-keyword TypeError.
        fit_params.setdefault("verbose", False)
        self.model_ = CatBoostClassifier(**self.params)
        self.model_.fit(X, y, **fit_params)
        if hasattr(self.model_, "classes_"):
            self.classes_ = self.model_.classes_
        else:
            self.classes_ = np.unique(y)
        # Record the feature count only when X exposes a 2-D (or higher) shape;
        # inputs without a usable shape (e.g. plain lists) are skipped.
        shape = getattr(X, "shape", None)
        if shape is not None and len(shape) > 1:
            self.n_features_in_ = shape[1]
        return self

    def predict(self, X):
        """Return the wrapped model's class predictions for ``X``."""
        return self.model_.predict(X)

    def predict_proba(self, X):
        """Return class probabilities with columns ordered like ``self.classes_``."""
        proba = self.model_.predict_proba(X)
        # Reorder columns only when the model's class order differs from classes_.
        if hasattr(self.model_, "classes_") and list(self.model_.classes_) != list(self.classes_):
            order = [list(self.model_.classes_).index(c) for c in self.classes_]
            proba = np.asarray(proba)[:, order]
        return proba

    def get_params(self, deep=True):
        """Return a copy of the stored CatBoost parameters (``deep`` is ignored)."""
        return dict(self.params)

    def set_params(self, **params):
        """Update the stored parameters, rebuild the unfitted model, return ``self``."""
        self.params.update(params)
        self.model_ = CatBoostClassifier(**self.params)
        return self

In [9]:
# Hyper-parameters shared by the outcome regressor and the propensity classifier.
catboost_params = dict(iterations=500, depth=6, learning_rate=0.1)

# Outcome learner (passed to DoubleML as ml_g).
boost = SkCatBoostRegressor(**catboost_params)
# Propensity learner (passed to DoubleML as ml_m).
boost_class = SkCatBoostClassifier(**catboost_params, loss_function="MultiClass")

In [10]:
# Collapse the one-hot arm indicators (d_0, d_1, d_2) into one integer level column.
arm_indicators = data[treatment_cols].to_numpy()

# argmax would silently map an all-zero row to level 0 (control) and pick the
# first arm of a multi-hot row, so verify the encoding is strictly one-hot first.
assert ((arm_indicators == 0) | (arm_indicators == 1)).all(), (
    "treatment indicator columns must contain only 0/1"
)
assert (arm_indicators.sum(axis=1) == 1).all(), (
    "each row must belong to exactly one treatment arm"
)

# assign() returns a new frame, so `data` itself is left untouched.
data_dml = data.assign(d_level=arm_indicators.argmax(axis=1))

data_dml_base = dml.DoubleMLData(
    data_dml,
    y_col=outcome_col,
    d_cols="d_level",
    x_cols=confounders,
)
data_dml_base

<doubleml.data.base_data.DoubleMLData at 0x1692cde80>

In [13]:
# The Causalis model above is seeded via random_state=42, but nothing seeds DoubleML.
# DoubleML's cross-fitting sample splits are expected to draw from NumPy's global
# RNG (NOTE(review): confirm for the installed version), so seed it here to make
# the benchmark reproducible under Restart & Run All.
np.random.seed(42)

# make sure treatment_levels contains all levels you want to compare
result_doubleml = dml.DoubleMLAPOS(
    data_dml_base,
    ml_g=boost,
    ml_m=boost_class,
    treatment_levels=[0, 1, 2],
    n_folds=3,
    n_rep=1,
    trimming_threshold=0.01,
)
result_doubleml.fit()

# causal_contrast only takes the reference level(s); every other level is
# contrasted against it.
contrast_doubleml = result_doubleml.causal_contrast(reference_levels=0)

# keep only desired targets (1 and 2) if needed
contrast_summary = contrast_doubleml.summary
contrast_summary = contrast_summary.loc[
    contrast_summary.index.astype(str).str.startswith(("1", "2"))
]

# Bare last expression: rich DataFrame display instead of print()'s plain text.
contrast_summary

            coef   std err          t  P>|t|     2.5 %    97.5 %
1 vs 0 -1.319629  0.044631 -29.567224    0.0 -1.407105 -1.232153
2 vs 0  2.471794  0.079695  31.015554    0.0  2.315594  2.627994


In [15]:
import pandas as pd

dml_summary = contrast_doubleml.summary

# Read the point estimates from whichever column the summary exposes, falling
# back to the raw thetas when neither name is present.
estimate_col = next(
    (col for col in ("coef", "theta") if col in dml_summary.columns), None
)
if estimate_col is not None:
    dml_values = dml_summary[estimate_col].to_numpy(dtype=float)
else:
    dml_values = np.asarray(contrast_doubleml.all_thetas).ravel().astype(float)

contrast_labels = list(result_causalis.contrast_labels)

# The two libraries are matched row by row, so make a count mismatch fail loudly
# with a clear message. NOTE(review): the ordering itself is still assumed to agree
# (1 vs 0, 2 vs 0 on the DoubleML side; d_1 vs d_0, d_2 vs d_0 on the Causalis side).
assert len(dml_values) == len(contrast_labels), (
    f"DoubleML returned {len(dml_values)} contrasts, "
    f"Causalis returned {len(contrast_labels)}"
)

comparison = pd.DataFrame(
    {
        "contrast": contrast_labels,
        # Look the oracle up by label so it stays aligned if the label order changes.
        "oracle_ate": [oracle_ate[label] for label in contrast_labels],
        "causalis_multitreatment_irm": result_causalis.value,
        "doubleml_apos_contrast": dml_values,
    }
)
comparison["abs_diff_causalis_vs_dml"] = (
    comparison["causalis_multitreatment_irm"] - comparison["doubleml_apos_contrast"]
).abs()
comparison

Unnamed: 0,contrast,oracle_ate,causalis_multitreatment_irm,doubleml_apos_contrast,abs_diff_causalis_vs_dml
0,d_1 vs d_0,-1.199207,-1.281841,-1.319629,0.037788
1,d_2 vs d_0,2.537902,2.367443,2.471794,0.104351
