In [None]:
import os
import sys

# Make the local causalml checkout and the parent directory (presumably the
# project root providing `src.*`) importable. Paths resolve against the
# kernel's current working directory; entries already on sys.path are skipped.
for relative_dir in ("../causalml", "../"):
    module_path = os.path.abspath(relative_dir)
    if module_path in sys.path:
        continue
    sys.path.append(module_path)

In [None]:
import numpy as np
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer, SimpleImputer
from tqdm.notebook import tqdm
from xgboost import XGBRegressor

from causalml.inference.meta import BaseDRLearner, BaseTLearner, BaseXRegressor
from causalml.propensity import GradientBoostedPropensityModel
from src.data.data_module import generate_data_exp
from src.data.utils import split_eval

In [None]:
def get_regressor(missing_value):
    """Build the XGBoost outcome model used inside every meta-learner.

    Args:
        missing_value: sentinel XGBoost should treat as missing (the notebook
            settings use -1).

    Returns:
        A fresh, unfitted ``XGBRegressor``.
    """
    # NOTE(review): "logloss" is a binary-classification metric; it is presumably
    # only evaluated when an eval_set is passed to fit — confirm causalml never does.
    regressor_kwargs = {"missing": missing_value, "eval_metric": "logloss"}
    return XGBRegressor(**regressor_kwargs)


# Factories for the CATE meta-learners under comparison. Each takes the
# missing-value sentinel and returns a fresh, unfitted causalml learner whose
# base model comes from get_regressor.
learners = dict(
    T=lambda mv: BaseTLearner(learner=get_regressor(mv)),
    X=lambda mv: BaseXRegressor(learner=get_regressor(mv)),
    DR=lambda mv: BaseDRLearner(learner=get_regressor(mv)),
)


def xgb_prop(missing_value, X, W):
    """Fit an XGBoost propensity model on (X, W) and return its in-sample output.

    Args:
        missing_value: sentinel XGBoost treats as missing.
        X: covariate matrix.
        W: treatment assignment.

    Returns:
        Whatever ``GradientBoostedPropensityModel.fit_predict`` returns
        (presumably one propensity score per row of X — confirm in causalml).
    """
    propensity_model = GradientBoostedPropensityModel(
        missing=missing_value,
        eval_metric="logloss",
    )
    scores = propensity_model.fit_predict(X, W)
    return scores


def _no_propensity(missing_value, X, W):
    """Propensity provider that fits nothing and returns None.

    Passing ``p=None`` to ``estimate_ate`` presumably leaves propensity
    estimation to the causalml learner itself — confirm per learner.
    """
    return None


# Propensity-score providers keyed by name. Each is called as
# f(missing_value, X, W) and its result is forwarded as `p=` to estimate_ate.
prop_learners = {"none": _no_propensity, "xgb": xgb_prop}


def get_imputer(missing_value, max_iter=1500, tol=1.5e-3, random_state=None):
    """Build a fresh, unfitted IterativeImputer for sentinel-coded covariates.

    Args:
        missing_value: value marking missing entries (forwarded as ``missing_values``).
        max_iter: maximum number of imputation rounds (default 1500, unchanged).
        tol: stopping tolerance between rounds (default 1.5e-3, the same value
            previously written as ``15e-4``).
        random_state: seed forwarded to the imputer. ``None`` keeps the previous
            behaviour; pass an int to make randomized imputation settings reproducible.

    Returns:
        An unfitted ``sklearn.impute.IterativeImputer``.
    """
    return IterativeImputer(
        max_iter=max_iter,
        tol=tol,
        random_state=random_state,
        missing_values=missing_value,
    )

In [None]:
# EXP. SETTINGS
runs = 100  # total (subsample -> ATE estimate) repetitions per strategy
sims = 10  # number of freshly generated datasets; runs are split evenly across them

assert runs % sims == 0, "runs must be a multiple of sims"

n = 5000  # units per estimation subsample; each dataset is generated with 2 * n units
d = 20
z_d_dim = 10  # column split: the imputation strategies treat X[:, :z_d_dim] and X[:, z_d_dim:] separately
amount_of_missingness = 0.3
missing_value = -1  # sentinel marking missing entries (used by XGBoost and the imputers)

learner = "X"  # key into `learners`
prop_learner = "xgb"  # key into `prop_learners`

assert learner in learners, f"unknown learner {learner!r}; choose from {sorted(learners)}"
assert prop_learner in prop_learners, (
    f"unknown prop_learner {prop_learner!r}; choose from {sorted(prop_learners)}"
)

data = "twins"

# Seed numpy's global RNG so the subsampling (and, presumably, the data
# generation in generate_data_exp) is reproducible across Restart & Run All.
seed = 0
np.random.seed(seed)

# DEBUG SETTINGS
verbose = False

In [None]:
ground_truth = []

ATE_impute_all = []
ATE_impute_nothing = []
ATE_impute_smartly = []
ATE_impute_wrongly = []


def _estimate_ate(X_in, W_in, Y_in):
    """Fit a fresh CATE learner on (X_in, W_in, Y_in) and return its ATE estimate.

    The learner type and propensity provider come from the `learner` and
    `prop_learner` settings; `missing_value` is forwarded to both. Returns
    element 0 of causalml's `estimate_ate` output (presumably the point
    estimate, with the interval bounds discarded — confirm per learner).
    """
    model = learners[learner](missing_value)
    ps = prop_learners[prop_learner](missing_value, X_in, W_in)
    return model.estimate_ate(X_in, W_in, Y_in, p=ps)[0]


