Strongly influenced by https://johaupt.github.io/blog/xbcf.html

In [None]:
import numpy as np
from xbcausalforest import XBCF
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Number of simulated units and the probability used as the default
# treatment propensity when generating data.
n_observations = 20000
treatment_fraction = 0.5

# Tree-ensemble hyperparameters, one value for the prognostic part and one for
# the treatment part of the model, as suggested in section 5.2 of
# https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full
n_trees_prognostic, n_trees_treatment = 200, 50
alpha_prognostic, alpha_treatment = 0.95, 0.25
beta_prognostic, beta_treatment = 2, 3

## 1. Generate data

In [None]:
# Single seeded generator shared by the simulation code below, so that a
# fresh top-to-bottom run draws the same random numbers.
rng = np.random.default_rng(seed=42)

In [None]:
def generate_covariates(n):
    """Draw ``n`` rows of two correlated Gaussian covariates.

    Each covariate has mean 0 and variance 1; their covariance is 0.1.
    Samples are taken from the module-level generator ``rng`` and returned
    as an array of shape ``(n, 2)``.
    """
    mean = [0, 0]
    covariance = np.array([[1, 0.1], [0.1, 1]])
    return rng.multivariate_normal(mean, covariance, size=(n,))

In [None]:
def data_generating_process(X, propensity=treatment_fraction):
    """Simulate treatment assignment, outcomes and true treatment effects.

    Parameters
    ----------
    X : array of shape (n, 2)
        Covariates; columns 0 and 1 are used.
    propensity : float
        Probability with which each unit is assigned to treatment.

    Returns
    -------
    y : array of shape (n,)
        Observed outcome: the noisy prognostic effect plus ``z * tau``.
    z : int32 array of shape (n,)
        Treatment indicator (1 = treated), drawn independently per unit.
    tau : array of shape (n,)
        True individual treatment effect.
    """
    n = X.shape[0]
    z = rng.binomial(1, propensity, size=(n,)).astype("int32")

    # Expected outcome without treatment: non-linear in both covariates.
    mu = (
        100
        + 5 * X[:, 0]
        - 5 * X[:, 1]
        + 2 * X[:, 0] ** 2
        + 5 * X[:, 0] * X[:, 1]
    )
    # Draw the noise from the seeded generator `rng` (not the global NumPy
    # RNG) so results are reproducible, and size it from X itself so the
    # function works for any number of rows.
    prognostic_effect = rng.normal(mu, scale=2, size=(n,))

    # Heterogeneous treatment effect, with a jump of -3 where X[:, 0] > 0.
    tau = (
        2
        + 0.5 * X[:, 1]
        + 0.5 * X[:, 1] ** 2
        - 0.5 * X[:, 0] ** 2
        - 3 * (X[:, 0] > 0)
    )
    y = prognostic_effect + z * tau
    return y, z, tau

In [None]:
# Draw independent training and test covariates. All draws come from the
# shared generator `rng`, so the order of these calls determines the data.
X = generate_covariates(n_observations)
X_test = generate_covariates(n_observations)

# Training data: outcome, treatment indicator and true effect.
y, z, tau = data_generating_process(X)
# For the test set only the true effect is kept, as ground truth for the
# predictions made later.
_, _, tau_test = data_generating_process(X_test)

## 2. Sanity check of data

In [None]:
# Mean and standard deviation of the outcome among control units (z == 0).
np.mean(y[z == 0]), np.std(y[z == 0])

In [None]:
# Mean and standard deviation of the true treatment effect.
np.mean(tau), np.std(tau)

In [None]:
# Smallest, average and largest true treatment effect.
np.min(tau), np.mean(tau), np.max(tau)

In [None]:
# Compare the outcome distributions of the control and treatment groups.
fig, ax = plt.subplots()
ax.hist(y[z==0], alpha=.6, bins=40, density=True, label="control")
ax.hist(y[z==1], alpha=.6, bins=40, density=True, label="treatment")
# Give each mean line an explicit colour (C0/C1 are the first two entries of
# the colour cycle, which the two histograms take under the default style)
# and a dashed style. Without an explicit colour both lines are drawn
# identically and cannot be told apart in the plot or the legend.
ax.axvline(y[z==0].mean(), color="C0", linestyle="--", label="avg(control outcome)")
ax.axvline(y[z==1].mean(), color="C1", linestyle="--", label="avg(treatment outcome)")
ax.set_xlabel("outcome")
ax.legend()

