### Probabilistic ERUPT in CATE Setting

Suppose we have two treatments for a medical condition and want to evaluate their effectiveness for patients segmented by age group. Our goal is to use probabilistic ERUPT to compare models that estimate the conditional average treatment effect (CATE).

#### Treatments

- **Treatment A**: Medication A
- **Treatment B**: Medication B

#### Data from Clinical Trials

Here is hypothetical data showing patient outcomes after receiving each treatment, segmented by age group:

| Patient | Age Group | Treatment | Outcome |
|---------|-----------|-----------|---------|
| 1       | Young     | A         | 5       |
| 2       | Young     | A         | 3       |
| 3       | Young     | B         | 2       |
| 4       | Old       | B         | 4       |
| 5       | Old       | A         | 6       |
| 6       | Old       | B         | 5       |
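
The per-group average outcomes used later in step 3 can be reproduced from this table with a pandas `groupby`. The DataFrame below just re-enters the six hypothetical patients; the variable names are illustrative.

```python
import pandas as pd

# The six hypothetical trial patients from the table above
trial = pd.DataFrame(
    {
        "patient": [1, 2, 3, 4, 5, 6],
        "age_group": ["Young", "Young", "Young", "Old", "Old", "Old"],
        "treatment": ["A", "A", "B", "B", "A", "B"],
        "outcome": [5, 3, 2, 4, 6, 5],
    }
)

# Mean observed outcome for each (age group, treatment) segment
segment_means = trial.groupby(["age_group", "treatment"])["outcome"].mean()
print(segment_means)
```

Note that the Young/B and Old/A segments each rest on a single patient, so these averages are very noisy.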

#### Model Predictions

Two models each provide an estimated impact on the outcome for every combination of age group and treatment. Each estimate comes with an uncertainty, expressed as a standard deviation.

**Model 1 Estimates:**

| Age Group | Treatment | Estimated Impact | Std Deviation |
|-----------|-----------|------------------|---------------|
| Young     | A         | 4.5              | 1.0           |
| Young     | B         | 3.5              | 0.5           |
| Old       | A         | 5.5              | 0.8           |
| Old       | B         | 4.0              | 0.6           |

**Model 2 Estimates:**

| Age Group | Treatment | Estimated Impact | Std Deviation |
|-----------|-----------|------------------|---------------|
| Young     | A         | 4.0              | 0.8           |
| Young     | B         | 4.0              | 0.8           |
| Old       | A         | 5.0              | 0.7           |
| Old       | B         | 5.0              | 0.7           |
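
Each model describes its estimates as distributions. If we assume these are normal and that the draws for A and B are independent (the tables do not say otherwise), then the probability that a single Thompson-sampling draw prefers A over B in a group has a closed form: $\Phi\big((\mu_A - \mu_B)/\sqrt{\sigma_A^2 + \sigma_B^2}\big)$. The reason is that the difference of two independent normals is itself normal. The helpers `normal_cdf` and `prob_choose_a` below are hypothetical and use only the standard library.

```python
import math

# (estimated impact, std deviation) per model, age group and treatment, from the tables
ESTIMATES = {
    "Model 1": {"Young": {"A": (4.5, 1.0), "B": (3.5, 0.5)},
                "Old": {"A": (5.5, 0.8), "B": (4.0, 0.6)}},
    "Model 2": {"Young": {"A": (4.0, 0.8), "B": (4.0, 0.8)},
                "Old": {"A": (5.0, 0.7), "B": (5.0, 0.7)}},
}


def normal_cdf(z: float) -> float:
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def prob_choose_a(mu_a: float, sd_a: float, mu_b: float, sd_b: float) -> float:
    """P(draw_A > draw_B) for independent draws from N(mu_a, sd_a^2) and N(mu_b, sd_b^2)."""
    return normal_cdf((mu_a - mu_b) / math.sqrt(sd_a**2 + sd_b**2))


for model, groups in ESTIMATES.items():
    for group, est in groups.items():
        p_a = prob_choose_a(*est["A"], *est["B"])
        print(f"{model} / {group}: P(choose A) = {p_a:.3f}")
```

Under these assumptions, Model 1 picks A in about 81% of Young draws and 93% of Old draws. Model 2 has identical estimates for A and B in both groups, so it splits its draws evenly.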

### Probabilistic ERUPT Process in CATE Setting

Here’s how the process can be adapted to evaluate and compare these models in the CATE setting:

1. **Probabilistic Treatment Selection Based on CATE**:
   - We use Thompson sampling to select treatments probabilistically based on the CATE estimates and their uncertainties.

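   For each age group, we draw one value per treatment from a normal distribution. Its mean is the model's estimated impact and its standard deviation is the model's uncertainty. We then choose the treatment with the larger draw.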

   **Example of one iteration for each model:**

   - **Model 1**:
     - **Young Group**: Samples 4.2 for A and 3.8 for B — chooses A.
     - **Old Group**: Samples 5.1 for A and 3.7 for B — chooses A.

   - **Model 2**:
     - **Young Group**: Samples 3.5 for A and 4.5 for B — chooses B.
     - **Old Group**: Samples 4.6 for A and 5.2 for B — chooses B.

2. **Simulate the Selection Process**:
   - Perform this sampling many times to simulate different scenarios.
   - Calculate how often each treatment is chosen in each age group.

3. **Calculate ERUPT Scores**:
   - For each model, combine the selection frequencies from step 2 with the observed outcomes. In each age group, the expected outcome is the frequency-weighted average of the observed outcomes for A and B.

   Average observed outcomes per age group and treatment, from the trial table:

   - **Young Group**: Average outcomes for A: 4, for B: 2.
   - **Old Group**: Average outcomes for A: 6, for B: 4.5.

   **Illustrative selection frequencies over many iterations:**

   - **Model 1**:
     - Young Group: chooses A in about 70% of iterations.
     - Old Group: chooses A in about 80% of iterations.
     - Expected outcome, summed over both age groups: $(0.7 \times 4 + 0.3 \times 2)_{\text{Young}} + (0.8 \times 6 + 0.2 \times 4.5)_{\text{Old}}$
     - $= (2.8 + 0.6) + (4.8 + 0.9) = 3.4 + 5.7 = 9.1$

   - **Model 2**:
     - Young Group: chooses B in about 60% of iterations.
     - Old Group: chooses B in about 60% of iterations.
     - Expected outcome, summed over both age groups: $(0.4 \times 4 + 0.6 \times 2)_{\text{Young}} + (0.4 \times 6 + 0.6 \times 4.5)_{\text{Old}}$
     - $= (1.6 + 1.2) + (2.4 + 2.7) = 2.8 + 5.1 = 7.9$

   Both age groups contain three patients, so ranking the models by this sum gives the same order as ranking them by the average outcome per patient.

4. **Compare Models**:
   - Compare these expected outcomes. The model with the higher value is considered better at using its CATE estimates for treatment selection.
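
The four steps can be sketched end to end for the toy example. This is a simplified stand-in with two assumptions. First, the draws for each treatment are independent normals. Second, the per-group observed means replace the propensity-weighted outcomes that the `ERUPT` class computes on real data. The function name `toy_probabilistic_erupt` is hypothetical.

```python
import numpy as np

# Mean observed outcome per (age group, treatment), from the trial table
OBSERVED = {"Young": {"A": 4.0, "B": 2.0}, "Old": {"A": 6.0, "B": 4.5}}

# (estimated impact, std deviation) per age group and treatment, from the model tables
MODELS = {
    "Model 1": {"Young": {"A": (4.5, 1.0), "B": (3.5, 0.5)},
                "Old": {"A": (5.5, 0.8), "B": (4.0, 0.6)}},
    "Model 2": {"Young": {"A": (4.0, 0.8), "B": (4.0, 0.8)},
                "Old": {"A": (5.0, 0.7), "B": (5.0, 0.7)}},
}


def toy_probabilistic_erupt(estimates, observed, iterations=10_000, seed=0):
    """Thompson-sample a treatment per age group, then sum over groups the
    frequency-weighted observed outcomes of the chosen treatments."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for group, arms in estimates.items():
        names = list(arms)
        means = np.array([arms[t][0] for t in names])
        sds = np.array([arms[t][1] for t in names])
        # Step 1: one draw per treatment per iteration, shape (iterations, n_treatments)
        draws = rng.normal(means, sds, size=(iterations, len(names)))
        chosen = draws.argmax(axis=1)
        # Step 2: how often each treatment wins in this group
        freq = np.bincount(chosen, minlength=len(names)) / iterations
        # Step 3: expected observed outcome under those frequencies
        total += sum(f * observed[group][t] for f, t in zip(freq, names))
    return total


# Step 4: compare the models
scores = {name: toy_probabilistic_erupt(est, OBSERVED) for name, est in MODELS.items()}
print(scores)
print("Better model:", max(scores, key=scores.get))
```

In this sketch, the selection frequencies come from the model tables rather than being set by hand. As a result, the totals should land near 9.5 for Model 1 and 8.25 for Model 2, and the ranking is the same.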

### Conclusion

In this example, **Model 1** scores higher. Its estimates steer the sampling toward the treatment with the better observed outcome in each age group, which is A in both groups. This illustrates how probabilistic ERUPT compares models by how well their CATE estimates, together with their uncertainty, translate into treatment decisions.
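
For a binary treatment, the procedure can be summarized as follows. The notation is introduced here:

- $\hat\tau_i$ and $\hat\sigma_i$ are a model's CATE estimate and standard deviation for unit $i$.
- $W_i \in \{0, 1\}$ is the unit's observed treatment and $y_i$ its outcome.
- $\hat p(W_i \mid X_i)$ is the fitted propensity of the observed treatment.
- $K$ is the number of sampling iterations.

This sketch assumes normal sampling distributions. It follows the inverse-propensity weighting in `ERUPT.weights` and ignores that method's clipping of large weights.

$$
\tilde\tau_i^{(k)} \sim \mathcal{N}\left(\hat\tau_i, \hat\sigma_i^2\right), \qquad
d_i^{(k)} = \mathbb{1}\left[\tilde\tau_i^{(k)} > 0\right], \qquad
w_i^{(k)} = \frac{\mathbb{1}\left[d_i^{(k)} = W_i\right]}{\hat p(W_i \mid X_i)}
$$

$$
\text{pERUPT} = \frac{1}{K} \sum_{k=1}^{K} \frac{\sum_i w_i^{(k)} y_i}{\sum_i w_i^{(k)}}
$$

The inner ratio is what `ERUPT.score` returns for a fixed policy. This holds because `weights` rescales the weights to sum to the number of rows before the mean is taken.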

### Suitable Estimators (Providing Uncertainty Estimates)

Natively:
- Causal Forest DML: Derives uncertainty estimates from the forest structure.
- DR Ortho Forest: Provides standard errors.
- DML Ortho Forest: Also provides standard errors.


Under certain conditions:
- T-Learner: If the underlying models provide standard deviations.
- X-Learner: If the second-stage models provide standard deviations.
- Forest DR Learner: Can provide standard deviations through the variability of forest predictions.
- Linear DR Learner: Can provide standard errors for predictions.
- Sparse Linear DR Learner: Can provide uncertainties through methods like bootstrapping.
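- Linear DML: Provides standard errors when the CATE is modeled as linear in the features; the first-stage nuisance models can be arbitrary.
- Sparse Linear DML: Similar to Linear DML; uncertainty can come from debiased-lasso inference or bootstrapping.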
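
Several of these estimators report confidence intervals rather than standard deviations. If we assume the interval is approximately normal and symmetric, its width can be converted into a standard deviation that Thompson sampling can use. The helper `interval_to_std` below is hypothetical.

```python
from statistics import NormalDist

import numpy as np


def interval_to_std(lower, upper, alpha=0.05):
    """Approximate per-unit standard deviations from two-sided (1 - alpha) intervals.

    Assumes each interval is estimate +/- z * sd, where z is the (1 - alpha/2) normal quantile.
    """
    z = NormalDist().inv_cdf(1 - alpha / 2)  # about 1.96 for alpha = 0.05
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    return (upper - lower) / (2 * z)


# Example: 95% intervals for three units
lower = np.array([0.0, 1.0, -0.5])
upper = np.array([3.92, 2.96, 0.48])
print(interval_to_std(lower, upper))
```

In EconML, estimators that support inference expose intervals through `effect_interval(X, alpha=...)`, which returns arrays of lower and upper bounds. These arrays could be passed to a helper like this one.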

In [None]:
import copy
import logging
import math
from typing import Optional, Dict, Union, Any, List

import numpy as np
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

from econml.cate_interpreter import SingleTreeCateInterpreter  # noqa: F401
from dowhy.causal_estimator import CausalEstimate
from dowhy import CausalModel


from causaltune.thirdparty.causalml import metrics
from causaltune.erupt import ERUPT
from causaltune.utils import treatment_values, psw_joint_weights

import dcor


class DummyEstimator:
    def __init__(
        self, cate_estimate: np.ndarray, effect_intervals: Optional[np.ndarray] = None
    ):
        self.cate_estimate = cate_estimate
        self.effect_intervals = effect_intervals

    def const_marginal_effect(self, X):
        return self.cate_estimate


def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List[str]:
    if problem == "iv":
        metrics = ["energy_distance"]
        if not scores_only:
            metrics.append("ate")
        return metrics
    elif problem == "backdoor":
        if multivalue:
            # TODO: support other metrics for the multivalue case
            return ["energy_distance", "psw_energy_distance"]
        else:
            metrics = [
                "erupt",
                "norm_erupt",
                "prob_erupt",
                "qini",
                "auc",
                # "r_scorer",
                "energy_distance",
                "psw_energy_distance",
            ]
            if not scores_only:
                metrics.append("ate")
            return metrics


class Scorer:
    def __init__(
        self,
        causal_model: CausalModel,
        propensity_model: Any,
        problem: str,
        multivalue: bool,
    ):
        """
        Contains scoring logic for CausalTune.

        Access methods and attributes via `CausalTune.scorer`.

        """

        self.problem = problem
        self.multivalue = multivalue
        self.causal_model = copy.deepcopy(causal_model)

        self.identified_estimand = causal_model.identify_effect(
            proceed_when_unidentifiable=True
        )

        if problem == "backdoor":
            print(
                "Fitting a Propensity-Weighted scoring estimator to be used in scoring tasks"
            )
            treatment_series = causal_model._data[causal_model._treatment[0]]
            # this will also fit self.propensity_model, which we'll also use in self.erupt
            self.psw_estimator = self.causal_model.estimate_effect(
                self.identified_estimand,
                method_name="backdoor.causaltune.models.MultivaluePSW",
                control_value=0,
                treatment_value=treatment_values(treatment_series, 0),
                target_units="ate",  # condition used for CATE
                confidence_intervals=False,
                method_params={
                    "init_params": {"propensity_model": propensity_model},
                },
            ).estimator

            treatment_name = self.psw_estimator._treatment_name
            if not isinstance(treatment_name, str):
                treatment_name = treatment_name[0]

            # No need to call self.erupt.fit() as propensity model is already fitted
            # self.propensity_model = est.propensity_model
            self.erupt = ERUPT(
                treatment_name=treatment_name,
                propensity_model=self.psw_estimator.estimator.propensity_model,
                X_names=self.psw_estimator._effect_modifier_names
                + self.psw_estimator._observed_common_causes_names,
            )

    def ate(self, df: pd.DataFrame) -> tuple:
        """Calculate the Average Treatment Effect. Provide naive std estimates in single-treatment cases.

        Args:
            df (pandas.DataFrame): input dataframe

        Returns:
            tuple: tuple containing the ATE, standard deviation of the estimate (or None if multi-treatment),
                and sample size (or None if estimate has more than one dimension)

        """

        estimate = self.psw_estimator.estimator.effect(df).mean(axis=0)

        if len(estimate) == 1:
            # for now, let's cheat on the std estimation, take that from the naive ate
            treatment_name = self.causal_model._treatment[0]
            outcome_name = self.causal_model._outcome[0]
            naive_est = Scorer.naive_ate(df[treatment_name], df[outcome_name])
            return estimate[0], naive_est[1], naive_est[2]
        else:
            return estimate, None, None

    def resolve_metric(self, metric: str) -> str:
        """Check if supplied metric is supported. If not, default to 'energy_distance'.

        Args:
            metric (str): evaluation metric

        Returns:
            str: metric/'energy_distance'

        """

        metrics = supported_metrics(self.problem, self.multivalue, scores_only=True)

        if metric not in metrics:
            logging.warning(
                f"Using energy_distance metric as {metric} is not in the list "
                f"of supported metrics for this usecase ({str(metrics)})"
            )
            return "energy_distance"
        else:
            return metric

    def resolve_reported_metrics(
        self, metrics_to_report: Union[List[str], None], scoring_metric: str
    ) -> List[str]:
        """Check if supplied reporting metrics are valid.

        Args:
            metrics_to_report (Union[List[str], None]): list of strings specifying the evaluation metrics to compute.
                Possible options include 'ate', 'erupt', 'norm_erupt', 'prob_erupt', 'qini', 'auc',
                'energy_distance' and 'psw_energy_distance'.
            scoring_metric (str): specified metric

        Returns:
            List[str]: list of valid metrics
        """

        metrics = supported_metrics(self.problem, self.multivalue, scores_only=False)
        if metrics_to_report is None:
            return metrics
        else:
            metrics_to_report = sorted(list(set(metrics_to_report + [scoring_metric])))
            for m in metrics_to_report.copy():
                if m not in metrics:
                    logging.warning(
                        f"Dropping the metric {m} for problem {self.problem}: "
                        f"must be one of {metrics}"
                    )
                    metrics_to_report.remove(m)
        return metrics_to_report

    @staticmethod
    def energy_distance_score(
        estimate: CausalEstimate,
        df: pd.DataFrame,
    ) -> float:
        """Calculate energy distance score between treated and controls.
        For theoretical details, see Ramos-Carreño and Torrecilla (2023).

        Args:
            estimate (dowhy.causal_estimator.CausalEstimate): causal estimate to evaluate
            df (pandas.DataFrame): input dataframe

        Returns:
            float: energy distance score

        """

        Y0X, _, split_test_by = Scorer._Y0_X_potential_outcomes(estimate, df)

        YX_1 = Y0X[Y0X[split_test_by] == 1]
        YX_0 = Y0X[Y0X[split_test_by] == 0]
        select_cols = estimate.estimator._effect_modifier_names + ["yhat"]

        energy_distance_score = dcor.energy_distance(
            YX_1[select_cols], YX_0[select_cols]
        )

        return energy_distance_score

    @staticmethod
    def _Y0_X_potential_outcomes(estimate: CausalEstimate, df: pd.DataFrame):
        est = estimate.estimator
        # assert est.identifier_method in ["iv", "backdoor"]
        treatment_name = (
            est._treatment_name
            if isinstance(est._treatment_name, str)
            else est._treatment_name[0]
        )
        df["dy"] = estimate.estimator.effect_tt(df)
        df["yhat"] = df[est._outcome_name] - df["dy"]

        split_test_by = (
            est.estimating_instrument_names[0]
            if est.identifier_method == "iv"
            else treatment_name
        )

        Y0X = copy.deepcopy(df)
        return Y0X, treatment_name, split_test_by

    def psw_energy_distance(
        self,
        estimate: CausalEstimate,
        df: pd.DataFrame,
        normalise_features=False,
    ) -> float:
        """
        Calculate propensity score adjusted energy distance score between treated and controls.

        Features are normalised using the sklearn.preprocessing.QuantileTransformer

        For theoretical details, see Ramos-Carreño and Torrecilla (2023).

        @param estimate (dowhy.causal_estimator.CausalEstimate): causal estimate to evaluate
        @param df (pandas.DataFrame): input dataframe
        @param normalise_features (bool): whether to normalise features with QuantileTransformer

        @return float: propensity-score weighted energy distance score

        """

        Y0X, treatment_name, split_test_by = Scorer._Y0_X_potential_outcomes(
            estimate, df
        )

        Y0X_1 = Y0X[Y0X[split_test_by] == 1]
        Y0X_0 = Y0X[Y0X[split_test_by] == 0]

        YX_1_all_psw = self.psw_estimator.estimator.propensity_model.predict_proba(
            Y0X_1[
                self.causal_model.get_effect_modifiers()
                + self.causal_model.get_common_causes()
            ]
        )
        treatment_series = Y0X_1[treatment_name]

        YX_1_psw = np.zeros(YX_1_all_psw.shape[0])
        for i in treatment_series.unique():
            YX_1_psw[treatment_series == i] = YX_1_all_psw[:, i][treatment_series == i]

        YX_0_psw = self.psw_estimator.estimator.propensity_model.predict_proba(
            Y0X_0[
                self.causal_model.get_effect_modifiers()
                + self.causal_model.get_common_causes()
            ]
        )[:, 0]

        select_cols = estimate.estimator._effect_modifier_names + ["yhat"]
        features = estimate.estimator._effect_modifier_names

        xy_psw = psw_joint_weights(YX_1_psw, YX_0_psw)
        xx_psw = psw_joint_weights(YX_0_psw)
        yy_psw = psw_joint_weights(YX_1_psw)

        xy_mean_weights = np.mean(xy_psw)
        xx_mean_weights = np.mean(xx_psw)
        yy_mean_weights = np.mean(yy_psw)

        if normalise_features:
            qt = QuantileTransformer(n_quantiles=200)
            X_quantiles = qt.fit_transform(Y0X[features])

            Y0X_transformed = pd.DataFrame(
                X_quantiles, columns=features, index=Y0X.index
            )
            Y0X_transformed.loc[:, ["yhat", split_test_by]] = Y0X[
                ["yhat", split_test_by]
            ]

            Y0X_1 = Y0X_transformed[Y0X_transformed[split_test_by] == 1]
            Y0X_0 = Y0X_transformed[Y0X_transformed[split_test_by] == 0]

        exponent = 1
        distance_xy = np.reciprocal(xy_mean_weights) * np.multiply(
            xy_psw,
            dcor.distances.pairwise_distances(
                Y0X_1[select_cols], Y0X_0[select_cols], exponent=exponent
            ),
        )
        distance_yy = np.reciprocal(yy_mean_weights) * np.multiply(
            yy_psw,
            dcor.distances.pairwise_distances(Y0X_1[select_cols], exponent=exponent),
        )
        distance_xx = np.reciprocal(xx_mean_weights) * np.multiply(
            xx_psw,
            dcor.distances.pairwise_distances(Y0X_0[select_cols], exponent=exponent),
        )
        psw_energy_distance = (
            2 * np.mean(distance_xy) - np.mean(distance_xx) - np.mean(distance_yy)
        )
        return psw_energy_distance

    @staticmethod
    def qini_make_score(
        estimate: CausalEstimate, df: pd.DataFrame, cate_estimate: np.ndarray
    ) -> float:
        """Calculate the Qini score, defined as the area between the Qini curves of a model and random.

        Args:
            estimate (dowhy.causal_estimator.CausalEstimate): causal estimate to evaluate
            df (pandas.DataFrame): input dataframe
            cate_estimate (np.ndarray): array with cate estimates

        Returns:
            float: Qini score

        """

        est = estimate.estimator
        new_df = pd.DataFrame()
        new_df["y"] = df[est._outcome_name]
        treatment_name = est._treatment_name
        if not isinstance(treatment_name, str):
            treatment_name = treatment_name[0]
        new_df["w"] = df[treatment_name]
        new_df["model"] = cate_estimate

        qini_score = metrics.qini_score(new_df)

        return qini_score["model"]

    @staticmethod
    def auc_make_score(
        estimate: CausalEstimate, df: pd.DataFrame, cate_estimate: np.ndarray
    ) -> float:
        """Calculate the area under the uplift curve.

        Args:
            estimate (dowhy.causal_estimator.CausalEstimate): causal estimate to evaluate
            df (pandas.DataFrame): input dataframe
            cate_estimate (np.ndarray): array with cate estimates

        Returns:
            float: area under the uplift curve

        """

        est = estimate.estimator
        new_df = pd.DataFrame()
        new_df["y"] = df[est._outcome_name]
        treatment_name = est._treatment_name
        if not isinstance(treatment_name, str):
            treatment_name = treatment_name[0]
        new_df["w"] = df[treatment_name]
        new_df["model"] = cate_estimate

        auc_score = metrics.auuc_score(new_df)

        return auc_score["model"]

    @staticmethod
    def real_qini_make_score(
        estimate: CausalEstimate, df: pd.DataFrame, cate_estimate: np.ndarray
    ) -> float:
        # TODO  To calculate the 'real' qini score for synthetic datasets, to be done

        # est = estimate.estimator
        new_df = pd.DataFrame()

        # new_df['tau'] = [df['y_factual'] - df['y_cfactual']]
        new_df["model"] = cate_estimate

        qini_score = metrics.qini_score(new_df)

        return qini_score["model"]

    @staticmethod
    def r_make_score(
        estimate: CausalEstimate, df: pd.DataFrame, cate_estimate: np.ndarray, r_scorer
    ) -> float:
        """Calculate r_score.
        For details refer to Nie and Wager (2017) and Schuler et al. (2018). Adapted from the EconML implementation.

        Args:
            estimate (dowhy.causal_estimator.CausalEstimate): causal estimate to evaluate
            df (pandas.DataFrame): input dataframe
            cate_estimate (np.ndarray): array with cate estimates
            r_scorer: callable object used to compute the R-score

        Returns:
            float: r_score

        """

        # TODO
        return r_scorer.score(cate_estimate)

    @staticmethod
    def naive_ate(treatment: pd.Series, outcome: pd.Series):
        """Calculate simple ATE.

        Args:
            treatment (pandas.Series): series of treatments
            outcome (pandas.Series): series of outcomes

        Returns:
            tuple: tuple of simple ATE, standard deviation, and sample size

        """

        treated = (treatment == 1).sum()

        mean_ = outcome[treatment == 1].mean() - outcome[treatment == 0].mean()
        std1 = outcome[treatment == 1].std() / (math.sqrt(treated) + 1e-3)
        std2 = outcome[treatment == 0].std() / (
            math.sqrt(len(outcome) - treated) + 1e-3
        )
        std_ = math.sqrt(std1 * std1 + std2 * std2)
        return (mean_, std_, len(treatment))

    def group_ate(
        self, df: pd.DataFrame, policy: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:
        """Compute the average treatment effect (ATE) for different groups specified by a policy.

        Args:
            df (pandas.DataFrame): input dataframe, should contain columns for the treatment, outcome, and policy
            policy (Union[pd.DataFrame, np.ndarray]): policy column in df or an array of the policy values,
                used to group the data

        Returns:
            pandas.DataFrame: ATE, std, and size per policy

        """

        tmp = {"all": self.ate(df)}
        for p in sorted(pd.unique(np.asarray(policy))):
            tmp[p] = self.ate(df[policy == p])

        tmp2 = [
            {"policy": str(p), "mean": m, "std": s, "count": c}
            for p, (m, s, c) in tmp.items()
        ]

        return pd.DataFrame(tmp2)

    def make_scores(
        self,
        estimate: CausalEstimate,
        df: pd.DataFrame,
        metrics_to_report: List[str],
        r_scorer=None,
    ) -> dict:
        """Calculate various performance metrics for a given causal estimate using a given DataFrame.

        Args:
            estimate (dowhy.causal_estimator.CausalEstimate): causal estimate to evaluate
            df (pandas.DataFrame): input dataframe
            metrics_to_report (List[str]): list of strings specifying the evaluation metrics to compute.
                Possible options include 'ate', 'erupt', 'norm_erupt', 'prob_erupt', 'qini', 'auc',
                'energy_distance' and 'psw_energy_distance'.
            r_scorer (Optional): callable object used to compute the R-score, default is None

        Returns:
            dict: dictionary containing the evaluation metrics specified in metrics_to_report.
                The values key in the dictionary contains the input DataFrame with additional columns for
                the propensity scores, the policy, the normalized policy, and the weights, if applicable.
        """

        out = dict()
        df = df.copy().reset_index()

        est = estimate.estimator
        treatment_name = est._treatment_name
        if not isinstance(treatment_name, str):
            treatment_name = treatment_name[0]
        outcome_name = est._outcome_name

        cate_estimate = est.effect(df)

        # TODO: fix this hack with proper treatment of multivalues
        if len(cate_estimate.shape) > 1 and cate_estimate.shape[1] == 1:
            cate_estimate = cate_estimate.reshape(-1)

        # TODO: fix this, currently broken
        # covariates = est._effect_modifier_names
        # Include CATE Interpreter for both IV and CATE models
        # intrp = SingleTreeCateInterpreter(
        #     include_model_uncertainty=False, max_depth=2, min_samples_leaf=10
        # )
        # intrp.interpret(DummyEstimator(cate_estimate), df[covariates])
        # intrp.feature_names = covariates
        # out["intrp"] = intrp

        if self.problem == "backdoor":
            values = df[[treatment_name, outcome_name]].copy()
            simple_ate = self.ate(df)[0]
            if isinstance(simple_ate, float):
                # simple_ate = simple_ate[0]
                # .reset_index(drop=True)
                values[
                    "p"
                ] = self.psw_estimator.estimator.propensity_model.predict_proba(
                    df[
                        self.causal_model.get_effect_modifiers()
                        + self.causal_model.get_common_causes()
                    ]
                )[
                    :, 1
                ]
                values["policy"] = cate_estimate > 0
                values["norm_policy"] = cate_estimate > simple_ate
                values["weights"] = self.erupt.weights(df, lambda x: cate_estimate > 0)
            else:
                pass
                # TODO: what do we do here if multiple treatments?

            if "erupt" in metrics_to_report:
                erupt_score = self.erupt.score(df, df[outcome_name], cate_estimate > 0)
                out["erupt"] = erupt_score

            if "norm_erupt" in metrics_to_report:
                norm_erupt_score = (
                    self.erupt.score(df, df[outcome_name], cate_estimate > simple_ate)
                    - simple_ate * values["norm_policy"].mean()
                )
                out["norm_erupt"] = norm_erupt_score

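            if "prob_erupt" in metrics_to_report:
                treatment_effects = pd.Series(cate_estimate, index=df.index)
                # Placeholder uncertainty: the spread of the CATE estimates across units,
                # repeated for every unit. Per-unit standard errors from the estimator
                # (where available) would make the sampling reflect model confidence.
                treatment_std_devs = pd.Series(cate_estimate.std(), index=df.index)
                prob_erupt_score = self.erupt.probabilistic_erupt_score(
                    df, df[outcome_name], treatment_effects, treatment_std_devs
                )
                out["prob_erupt"] = prob_erupt_score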

            if "qini" in metrics_to_report:
                out["qini"] = Scorer.qini_make_score(estimate, df, cate_estimate)

            if "auc" in metrics_to_report:
                out["auc"] = Scorer.auc_make_score(estimate, df, cate_estimate)

            if r_scorer is not None:
                out["r_score"] = Scorer.r_make_score(
                    estimate, df, cate_estimate, r_scorer
                )

            # values = values.rename(columns={treatment_name: "treated"})
            assert len(values) == len(df), "Index weirdness when adding columns!"
            values = values.copy()
            out["values"] = values

        if "ate" in metrics_to_report:
            out["ate"] = cate_estimate.mean()
            out["ate_std"] = cate_estimate.std()

        if "energy_distance" in metrics_to_report:
            out["energy_distance"] = Scorer.energy_distance_score(estimate, df)

        if "psw_energy_distance" in metrics_to_report:
            out["psw_energy_distance"] = self.psw_energy_distance(
                estimate,
                df,
            )

        del df
        return out

    @staticmethod
    def best_score_by_estimator(
        scores: Dict[str, dict], metric: str
    ) -> Dict[str, dict]:
        """Obtain best score for each estimator.

        Args:
            scores (Dict[str, dict]): CausalTune.scores dictionary
            metric (str): metric of interest

        Returns:
            Dict[str, dict]: dictionary containing best score by estimator

        """

        for k, v in scores.items():
            if "estimator_name" not in v:
                raise ValueError(
                    f"Malformed scores dict, 'estimator_name' field missing in {k}, {v}"
                )

        estimator_names = sorted(
            list(
                set(
                    [
                        v["estimator_name"]
                        for v in scores.values()
                        if "estimator_name" in v
                    ]
                )
            )
        )
        best = {}
        for name in estimator_names:
            est_scores = [
                v
                for v in scores.values()
                if "estimator_name" in v and v["estimator_name"] == name
            ]
            best[name] = (
                min(est_scores, key=lambda x: x[metric])
                if metric in ["energy_distance", "psw_energy_distance"]
                else max(est_scores, key=lambda x: x[metric])
            )

        return best
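
`Scorer.naive_ate` is a difference in means with an unpooled standard error. `Scorer.ate` uses it for the standard deviation in the single-treatment case. Here $\bar y_1, \bar y_0$ are the mean outcomes, $s_1, s_0$ the sample standard deviations, and $n_1, n_0$ the counts of treated and control units. Ignoring the `1e-3` added to each denominator to avoid division by zero, it computes:

$$
\widehat{\text{ATE}} = \bar y_1 - \bar y_0, \qquad
\widehat{\text{se}} = \sqrt{\frac{s_1^2}{n_1} + \frac{s_0^2}{n_0}}
$$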


In [None]:
import copy
from typing import Callable, List, Optional, Union

import numpy as np
import pandas as pd

# implementation of https://papers.ssrn.com/sol3/papers.cfm?abstract_id=3111957
# we assume treatment takes integer values from 0 to n


class DummyPropensity:
    def __init__(self, p: pd.Series, treatment: pd.Series):
        n_vals = max(treatment) + 1
        out = np.zeros((len(treatment), n_vals))
        for i, pp in enumerate(p.values):
            out[i, treatment.values[i]] = pp
        self.p = out

    def fit(self, *args, **kwargs):
        pass

    def predict_proba(self):
        return self.p


class ERUPT:
    def __init__(
        self,
        treatment_name: str,
        propensity_model,
        X_names: Optional[List[str]] = None,
        clip: float = 0.05,
        remove_tiny: bool = True,
    ):
        self.treatment_name = treatment_name
        self.propensity_model = copy.deepcopy(propensity_model)
        self.X_names = X_names
        self.clip = clip
        self.remove_tiny = remove_tiny

    def fit(self, df: pd.DataFrame):
        if self.X_names is None:
            self.X_names = [c for c in df.columns if c != self.treatment_name]
        self.propensity_model.fit(df[self.X_names], df[self.treatment_name])

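    def score(
        self,
        df: pd.DataFrame,
        outcome: pd.Series,
        policy: Union[Callable, np.ndarray, pd.Series],
    ) -> float:
        """ERUPT estimate of the mean outcome under `policy` (inverse-propensity weighted)."""
        w = self.weights(df, policy)
        return (w * outcome).mean()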

    def weights(
        self, df: pd.DataFrame, policy: Union[Callable, np.ndarray, pd.Series]
    ) -> pd.Series:
        W = df[self.treatment_name].astype(int)
        assert all(
            [x >= 0 for x in W.unique()]
        ), "Treatment values must be non-negative integers"

        if callable(policy):
            policy = policy(df).astype(int)
        if isinstance(policy, pd.Series):
            policy = policy.values
        policy = np.array(policy)

        d = pd.Series(index=df.index, data=policy)
        assert all(
            [x >= 0 for x in d.unique()]
        ), "Policy values must be non-negative integers"

        if isinstance(self.propensity_model, DummyPropensity):
            p = self.propensity_model.predict_proba()
        else:
            p = self.propensity_model.predict_proba(df[self.X_names])
        p = np.maximum(p, 1e-4)

        weight = np.zeros(len(df))

        for i in W.unique():
            weight[W == i] = 1 / p[:, i][W == i]

        weight[d != W] = 0.0

        if self.remove_tiny:
            weight[weight > 1 / self.clip] = 0.0
        else:
            weight[weight > 1 / self.clip] = 1 / self.clip

        weight *= len(df) / sum(weight)
        assert not np.isnan(weight.sum()), "NaNs in ERUPT weights"

        return pd.Series(index=df.index, data=weight)

    def probabilistic_erupt_score(
        self,
        df: pd.DataFrame,
        outcome: pd.Series,
        treatment_effects: pd.Series,
        treatment_std_devs: pd.Series,
        iterations: int = 1000,
        random_state: Optional[int] = None,
    ) -> float:
        """Average ERUPT score of policies drawn by Thompson sampling (binary treatment).

        In each iteration, an effect is drawn for every unit from
        N(treatment_effects, treatment_std_devs**2). Units whose draw is positive
        are assigned to treatment (1) and the rest to control (0). The resulting
        policy is scored with ERUPT, and the scores are averaged over iterations.
        """
        rng = np.random.default_rng(random_state)
        # Per-unit means and standard deviations, aligned with the rows of df
        effects = np.asarray(treatment_effects, dtype=float)
        std_devs = np.asarray(treatment_std_devs, dtype=float)

        scores = []
        for _ in range(iterations):
            sampled_effects = rng.normal(effects, std_devs)
            policy = (sampled_effects > 0).astype(int)
            scores.append(self.score(df, outcome, policy))

        return float(np.mean(scores))
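
To sanity-check the scoring idea outside CausalTune, the sketch below scores Thompson-sampling policies on synthetic randomized data. The propensity is known to be 0.5, so no propensity model is fitted. The data-generating process and the function names `erupt` and `probabilistic_erupt` are assumptions made for illustration.

```python
import numpy as np


def erupt(y, w, policy, propensity):
    """IPW estimate of the mean outcome under `policy` (binary treatment).

    Units whose observed treatment matches the policy get weight 1 / P(observed treatment);
    all other units get weight 0.
    """
    p_obs = np.where(w == 1, propensity, 1 - propensity)
    weights = (policy == w) / p_obs
    return np.sum(weights * y) / np.sum(weights)


def probabilistic_erupt(y, w, propensity, cate, cate_sd, iterations=200, seed=0):
    """Average ERUPT over policies drawn by per-unit Thompson sampling."""
    rng = np.random.default_rng(seed)
    scores = []
    for _ in range(iterations):
        draws = rng.normal(cate, cate_sd)  # one sampled effect per unit
        policy = (draws > 0).astype(int)  # treat where the sampled effect is positive
        scores.append(erupt(y, w, policy, propensity))
    return float(np.mean(scores))


# Synthetic randomized trial: the true effect is +1 when x > 0 and -1 otherwise
rng = np.random.default_rng(42)
n = 5_000
x = rng.normal(size=n)
w = rng.integers(0, 2, size=n)
true_cate = np.where(x > 0, 1.0, -1.0)
y = x + w * true_cate + rng.normal(size=n)
propensity = np.full(n, 0.5)

good = probabilistic_erupt(y, w, propensity, cate=true_cate, cate_sd=np.full(n, 0.5))
bad = probabilistic_erupt(y, w, propensity, cate=-true_cate, cate_sd=np.full(n, 0.5))
print(f"correct-sign model: {good:.3f}, wrong-sign model: {bad:.3f}")
```

A model whose sampled policies treat the right units should score clearly higher than the same model with its CATE signs flipped.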