def _impute_columns(X_in, cols):
    """Return a float copy of X_in in which only the columns selected by `cols` are imputed.

    The imputer is fitted on those columns alone, so all other columns keep
    their `missing_value` sentinels.
    """
    # astype copies; float dtype guarantees imputed values are not truncated to ints.
    X_out = X_in.astype(float)
    # fit_transform yields the same result as fit() followed by transform() on the
    # same data, but runs the iterative imputation once instead of twice.
    X_out[:, cols] = get_imputer(missing_value).fit_transform(X_out[:, cols])
    return X_out


runs_per_sim = runs // sims

for _ in tqdm(range(sims)):
    X, X_, Y0, Y1, Y, CATE, W, Z_up, Z_down = generate_data_exp(
        n * 2, d, z_d_dim, amount_of_missingness, missing_value=missing_value, data=data
    )
    # Reference ATE over all 2 * n generated units; identical for every subsample of
    # this dataset, so it is computed once per dataset.
    # NOTE(review): the estimators only see a random half of the units; confirm the
    # full-population ATE (rather than the subsample's) is the intended target.
    true_ate = Y1.mean() - Y0.mean()

    for _ in tqdm(range(runs_per_sim), leave=False):
        # Random half: n of the 2 * n units, drawn without replacement.
        idxs = np.random.choice(n * 2, size=n, replace=False)
        X__i, Y_i, W_i = X_[idxs], Y[idxs], W[idxs]

        ground_truth.append(true_ate)

        # IMPUTE ALL: impute every column; the imputer's output is used as-is.
        X_all = get_imputer(missing_value).fit_transform(X__i)
        ATE_impute_all.append(_estimate_ate(X_all, W_i, Y_i))
        if verbose:
            print("all", X_all.min())

        # IMPUTE NOTHING: learners receive the raw, sentinel-coded covariates.
        X_nothing = X__i.copy()
        ATE_impute_nothing.append(_estimate_ate(X_nothing, W_i, Y_i))
        if verbose:
            print("nothing", X_nothing.min())

        # IMPUTE SMARTLY: impute only the columns from z_d_dim onward.
        X_smart = _impute_columns(X__i, slice(z_d_dim, None))
        ATE_impute_smartly.append(_estimate_ate(X_smart, W_i, Y_i))
        if verbose:
            print("smart down", X_smart[:, :z_d_dim].min())
            print("smart up", X_smart[:, z_d_dim:].min())

        # IMPUTE WRONGLY: impute only the first z_d_dim columns.
        X_wrong = _impute_columns(X__i, slice(None, z_d_dim))
        ATE_impute_wrongly.append(_estimate_ate(X_wrong, W_i, Y_i))
        if verbose:
            print("wrong down", X_wrong[:, :z_d_dim].min())
            print("wrong up", X_wrong[:, z_d_dim:].min())

In [None]:
# Header records every setting that changes the reported numbers (including the
# propensity model and subsample size) so pasted results stay attributable.
print("# SETUP")
print(f"# DATA = {data}")
print(f"# learner = {learner}")
print(f"# prop_learner = {prop_learner}")
print(f"# amount_of_missingness = {amount_of_missingness}")
print(f"# z_d_dim = {z_d_dim}")
print(f"# n = {n}")
print(f"# amount of runs = {runs}")

# NOTE(review): split_eval (src.data.utils) receives the per-run ATE estimates,
# the matching ground truths and `sims`, and returns a (means, stds) pair whose
# .mean() is reported below — confirm exactly what it aggregates.
all_means, all_stds = split_eval(ATE_impute_all, ground_truth, sims)
no_means, no_stds = split_eval(ATE_impute_nothing, ground_truth, sims)
smart_means, smart_stds = split_eval(ATE_impute_smartly, ground_truth, sims)
wrong_means, wrong_stds = split_eval(ATE_impute_wrongly, ground_truth, sims)

summary_rows = [
    ("ALL IMPUTATION", all_means, all_stds),
    ("NO IMPUTATION", no_means, no_stds),
    ("SMART IMPUTATION", smart_means, smart_stds),
    ("WRONG IMPUTATION", wrong_means, wrong_stds),
]
for label, means, stds in summary_rows:
    # Labels are left-padded to 17 chars to keep the columns aligned as before.
    print(f"#   {label:<17}:\t {means.mean()}\t{stds.mean()}")

In [None]:
# SETUP
# DATA = twins
# learner = T
# amount_of_missingness = 0.3
# z_d_dim = 10
# amount of runs = 100
#   ALL IMPUTATION   :	 5.7554849319306065	2.5355438478863563
#   NO IMPUTATION    :	 6.338256012605365	3.156546417385256
#   SMART IMPUTATION :	 3.8482661724632488	2.0178561822657173
#   WRONG IMPUTATION :	 6.530310658947888	3.150978112501435

# SETUP
# DATA = twins
# learner = DR
# amount_of_missingness = 0.3
# z_d_dim = 10
# amount of runs = 100
#   ALL IMPUTATION   :	 8.975225169035607	3.0473058082213367
#   NO IMPUTATION    :	 10.31051638371869	4.182175835788135
#   SMART IMPUTATION :	 5.754394464820508	2.4033580089100903
#   WRONG IMPUTATION :	 9.297950298812191	3.739133997820358


# SETUP
# DATA = twins
# learner = X
# amount_of_missingness = 0.3
# z_d_dim = 10
# amount of runs = 100
#   ALL IMPUTATION   :	 30.792042486211823	8.45980458171202
#   NO IMPUTATION    :	 8.980250054602692	5.661282026773433
#   SMART IMPUTATION :	 8.153414737744438	4.562703649784634
#   WRONG IMPUTATION :	 10.369293354405885	6.097239453615901