In [None]:
# Kernel density estimate of the true treatment effects in the training data.
sns.kdeplot(tau)

In [None]:
# Top row: control outcomes against each covariate.
# Bottom row: true treatment effects of treated units against each covariate.
is_control = z == 0
is_treated = z == 1

fig, axes = plt.subplots(
    nrows=2, ncols=2, sharex=True, sharey=False, figsize=(10, 8)
)
for col in range(2):
    axes[0, col].scatter(X[is_control, col], y[is_control])
    axes[1, col].scatter(X[is_treated, col], tau[is_treated])

axes[0, 0].set_ylabel("outcome")
axes[1, 0].set_ylabel("treatment effect")
axes[1, 0].set_xlabel("X_1")
axes[1, 1].set_xlabel("X_2")

## 3. Estimation

In [None]:
# tau_pr and tau_trt are set to a fixed share (0.6 and 0.1) of the empirical
# outcome variance, divided by the number of trees in the respective ensemble.
outcome_variance = np.var(y)

xbcf = XBCF(
    # Sampler settings.
    num_sweeps=50,
    burnin=15,
    parallel=True,
    standardize_target=True,
    # Tree-growing limits.
    max_depth=250,
    Nmin=1,
    num_cutpoints=100,
    # Arguments for the prognostic ("_pr") ensemble.
    num_trees_pr=n_trees_prognostic,
    alpha_pr=alpha_prognostic,
    beta_pr=beta_prognostic,
    tau_pr=0.6 * outcome_variance / n_trees_prognostic,
    p_categorical_pr=0,
    # Arguments for the treatment ("_trt") ensemble.
    num_trees_trt=n_trees_treatment,
    alpha_trt=alpha_treatment,
    beta_trt=beta_treatment,
    tau_trt=0.1 * outcome_variance / n_trees_treatment,
    p_categorical_trt=0,
)

In [None]:
# Fit on the training data; the same covariate matrix is passed as both
# `x` and `x_t`.
xbcf.fit(x=X, x_t=X, y=y, z=z)

In [None]:
# Predicted treatment effects for the held-out covariates.
# NOTE(review): return_mean=True presumably averages the prediction over the
# sampler sweeps, giving one value per unit; confirm against the XBCF docs.
tau_hat = xbcf.predict(X_test, return_mean=True)

In [None]:
# Average true effect vs. average predicted effect on the test set.
tuple(np.mean(effects) for effects in (tau_test, tau_hat))

In [None]:
# True vs. predicted treatment effect on the test set, on a fixed square
# window of [-10, 10] in both directions.
fig, ax = plt.subplots(figsize=(4, 4))
ax.set(
    xlim=(-10, 10),
    ylim=(-10, 10),
    xlabel='True effect',
    ylabel='XBCF prediction',
)
ax.scatter(tau_test, tau_hat, alpha=0.5);

In [None]:
# One panel per covariate: true and estimated treatment effect on the test
# set, plotted against that covariate.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
for col, x_label in enumerate(("X_1", "X_2")):
    panel = axes[col]
    covariate = X_test[:, col]
    panel.scatter(covariate, tau_test, label="true value")
    panel.scatter(covariate, tau_hat, alpha=0.2, label="estimate")
    panel.set_xlabel(x_label)
    panel.set_ylabel("tau")

axes[0].legend()
axes[1].legend()

In [None]:
# Un-averaged predictions for the test set (presumably one column per
# sweep; confirm against the XBCF docs), with the first `burnin` columns
# dropped.
tau_draws = xbcf.predict(X_test, return_mean=False)
n_burnin = xbcf.getParams()['burnin']
tau_posterior = tau_draws[:, n_burnin:]

In [None]:
# 90% (= 1 - 0.1) credible interval on tau hat for the first 5 units:
# the 5% and 95% quantiles of each row, returned as (lower, upper) columns.
np.transpose(np.quantile(tau_posterior[:5], q=[0.05, 0.95], axis=1))