In [None]:
# Automatically reload edited modules (e.g. the local chiseling package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt

import sys
sys.path.append("../../../")

from chiseling.dgps.bart_dataset import BARTDataset
from chiseling.source.strategies.legacy.chiseling_interpretable import ChiselingInterpretable
from chiseling.source.learners.baselearners_causal import make_causal_random_forest_classifier_learner
from chiseling.source.protocol.IRST import UnitRegistrar
from chiseling.source.protocol.utils import aipw_intercept_pseudo_outcome

### Config

In [None]:
# Reproducibility and data location
RANDOM_SEED = 42
BART_PATH = "./bart_dataset_processed.tsv.gz"

# Chiseling settings (each is passed to ChiselingInterpretable further down)
alpha = 0.05
# Also subtracted from the pseudo-outcome before the strategy is run
test_thresh_interpret = 0.45
n_min = 30
n_burn_in = 1000
shrink_to_boundary = True

# Figure output: figures are written under FIGURE_SAVEPATH only when SAVE is True
SAVE = False
FIGURE_SAVEPATH = "../../../figures/"

### Run interpretable random forest analysis

In [None]:
# Seed NumPy's legacy global RNG so any np.random usage in this notebook is reproducible
np.random.seed(seed=RANDOM_SEED)

In [None]:
# Load the processed BART dataset and drop the "year" column
bart_sampler = BARTDataset(bart_path=BART_PATH, random_seed=RANDOM_SEED)
bart_df = bart_sampler.bart_df.drop("year", axis=1)
# Later cells assume a fixed column layout: "outcome" first, "treatment" second,
# covariates after (TX[:, 0] is passed to the AIPW pseudo-outcome and
# iloc[:, 2:] is used as the covariate block). Fail loudly if that changes.
assert list(bart_df.columns[:2]) == ["outcome", "treatment"], (
    "unexpected leading columns: {}".format(list(bart_df.columns[:2]))
)
# Extract variables
Y = bart_df.loc[:, "outcome"].values
TX = bart_df.iloc[:, 1:].values
# Positions (within the columns of TX, i.e. treatment followed by covariates) of
# discrete columns: 10 or fewer distinct values, counting NaN as a value.
# np.flatnonzero yields the positions directly and, unlike Index.get_indexer,
# does not raise on duplicate column labels. The printed count includes the
# treatment column whenever it qualifies.
bart_cov_df = bart_df.iloc[:, 1:]
discrete_col_inds = np.flatnonzero(bart_cov_df.nunique(dropna=False).le(10).to_numpy())
print("n discrete feats = {} / n feats = {}".format(len(discrete_col_inds), bart_cov_df.shape[1]))

In [None]:
# Build the AIPW intercept pseudo-outcome from TX[:, 0] (presumably the
# treatment), the remaining columns of TX, the outcome Y and the sampler's
# propensity scores, then report its mean and (ddof=0) standard error.
pY = aipw_intercept_pseudo_outcome(
    TX[:, 0],
    TX[:, 1:],
    Y,
    bart_sampler.get_propensity(),
    random_seed=RANDOM_SEED,
)
print(f"mean = {pY.mean()}, se = {np.sqrt(pY.var() / pY.shape[0])}")
# Shift by the test threshold so values are expressed relative to it
pY = pY - test_thresh_interpret

In [None]:
# Make the learner: a causal random-forest classifier learner seeded with
# RANDOM_SEED; it is handed to ChiselingInterpretable as `learner`.
learner = make_causal_random_forest_classifier_learner(random_seed=RANDOM_SEED)

In [None]:
# Configure the interpretable chiseling strategy on all of TX.
# NOTE(review): ignored_facets=[0] presumably keeps the region from being cut
# along TX[:, 0] (the column passed as treatment to the pseudo-outcome) — confirm.
chiseling = ChiselingInterpretable(
    # data and pseudo-outcome
    X=TX,
    Y=Y,
    pY=pY,
    binary=False,
    # test threshold / alpha settings
    test_thresh=test_thresh_interpret,
    alpha=alpha,
    alpha_init=0,
    alpha_min='auto',
    # learner and refit/reveal schedule
    learner=learner,
    use_learner_weights=False,
    skip_const_predictor=False,
    n_burn_in=n_burn_in,
    refit_batch_prop=0.05,
    reveal_batch_prop=0.01,
    # region construction
    margin_width=1,
    n_min=n_min,
    min_box_cond_samps=100,
    shrink_to_boundary=shrink_to_boundary,
    tiebreak=False,
    ignored_facets=[0],
    discrete_coords=discrete_col_inds,
    # reproducibility
    random_seed=RANDOM_SEED,
)

In [None]:
# Run the chiseling strategy with verbose output enabled
chiseling.run_strategy(verbose=True)

In [None]:
# Display the protocol's testing history
chiseling.protocol.testing_history

In [None]:
# Smallest value of the protocol metadata's reg_mass_est column
# (presumably the estimated region mass — TODO confirm)
chiseling.protocol.protocol_metadata.reg_mass_est.min()

