# MOUD Case Study
We apply ROOT to the medications for opioid use disorder (MOUD) data. Here we are interested in transporting the treatment effect from the Starting Treatment With Agonist Replacement Therapies (START) trial to the population of individuals in the US seeking treatment for opioid use disorder, as represented by the Treatment Episode Data Set–Admissions (TEDS-A) 2015–2017. We are interested in (i) the characteristics of the subpopulation for which we can precisely estimate the target average treatment effect (TATE) using the trial evidence, (ii) the TATE estimate for this subpopulation, and (iii) the characteristics identifying those who are underrepresented in the trial cohort.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.tree as tree
from sklearn.impute import KNNImputer
import importlib

#### main packages
import learn_w as learn
import black

# Reload learn_w so edits to the module are picked up without restarting the kernel.
importlib.reload(learn)

import warnings

# NOTE(review): this silences *every* warning for the whole session, including
# pandas deprecation / chained-assignment FutureWarnings and seaborn deprecations
# that signal behavior changes in later cells -- consider a narrower filter.
warnings.filterwarnings("ignore")

# IPython magic (not plain Python): loads the jupyter_black notebook extension.
%load_ext jupyter_black
# Default seaborn theme for all plots (some cells later override font_scale).
sns.set(font_scale=1.25, style="whitegrid")
# Global NumPy seed; later cells reseed with 42 before stochastic steps.
np.random.seed(0)

## Fetching the Data
Reading the data from CSV

In [None]:
outcome_cols = ["opioiduse12", "opioiduse24"]  # outcome columns (12-week and 24-week)
treatment_col = "medicine_assigned"  # trial treatment-assignment column
discrete_cov = ["xrace", "mar", "sex"]  # covariate columns that are discrete

## reading and processing CTN 27 and TEDS-A data

# Root folder of the MOUD data. Both inputs below live under it, so this is the
# single place to change when running on another machine.
MOUD_DATA_DIR = "/Users/harshparikh/Library/CloudStorage/OneDrive-JohnsHopkins/MOUD_data"

# Harmonized CTN-0094 baseline file; the CTN-0027 (START) rows are selected later.
baseline_harmonized = pd.read_csv(
    MOUD_DATA_DIR
    + "/updated_data/ctn0094/drv/clean_patients_with_relapse_wide.csv",
    index_col=0,
)

# stacked_list_1.csv ... stacked_list_5.csv; only the first is analysed below.
stacked_list = [
    pd.read_csv(f"{MOUD_DATA_DIR}/stacked_list_{i}.csv", index_col=0)
    for i in range(1, 6)
]
df = stacked_list[0]

# TEDS-A (target) rows are those with trialdata == 0.
df_tedsa = df.loc[df["trialdata"] == 0]

### Combining TEDS-A + CTN 27 dataset

In [None]:
np.random.seed(42)  # reset the global seed for reproducibility

# MOUD trial cohort: rows of the harmonized file belonging to CTN-0027.
ct94 = baseline_harmonized.loc[baseline_harmonized["project"] == 27]
outcome94 = ct94[outcome_cols]  # trial outcomes, re-attached after imputation

# Columns present in both the trial and the target (TEDS-A) data. A sorted list is
# used rather than a raw set for two reasons. pandas >= 2.0 rejects sets as column
# indexers. Also, the iteration order of a set of strings changes between Python
# sessions, which would make the downstream column order (and anything keyed on
# it) irreproducible.
common_cols = sorted(set(df_tedsa.columns) & set(ct94.columns))

# Education and marital status are dropped because of heavy missingness.
ct94_cc = ct94[common_cols].drop(columns=["edu", "mar"])
ct94_cc["sex"] = (ct94["sex"] == "male").astype(int)  # male = 1, female = 0

# Impute the remaining missing trial covariates from the 4 nearest neighbours.
imputer = KNNImputer(n_neighbors=4, weights="distance", add_indicator=False)
ct94_cc = pd.DataFrame(
    imputer.fit_transform(ct94_cc), index=ct94_cc.index, columns=ct94_cc.columns
)

