In [None]:
import sys, os

module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

import numpy as np

from tqdm.notebook import tqdm

from xgboost import XGBRegressor, XGBClassifier
from econml.dml import NonParamDML, LinearDML
from econml.dr import LinearDRLearner, ForestDRLearner
from econml.metalearners import XLearner
from econml import metalearners

import sklearn
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.svm import SVC, SVR

from src.data import data_module
from src.data.utils import split_eval_cate

In [None]:
def get_regressor(missing_value):
    """Build an unfitted XGBoost regressor for the outcome / CATE models.

    Parameters
    ----------
    missing_value :
        Sentinel that marks a missing entry in the feature matrix; forwarded
        to XGBoost's ``missing`` argument.

    Returns
    -------
    XGBRegressor
        A fresh, unfitted estimator.
    """
    # "logloss" is a classification metric and is not meaningful for a
    # regression target; "rmse" is the regression evaluation metric.
    # The metric is only used for evaluation/monitoring, not for the
    # training objective.
    return XGBRegressor(missing=missing_value, eval_metric="rmse")


def get_classifier(missing_value):
    """Build an unfitted XGBoost classifier (used as the propensity model).

    ``missing_value`` is the sentinel marking missing feature entries and is
    forwarded to XGBoost's ``missing`` argument.
    """
    classifier_kwargs = {
        "use_label_encoder": False,
        "missing": missing_value,
        "eval_metric": "logloss",
    }
    return XGBClassifier(**classifier_kwargs)


def get_imputer(missing_value):
    """Build an unfitted ``IterativeImputer``.

    Entries equal to ``missing_value`` are the ones the imputer treats as
    missing and fills in.
    """
    imputer_kwargs = dict(
        max_iter=1500,
        tol=15e-4,
        random_state=None,
        missing_values=missing_value,
    )
    return IterativeImputer(**imputer_kwargs)


def _make_t_learner(missing_value):
    """Create an unfitted EconML T-learner backed by XGBoost regressors."""
    return metalearners.TLearner(models=get_regressor(missing_value))


def _make_x_learner(missing_value):
    """Create an unfitted EconML X-learner (XGBoost outcome, propensity and CATE models)."""
    return XLearner(
        models=get_regressor(missing_value),
        propensity_model=get_classifier(missing_value),
        cate_models=get_regressor(missing_value),
    )


# Registry: learner key -> factory taking the missing-value sentinel and
# returning a fresh, unfitted CATE estimator.
learners = {
    "T": _make_t_learner,
    "X": _make_x_learner,
}


def evaluate(ground_truth, estimate, W):
    """Root-mean-squared error of the estimated treatment effect.

    Parameters
    ----------
    ground_truth : array-like
        True treatment effects.
    estimate : array-like
        Estimated treatment effects, aligned with ``ground_truth``.
    W : array-like
        Treatment indicator; entries equal to 0 / 1 select the two subgroups.

    Returns
    -------
    tuple
        ``(PEHE, PEHE_0, PEHE_1)``: the error over all units, over units with
        ``W == 0`` and over units with ``W == 1``.
    """

    def _rmse(est, truth):
        return np.sqrt(((est - truth) ** 2).mean())

    control = W == 0
    treated = W == 1

    overall = _rmse(estimate, ground_truth)
    on_control = _rmse(estimate[control], ground_truth[control])
    on_treated = _rmse(estimate[treated], ground_truth[treated])
    return overall, on_control, on_treated

In [None]:
# EXP. SETTINGS
runs = 10  # total number of runs, split evenly across simulations
sims = 1  # number of simulations

assert runs % sims == 0, "runs must be divisible by sims"

# Seed NumPy's global RNG so the random train/test split is reproducible
# under Restart & Run All.
# NOTE(review): presumably `data_module` also draws from NumPy's global RNG;
# if it uses its own generator it needs to be seeded separately -- confirm.
seed = 0
np.random.seed(seed)

train_size = 5000  # number of rows drawn for training in every run

d = 20  # passed to data_module._generate_covariates
z_d_dim = 10  # column index at which the covariates are split for imputation
amount_of_missingness = [0.1, 0.2, 0.3, 0.4, 0.5]
missing_value = -1  # sentinel marking a missing covariate entry

# DEBUG SETTINGS
verbose = False


# PREP
# Allow scalars for the swept settings by wrapping them in a one-element list.
amount_of_missingness = (
    amount_of_missingness
    if isinstance(amount_of_missingness, list)
    else [amount_of_missingness]
)
z_d_dim = z_d_dim if isinstance(z_d_dim, list) else [z_d_dim]

In [None]:
# WE FIX X and Y
# Covariates and potential outcomes are generated once and shared by every
# run; twice `train_size` rows are generated.
n_samples = 2 * train_size
X = data_module._generate_covariates(d, n_samples)
Y0, Y1, CATE = data_module._generate_outcomes(X)

In [None]:
# Key into `learners` selecting which meta-learner the experiment loop
# instantiates ("T" or "X").
learner = "T"

In [None]:
def _impute_columns(X_train, X_test, columns, missing_value):
    """Return copies of the train/test covariates with ``columns`` imputed.

    Parameters
    ----------
    X_train, X_test : np.ndarray
        Covariate matrices in which ``missing_value`` marks missing entries.
    columns : slice or None
        Column selection handed to the imputer. ``None`` skips imputation
        and returns plain copies.
    missing_value :
        Sentinel identifying missing entries.

    Returns
    -------
    tuple of np.ndarray
        ``(X_train_preprocessed, X_test_preprocessed)``. The imputer is
        fitted on the training rows only and then applied to both sets.
    """
    X_train_out = X_train.copy()
    X_test_out = X_test.copy()
    if columns is None:
        return X_train_out, X_test_out

    imputer = get_imputer(missing_value)
    imputer.fit(X_train[:, columns])
    X_train_out[:, columns] = imputer.transform(X_train[:, columns])
    X_test_out[:, columns] = imputer.transform(X_test[:, columns])
    return X_train_out, X_test_out


n_runs = runs // sims

# Every row that is not drawn for training is used for testing, so the stored
# per-run vectors have length len(X) - train_size (not train_size; the two
# only coincide when len(X) == 2 * train_size).
assert 10 < train_size < len(X)
test_size = len(X) - train_size

# dicts: [name] -> [sim index, a_o_m index, z_d_d index, run index, test_size]
result_shape = (sims, len(amount_of_missingness), len(z_d_dim), n_runs, test_size)

res_dict = {
    name: np.zeros(result_shape)
    for name in ("ground_truth", "all", "nothing", "smartly", "wrongly")
}

# Treatment assignment of the test rows, stored alongside the estimates.
past_w = np.zeros(result_shape)

for sim_idx in tqdm(range(sims), desc="Simulation"):
    for miss_idx, a_o_m in enumerate(
        tqdm(amount_of_missingness, desc="Missingness", leave=False)
    ):
        for zdim_idx, z_d_d in enumerate(tqdm(z_d_dim, desc="Z dim", leave=False)):
            # Which columns each strategy hands to the imputer:
            #   all     -> every column
            #   nothing -> no imputation (the learner sees the sentinel)
            #   smartly -> only columns z_d_d and onwards
            #   wrongly -> only the first z_d_d columns
            imputed_columns = {
                "all": slice(None),
                "nothing": None,
                "smartly": slice(z_d_d, None),
                "wrongly": slice(None, z_d_d),
            }

            for run_idx in tqdm(range(n_runs), leave=False, desc="Runs per sim"):
                cell = (sim_idx, miss_idx, zdim_idx, run_idx)

                # GENERATE MISSINGNESS DEPENDENT VARIABLES
                Z_down = data_module._Z_down(a_o_m, X, z_d_d)
                W = data_module._treatments(Z_down, X, z_d_d)
                Y = data_module._generate_observed_outcomes(Y0, Y1, W)
                Z_up = data_module._Z_up(a_o_m, X, z_d_d, W)
                X_ = data_module._complete_covariates(
                    X, z_d_d, Z_up, Z_down, missing_value
                )

                # Random train/test split: train_size rows train, the rest test.
                idxs = np.random.choice(len(X), size=train_size, replace=False)
                mask = np.zeros(len(X), dtype=bool)
                mask[idxs] = True

                X_train, Y_train, W_train = X_[mask], Y[mask], W[mask]
                X_test, W_test, CATE_test = X_[~mask], W[~mask], CATE[~mask]

                res_dict["ground_truth"][cell] = CATE_test
                past_w[cell] = W_test

                for name, columns in imputed_columns.items():
                    X_train_preprocessed, X_test_preprocessed = _impute_columns(
                        X_train, X_test, columns, missing_value
                    )

                    estimator = learners[learner](missing_value)
                    estimator.fit(Y_train, W_train, X=X_train_preprocessed)
                    res_dict[name][cell] = estimator.effect(X_test_preprocessed)

                    if verbose:
                        # Show the minimum of the *preprocessed* training
                        # matrix on both sides of the z_d_d split, so the
                        # output shows which side still holds the sentinel.
                        print(
                            name,
                            "down",
                            X_train_preprocessed[:, :z_d_d].min(),
                            "up",
                            X_train_preprocessed[:, z_d_d:].min(),
                        )

In [None]:
def sens_eval(setting_key, result_dict, aggr_tuple):
    """Summarise the squared estimation error of one imputation setting.

    Parameters
    ----------
    setting_key : str
        Key of the setting in ``result_dict`` to compare with
        ``result_dict["ground_truth"]``.
    result_dict : dict
        Maps names to 5-D arrays; the first two axes are swapped before
        aggregating, axis 4 is averaged first and axis 3 is then reduced by
        mean / std.
    aggr_tuple : int or tuple of int
        Axes averaged last, after axes 4 and 3 have been reduced.

    Returns
    -------
    tuple of np.ndarray
        ``(means, stds)`` rounded to 5 and 4 decimals respectively.
    """
    squared_error = (result_dict[setting_key] - result_dict["ground_truth"]) ** 2

    # Swap axes 0 and 1, then collapse the last axis into one value per run.
    per_run = np.swapaxes(squared_error, 0, 1).mean(axis=4)

    mean_over_runs = per_run.mean(axis=3)
    std_over_runs = per_run.std(axis=3)

    means = mean_over_runs.mean(axis=aggr_tuple)
    stds = std_over_runs.mean(axis=aggr_tuple)
    return np.round(means, decimals=5), np.round(stds, decimals=4)


# One line per imputation setting: (mean, std) pairs in the order of the
# first axis of sens_eval's output.
for setting in res_dict:
    if setting == "ground_truth":
        continue
    means, stds = sens_eval(setting, res_dict, (2, 1))
    pairs = " ".join(str(pair) for pair in zip(means, stds))
    print(f"# set: {setting} \t {pairs}")

In [None]:
# Same summary, with every value paired with its amount of missingness and
# means / stds printed on separate lines.
for setting in res_dict:
    if setting == "ground_truth":
        continue
    means, stds = sens_eval(setting, res_dict, (2, 1))
    mean_pairs = " ".join(str(pair) for pair in zip(amount_of_missingness, means))
    std_pairs = " ".join(str(pair) for pair in zip(amount_of_missingness, stds))
    print(f"# set: {setting} \t means: {mean_pairs}\n#\t\t\t stds:  {std_pairs}")

In [None]:
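# Recorded output of the summary cell above, kept for reference.
# Each pair is (amount of missingness, value): "means" is the mean squared
# error of the estimated CATE, "stds" its standard deviation across runs.
# T-Learner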
# set: all 	 means: (0.1, 0.6387) (0.2, 0.97497) (0.3, 1.0933) (0.4, 0.77673) (0.5, 0.6489)
# 		 stds:  (0.1, 0.0267) (0.2, 0.0711) (0.3, 0.0584) (0.4, 0.0589) (0.5, 0.0424)
# set: nothing 	 means: (0.1, 1.10829) (0.2, 1.00463) (0.3, 0.94244) (0.4, 0.88775) (0.5, 0.82824)
# 		 stds:  (0.1, 0.0732) (0.2, 0.0638) (0.3, 0.051) (0.4, 0.0555) (0.5, 0.0632)
# set: smartly 	 means: (0.1, 0.60229) (0.2, 0.56311) (0.3, 0.54543) (0.4, 0.40746) (0.5, 0.35365)
# 		 stds:  (0.1, 0.0527) (0.2, 0.0555) (0.3, 0.0811) (0.4, 0.0371) (0.5, 0.0273)
# set: wrongly 	 means: (0.1, 1.18941) (0.2, 1.10463) (0.3, 1.05624) (0.4, 0.95405) (0.5, 0.87472)
# 		 stds:  (0.1, 0.0787) (0.2, 0.0531) (0.3, 0.061) (0.4, 0.0537) (0.5, 0.0524)


# X-Learner
# set: all 	 means: (0.1, 0.46102) (0.2, 0.79628) (0.3, 0.83915) (0.4, 0.56256) (0.5, 0.4036)
# 		 stds:  (0.1, 0.0389) (0.2, 0.0624) (0.3, 0.0722) (0.4, 0.0522) (0.5, 0.0304)
# set: nothing 	 means: (0.1, 0.03348) (0.2, 0.04115) (0.3, 0.05752) (0.4, 0.05961) (0.5, 0.08526)
# 		 stds:  (0.1, 0.0047) (0.2, 0.0148) (0.3, 0.0272) (0.4, 0.0204) (0.5, 0.0254)
# set: smartly 	 means: (0.1, 0.21161) (0.2, 0.20825) (0.3, 0.22869) (0.4, 0.17132) (0.5, 0.16462)
# 		 stds:  (0.1, 0.0373) (0.2, 0.0283) (0.3, 0.0293) (0.4, 0.0125) (0.5, 0.0135)
# set: wrongly 	 means: (0.1, 0.04181) (0.2, 0.04966) (0.3, 0.11158) (0.4, 0.08371) (0.5, 0.04647)
# 		 stds:  (0.1, 0.0099) (0.2, 0.0138) (0.3, 0.097) (0.4, 0.1261) (0.5, 0.0473)