### Inspect the rejected region

In [None]:
# Get the region that the chiseling protocol rejected
region = chiseling.protocol.get_rejected_region()

In [None]:
# Register the rows of TX as units (registrar seeded with RANDOM_SEED) and
# flag which of them fall inside the rejected region.
# NOTE(review): presumably this reproduces the registration used inside the
# protocol so that region.in_region sees the same representation — confirm.
unit_reg = UnitRegistrar(random_seed=RANDOM_SEED)
regTX = unit_reg.register_units(TX)
in_reg_indics = region.in_region(regTX)

In [None]:
# Rows of the original dataframe whose units lie inside the rejected region
region_bart_df = bart_df.loc[in_reg_indics, :]

In [None]:
# Check subgroup treatment effect: mean outcome per treatment arm among region units
region_bart_df.groupby("treatment").outcome.mean()

In [None]:
# Per-column bounds of the region's units over columns 2 onward
# (presumably the covariates that follow outcome and treatment)
box_lb, box_ub = (
    region_bart_df.iloc[:, 2:].min(axis=0).values,
    region_bart_df.iloc[:, 2:].max(axis=0).values,
)

In [None]:
# Check whether each feature passes its inclusion criteria: for every unit (row)
# and feature (column), is the value within the region's [lower, upper] bounds?
per_feature_inclusion_indics = (
    (bart_df.iloc[:, 2:].values >= box_lb)
    & (bart_df.iloc[:, 2:].values <= box_ub)
)

In [None]:
# Fraction of all units inside each feature's bounds (one value per feature)
avg_inclusion_rate = np.mean(per_feature_inclusion_indics, axis=0)

In [None]:
# Visualize average inclusion rate: distribution over features of the fraction
# of all units inside that feature's region bounds (1 = feature does not restrict).
# Reuses avg_inclusion_rate rather than recomputing the column means.
fig, ax = plt.subplots(1, 1, figsize=(6, 3.5))

ax.hist(avg_inclusion_rate, bins=50)
ax.set_xlabel("Per-feature inclusion rate")
ax.set_ylabel("Number of features")
plt.show()

print("Number of inclusion rates < 1:", (avg_inclusion_rate < 1).sum())

In [None]:
# Rank features from most to least restrictive (lowest inclusion rate first)
sorted_avg_inclusion_rates = np.sort(avg_inclusion_rate)
sorted_features = bart_df.columns[2:].to_numpy()[np.argsort(avg_inclusion_rate)]
print(sorted_features[:5])
print(sorted_avg_inclusion_rates[:5])

In [None]:
# Fraction of all units that satisfy the bounds of the five lowest-inclusion
# features simultaneously, i.e. how much the region is replicated by them alone
top5_features = sorted_features[:5]
top5_inclusion_indics = per_feature_inclusion_indics[
    :, np.isin(bart_df.columns.values[2:], top5_features)
]
top5_inclusion_rate = np.all(top5_inclusion_indics, axis=1).mean()
print(top5_inclusion_rate)

In [None]:
# Visualize top 10 features (lowest inclusion rates). Each panel is a density
# histogram of the feature over all units with the region's lower bound (red)
# and upper bound (green). For each feature we also print its inclusion rate and
# the treated-minus-control difference in mean outcome among all units inside
# that single feature's bounds.
fig, ax = plt.subplots(2, 5, figsize=(24, 8))

# One feature per panel; zip stops early if fewer features than panels exist
top_features = sorted_features[:ax.size]
for i, (panel_ax, feat) in enumerate(zip(ax.flat, top_features)):
    feat_values = bart_df.loc[:, feat]
    feat_lb = region_bart_df.loc[:, feat].min()
    feat_ub = region_bart_df.loc[:, feat].max()
    # Calculate per feature subgroup ATE (mean outcome, treatment 1 minus treatment 0)
    feat_region_inds = (feat_lb <= feat_values) & (feat_values <= feat_ub)
    feat_trt_effect_df = bart_df.loc[feat_region_inds].groupby("treatment").outcome.mean()
    feat_trt_effect = feat_trt_effect_df[1] - feat_trt_effect_df[0]
    # Plot
    panel_ax.hist(feat_values.values, bins=50, density=True)
    panel_ax.axvline(feat_lb, color="red")
    panel_ax.axvline(feat_ub, color="green")
    panel_ax.set_title(feat)
    print("{}: inclusion rate = {}, trt effect = {}".format(feat, sorted_avg_inclusion_rates[i], feat_trt_effect))

# Hide panels left empty when there are fewer features than panels
for unused_ax in ax.ravel()[len(top_features):]:
    unused_ax.set_visible(False)

plt.show()

In [None]:
# Boolean masks over the rows of TX: units the protocol revealed randomly
# (per the is_random flag in its metadata) and all remaining units
meta = chiseling.protocol.protocol_metadata
rand_inds = meta.loc[meta.is_random, "orig_ind"].values
rand_indics = np.isin(np.arange(TX.shape[0]), rand_inds)
nonrand_indics = np.logical_not(rand_indics)

