# The real-world dataset: NHEFS

In this hands-on session, we will perform causal analyses on a real-world healthcare dataset known as the *NHANES I Epidemiologic Follow-up Study (NHEFS)* dataset. NHEFS is a government-initiated longitudinal study designed to investigate the relationships between clinical, nutritional, and behavioral factors. For more detail, please see the [CDC webpage](https://wwwn.cdc.gov/nchs/nhanes/nhefs/default.aspx/).

Our main task is to estimate the average treatment effect (ATE) of quitting smoking ($T$) on weight gain ($Y$). The NHEFS cohort includes 1,566 cigarette smokers between 25 and 74 years of age who completed two medical examinations at separate time points: a baseline visit and a follow-up visit approximately 10 years later. Individuals were assigned to the treatment group if they reported smoking cessation before the follow-up visit. Otherwise, they were assigned to the control group. Finally, each individual’s weight gain, $Y$, is the difference in *kg* between their body weight at the follow-up visit and their body weight at the baseline visit.

# Our objectives in the session

We aim to cover the following:
1. Learn how to implement propensity score re-weighting to estimate the ATE in Python.
2. Learn how to implement covariate adjustment strategies to estimate the conditional average treatment effect (CATE) as well as ATE in Python.
3. By comparing *naive* estimates to what we have in (1) and (2), observe how confounders, when unadjusted, can introduce bias into the ATE estimate.
4. Learn how to check whether some of the statistical properties required for causal inference are satisfied in the data.
5. Learn how to implement simple, causally motivated decision-making policies on top of CATE estimates.

The remainder of the notebook will guide us through this task and provide the necessary boilerplate code.

In [None]:
# import packages and load the data
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import seaborn as sns
import pandas as pd
import math
import matplotlib.pyplot as plt
from collections import OrderedDict

from google.colab import auth
auth.authenticate_user()

In [None]:
nhefs_all = pd.read_csv("https://github.com/mlhcmit/psets/blob/master/2025/pset3/nhefs.csv?raw=true")

outcome_col = ["wt82_71"]  # weight gain measured after 10 years from the baseline.
treatment_col = ["qsmk"]   # indicator for smoking cessation

# drop samples with missing outcomes for the assignment
# note that this could introduce selection bias and a more principled analysis would account for the censored samples.
missing = nhefs_all[outcome_col].isnull().any(axis=1)
nhefs = nhefs_all.loc[~missing].copy()  # copy so that columns can be added later without modifying a view of nhefs_all

# The causal problem and challenges

Let $Y$ denote an outcome of interest and $T \in \{0, 1\}$ denote a binary treatment. Furthermore, let $Y(t)$ denote the potential outcome of an individual under treatment $T = t$. We learned that, in randomized experiments, the average treatment effect (ATE) can be identified as

\begin{align}
  E[Y(1) - Y(0)] = E[Y|T = 1] - E[Y|T=0].
\end{align}

In other words, in randomized experiments where a treatment is randomly assigned, association admits a *causal* interpretation.

In observational data, however, treatment (i.e. quitting smoking) is not randomly assigned, and it very likely depends on patient characteristics at the baseline. For example, we can reasonably assume that alcohol consumption and exercise habits may influence one's decision to quit smoking (i.e. treatment) as well as weight gain (i.e. outcome). Therefore, if we fail to adjust for these factors (i.e. confounders), we may introduce bias in our estimate of the causal effect of quitting smoking on weight gain.
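
The contrast between the randomized and the observational setting can be illustrated with a small simulation. This is only a sketch on made-up data, not the NHEFS cohort: the binary confounder `x`, the constant true effect of 2.0, and the treatment probabilities 0.8 and 0.2 are all assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Hypothetical binary confounder that raises both the chance of treatment
# and the outcome itself.
x = rng.binomial(1, 0.5, size=n)

# Potential outcomes with an assumed constant true effect of 2.0.
y0 = 1.0 + 3.0 * x + rng.normal(0, 1, size=n)
y1 = y0 + 2.0
true_ate = np.mean(y1 - y0)


def difference_in_means(y, t):
    """Naive estimate E[Y|T=1] - E[Y|T=0]."""
    return y[t == 1].mean() - y[t == 0].mean()


# Randomized assignment: T is independent of X.
t_rand = rng.binomial(1, 0.5, size=n)
y_rand = np.where(t_rand == 1, y1, y0)
naive_randomized = difference_in_means(y_rand, t_rand)

# Confounded assignment: P(T=1|X=1) = 0.8 and P(T=1|X=0) = 0.2.
t_conf = rng.binomial(1, np.where(x == 1, 0.8, 0.2))
y_conf = np.where(t_conf == 1, y1, y0)
naive_confounded = difference_in_means(y_conf, t_conf)

print(f"true ATE:                  {true_ate:.2f}")
print(f"naive estimate (random T): {naive_randomized:.2f}")
print(f"naive estimate (confounded T): {naive_confounded:.2f}")
```

Under random assignment the naive difference should land close to the true effect, while under confounded assignment it absorbs part of the effect of `x` on the outcome.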

As the lecture covered, we will need additional assumptions and adjustment strategies, which you will get a chance to utilize throughout this problem set. We will start with computing the mean difference in outcome between the treatment (quitters; $T = 1$) and control groups (non-quitters; $T = 0$), without any adjustment for potential confounders.

In [None]:
## TODO: Task #1 #############################################################
#                                                                            #
# Compute mean difference in outcomes between treatment (quitters) and       #
# control groups (non-quitters). In other words, E[Y|T=1] - E[Y|T=0].        #
#                                                                            #
##############################################################################

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

...

# Exploring the NHEFS data

We have seen one simple example of why we should adjust for confounders in order to obtain an unbiased estimate of the ATE. In this section, we will use the NHEFS dataset to do a simple exploratory analysis on the provided set of confounders and investigate important assumptions we have to make, including overlap (i.e. common support; $0 < P(T = t|X = x) < 1,~\forall t, x$), before adjusting for the set of confounders. Note that throughout the analysis, we'll assume ignorability (i.e. $(Y(0), Y(1)) \perp\!\!\!\perp T | X$) unless mentioned otherwise.
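
For a single discrete covariate, the overlap condition can be checked directly by estimating $P(T = 1|X = x)$ within each stratum. The sketch below uses made-up data in which one stratum is never treated. The names `stratum`, `p_treat`, and `toy` are hypothetical and not part of the NHEFS data.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n = 5_000

# Hypothetical discrete covariate with three strata.
stratum = rng.integers(0, 3, size=n)
# Assumed treatment probabilities per stratum; stratum 2 is never treated,
# so overlap fails there by construction.
p_treat = np.array([0.5, 0.2, 0.0])[stratum]
t = rng.binomial(1, p_treat)

toy = pd.DataFrame({"stratum": stratum, "t": t})
# Empirical P(T = 1 | X = x) for each stratum x.
p_hat = toy.groupby("stratum")["t"].mean()
# Overlap requires 0 < P(T = 1 | X = x) < 1 in every stratum.
overlap_ok = (p_hat > 0) & (p_hat < 1)
print(pd.DataFrame({"p_hat": p_hat, "overlap_ok": overlap_ok}))
```

With continuous or high-dimensional covariates this stratum-by-stratum check is no longer feasible, which is one motivation for looking at estimated propensity scores instead.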

We will provide code that generates distributions of confounders across treatment and control groups as well as a summary table, so you can assess how the treatment and control groups differ in their confounder distributions.

# Provided set of confounders and explanations:

1. sex - 0: MALE 1: FEMALE
2. age - age in 1971
3. race - 0: WHITE 1: BLACK OR OTHER IN 1971
4. ht - HEIGHT IN CENTIMETERS IN 1971
5. education - EDUCATION BY 1971: 1: 8TH GRADE OR LESS, 2: HS DROPOUT, 3: HS, 4:COLLEGE DROPOUT, 5: COLLEGE OR MORE
6. alcoholpy - HAVE YOU HAD 1 DRINK PAST YEAR? IN 1971,  1:EVER, 0:NEVER; 2:MISSING
7. smokeintensity - NUMBER OF CIGARETTES SMOKED PER DAY IN 1971
8. smokeyrs - YEARS OF SMOKING
9. wt71 - WEIGHT AT THE BASELINE IN 1971, IN KILOGRAMS
10. exercise - IN RECREATION, HOW MUCH EXERCISE? IN 1971, 0:much exercise, 1:moderate exercise, 2:little or no exercise


# Visualize distributions of confounders across treatment and control groups.

In [None]:
confounders_cols = [
    'sex', 'age', 'race', 'ht', 'education', 'alcoholpy', 'smokeintensity', 'smokeyrs', 'wt71', 'exercise'
]

nhefs_quit = nhefs[nhefs.qsmk == 1]
nhefs_noquit = nhefs[nhefs.qsmk == 0]

plt.figure(figsize=(15,15))
num_bins = 20

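for idx, feat in enumerate(confounders_cols):
  plt.subplot(4,3,idx + 1)
  # sns.distplot is deprecated; histplot with stat='density' and kde=True gives an equivalent plot
  sns.histplot(nhefs_quit[feat], label = 'quitters', bins = num_bins, stat = 'density', kde = True, color = 'C0')
  sns.histplot(nhefs_noquit[feat], label = 'non-quitters', bins = num_bins, stat = 'density', kde = True, color = 'C1')
  plt.legend()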

plt.show()

# Create summary table

In [None]:
nhefs['university'] = (nhefs.education == 5).astype('int')
nhefs['no_exercise'] = (nhefs.exercise == 2).astype('int')

summaries = OrderedDict((
    ('sex', lambda x: (100 * (x == 0)).mean()),
    ('age', 'mean'),
    ('race', lambda x: (100 * (x == 0)).mean()),
    ('ht', 'mean'),
    ('university', lambda x: 100 * x.mean()),
    ('alcoholpy', lambda x: (100 * (x == 1)).mean()),
    ('smokeintensity', 'mean'),
    ('smokeyrs', 'mean'),
    ('wt71', 'mean'),
    ('no_exercise', lambda x: 100 * x.mean()),
))

table = nhefs.groupby('qsmk').agg(summaries)
table.sort_index(ascending=False, inplace=True)
table = table.T

table.index = [
    'Men, %',
    'Age, years',
    'White, %',
    'Height at the baseline, cm',
    'University education, %',
    'Drinks alcohol, %',
    'Cigarettes/day',
    'Years smoking',
    'Weight at the baseline, kg',
    'Little or no exercise, %',
]

# add confidence interval
table_with_ci = pd.DataFrame([], index = table.index, columns = ['treatment (qsmk = 1)', 'control (qsmk = 0)'])

for idx, col_name in zip(table.index, confounders_cols):
  if '%' in idx:
    # treat group
    p_hat = table.at[idx, 1]/100
    se_treat = math.sqrt(p_hat*(1-p_hat)/len(nhefs_quit))*100
    lower_95_ci = table.at[idx, 1] - 1.96*se_treat; upper_95_ci = table.at[idx, 1] + 1.96*se_treat;
    table_with_ci.at[idx, 'treatment (qsmk = 1)'] = '{0:>0.1f} - {1:>0.1f}'.format(lower_95_ci, upper_95_ci)

    # control group (use the control group's own standard error)
    p_hat = table.at[idx, 0]/100
    se_control = math.sqrt(p_hat*(1-p_hat)/len(nhefs_noquit))*100
    lower_95_ci = table.at[idx, 0] - 1.96*se_control; upper_95_ci = table.at[idx, 0] + 1.96*se_control;
    table_with_ci.at[idx, 'control (qsmk = 0)'] = '{0:>0.1f} - {1:>0.1f}'.format(lower_95_ci, upper_95_ci)

  else:
    # treat group
    std = nhefs_quit[col_name].std()
    se_treat = std/math.sqrt(len(nhefs_quit))
    lower_95_ci = table.at[idx, 1] - 1.96*se_treat; upper_95_ci = table.at[idx, 1] + 1.96*se_treat;
    table_with_ci.at[idx, 'treatment (qsmk = 1)'] = '{0:>0.1f} - {1:>0.1f}'.format(lower_95_ci, upper_95_ci)

    # control group (use the control group's own standard error)
    std = nhefs_noquit[col_name].std()
    se_control = std/math.sqrt(len(nhefs_noquit))
    lower_95_ci = table.at[idx, 0] - 1.96*se_control; upper_95_ci = table.at[idx, 0] + 1.96*se_control;
    table_with_ci.at[idx, 'control (qsmk = 0)'] = '{0:>0.1f} - {1:>0.1f}'.format(lower_95_ci, upper_95_ci)

print('95% confidence intervals are shown')
table_with_ci

## Questions :

1. Does the provided set of confounders make sense? Pick a couple confounders and discuss why they qualify as confounders.

2. Based on the provided plots and table, can you conclude the overlap assumption is satisfied? Why or why not? What would be an implication of a high-dimensional set of confounders on the overlap assumption?


# Propensity Score Re-weighting

We have seen how failing to adjust for the set of confounders can introduce bias in our causal estimate. As we learned in class, propensity score reweighting is one of the most widely used adjustment methods for causal inference. Given that the ignorability and overlap assumptions are satisfied and that our propensity prediction model is correctly specified, scaling each outcome by the inverse of the corresponding propensity score creates a pseudo-population where the treatment assignment is effectively random. More formally, our estimated ATE, $\hat{\theta}$, is derived as follows:

\begin{align}
\hat{\theta} = \frac{1}{N}\sum_{i ~\text{s.t.}~ t_i = 1}\frac{y_i}{\hat{P}(T = 1|X = x_i)} - \frac{1}{N}\sum_{i ~\text{s.t.}~ t_i = 0}\frac{y_i}{\hat{P}(T = 0|X = x_i)}
\end{align}
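
As a brief sketch of why this works, write $e(x) = P(T = 1|X = x)$ for the true propensity score (notation introduced here for convenience). Under the assumptions above, the first term targets $E[Y(1)]$:

\begin{align}
E\left[\frac{\mathbb{1}[T = 1]\, Y}{e(X)}\right]
&= E\left[\frac{\mathbb{1}[T = 1]\, Y(1)}{e(X)}\right] \\
&= E\left[\frac{E[\mathbb{1}[T = 1] \mid X]\; E[Y(1) \mid X]}{e(X)}\right] \\
&= E\left[E[Y(1) \mid X]\right] = E[Y(1)].
\end{align}

The first equality uses $Y = Y(1)$ whenever $T = 1$, the second uses ignorability together with the law of iterated expectations, and the third uses $E[\mathbb{1}[T = 1] \mid X] = e(X)$, which is strictly positive under overlap. The same argument with $1 - e(X)$ shows that the second term targets $E[Y(0)]$.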

In this part, you will use sklearn's [LogisticRegressionCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegressionCV.html) to estimate the probability of treatment given the set of confounders (i.e. $\hat{P}(T = t|X)$). Note that CV stands for "cross-validation" and it automatically handles choosing the best regularization parameter when fitting the Logistic Regression model.

As an aside, in practice you would want to use *cross-fitting* when estimating propensity scores, to avoid bias due to overfitting. If you are interested in learning more, please see [here](https://cdn1.sph.harvard.edu/wp-content/uploads/sites/1268/2021/03/ciwhatif_hernanrobins_30mar21.pdf) (particularly Chapter 18-4).
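
A simplified sketch of the out-of-fold idea behind cross-fitting is shown below, assuming made-up data and a plain `LogisticRegression` rather than the model used in the task. Each sample's propensity score is predicted by a model fitted on the other folds only. The names `X_toy`, `t_toy`, and `true_propensity` are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(2)
n = 2_000

# Hypothetical confounders and a treatment whose probability depends on them.
X_toy = rng.normal(size=(n, 3))
true_propensity = 1 / (1 + np.exp(-(0.8 * X_toy[:, 0] - 0.5 * X_toy[:, 1])))
t_toy = rng.binomial(1, true_propensity)

# With cv=5, each sample's score comes from a model fitted on the other 4 folds.
proba = cross_val_predict(
    LogisticRegression(), X_toy, t_toy, cv=5, method="predict_proba"
)
propensity_crossfit = proba[:, 1]  # column 1 is the estimated P(T = 1 | X)

print(propensity_crossfit[:5])
```

A full cross-fitting procedure would also compute the final effect estimate fold by fold; this sketch only covers the propensity estimation step.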

In [None]:
## TODO: Task #2 #############################################################
#                                                                            #
# Estimate propensity scores for both treatment and control groups. For the  #
# LogisticRegression hyperparameters, use LogisticRegressionCV from sklearn  #
# to choose the best L1 penalty coefficient.  Use "fit" and "predict_proba"  #
# functions to fit your model and use it to predict propensity scores.       #
#                                                                            #
##############################################################################

from sklearn.linear_model import LogisticRegressionCV

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

...

## Questions :
3. Obtain the coefficients for confounders in your propensity estimation model (use `model.coef_` to obtain coefficients from the trained model). Do the coefficients make sense? Choose coefficients for a couple of confounders and discuss.

Now that we have estimated propensity scores, we can visualize the empirical distribution of our estimated propensity scores and gain additional insights on the validity of the overlap condition. We have already provided the code for the visualization, and all you have to do is to assign your estimated propensity scores to the `propensity_scores` variable.


In [None]:
plt.figure(figsize=(10,5))  # create the figure first so that the size applies to this plot
sns.histplot(propensity_scores, bins = num_bins, stat = 'density', kde = True)
plt.xlabel('propensity scores')
plt.show()

## Questions :
4. Assuming that our propensity score model is correctly specified, what does the above plot tell you about whether or not the overlap assumption is met?
5. What advantage does this approach offer over comparing the marginal distributions of the confounders?

Now, we are going to use the propensity re-weighting method to estimate the average treatment effect (ATE) using the formula provided above. We restate it below, writing $\widehat{\text{ATE}}$ for the estimate $\hat{\theta}$.

\begin{align}
\widehat{\text{ATE}} = \frac{1}{N}\sum_{i ~\text{s.t.}~ t_i = 1}\frac{y_i}{\hat{P}(T = 1|X = x_i)} - \frac{1}{N}\sum_{i ~\text{s.t.}~ t_i = 0}\frac{y_i}{\hat{P}(T = 0|X = x_i)}
\end{align}

where $N$ is the size of the entire cohort.
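
To make the roles of the two sums and of $N$ concrete, here is a minimal sketch of the estimator on made-up data where the propensity scores are known exactly. In the task they have to be estimated. The data-generating choices (binary confounder `x`, true effect 2.0) are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 200_000

# Hypothetical confounded data with known propensity scores.
x = rng.binomial(1, 0.5, size=n)
propensity = np.where(x == 1, 0.8, 0.2)  # P(T = 1 | X), assumed known here
t = rng.binomial(1, propensity)
y0 = 1.0 + 3.0 * x + rng.normal(0, 1, size=n)
y1 = y0 + 2.0                             # assumed true effect: 2.0
y = np.where(t == 1, y1, y0)

# Both sums are divided by the full cohort size N, not by the group sizes.
treated_term = np.sum(y[t == 1] / propensity[t == 1]) / n
control_term = np.sum(y[t == 0] / (1 - propensity[t == 0])) / n
ate_ipw = treated_term - control_term

naive = y[t == 1].mean() - y[t == 0].mean()
print(f"IPW estimate:   {ate_ipw:.2f}")
print(f"naive estimate: {naive:.2f}")
```

Note that the control group is weighted by $1/\hat{P}(T = 0|X = x_i)$, which is `1 - propensity` in this sketch.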

In [None]:
## TODO: Task #3 #############################################################
#                                                                            #
# Apply propensity re-weighting using the provided formula above and         #
# estimate ATE.                                                              #
#                                                                            #
##############################################################################

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

...

# Covariate Adjustment

Covariate adjustment is another popular strategy where one explicitly models the relationships between the outcome $Y$, treatment $T$ and a set of confounders $X$. With the ignorability and overlap assumptions met, a correctly specified model provides an unbiased estimate of the average treatment effect (ATE). In this section, you will use sklearn's [LinearRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html) model to obtain
\begin{align}
  \hat{E}[Y|T = t, X = x].
\end{align}

As we learned in class, when the identifiability assumptions are met, we can relate the data to the potential outcomes. Namely, $\hat{E}[Y(t)|X = x] = \hat{E}[Y|T = t, X = x]$. Therefore, we have
\begin{align}
  \widehat{\text{CATE}} &= \hat{E}[Y(1)|X = x] - \hat{E}[Y(0)|X = x] \\
  &= \hat{E}[Y|T = 1, X = x] - \hat{E}[Y|T = 0, X = x].
\end{align}
Then, $\widehat{\text{ATE}}$ can be estimated as
\begin{align}
   \hat{\theta} &= \hat{E}_{x \sim p(x)}[\widehat{\text{CATE}}] \\
  &= \frac{1}{N} \sum_{i=1}^N \left(\hat{E}[Y|T = 1, X = x_i] - \hat{E}[Y|T = 0, X = x_i]\right),
\end{align}
where $N$ is the size of the entire cohort.
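
The averaging step can be sketched as follows on made-up data: fit one outcome model, predict for every individual with the treatment set to 1 and then to 0, and average the difference. The single confounder `x` and the true effect of 2.0 are assumptions for illustration, not properties of the NHEFS data.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(4)
n = 50_000

x = rng.normal(size=n)                      # hypothetical confounder
t = rng.binomial(1, 1 / (1 + np.exp(-x)))   # treatment depends on x
y = 1.0 + 3.0 * x + 2.0 * t + rng.normal(0, 1, size=n)  # assumed true effect: 2.0

# Fit a single outcome model for E[Y | T, X] with T included as a feature.
features_obs = np.column_stack([t, x])
outcome_model = LinearRegression().fit(features_obs, y)

# Predict for every individual under T = 1 and under T = 0.
y_hat_1 = outcome_model.predict(np.column_stack([np.ones(n), x]))
y_hat_0 = outcome_model.predict(np.column_stack([np.zeros(n), x]))

cate_hat = y_hat_1 - y_hat_0     # estimated CATE at each x_i
ate_adjusted = cate_hat.mean()   # average over the whole cohort
print(f"covariate-adjusted ATE estimate: {ate_adjusted:.2f}")
```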

In [None]:
## TODO: Task #4 #############################################################
#                                                                            #
# Use covariate adjustment and estimate the ATE.                             #
# Use sklearn's LinearRegression package. Use "fit" function.                #
#                                                                            #
##############################################################################

from sklearn.linear_model import LinearRegression

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

...

## Questions :
6. In the linear model, the coefficient for the treatment variable $T$ corresponds to $\hat{\theta}$ (the ATE estimate). Can you explain why?

7. Report your estimated ATEs from the propensity score re-weighting and covariate adjustment methods. How are they different from the mean difference of outcomes between the treatment and control groups without adjusting for confounders (from task #1)? If there is a difference, does its sign make sense (i.e., whether the naive estimate is an under- or overestimate)?

8. Do you think our model for the covariate adjustment task (i.e. linear regression) is valid? What are some potential limitations of this linear model approach?

# Clinical Decision-Making on the Basis of Conditional (Individualized) Average Treatment Effects (CATE)

In practice, we would like to make clinical decisions on the basis of the CATE rather than the ATE. In the next synthetic toy example, we will see how such a decision-making rule can be implemented on the basis of the CATE.

## Synthetic Data Generation Process

## Overview
We will create a realistic clinical dataset where patients choose between two blood pressure medications. The data generation process mimics real-world clinical scenarios with confounding, complex treatment effects, and realistic patient characteristics.

---

## Patient Characteristics (Confounders)

### Baseline Variables
We generate 5 key patient characteristics that influence both treatment selection and outcomes:

| Variable | Distribution | Range | Clinical Meaning |
|----------|-------------|-------|------------------|
| **Age** | Normal(65, 15) | 30-90 years | Patient age |
| **Baseline BP** | Normal(160, 20) | 120-220 mmHg | Systolic blood pressure before treatment |
| **BMI** | Normal(28, 5) | 18-45 kg/m² | Body mass index |
| **Diabetes** | Bernoulli(0.3) | 0/1 | Diabetes status (30% have diabetes) |
| **Kidney Function** | Normal(60, 20) | 15-120 mL/min | Estimated glomerular filtration rate (eGFR) |


In [None]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.tree import export_text
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seed for reproducibility
np.random.seed(42)


def generate_synthetic_data(n=2000):
    ## TODO: Task #1 ###################################################################
    #                                                                                  #
    # Generate synthetic patient cohort following the distributions in the table above #
    # Use numpy                                                                        #
    #                                                                                  #
    ####################################################################################

    ##########################
    #                        #
    #        SOLUTION        #
    #                        #
    ##########################

    age = ...
    age = np.clip(age, 30, 90)

    baseline_bp = ...

    bmi = ...

    kidney_function = ...

    diabetes = ...

    # Create DataFrame
    data = pd.DataFrame({
        'age': age,
        'baseline_bp': baseline_bp,
        'bmi': bmi,
        'diabetes': diabetes,
        'kidney_function': kidney_function
    })

    # Save normalized versions
    data['age_norm'] = (data['age'] - 65) / 15
    data['bp_norm'] = (data['baseline_bp'] - 160) / 20
    data['bmi_norm'] = (data['bmi'] - 28) / 5
    data['kidney_norm'] = (data['kidney_function'] - 60) / 20

    return data

## A ground-truth CATE function and treatment assignment strategy

Below, we define a "complex" ground-truth CATE function and a treatment assignment strategy that depend on the confounders generated above.

Please take a moment to investigate these functions to understand how different confounders affect the effectiveness of the treatment *and* the probability of treatment assignment by the clinician.

In [None]:
def true_cate_function(data):
    """
    Define the ground truth CATE function.
    This represents the true difference in treatment effect between new and standard medication.
    """

    age_norm = data['age_norm']
    bp_norm = data['bp_norm']
    bmi_norm = data['bmi_norm']
    diabetes = data['diabetes']
    kidney_norm = data['kidney_norm']

    # Complex CATE function with interactions and non-linearities
    cate = (
        # Base effect: new drug works better for older patients with higher BP
        5 * age_norm * bp_norm +

        # BMI interaction: new drug less effective for very high BMI
        -3 * np.maximum(0, bmi_norm - 1)**2 +

        # Diabetes interaction: new drug much better for diabetics with kidney issues
        8 * diabetes * np.maximum(0, -kidney_norm) +

        # Age-kidney interaction: new drug worse for elderly with good kidney function
        -4 * np.maximum(0, age_norm) * np.maximum(0, kidney_norm) +

        # Non-linear baseline BP effect
        3 * np.sin(bp_norm * np.pi) +

        # Threshold effect for very high baseline BP
        6 * (bp_norm > 1.5).astype(int) +

        # Random noise to make it more realistic
        np.random.normal(0, 1, len(data))
    )

    return cate


def generate_treatment_assignment(data, cate):
    """
    Generate treatment assignment with confounding.
    We simulate a case where doctors tend to give new medication to patients in more adverse conditions.
    """

    # Propensity (i.e., treatment assignment score)

    propensity_logit = (
        0.5 * data['bp_norm'] +  # Higher BP more likely to get new drug
        0.3 * data['diabetes'] +  # Diabetics more likely to get new drug
        -0.2 * data['age_norm'] +  # Slightly less likely for very elderly
        0.1 * data['bmi_norm']  # Higher BMI more likely to get new drug
    )

    propensity = 1 / (1 + np.exp(-propensity_logit))
    treatment = np.random.binomial(1, propensity, len(data))

    return treatment, propensity


def generate_outcomes(data, treatment, cate):
    """
    Generate potential outcomes and observed outcomes.
    """
    # Potential outcome for control group (standard medication)
    y0_base = (
        5 +  # Base reduction in the Systolic Blood Pressure (the outcome of interest)
        np.random.normal(0, 1, len(data))  # Random noise
    )

    # Ground-truth potential outcomes
    y0 = y0_base         # Standard medication
    y1 = y0_base + cate  # New medication = Standard medication + CATE

    # Observed outcome
    y_observed = treatment * y1 + (1 - treatment) * y0

    return y0, y1, y_observed

## Generate synthetic data

Now that we have defined all our confounding variables, and how the potential outcomes, CATE, and the treatment assignment depend on them, we go ahead and sample a cohort.

In [None]:
# Generate the synthetic dataset
print("Generating synthetic clinical data...")
data = generate_synthetic_data(n=2000)

# Generate true CATE
true_cate = true_cate_function(data)
data['true_cate'] = true_cate

# Generate treatment assignment
treatment, propensity = generate_treatment_assignment(data, true_cate)
data['treatment'] = treatment
data['propensity'] = propensity

# Generate outcomes
y0, y1, y_observed = generate_outcomes(data, treatment, true_cate)
data['y0'] = y0  # Potential outcome under control
data['y1'] = y1  # Potential outcome under treatment
data['outcome'] = y_observed  # Observed outcome

print(f"Synthetic dataset created with {len(data)} patients")

## Optimal Treatment Decisions

For simplicity, we assume that a higher outcome (reduction in systolic blood pressure) is always better.

Therefore, when CATE is greater than zero, we would like to assign that patient to the new medicine (treatment). When the CATE is smaller than zero we would assign the patient to the standard medicine (control).
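
The rule amounts to thresholding the CATE at zero. A minimal sketch on five made-up CATE values (the array `cate_toy` is hypothetical) is shown below. A CATE of exactly zero is sent to the standard medicine here, which is an assumption of this sketch.

```python
import numpy as np

# Hypothetical CATE values for five patients (positive means the new drug helps more).
cate_toy = np.array([2.5, -1.0, 0.0, 0.3, -4.2])

# Recommend the new medicine (1) only when the CATE is strictly positive.
recommended = (cate_toy > 0).astype(int)
print(recommended)  # [1 0 0 1 0]
```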

In [None]:
## TODO: Task #2 ###################################################################
#                                                                                  #
# Create a column for the optimal treatment decisions for each patient             #
# This should be based on the value of "true_cate"                                 #
#                                                                                  #
####################################################################################

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

data['optimal_treatment'] = ...

## Estimating the CATE via covariate adjustment

We will first split the data into two chunks, containing 70% and 30% of the cohort, and call them the "train" and "test" splits for convenience.

We will use the "train" split to fit outcome models to estimate the CATE function (covariate adjustment).

We will use the [RandomForestRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html) from sklearn to fit these outcome models.
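
The idea of fitting one outcome model per treatment arm and differencing their predictions can be sketched on made-up data as follows. The single covariate, the randomized treatment, the step-shaped true CATE `tau`, and the forest hyperparameters are all assumptions chosen to keep the illustration small.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(5)
n = 4_000

x = rng.uniform(-1, 1, size=(n, 1))      # single hypothetical covariate
t = rng.binomial(1, 0.5, size=n)         # randomized here to keep the toy simple
tau = np.where(x[:, 0] > 0, 2.0, -2.0)   # assumed true CATE: sign flips at x = 0
y = x[:, 0] + t * tau + rng.normal(0, 0.5, size=n)

# One outcome model per arm, each fitted only on that arm's samples.
m0 = RandomForestRegressor(n_estimators=100, min_samples_leaf=20, random_state=0)
m0.fit(x[t == 0], y[t == 0])
m1 = RandomForestRegressor(n_estimators=100, min_samples_leaf=20, random_state=0)
m1.fit(x[t == 1], y[t == 1])

# Estimated CATE: predicted outcome under treatment minus predicted outcome under control.
cate_hat = m1.predict(x) - m0.predict(x)
sign_agreement = np.mean((cate_hat > 0) == (tau > 0))
print(f"share of samples with correct CATE sign: {sign_agreement:.2f}")
```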

In [None]:
# Split data
features = ['age', 'baseline_bp', 'bmi', 'diabetes', 'kidney_function']
X = data[features]
y = data['outcome']
treatment = data['treatment']

X_train, X_test, y_train, y_test, t_train, t_test = train_test_split(
    X, y, treatment, test_size=0.3, random_state=42
)

# Get corresponding true CATE and optimal treatment decision for test set
true_cate_test = data.loc[X_test.index, 'true_cate']
optimal_treatment_test = data.loc[X_test.index, 'optimal_treatment']

## TODO: Task #3 ###################################################################
#                                                                                  #
# Fit separate outcome models for the control and the treatment groups.            #
# Use the train split                                                              #
#                                                                                  #
####################################################################################

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

control_mask = t_train == 0    # indices of the control group (standard medicine) in the training set
treatment_mask = t_train == 1  # indices of the treatment group (new medicine) in the training set


model_control = ...  # initiate the RandomForestRegressor model for the control group
...  # fit the model

model_treatment = ...  # initiate the RandomForestRegressor model for the treatment group
...  # fit the model

## CATE-based treatment decisions

For simplicity, we assume that a higher reduction in the systolic blood pressure is always better, and we prefer the treatment that reduces it more.

Since in this example, we only care about the "sign" of the CATE function, we will make our treatment decisions based on the sign of the CATE function solely.

A patient will get the new medicine (treatment=1) if their estimated CATE is greater than zero, and the standard medicine (treatment=0) otherwise.

In [None]:
## TODO: Task #4 ########################################################################
#                                                                                       #
# Estimate CATE in "train" split using the difference in predictions of outcome models  #
# Make treatment decisions based on the sign of the CATE                                #
#                                                                                       #
#########################################################################################

##########################
#                        #
#        SOLUTION        #
#                        #
##########################

# Predict outcomes under both treatments for train set (use X_train)
y0_pred_train = ...
y1_pred_train = ...

# Estimate CATE in the train split
cate_estimated_train = ...
cate_treatment_recommendation_train = ...

# repeat the steps above for the test-split for evaluation later

# Predict outcomes under both treatments for test set (use X_test)
y0_pred_test = ...
y1_pred_test = ...

# Estimate CATE in the test split
cate_estimated_test = ...
cate_treatment_recommendation_test = ...

## Fit a decision tree in the train set based on CATE estimates

CATE functions in reality can be complex. In clinical decision-making, simpler models that can be interpreted easily have advantages.

In the previous parts, we fitted two outcome models on the "train" split and used them to estimate the CATE for individuals in both the "train" and "test" splits.

In this step, we will fit a [DecisionTree](https://scikit-learn.org/stable/modules/tree.html) on the "train" split to predict the treatment assignments made based on the CATE function, which could result in a simpler model.
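
The distillation step can be sketched on made-up data: take the sign of an estimated CATE as a label, fit a shallow tree to those labels, and check how often the tree agrees on held-out samples. The names `X_toy`, `cate_hat_toy`, and the feature names `x0`, `x1` are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(6)
n = 2_000

# Hypothetical covariates and an estimated CATE that a complex model might have produced.
X_toy = rng.uniform(0, 1, size=(n, 2))
cate_hat_toy = 3.0 * (X_toy[:, 0] - 0.5) + 0.2 * rng.normal(size=n)
labels = (cate_hat_toy > 0).astype(int)  # CATE-based recommendation to imitate

# Fit on the first half, evaluate agreement on the second half.
X_fit, X_eval = X_toy[:1000], X_toy[1000:]
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_fit, labels[:1000])
agreement = np.mean(tree.predict(X_eval) == labels[1000:])

print(f"agreement with CATE-based recommendations: {agreement:.2f}")
print(export_text(tree, feature_names=["x0", "x1"]))
```

The printed rules are what a clinician would read, so the depth limit trades some agreement for readability.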

In [None]:
## TODO: Task #5 #####################################################################
#                                                                                    #
# Fit a decision tree in the train split to predict CATE-based treatment assignments #
# Set max_depth=3 for easier interpretation                                          #
# Note that there is a tradeoff with performance when we use lower depth.            #
#                                                                                    #
# Get treatment assignment predictions from the tree once fitted in the "test" split #
#                                                                                    #
######################################################################################


##########################
#                        #
#        SOLUTION        #
#                        #
##########################

decision_tree = ...  # instantiate the DecisionTree
... # fit the model

tree_predictions_test = ...  # CATE-sign predictions from the tree

## Analyzing the quality of decisions made by the tree based system and visualizing our tree-based rule

Compare the performance of the tree-based strategy to that of the optimal strategy and the CATE-based strategy.

## Questions

1. Do these findings make sense (i.e., the relative ranking of the performance of different strategies)?
2. What would it take for the tree-based strategy's performance to match that of the CATE-based strategy?
3. What would it take for the CATE-based strategy's performance to match that of the optimal strategy?

In [None]:
print("\nTree-based recommendation performance per treatment class:")
print(classification_report(optimal_treatment_test, tree_predictions_test,
                          target_names=['Standard Med', 'New Med']))

# Compare different strategies
strategies = {
    'Always Standard': np.zeros(len(X_test)),
    'Always New': np.ones(len(X_test)),
    'Random': np.random.binomial(1, 0.5, len(X_test)),
    'CATE-based': cate_treatment_recommendation_test,
    'Decision Tree': tree_predictions_test,
    'True Optimal': optimal_treatment_test
}

results = {}
y0_test = data.loc[X_test.index, 'y0']
y1_test = data.loc[X_test.index, 'y1']

for strategy_name, recommendations in strategies.items():
    # Calculate expected outcome under this strategy
    outcomes = recommendations * y1_test + (1 - recommendations) * y0_test
    avg_outcome = np.mean(outcomes)

    # Calculate how often we make the right decision
    correct_decisions = np.mean(recommendations == optimal_treatment_test)

    results[strategy_name] = {
        'avg_outcome': avg_outcome,
        'correct_decisions': correct_decisions
    }

print("\nStrategy Comparison:")
print("-" * 60)
print(f"{'Strategy':<15} {'Avg Outcome':<12} {'Correct %':<10}")
print("-" * 60)
for strategy, metrics in results.items():
    print(f"{strategy:<15} {metrics['avg_outcome']:<12.2f} {metrics['correct_decisions']:<10.1%}")


strategy_names = list(results.keys())
avg_outcomes = [results[s]['avg_outcome'] for s in strategy_names]

plt.figure()

plt.bar(range(len(strategy_names)), avg_outcomes, alpha=0.8)
plt.xlabel('Strategy')
plt.ylabel('Average Outcome (BP Reduction)')
plt.title('Treatment Strategy Comparison')
plt.xticks(range(len(strategy_names)),strategy_names, rotation=45, ha='right')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Display the decision tree rules
print("\nDecision Tree Rules for Clinical Use:")
print("=" * 50)

tree_rules = export_text(decision_tree, feature_names=features, max_depth=4)
print(tree_rules)