# Binarize treatment: methadone = 1, buprenorphine = 0.
ct94_cc["med_met"] = (ct94[treatment_col] == "met").astype(int)
ct94_cc = ct94_cc.dropna()  # expected to be a no-op after imputation

ct94_cc["S"] = 1  # sample indicator: 1 = trial
ct94_cc = ct94_cc.round(0).astype(int)  # imputed values rounded back to integers

ct94_cc = ct94_cc.join(outcome94, how="inner")  # re-attach the outcome columns
ct94_cc.groupby(by="med_met").mean()[outcome_cols]  # mean outcomes per treatment arm

In [None]:
# Target cohort (TEDS-A) restricted to the covariates shared with the trial.
# Education and marital status are dropped because of heavy missingness. `sorted`
# gives a list indexer (pandas >= 2.0 rejects sets) and a deterministic column
# order, whether `common_cols` is a set or already a list.
df_tedsa_cc = df_tedsa[sorted(common_cols)].drop(columns=["edu", "mar"])
df_tedsa_cc["S"] = 0  # sample indicator: 0 = target sample

# TEDS-A stores age as a category code. Map each code to a representative age,
# presumably in years to match the trial's age column. Because this conversion is
# coarse, age is not used when characterizing underrepresentation.
tedsa_age_code_to_years = {
    1: 13,
    2: 16,
    3: 18,
    4: 22,
    5: 27,
    6: 32,
    7: 37,
    8: 42,
    9: 47,
    10: 52,
    11: 60,
    12: 68,
}
# Assign the result back instead of calling `replace(..., inplace=True)` on
# `df_tedsa_cc["age"]`. That call is chained assignment: pandas 2.2 deprecates it,
# and under Copy-on-Write (the default in pandas 3.0) it leaves the frame
# unchanged. The warning pandas emits is hidden by the notebook-wide
# `warnings.filterwarnings("ignore")`.
df_tedsa_cc["age"] = df_tedsa_cc["age"].replace(tedsa_age_code_to_years)

In [None]:
# Stack the target cohort (TEDS-A, rows shuffled by a full-size sample without
# replacement) and the trial cohort into one frame.
df_primary = pd.concat([df_tedsa_cc.sample(frac=1, replace=False), ct94_cc])
# Drop the 12-week outcome and fill missing values with 0. The target sample has
# no outcome data, and the filled values there are never used as outcomes.
# NOTE(review): fillna(0) is applied to *every* column, not only the target
# outcome. Any missing TEDS-A covariate, and any missing trial outcome (if some
# trial participants lack opioiduse24), also becomes 0. That would change the
# RCT-ATE computed later -- confirm this is intended.
df_ = df_primary.drop(columns=[outcome_cols[0]]).fillna(0)

In [None]:
# Analysis configuration: the 24-week outcome, the methadone indicator as the
# treatment, and "S" as the trial-membership indicator.
outcome = outcome_cols[1]
treatment = "med_met"
sample = "S"
data = df_
S, Y, T = df_[sample], df_[outcome], df_[treatment]  # sample / outcome / treatment

# Readable covariate names, applied after one-hot encoding race.
readable_names = {
    "sex": "Male",
    "age": "Age",
    "ivdrug": "IV Drug Use",
    "bamphetamine30_base": "Hx Amphetamine",
    "bbenzo30_base": "Hx Benzo",
    "bcannabis30_base": "Hx Cannabis",
    "xrace_1": "White",
    "xrace_2": "Black",
    "xrace_3": "Hispanic",
    "xrace_4": "Other Race",
}
data_dummy = pd.get_dummies(data, columns=["xrace"]).rename(columns=readable_names)

# Pre-treatment covariates: every column except outcome, treatment and sample flag.
X = data_dummy.drop(columns=[outcome, treatment, sample])