In [None]:
# Running estimate of the intersection-subgroup ATE versus the single-feature
# subgroup ATE. Features are visited in sorted_features order (lowest inclusion
# rate first), and randomly revealed units are always excluded:
#   - single feature:       non-random units within this feature's region bounds
#   - running intersection: non-random units within the bounds of this feature
#                           and of every feature visited before it
# "ATE" here is the difference in mean outcome, treatment 1 minus treatment 0.
# NOTE(review): sizes are means over ALL units (random units count as outside),
# so they are fractions of the full sample rather than of the non-random units — confirm intended.
feat_region_inds = True
subgroup_ATEs, feat_subgroup_ATEs = [], []
subgroup_sizes, feat_subgroup_sizes = [], []
for feat in sorted_features:
    feat_col = bart_df.loc[:, feat]
    feat_lb = region_bart_df.loc[:, feat].min()
    feat_ub = region_bart_df.loc[:, feat].max()
    # Region indicators: this feature alone, then the running intersection
    per_feat_region_inds = nonrand_indics & (feat_lb <= feat_col) & (feat_col <= feat_ub)
    feat_region_inds = feat_region_inds & per_feat_region_inds
    # Difference in per-arm mean outcome for each subgroup
    per_feat_arm_means = bart_df.loc[per_feat_region_inds].groupby("treatment").outcome.mean()
    feat_subgroup_ATEs.append(per_feat_arm_means[1] - per_feat_arm_means[0])
    running_arm_means = bart_df.loc[feat_region_inds].groupby("treatment").outcome.mean()
    subgroup_ATEs.append(running_arm_means[1] - running_arm_means[0])
    # Subgroup sizes as fractions of all units
    subgroup_sizes.append(feat_region_inds.mean())
    feat_subgroup_sizes.append(per_feat_region_inds.mean())
subgroup_ATEs, feat_subgroup_ATEs, subgroup_sizes, feat_subgroup_sizes = map(
    np.array, (subgroup_ATEs, feat_subgroup_ATEs, subgroup_sizes, feat_subgroup_sizes)
)

In [None]:
# Summary figure over feature rank (lowest inclusion rate first):
#   left  - subgroup ATE (treatment 1 minus treatment 0 mean outcome)
#   right - subgroup size (fraction of all units)
# each shown for the running intersection and for every single feature, with a
# reference line for the final intersection and (left only) the full population.
fig, ax = plt.subplots(1,2,figsize=(12,5))

LABEL_FONTSIZE = 18
MARKERSIZE = 25
LINEWIDTH = 1
LINEALPHA = 0.6
TICK_FONTSIZE = 13
LEGEND_FONTSIZE = 13
FULL_COLOR = "#2032DA"
PER_COLOR = "#69C1B9"
FINAL_COLOR = "#2032DA"
ATE_COLOR = "red"

N_FEATURES = 40
# Never plot more ranks than were computed: otherwise x (length N_FEATURES) and
# the truncated y arrays differ in length and matplotlib raises.
n_plot = min(N_FEATURES, len(subgroup_ATEs))
feature_ranks = np.arange(1, 1 + n_plot)

# Per-arm means over the whole sample, computed once
full_arm_means = bart_df.groupby("treatment").outcome.mean()
FULL_TRT_EFFECT = full_arm_means[1] - full_arm_means[0]

# Artists are added in the same order on each panel (running, single, final) so
# the legend keeps its original ordering.
panels = [
    (ax[0], subgroup_ATEs, feat_subgroup_ATEs, "Subgroup ATE"),
    (ax[1], subgroup_sizes, feat_subgroup_sizes, "Subgroup Size"),
]
for panel_ax, running_vals, single_vals, ylabel in panels:
    for vals, color, label in ((running_vals, FULL_COLOR, "Running Intersection"),
                               (single_vals, PER_COLOR, "Single Feature")):
        panel_ax.plot(feature_ranks, vals[:n_plot], color=color, linewidth=LINEWIDTH, alpha=LINEALPHA)
        panel_ax.scatter(feature_ranks, vals[:n_plot], s=MARKERSIZE, label=label, color=color)
    # The last running-intersection value is the final subgroup
    panel_ax.axhline(running_vals[-1], color=FINAL_COLOR, label="Final Subgroup")
    panel_ax.spines[['right', 'top']].set_visible(False)
    panel_ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
    panel_ax.set_ylabel(ylabel, fontsize=LABEL_FONTSIZE)
    panel_ax.set_xlabel("Feature Rank", fontsize=LABEL_FONTSIZE)

ax[0].axhline(FULL_TRT_EFFECT, color=ATE_COLOR, label="Full Population")

# Global formatting
ax[1].set_ylim(0,1)
ax[0].legend(fontsize=LEGEND_FONTSIZE)
plt.tight_layout()

# Save
if SAVE:
    fig.savefig(FIGURE_SAVEPATH + "bart_interpret_facet.pdf", bbox_inches="tight", dpi=300)

plt.show()