# LaTeX table of column means, one column per value of the sample indicator.
latex_table = data_dummy.groupby(by=[sample]).mean().T.round(4).to_latex()

## Estimate RCT-ATE and Target-ATE

In [None]:
importlib.reload(learn)
# Reseed before the stochastic estimation step for reproducibility.
np.random.seed(42)
# Run the IPW transport estimator from learn_w. Per the author's description it
# returns a dataframe of unit-specific weighted outcomes (df_v; its "te" column is
# used below), selection scores (pi), and the fitted selection-score model (pi_m)
# and propensity-score model (e_m), plus data2.
# NOTE(review): pi and data2 are not used by any active code in this notebook.
df_v, pi, pi_m, e_m, data2 = learn.estimate_ipw(data_dummy, outcome, treatment, sample)

In [None]:
# RCT-ATE: difference in mean outcome between the methadone (T=1) and
# buprenorphine (T=0) arms of the trial sample, in percentage points.
in_trial = df_[sample] == 1
y_treated = df_.loc[in_trial & (df_[treatment] == 1), outcome]
y_control = df_.loc[in_trial & (df_[treatment] == 0), outcome]
rct_ate = y_treated.mean() - y_control.mean()
# The two arms are independent samples, so the standard error of the difference
# combines the per-arm standard errors in quadrature. Adding them overstates it.
rct_se = np.sqrt(y_treated.sem() ** 2 + y_control.sem() ** 2)
print("RCT-ATE: %.2f ± %.2f" % (100 * rct_ate, 100 * rct_se))

# Transported ATE: mean (and SE) of the unit-level IPW effect contributions.
print(
    "Transported ATE: %.2f ± %.2f" % (100 * df_v["te"].mean(), 100 * df_v["te"].sem())
)

## Characterizing Underrepresented Population via Selection Score

In [None]:
np.random.seed(42)
data_dummy_logit = data_dummy.copy(deep=True)
# Selection score pi(x) from the fitted selection model: the probability of the
# second class, presumably S = 1 (trial membership) -- confirm against pi_m.classes_.
data_dummy_logit["pi(x)"] = pi_m.predict_proba(X)[:, 1]
# Ratio of pi(x) to the marginal trial share pi = mean(S), plotted as l(x)/l below.
data_dummy_logit["pi(x)/pi"] = data_dummy_logit["pi(x)"] / data_dummy_logit["S"].mean()

### Plotting selection scores per study sample

In [None]:
fig, ax = plt.subplots(sharex=True, figsize=(10, 3), dpi=600)
sns.set(font_scale=1.8, style="whitegrid")
# Split violins of pi(x)/pi for the target (S=0) and trial (S=1) samples.
sns.violinplot(
    data=data_dummy_logit,
    x="pi(x)/pi",
    y="S",
    hue="S",
    split=True,
    orient="h",
    ax=ax,
    alpha=0.5,
    inner="quart",
    # `bw` is deprecated since seaborn 0.13 in favor of `bw_method`. seaborn
    # forwards the old value to `bw_method`, so the bandwidth is unchanged.
    bw_method=0.25,
    palette="Set1",
)
plt.xlabel(r"$\ell(x)/{\ell}$")
plt.tight_layout()
plt.savefig("selection_logit.pdf")

# Keep only the non-covariate columns (outcome, treatment, S and the selection
# scores) and display their means.
data_dummy_logit = data_dummy_logit.drop(columns=X.columns, errors="ignore")
data_dummy_logit.mean()

### Ad-hoc $w(x) = \mathbf{1}[ \ell(x)/\ell > 0.87 ]$

In [None]:
# Ad-hoc rule: keep a unit when its selection odds ratio
#     [pi(x) / pi] / [(1 - pi(x)) / (1 - pi)],   pi = mean(S),
# exceeds the predefined cut-off.
# NOTE(review): this thresholds the odds ratio, whereas the selection-score plots
# (and the optimal-threshold rule) use pi(x)/pi. The two only agree approximately,
# when pi(x) and pi are small -- confirm which quantity is intended.
PREDEFINED_THRESHOLD = 0.87
pi_hat = S.mean()
selection_odds_ratio = (data_dummy_logit["pi(x)"] / pi_hat) / (
    (1 - data_dummy_logit["pi(x)"]) / (1 - pi_hat)
)
data_dummy_logit["w (predefined threshold)"] = (
    selection_odds_ratio > PREDEFINED_THRESHOLD
).astype(int)

# Transported effect over the units kept by the rule (w = 1), in percentage points.
kept = data_dummy_logit["w (predefined threshold)"] == 1
print(
    r"Post Pruning, ATTE: %.3f ± %.3f"
    % (
        100 * df_v["te"].loc[kept].mean(),
        100 * df_v["te"].loc[kept].sem(),
    )
)

### Modeling $w(x) = \mathbf{1}[\pi(x)/\pi > a^\star]$ and finding optimal $a^\star$

In [None]:
np.random.seed(42)
# Candidate thresholds a for w(x) = 1[pi(x)/pi > a].
a_vals = np.linspace(0.83, 0.86, num=5000)


def obj(a):
    """Return the pruned-sample standard error of the transported effect.

    Units with ``pi(x)/pi <= a`` are excluded. The objective is the standard error
    (in percentage points) of ``df_v["te"]`` over the retained units.

    Returns ``np.inf`` when fewer than two units are retained. The SE is undefined
    there (pandas gives NaN), and ``np.argmin`` would otherwise pick the first NaN
    as the "minimum". The function has no side effects on ``data_dummy_logit``.
    """
    keep = data_dummy_logit["pi(x)/pi"] > a
    retained_te = df_v["te"].loc[keep]
    if retained_te.count() < 2:
        return np.inf
    return 100 * retained_te.sem()


objs = np.array([obj(a) for a in a_vals])

plt.axhline(100 * df_v["te"].sem())  # reference: SE with no pruning
# Undefined objective values are left as gaps in the curve.
sns.lineplot(x=a_vals, y=np.where(np.isfinite(objs), objs, np.nan))

best = np.argmin(objs)
a_star = a_vals[best]
data_dummy_logit["w (optimal threshold)"] = (
    data_dummy_logit["pi(x)/pi"] > a_star
).astype(int)

# Optimal threshold, its objective, and the share of trial units kept.
a_star, objs[best], data_dummy_logit["w (optimal threshold)"].loc[
    data_dummy_logit["S"] == 1
].mean()

### Printing treatment effects estimates

In [None]:
# ATTE summaries (percentage points): first over all units, then over the units
# each selection-score threshold rule keeps (w = 1).
te_all = df_v["te"]
for template, w_col in (
    (r"Pre Pruning, ATTE: %.3f ± %.3f", None),
    (
        r"Post Pruning (using predefined $\pi(x)$ threshold), ATTE: %.3f ± %.3f",
        "w (predefined threshold)",
    ),
    (
        r"Post Pruning (using optimal $\pi(x)$ threshold), ATTE: %.3f ± %.3f",
        "w (optimal threshold)",
    ),
):
    if w_col is None:
        subset = te_all
    else:
        subset = te_all.loc[data_dummy_logit[w_col].astype(int) == 1]
    print(template % (100 * subset.mean(), 100 * subset.sem()))

## Characterizing Underrepresented Population via Indicator $w(x)$
$w(x) = \sum_i w_i \mathbf{1}[x=X_i]$

In [None]:
importlib.reload(learn)
# Unused alternative: tree-based indicator via learn.tree_opt.
# D, w_tree, testing_data = learn.tree_opt(data, outcome, treatment, sample, leaf_proba=1)
np.random.seed(42)
# Indicator-style w(x) from learn.kmeans_opt with k=400 and threshold=1. Presumably
# units are grouped into k clusters and whole clusters are kept or pruned -- TODO
# confirm in learn_w. D_labels["w"] carries the per-unit 0/1 indicator used below.
D_labels, f, testing_data = learn.kmeans_opt(
    data=data_dummy,
    outcome=outcome,
    k=400,
    treatment=treatment,
    sample=sample,
    threshold=1,
)


# Stored by index alignment with data_dummy_logit.
data_dummy_logit["w (Indicator)"] = D_labels["w"].astype(int)

In [None]:
# Transported effect over the units kept by the indicator-based w(x) from
# learn.kmeans_opt (w = 1), in percentage points. The label previously said
# "optimal pi(x) threshold", a copy-paste from the threshold section.
kept = data_dummy_logit["w (Indicator)"] == 1
print(
    r"Post Pruning (using indicator $w(x)$), ATTE: %.3f ± %.3f"
    % (
        100 * df_v["te"].loc[kept].mean(),
        100 * df_v["te"].loc[kept].sem(),
    )
)

## Characterizing Underrepresented Population via ROOT

In [None]:
importlib.reload(learn)

# Run ROOT via learn.forest_opt: 5000 trees with leaf_proba=0.25 and
# vote_threshold=9/10 (see learn_w for their exact meaning). Any stale "w" column
# is dropped from the input first. From how the outputs are used below:
# D_forest has a per-unit "vsq" column and one "w_tree_<i>" column per tree,
# w_forest[i]["local objective"] is tree i's objective, and D_rash has a per-unit
# "v" column.
np.random.seed(42)
D_rash, D_forest, w_forest, rashomon_set, f, testing_data = learn.forest_opt(
    data=data_dummy.drop(columns=["w"], errors="ignore"),
    outcome=outcome,
    treatment=treatment,
    sample=sample,
    leaf_proba=0.25,
    num_trees=5000,
    vote_threshold=9 / 10,
)

# Baseline objective: the SE when no region is pruned, sqrt(sum(vsq)) / n over
# all n rows of D_forest. Used to normalize the per-tree objectives.
baseline_loss = np.sqrt(np.sum(D_forest["vsq"]) / ((D_forest.shape[0] ** 2)))

In [None]:
np.random.seed(42)

### Sorting trees by their objective values (smallest to largest)
# w_forest is indexed 0..len-1. After sorting, the frame's index still holds each
# tree's original position i, which names its "w_tree_<i>" column in D_forest.
local_obj = pd.DataFrame(
    np.array([w_forest[i]["local objective"] for i in range(len(w_forest))]),
    columns=["Objective"],
).sort_values(by="Objective")

### choosing top-k trees and plotting relative obj (Creating a Rashomon Set) -- seeing where the objective value stabilizes
top_k = 25
# Objective of the top_k best trees relative to the no-pruning baseline.
plt.plot((local_obj.iloc[:top_k] / baseline_loss)["Objective"].values)

# Column names in D_forest of the trees forming the Rashomon set.
w_rash = ["w_tree_%d" % (i) for i in list(local_obj.iloc[:top_k].index)]

In [None]:
# Per-unit average vote across the Rashomon-set trees.
avg_votes = D_forest[w_rash].mean(axis=1)
D_rash["w_opt"] = avg_votes
# Assuming each w_tree_* column is a 0/1 vote, the average lies in [0, 1]. Then
# astype(int) == 1 keeps only units that *all* top_k trees retain (a unanimous vote).
print(
    r"Post Pruning, ATTE: %.3f ± %.3f"
    % (
        100 * D_rash["v"].loc[D_rash["w_opt"].astype(int) == 1].mean(),
        100 * D_rash["v"].loc[D_rash["w_opt"].astype(int) == 1].sem(),
    )
)

# The tree classifier below has no random_state, so it draws from NumPy's global
# RNG, which is seeded here.
np.random.seed(42)
num_trees = 1
# Interpretable summary: an 8-leaf tree predicting the unanimous-retain label.
# Age is excluded because the target sample's categorical age codes were mapped to
# representative ages, which makes age uninformative for underrepresentation.
explainer = tree.DecisionTreeClassifier(max_leaf_nodes=8).fit(
    X.drop(columns=["Age"]).loc[avg_votes.index], avg_votes == 1
)

In [None]:
# Label splits with the columns the explainer was actually trained on.
# scikit-learn records them in `feature_names_in_` when fitted on a DataFrame. The
# previous hard-coded list silently assumed a fixed column order, but that order
# derives from a set of common columns and is not stable across sessions, so
# splits could be mislabeled.
feature_names = list(
    getattr(explainer, "feature_names_in_", X.drop(columns=["Age"]).columns)
)

fig, ax = plt.subplots(nrows=num_trees, figsize=(20, 8), dpi=600)
axes = np.atleast_1d(ax)  # a single subplot comes back as a bare Axes
for i in range(num_trees):
    # A single DecisionTreeClassifier is drawn directly. An ensemble explainer
    # exposes its trees through `estimators_` (used when num_trees > 1).
    estimator = explainer if num_trees == 1 else explainer.estimators_[i, 0]
    tree.plot_tree(
        estimator,
        feature_names=feature_names,
        ax=axes[i],
        filled=True,
        fontsize=10,
        impurity=False,
    )
plt.savefig("tedsa_ctn27.pdf", dpi=600)

## Plotting All results in Selections Score Space

In [None]:
data_dummy_logit["w (ROOT)"] = D_rash["w_opt"].astype(int)

# Only trial units (S == 1) are plotted; all columns are cast to float for seaborn.
trial_units = data_dummy_logit.loc[data_dummy_logit["S"] == 1].astype(float)

# One panel per pruning rule: (axes position, w column, title, threshold or None).
# NOTE(review): the predefined rule thresholds the selection *odds ratio* at 0.87,
# while this x-axis is pi(x)/pi. The dashed line in panel (a) therefore only
# approximates the cut-off actually applied -- confirm.
panels = [
    (
        (0, 0),
        "w (predefined threshold)",
        "(a) Predefined Threshold on Selection Score",
        0.87,
    ),
    (
        (0, 1),
        "w (optimal threshold)",
        "(b) Optimized Threshold on Selection Score",
        a_vals[np.argmin(objs)],
    ),
    ((1, 0), "w (Indicator)", "(c) Indicator", None),
    ((1, 1), "w (ROOT)", "(d) ROOT", None),
]

fig, ax = plt.subplots(
    nrows=2, ncols=2, sharex=True, sharey=True, figsize=(20, 8), dpi=600
)

for position, w_col, title, threshold in panels:
    panel_ax = ax[position]
    show_legend = position == (1, 1)  # only the ROOT panel carries a legend
    # Individual trial units, split by their w value ...
    sns.swarmplot(
        data=trial_units,
        x="pi(x)/pi",
        y=w_col,
        hue=w_col,
        orient="h",
        ax=panel_ax,
        size=5,
        legend=False,
        hue_order=[1, 0],
    )
    # ... overlaid on the distribution of pi(x)/pi within each w group.
    sns.violinplot(
        data=trial_units,
        x="pi(x)/pi",
        y=w_col,
        hue=w_col,
        orient="h",
        alpha=0.25,
        ax=panel_ax,
        inner="quart",
        fill=True,
        hue_order=[1, 0],
        **({} if show_legend else {"legend": False}),
    )
    if threshold is not None:
        panel_ax.axvline(threshold, ls="--", c="black")
    panel_ax.set_title(title)
    panel_ax.set_ylabel("w")
    if show_legend:
        panel_ax.legend(title="w")

plt.xlabel(r"$\ell(x)/\ell$")
plt.tight_layout()
plt.savefig("underrep_root.pdf")


In [None]:
### Column means over all units (trial and target) in data_dummy_logit.
# For each 0/1 "w (...)" column, the mean is the fraction of units with w = 1.
# These are the units kept in the pruned ATTE estimates above. The fraction
# flagged as underrepresented by each rule is therefore 1 minus that mean.
data_dummy_logit.mean()