# Homework 5, Question 3 (40 points)
In this question, we explore the R.O.A.D. methodology. We will consider `spleen_mlopt_final_train.csv` and `spleen_mlopt_final_test.csv`, a subset of data on splenectomy that consists of 2,400 individuals. The features, treatment, and outcome variables are as follows:
- Features: `sex`, `age`, `sbp`, `pulserate`, `respiratoryrate`, `pulseoximetry`, `totalgcs`, `intubated`, `bmi`, etc.
- Treatment: `treatment`, which is either "splenectomy" (spleen removal surgery, i.e., the treatment) or "observation" (the control, i.e., no surgery)
- Outcome: `o_mortality`, where 1 indicates patient death ("expiration") and 0 otherwise

*__Important:__* Use the provided seed everywhere it applies (data splitting and any tree training), and do not change the order or columns of the dataset unless otherwise specified.

In [1]:
# Load packages
# If you need to install packages, please do not leave the output of installation in your homework submission.
using CSV, DataFrames, CategoricalArrays, Plots, Statistics, Random, StatsPlots, Gurobi, JuMP

seed = 42;

#### Read in the data and get the features, treatment, and outcome

In [2]:
train = CSV.read("spleen_mlopt_final_train.csv", DataFrame)
test = CSV.read("spleen_mlopt_final_test.csv", DataFrame)

X_train = train[:, Not([:treatment, :o_mortality])]
t_train = train[:, :treatment]
y_train = train[:, :o_mortality]

X_test = test[:, Not([:treatment, :o_mortality])]
t_test = test[:, :treatment]
y_test = test[:, :o_mortality];

## Part 1
In this part, we will perform the first main step of R.O.A.D., which removes observed confounding by selecting a matched dataset from the original set of data. 

### Part 1(a) - 5 points

Using the *__control patients in the training dataset__*, train a `RandomForest` via `GridSearch` to predict the mortality outcome; use 5-fold cross-validation with AUC as the criterion to fine-tune the hyperparameters over the following values: `max_depth` of [5], `minbucket` of [20, 50], and `num_trees` of [50, 100]. This model estimates mortality outcomes for patients who do not receive treatment.

Please answer the following questions.
- How good is our estimator? That is, what are the AUC and accuracy of our estimator on the control patients in the test set?
- Use this model to predict what the mortality risk (i.e., the probability of mortality) would be if none of the training set patients received treatment. What is the mortality risk for the last training set patient (patient \#1920)?
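To make the two requested metrics concrete, here is a minimal sketch in plain Julia (standard library only, no IAI) of a rank-based AUC and of accuracy at a 0.5 threshold. The helper names `auc_score` and `accuracy_at` and the 0.5 threshold are assumptions for illustration; IAI's `score` may handle ties and thresholds differently.

```julia
using Statistics

# Hypothetical helper: AUC as the probability that a randomly chosen positive
# (died, label 1) gets a higher score than a randomly chosen negative (label 0).
# Ties count as one half. O(n_pos * n_neg), fine for small illustrations.
function auc_score(scores::AbstractVector{<:Real}, labels::AbstractVector{<:Integer})
    pos = scores[labels .== 1]
    neg = scores[labels .== 0]
    (isempty(pos) || isempty(neg)) && error("need both classes to compute AUC")
    wins = 0.0
    for p in pos, q in neg
        wins += p > q ? 1.0 : (p == q ? 0.5 : 0.0)
    end
    return wins / (length(pos) * length(neg))
end

# Hypothetical helper: fraction of correct predictions when predicting
# "died" whenever the predicted risk is at least `threshold`.
accuracy_at(scores, labels; threshold = 0.5) = mean((scores .>= threshold) .== (labels .== 1))

toy_risk = [0.10, 0.40, 0.35, 0.80]   # toy predicted mortality risks
toy_died = [0, 0, 1, 1]               # toy observed outcomes
toy_auc = auc_score(toy_risk, toy_died)    # 0.75: 3 of 4 (died, survived) pairs ranked correctly
toy_acc = accuracy_at(toy_risk, toy_died)  # 0.75: only the third patient is misclassified
```

AUC depends only on how the scores rank patients, while accuracy depends on the threshold: a model that never predicts a risk of at least 0.5 has accuracy equal to the survival rate no matter how well it ranks patients, which is worth keeping in mind when mortality is rare.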

In [3]:
idx_train_control = findall(val -> val == "observation", t_train)
X_train_control = X_train[idx_train_control, :]
t_train_control = t_train[idx_train_control]
y_train_control = y_train[idx_train_control]

idx_test_control = findall(val -> val == "observation", t_test)
X_test_control = X_test[idx_test_control, :]
t_test_control = t_test[idx_test_control]
y_test_control = y_test[idx_test_control];

In [4]:
# Train RF on control training data
depths = [5]
minbuckets = [20, 50]
num_trees = [50, 100]

# Grid Search for Random Forest
rf_grid = IAI.GridSearch(
    IAI.RandomForestClassifier(
        random_seed = seed,
        # criterion = :gini,
    ),
    max_depth = depths,
    minbucket = minbuckets,
    num_trees = num_trees,
)

IAI.fit_cv!(
    rf_grid,
    X_train_control,
    y_train_control;
    n_folds = 5,
    validation_criterion = :auc,
)

# Get best model
lnr = IAI.get_learner(rf_grid)



Fitted RandomForestClassifier

In [5]:
# AUC
test_auc = IAI.score(rf_grid, X_test_control, y_test_control, criterion=:auc)
# Accuracy
test_acc = IAI.score(rf_grid, X_test_control, y_test_control, criterion=:accuracy, positive_label=1)

println("Test AUC score: ", round(test_auc, digits=4))
println("Test Accuracy: ", round(test_acc, digits=4))


Test AUC score: 0.924
Test Accuracy: 0.9537


In [6]:
# Predict class probabilities for all training patients under "no treatment" model
probs_train_no_treat = IAI.predict_proba(lnr, X_train)

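# Find the predict_proba column that corresponds to mortality = 1.
# predict_proba returns one column per class label, so look the label up by name
# instead of relying on unique(), which returns labels in first-appearance order.
pos_label = 1
pos_idx   = findfirst(==(string(pos_label)), names(probs_train_no_treat))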

mortality_risk_last = probs_train_no_treat[end, pos_idx]

println("Mortality risk for the last training set patient under no treatment: ",
        round(mortality_risk_last, digits = 4))

Mortality risk for the last training set patient under no treatment: 0.2081


### Part 1(b) - 2.5 points

We now divide the training set patients into 5 equal-sized buckets, in order of increasing predicted mortality risk. How many patients received treatment in Bucket 5? Fill in the TODO section to print this.

In [7]:
# 1. Get mortality risk for each training patient (prob of death class)
risk = probs_train_no_treat[:, pos_idx]  

# 2. Sort patients by increasing risk, and keep their indices
order = sortperm(risk)   # order[i] = original row index of the i-th lowest risk patient
n = length(order)

# 3. Decide bucket size (≈ equal-sized buckets)
num_buckets = 5
bucket_size = ceil(Int, n / num_buckets)  

# 4. Build the 5 buckets from consecutive runs of the risk-sorted order
buckets = Vector{Vector{Int}}()  # each bucket is a vector of row indices

for b in 1:num_buckets
    # positions in the *sorted* order that belong to this bucket
    start_pos = (b - 1) * bucket_size + 1
    end_pos   = min(b * bucket_size, n)   # last bucket may be smaller

    # corresponding original row indices from `train` / `X_train`
    bucket_indices = order[start_pos:end_pos]
    push!(buckets, bucket_indices)
end

# 5. Take Bucket 5 (highest risk) and count treated patients there
bucket5_indices     = buckets[end]                 # indices of patients in bucket 5
bucket5_treatments  = train.treatment[bucket5_indices]
treated_in_bucket5  = sum(bucket5_treatments .== "splenectomy")

println("In bucket 5, ", treated_in_bucket5, " patients received treatment")

In bucket 5, 384 patients received treatment
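The cell above fixes the bucket size at `ceil(Int, n / num_buckets)`. That works here, but for other `n` it can leave trailing buckets empty (for example, `n = 6` with 5 buckets gives sizes 2, 2, 2, 0, 0). A minimal alternative sketch, using the hypothetical helper name `balanced_buckets`, splits the sorted order so that bucket sizes differ by at most one:

```julia
# Hypothetical helper: split patients, sorted by increasing risk, into k
# contiguous buckets whose sizes differ by at most one.
function balanced_buckets(risk::AbstractVector{<:Real}, k::Integer)
    order = sortperm(risk)   # row indices from lowest to highest risk
    n = length(order)
    # Bucket b covers sorted positions div((b-1)n, k)+1 through div(bn, k)
    return [order[div((b - 1) * n, k) + 1 : div(b * n, k)] for b in 1:k]
end

toy_risk = [0.9, 0.1, 0.5, 0.3, 0.7, 0.2]
toy_buckets = balanced_buckets(toy_risk, 5)
length.(toy_buckets)   # [1, 1, 1, 1, 2]
toy_buckets[end]       # [5, 1]: the two highest-risk rows
```

When `n` is divisible by the number of buckets, this gives exactly the same buckets as the `ceil`-based loop.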


### Part 1(c) - 2.5 points

For each bucket $k\in\{1,\dots,5\}$, we compute a distance matrix $\mathbf{D}_k$ such that $\mathbf{D}_k[i,j]$ is the squared Euclidean distance between the $i$-th treatment-group individual and the $j$-th control-group individual. The distance is computed on three features: [`sex`, `age`, `bmi`]. Before computing it, normalize the non-binary features by subtracting the full-training-set minimum and dividing by the full-training-set range, so that they fall between 0 and 1. Formally, let $x_{i,s}$ be the value of feature $s$ for individual $i$; then the normalized value $\hat{x}_{i,s}$ is defined as

$$\hat{x}_{i,s} = \frac{x_{i,s} - \min_{j\in\mathcal{T}} x_{j,s}}{\max_{j\in\mathcal{T}} x_{j,s} - \min_{j\in\mathcal{T}} x_{j,s}}, \quad \forall i\in\mathcal{T},\ \forall s,$$

where $\mathcal{T}$ is the set of training datapoints.

What is the dimension of $\mathbf{D}_{5}$, and what is the distance between treatment individual 10 and control individual 12 in the 5th bucket?

In [8]:
# Create groups: each group is a DataFrame containing the rows from train corresponding to each bucket
groups = [train[bucket_indices, :] for bucket_indices in buckets]

5-element Vector{DataFrame}:
 1551×36 DataFrame
  Row │ sex    age    sbp      pulserate  respiratoryrate  pulseoximetry  tota ⋯
      │ Int64  Int64  Float64  Float64    Float64          Float64        Floa ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │     1     26    116.0       87.0             23.0           99.0       ⋯
    2 │     0     38    118.0       84.0             18.0          100.0
    3 │     1     24    120.0       93.0             22.0           99.0
    4 │     1     28    135.0       81.0             19.0           99.0
    5 │     1     26    132.0       68.0             17.0           98.0       ⋯
    6 │     1     30    118.0      107.0             22.0          100.0
    7 │     1     24    130.0      106.0             18.0          100.0
    8 │     1     30    170.0      110.0             22.0          100.0
    9 │     1     29    137.0       90

In [9]:
println(X_train_control.sex[1:5])
println(X_train.age[1:5])
println(X_train_control.bmi[1:5])
println(train.treatment[1:5])

[1, 1, 1, 1, 0]
[61, 78, 33, 40, 32]
[19.074528, 23.155415, 20.597332, 38.532974, 31.43062]
String15["observation", "splenectomy", "observation", "observation", "observation"]

In [10]:
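# Squared Euclidean distance between two rows over the given feature columns
function compute_dist(row1, row2, features)
    dist = 0.0  # start from a Float64 so the accumulator type stays stable
    for f in features
        dist += (row1[f] - row2[f])^2
    end
    return dist
end

# distances[i, j] = squared distance between row i of df_treatment1 and row j of df_treatment2
function compute_dist_matrix(df_treatment1, df_treatment2, features)
    n1 = size(df_treatment1, 1)
    n2 = size(df_treatment2, 1)
    distances = zeros(n1, n2)
    for i in 1:n1
        for j in 1:n2
            distances[i, j] = compute_dist(df_treatment1[i, :], df_treatment2[j, :], features)
        end
    end
    return distances
end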

min_age = minimum(train[:, "age"])
max_age = maximum(train[:, "age"])
min_bmi = minimum(train[:, "bmi"])
max_bmi = maximum(train[:, "bmi"])

normalize_continuous_feature(value, min_val, max_val) = (value - min_val) / (max_val - min_val) # we define a normalization function

matching_features = ["sex", "age", "bmi"]
group_dist = []
std_gs = []
for g in groups
    std_g = copy(g)[:, vcat(matching_features, ["treatment"])]

    # [TODO: normalize non-binary features - you may use the min/max values computed above]
    std_g.age .= normalize_continuous_feature.(std_g.age, min_age, max_age)
    std_g.bmi .= normalize_continuous_feature.(std_g.bmi, min_bmi, max_bmi)

    # [TODO: call the compute_dist_matrix function with the appropriate arguments to compute the distance matrix for each group]
    std_g_control = filter(:treatment => ==("observation"), std_g)
    std_g_treated = filter(:treatment => ==("splenectomy"), std_g)

    dist_matrix =  compute_dist_matrix(std_g_treated, std_g_control, matching_features)
    
    push!(group_dist, dist_matrix)
end

In [11]:
println("Shape of bucket 5: ", size(group_dist[5]))
println("Distance b/w treatment 10 and control 12 in bucket 5: ", group_dist[5][10, 12])

Shape of bucket 5: (384, 1165)
Distance b/w treatment 10 and control 12 in bucket 5: 1.2509047013624532
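The double loop in `compute_dist_matrix` builds each entry from a pair of `DataFrameRow`s, which can be slow for buckets with over a thousand controls. A minimal vectorized sketch on plain `Matrix{Float64}` inputs (hypothetical helper names `minmax_scale` and `sqdist_matrix`; toy data and min/max values are made up) applies the min-max normalization from Part 1(c) and then accumulates squared differences one feature at a time via broadcasting. Giving `sex` a minimum of 0 and a maximum of 1 leaves that binary feature unchanged.

```julia
# Hypothetical helper: min-max scale each column of X using reference
# (training-set) minima `lo` and maxima `hi`, one entry per column.
minmax_scale(X::AbstractMatrix{<:Real}, lo::AbstractVector{<:Real}, hi::AbstractVector{<:Real}) =
    (X .- lo') ./ (hi .- lo)'

# Hypothetical helper: D[i, j] = squared Euclidean distance between row i
# of T (treated) and row j of C (control).
function sqdist_matrix(T::AbstractMatrix{<:Real}, C::AbstractMatrix{<:Real})
    size(T, 2) == size(C, 2) || error("T and C must have the same columns")
    D = zeros(size(T, 1), size(C, 1))
    for s in 1:size(T, 2)
        # column vector minus row vector broadcasts to all (i, j) pairs
        D .+= (T[:, s] .- C[:, s]') .^ 2
    end
    return D
end

# Toy data, columns [sex, age, bmi]; lo/hi stand in for training-set min/max
toy_lo = [0.0, 18.0, 15.0]
toy_hi = [1.0, 98.0, 55.0]
toy_T = minmax_scale([1.0 58.0 35.0], toy_lo, toy_hi)                 # one treated patient
toy_C = minmax_scale([0.0 18.0 15.0; 1.0 58.0 25.0], toy_lo, toy_hi)  # two control patients
toy_D = sqdist_matrix(toy_T, toy_C)   # [1.5 0.0625]
```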


### Part 1(d) - 7 points

For each bucket $k\in\{1,\dots,5\}$, implement and run the matching algorithm from lecture:

$$\begin{aligned}
\min_{\boldsymbol{z}} & \quad \sum_{i\in\mathcal{S}_1^k}\sum_{j\in\mathcal{S}_0^k} z_{i,j}\, \mathbf{D}_k[i,j] \\
\text{s.t.} & \quad \sum_{j\in\mathcal{S}_0^k} z_{i,j} = 1, \quad \forall i \in \mathcal{S}_1^k \\
& \quad \sum_{i\in\mathcal{S}_1^k} z_{i,j} \leq 1, \quad \forall j \in \mathcal{S}_0^k \\
& \quad z_{i,j}\in\{0,1\}, \quad \forall i \in\mathcal{S}_1^k,~j\in\mathcal{S}_0^k,
\end{aligned}$$

where $\mathcal{S}_0^k$ and $\mathcal{S}_1^k$ are the sets of control patients and treatment patients in bucket $k$, respectively. Report the control patients that are matched in the first bucket. You may report them by index.
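To see why the matching is solved exactly rather than greedily, here is a small exhaustive-search sketch in plain Julia (no JuMP or Gurobi; hypothetical helpers `greedy_matching` and `exact_matching`) that enforces the same constraints: each treated patient gets exactly one control, and each control is used at most once. It also makes explicit that the problem is feasible only when a bucket has at least as many control as treated patients. Exhaustive search is exponential, so this is for toy instances only.

```julia
# Hypothetical helper: match treated rows in order, each to its nearest unused control.
function greedy_matching(D::AbstractMatrix{<:Real})
    size(D, 1) <= size(D, 2) || error("infeasible: more treated than controls")
    used = falses(size(D, 2))
    chosen = Int[]
    cost = 0.0
    for i in 1:size(D, 1)
        j = argmin([used[k] ? Inf : D[i, k] for k in 1:size(D, 2)])
        used[j] = true
        push!(chosen, j)
        cost += D[i, j]
    end
    return cost, chosen
end

# Hypothetical helper: minimum-cost matching by exhaustive search (tiny D only).
function exact_matching(D::AbstractMatrix{<:Real})
    n_t, n_c = size(D)
    n_t <= n_c || error("infeasible: $n_t treated but only $n_c controls")
    best_cost, best = Inf, Int[]
    used = falses(n_c)
    chosen = Int[]
    function search(i, cost)
        if i > n_t                   # every treated patient is matched
            if cost < best_cost
                best_cost, best = cost, copy(chosen)
            end
            return
        end
        for j in 1:n_c
            used[j] && continue      # each control at most once
            used[j] = true; push!(chosen, j)
            search(i + 1, cost + D[i, j])
            used[j] = false; pop!(chosen)
        end
    end
    search(1, 0.0)
    return best_cost, best
end

toy_Dk = [1.0 2.0; 1.5 10.0]   # 2 treated x 2 controls
greedy_matching(toy_Dk)        # (11.0, [1, 2])
exact_matching(toy_Dk)         # (3.5, [2, 1])
```

Greedy matching lets the first treated patient take control 1, which forces the second onto a distant control; the exact solution accepts a slightly worse first match for a much lower total distance, which is what the integer program above optimizes.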

In [12]:
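# Solve the bucket-level matching MIP: rows of dist_matrix are treated patients,
# columns are control patients.
function matching(dist_matrix)
    n_treat, n_control = size(dist_matrix)

    m = Model(Gurobi.Optimizer)
    set_optimizer_attribute(m, "OutputFlag", 0)

    # z[i, j] = 1 if treated patient i is matched to control patient j
    @variable(m, z[1:n_treat, 1:n_control], Bin)
    @objective(m, Min, sum(dist_matrix[i, j] * z[i, j] for i in 1:n_treat, j in 1:n_control))
    # each treated patient gets exactly one control
    @constraint(m, [i = 1:n_treat], sum(z[i, j] for j in 1:n_control) == 1)
    # each control is used at most once
    @constraint(m, [j = 1:n_control], sum(z[i, j] for i in 1:n_treat) <= 1)

    optimize!(m)
    # e.g. infeasible if the bucket has more treated than control patients
    termination_status(m) == MOI.OPTIMAL || error("matching not solved to optimality: ", termination_status(m))
    assignment = value.(z)
    return objective_value(m), assignment
end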

matching (generic function with 1 method)

In [13]:
data_list = []
for i in 1:length(groups)
    g_dist = group_dist[i]
    obj, assignment = matching(g_dist)
    g = groups[i]
    control = filter(row -> row.treatment == "observation", g)
    treatment = filter(row -> row.treatment == "splenectomy", g)
    
    control_matched_idx = [idx[2] for idx in findall(val -> val > 0.5, assignment)]  # threshold guards against solver values like 0.9999999
    control_matched = control[control_matched_idx, :]
    push!(data_list, treatment)
    push!(data_list, control_matched)

    # [TODO: print the control patients that are matched in the first bucket]
    if i == 1
        println("Matched control patients in bucket 1:")
        for r in eachrow(control_matched)
            println(r)
        end
    end
    
end

Matched control patients in bucket 1:
DataFrameRow
 Row │ sex    age    sbp      pulserate  respiratoryrate  pulseoximetry  totalgcs  intubated  bmi      cc_bleeding  cc_chf   cc_smoking  cc_renal  cc_cva   cc_diabetes  cc_mi    cc_pad   cc_hypertension  cc_copd  cc_steroid  cc_cirrhosis  transf_rbc_1hr  transf_wholeblood_1hr  liver_inj  kidney_inj  smallbowel_inj  colon_inj  spine_inj  pelvic_fx  tbi    spleen_grade_0  spleen_grade_1  spleen_grade_2  spleen_grade_3  treatment    o_mortality
     │ Int64  Int64  Float64  Float64    Float64          Float64        Float64   Int64      Float64  Float64      Float64  Float64     Float64   Float64  Float64      Float64

## Part 2

We now perform the second main step of R.O.A.D., which aims to remove unobserved confounding. Recall that the first step of R.O.A.D., implemented in Part 1, produces matched pairs of treatment and control patients. We can combine these matched pairs across buckets to get a "cleaner" subset of the full dataset, from which observed confounding has been removed. We call this the matched dataset. We now train two estimators, one for the control group and one for the treatment group, to perform counterfactual estimation, while fine-tuning a parameter $\rho$ to remove the unobserved confounding.

We begin by combining the matched pairs across buckets to produce the matched dataset.

In [14]:
matched_data = reduce(vcat, data_list)
first(matched_data, 5)

Row,sex,age,sbp,pulserate,respiratoryrate,pulseoximetry,totalgcs,intubated,bmi,cc_bleeding,cc_chf,cc_smoking,cc_renal,cc_cva,cc_diabetes,cc_mi,cc_pad,cc_hypertension,cc_copd,cc_steroid,cc_cirrhosis,transf_rbc_1hr,transf_wholeblood_1hr,liver_inj,kidney_inj,smallbowel_inj,colon_inj,spine_inj,pelvic_fx,tbi,spleen_grade_0,spleen_grade_1,spleen_grade_2,spleen_grade_3,treatment,o_mortality
,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Int64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Float64,Float64,Float64,Float64,String15,Int64
1,0,26,116.0,86.0,20.0,100.0,15.0,0,21.1073,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,1.0,0.0,0.0,0.0,splenectomy,0
2,1,33,138.0,97.0,18.0,98.0,15.0,0,23.5685,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,1.0,0.0,0.0,0.0,splenectomy,0
3,0,32,126.0,71.0,16.0,98.0,15.0,0,21.6713,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,splenectomy,0
4,1,18,145.0,100.0,18.0,100.0,15.0,0,19.7945,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,1.0,0.0,0.0,0.0,splenectomy,0
5,1,30,128.0,85.0,16.0,100.0,15.0,0,25.4571,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,splenectomy,0


### Part 2(a) - 2.5 points

On the control patients in the matched dataset, train a `RandomForest` via `GridSearch` to predict the outcome; use cross-validation to fine-tune the hyperparameters over the same values as in Part 1(a). This trains a *__control__* model; we denote the predicted mortality risk from this model for patient $\boldsymbol{x}_i$ as $h_{t=0}(\boldsymbol{x}_i)$.

How good is our estimator? That is, what are the AUC and accuracy of our estimator on the control patients in the test set?

In [15]:
matched_data_control = filter(row -> row.treatment == "observation", matched_data)
matched_data_treatment = filter(row -> row.treatment == "splenectomy", matched_data);

matched_data_control

Row,sex,age,sbp,pulserate,respiratoryrate,pulseoximetry,totalgcs,intubated,bmi,cc_bleeding,cc_chf,cc_smoking,cc_renal,cc_cva,cc_diabetes,cc_mi,cc_pad,cc_hypertension,cc_copd,cc_steroid,cc_cirrhosis,transf_rbc_1hr,transf_wholeblood_1hr,liver_inj,kidney_inj,smallbowel_inj,colon_inj,spine_inj,pelvic_fx,tbi,spleen_grade_0,spleen_grade_1,spleen_grade_2,spleen_grade_3,treatment,o_mortality
,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Int64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Float64,Float64,Float64,Float64,String15,Int64
1,1,37,118.0,82.0,17.0,100.0,15.0,0,22.4982,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0
2,1,24,136.0,105.0,18.0,100.0,15.0,0,25.1429,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,1.0,0.0,0.0,0.0,observation,0
3,1,21,130.0,84.0,20.0,98.0,14.0,0,25.6801,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0
4,1,30,130.0,110.0,18.0,99.0,15.0,0,25.5367,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0
5,1,26,120.0,71.0,18.0,99.0,15.0,0,19.3719,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,1.0,0.0,0.0,0.0,observation,0
6,1,34,136.0,104.0,18.0,98.0,15.0,0,23.9869,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0
7,1,26,130.0,74.0,20.0,100.0,15.0,0,22.1047,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,1.0,0.0,0.0,0.0,observation,0
8,1,42,116.0,80.0,18.0,98.0,15.0,0,23.9591,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0
9,1,27,142.0,103.0,16.0,100.0,14.0,0,23.5487,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0
10,1,40,152.0,71.0,16.0,97.0,15.0,0,25.1041,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,observation,0


In [25]:
# 1. Build X and y for the control-only classifier
#    Drop :treatment (constant and causing conversion issues)
X_control = select(matched_data_control, Not([:treatment, :o_mortality]))
y_control = matched_data_control.o_mortality 

# 2. Hyperparameter grid
depths     = [5]
minbuckets = [20, 50]
num_trees  = [50, 100]

rf_grid = IAI.GridSearch(
    IAI.RandomForestClassifier(
        random_seed = seed,
        # you can keep default criterion (= :gini) or specify:
        # criterion   = :gini,
    ),
    max_depth = depths,
    minbucket = minbuckets,
    num_trees = num_trees,
)

# 3. Fit with 5-fold CV, tuning by AUC
IAI.fit_cv!(
    rf_grid,
    X_control,
    y_control;
    n_folds = 5,
    validation_criterion = :auc,
)

# 4. Get the fitted model
lnr_control = IAI.get_learner(rf_grid)

Fitted RandomForestClassifier

In [26]:
# AUC
test_auc = IAI.score(rf_grid, X_test_control, y_test_control, criterion=:auc)
# Accuracy
test_acc = IAI.score(rf_grid, X_test_control, y_test_control, criterion=:accuracy, positive_label=1)

println("Test AUC score: ", round(test_auc, digits=4))
println("Test Accuracy: ", round(test_acc, digits=4))

Test AUC score: 0.9222
Test Accuracy: 0.9537


### Part 2(b) - 2.5 points

Now train a `RandomForest` via `GridSearch` on the treatment patients in the matched dataset to predict the outcome; use cross-validation to fine-tune the hyperparameters over the same values as in Part 2(a). This trains a *__treatment__* model; we denote the predicted mortality risk from this model for patient $\boldsymbol{x}_i$ as $h_{t=1}(\boldsymbol{x}_i)$.

How good is our estimator? That is, what are the AUC and accuracy of our estimator on the control patients in the test set?

In [27]:
# 1. Build X and y for the treatment-only classifier
#    Drop :treatment (constant and causing conversion issues)
X_treatment = select(matched_data_treatment, Not([:treatment, :o_mortality]))
y_treatment = matched_data_treatment.o_mortality 

# 2. Hyperparameter grid
depths     = [5]
minbuckets = [20, 50]
num_trees  = [50, 100]

rf_grid = IAI.GridSearch(
    IAI.RandomForestClassifier(
        random_seed = seed,
        # you can keep default criterion (= :gini) or specify:
        # criterion   = :gini,
    ),
    max_depth = depths,
    minbucket = minbuckets,
    num_trees = num_trees,
)

# 3. Fit with 5-fold CV, tuning by AUC
IAI.fit_cv!(
    rf_grid,
    X_treatment,
    y_treatment;
    n_folds = 5,
    validation_criterion = :auc,
)

# 4. Get the fitted model
lnr_treatment = IAI.get_learner(rf_grid)

Fitted RandomForestClassifier

In [29]:
# AUC
test_auc = IAI.score(rf_grid, X_test_control, y_test_control, criterion=:auc)
# Accuracy
test_acc = IAI.score(rf_grid, X_test_control, y_test_control, criterion=:accuracy, positive_label=1)

println("Test AUC score: ", round(test_auc, digits=4))
println("Test Accuracy: ", round(test_acc, digits=4))

Test AUC score: 0.9239
Test Accuracy: 0.9537


### Part 2(c) - 3 points

Use both the control model and the treatment model to predict the mortality risk on the matched dataset. What do you notice about the outcomes of the control model and the treatment model? In particular, answer this by evaluating and reporting the following quantity on the matched dataset: $$\hat{w}_{t=1} - \hat{w}_{t=0},\quad\text{where }\hat{w}_{t=1} = \frac{1}{n_s}\sum_{i=1}^{n_s} h_{t=1}(\boldsymbol{x}_i), ~\hat{w}_{t=0} = \frac{1}{n_s}\sum_{i=1}^{n_s}h_{t=0}(\boldsymbol{x}_i),$$
where $n_s$ is the number of patients in our matched dataset. 

What does this mean, are there any issues, and what should we do about it, if anything? (Please answer this in a __markdown__ cell below.)

In [31]:
X_matched = matched_data[:, Not([:treatment, :o_mortality])]

# Get predictions from both models
# h_t0: predicted mortality risk from control model (what happens without treatment)
# h_t1: predicted mortality risk from treatment model (what happens with treatment)
probs_control = IAI.predict_proba(lnr_control, X_matched)
probs_treatment = IAI.predict_proba(lnr_treatment, X_matched)

# Get the positive class index (mortality = 1)
# predict_proba returns one column per class label, so look up the column named "1"
# (both models were trained on 0/1 labels, so their columns are assumed to share this order)
pos_label = 1
pos_idx = findfirst(==(string(pos_label)), names(probs_control))

# Extract mortality risk probabilities (probability of class 1 = mortality)
h_t0 = probs_control[:, pos_idx]  # h_{t=0}(x_i) for all patients in matched dataset
h_t1 = probs_treatment[:, pos_idx]  # h_{t=1}(x_i) for all patients in matched dataset

# Calculate averages
n_s = length(h_t0)  # number of patients in matched dataset
w_t0 = mean(h_t0)  # average predicted mortality risk under control
w_t1 = mean(h_t1)  # average predicted mortality risk under treatment

# Calculate the difference
diff = w_t1 - w_t0

# Report results
println("Number of patients in matched dataset (n_s): ", n_s)
println("Average predicted mortality risk under control (w_t0): ", round(w_t0, digits=4))
println("Average predicted mortality risk under treatment (w_t1): ", round(w_t1, digits=4))
println("Difference (w_t1 - w_t0): ", round(diff, digits=4))

Number of patients in matched dataset (n_s): 2198
Average predicted mortality risk under control (w_t0): 0.0885
Average predicted mortality risk under treatment (w_t1): 0.1382
Difference (w_t1 - w_t0): 0.0498


We observe that the average predicted mortality risk under treatment (w_t1 = 0.1382) is **higher** than the average predicted mortality risk under control (w_t0 = 0.0885), with a difference of 0.0498. This suggests that, on average, the treatment model predicts higher mortality risk than the control model for the same matched patients.

**What does this mean?**
This difference indicates potential **unobserved confounding**. Even though we've matched patients on observed features (sex, age, BMI) and mortality risk buckets, there may be unobserved factors that affect both treatment assignment and outcomes. The treatment group may have been systematically different in ways not captured by our matching, leading the treatment model to learn patterns that reflect these unobserved differences rather than true treatment effects.

**Are there any issues?**
Yes, this is problematic because:
1. In an ideal randomized setting, the average predicted outcomes should be similar between treatment and control groups after matching (i.e., w_t1 - w_t0 should be close to zero).
2. The positive difference suggests that patients who received treatment may have been inherently at higher risk (due to unobserved factors), making it difficult to isolate the true treatment effect.

**What should we do about it?**
We should adjust for unobserved confounding by reweighting the treatment model training data. Specifically, we can use sample weights where surviving (non-expired) patients have weight ρ and expired patients have weight 1. By tuning ρ, we can find a value that makes w_t1 - w_t0 closest to zero, effectively balancing the predicted outcomes and removing the unobserved confounding bias.

### Part 2(d) - 5 points

Implement your suggestion in Part 2(c). 

*__Hint:__* If your suggestion includes tuning the parameter $\rho$, fine-tune $\rho$ over the following values: [1, 1.5, 2, 2.5]. For each $\rho$ value, train a corresponding *__treatment__* model, i.e., repeat Parts 2(b) and 2(c), *__but in a sample-weighted manner__*, such that surviving (non-expired) patients have weight $\rho$ and expired patients have weight $1$ (see [documentation](https://docs.interpretable.ai/stable/IAIBase/data/#Sample-Weights)). For each model trained in this manner, evaluate the corresponding quantity $\hat{w}_{t=1} - \hat{w}_{t=0}$. Select the $\rho$ and corresponding treatment model for which this quantity is closest *in magnitude* to zero. Report your selected $\rho$.
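As a rough illustration of why upweighting survivors lowers $\hat{w}_{t=1}$, the sketch below replaces the retrained treatment random forest with the simplest possible weighted model: a constant prediction equal to the weighted mortality rate $\sum_i \omega_i y_i / \sum_i \omega_i$, with $\omega_i = \rho$ for survivors and $\omega_i = 1$ for deaths. This stand-in, the helper names `weighted_mortality` and `select_rho`, the toy outcomes, and the value used for $\hat{w}_{t=0}$ are all hypothetical; the actual procedure retrains the forest for each $\rho$.

```julia
# Hypothetical stand-in for a sample-weighted model: the weighted mortality
# rate, with weight 1 for deaths (y = 1) and weight rho for survivors (y = 0).
function weighted_mortality(y::AbstractVector{<:Integer}, rho::Real)
    w = [yi == 1 ? 1.0 : float(rho) for yi in y]
    return sum(w .* y) / sum(w)
end

# Pick the rho whose implied w_t1 - w_t0 is closest to zero in magnitude.
function select_rho(y_treated, w_t0, rhos)
    gaps = [weighted_mortality(y_treated, rho) - w_t0 for rho in rhos]
    k = argmin(abs.(gaps))
    return rhos[k], gaps
end

toy_y_treated = vcat(ones(Int, 10), zeros(Int, 90))   # 10 deaths among 100 treated
toy_rho, toy_gaps = select_rho(toy_y_treated, 0.05, [1.0, 1.5, 2.0, 2.5])
toy_rho   # 2.0 (gaps ≈ 0.05, 0.019, 0.003, -0.007)
```

In this toy setting the gap shrinks steadily as $\rho$ grows, which is the intuition behind the hint; with a real forest the relationship need not be exactly monotone, which is why each $\rho$ is retrained and checked.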


In [None]:
best_rho = nothing
best_diff = nothing
best_lnr_treatment = nothing

# Prepare treatment data
X_treatment = select(matched_data_treatment, Not([:treatment, :o_mortality]))
y_treatment = matched_data_treatment.o_mortality

# Hyperparameter grid
depths = [5]
minbuckets = [20, 50]
num_trees = [50, 100]

for rho in [1, 1.5, 2, 2.5]
    println("Running for rho ", rho)
    
    # Create sample weights: surviving patients (y=0) get weight rho, expired (y=1) get weight 1
    sample_weights = [y == 0 ? rho : 1.0 for y in y_treatment]
    
    # GridSearch for Random Forest
    rf_grid = IAI.GridSearch(
        IAI.RandomForestClassifier(
            random_seed = seed,
        ),
        max_depth = depths,
        minbucket = minbuckets,
        num_trees = num_trees,
    )
    
    # Fit with CV using sample weights
    IAI.fit_cv!(
        rf_grid,
        X_treatment,
        y_treatment;
        n_folds = 5,
        validation_criterion = :auc,
        sample_weight = sample_weights,
    )
    
    # Get best model
    lnr_treatment = IAI.get_learner(rf_grid)
    
    # Calculate w_t1 - w_t0 on matched dataset
    probs_treatment = IAI.predict_proba(lnr_treatment, X_matched)
    h_t1 = probs_treatment[:, pos_idx]
    w_t1 = mean(h_t1)
    diff = w_t1 - w_t0
    
    println("  w_t1 - w_t0: ", round(diff, digits=4))
    
    if isnothing(best_rho) || abs(diff) < abs(best_diff)
        best_rho = rho
        best_diff = diff
        best_lnr_treatment = lnr_treatment
    end
end
println("******************")
println("Best rho ", best_rho)
println("Best difference (w_t1 - w_t0): ", round(best_diff, digits=4))


Running for rho 1.0
Running for rho 1.5
Running for rho 2.0
Running for rho 2.5
******************
Best rho 1.0


### Part 2(e) - 4 points

We use the learner for the control model and the learner for the treatment model to estimate outcomes for the matched data and organize this into a `DataFrame` where the rows are the individuals and the columns are treatment/no treatment. This will serve as our rewards matrix.

Using the matched data and this rewards matrix, train an `Optimal Policy Tree` via `GridSearch` and use cross-validation to fine-tune the hyperparameters over the following values: `max_depth` of [4, 5] and `minbucket` of [20, 50]. Provide a screenshot of the tree in your writeup (or make sure it is visible in your code). What are the sensitivity and specificity on the *__test set__*? Recall the definitions of sensitivity and specificity in R.O.A.D.:
- Sensitivity = out of the historically-untreated patients that expired, what fraction were prescribed treatment by the model
- Specificity = out of the historically-untreated patients that did not expire (survived), what fraction were not prescribed treatment by the model

How can we interpret the sensitivity and specificity of our R.O.A.D. model and what is a critical assumption about the treatment that we use when training and evaluating our model? Please answer this question in the provided *__markdown__* cell.
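The two definitions above can be written as a small helper. This is a sketch with a hypothetical name, `road_sens_spec`, assuming outcomes are coded 1 = expired and 0 = survived, and prescriptions use this dataset's treatment labels:

```julia
# Hypothetical helper. `y`: observed outcomes of historically-untreated patients
# (1 = expired, 0 = survived); `rx`: the model's prescriptions for the same patients.
function road_sens_spec(y::AbstractVector{<:Integer}, rx::AbstractVector{<:AbstractString};
                        treat::AbstractString = "splenectomy")
    expired = y .== 1
    prescribed = rx .== treat
    sensitivity = sum(expired .& prescribed) / sum(expired)        # expired and prescribed treatment
    specificity = sum(.!expired .& .!prescribed) / sum(.!expired)  # survived and not prescribed treatment
    return sensitivity, specificity
end

toy_y  = [1, 1, 0, 0, 0]
toy_rx = ["splenectomy", "observation", "observation", "splenectomy", "observation"]
road_sens_spec(toy_y, toy_rx)   # (0.5, 0.6666666666666666)
```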

In [34]:
# rewards matrix

rewards_c_matched = IAI.predict_proba(lnr_control, X_matched)[:, pos_idx];        # P(o_mortality = 1) under observation
rewards_t_matched = IAI.predict_proba(best_lnr_treatment, X_matched)[:, pos_idx]  # P(o_mortality = 1) under splenectomy
rewards_matched = DataFrame(Dict("observation" => rewards_c_matched, "splenectomy" => rewards_t_matched))
first(rewards_matched, 5)

Row,observation,splenectomy
,Float64,Float64
1,0.0183219,0.0336698
2,0.0160235,0.0355493
3,0.0207908,0.0396389
4,0.0226858,0.0350917
5,0.0214761,0.0447151
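Each row of this rewards matrix holds one patient's predicted mortality risk under each option. As a minimal sketch of how such a matrix maps to prescriptions, assuming lower predicted mortality is better and using the hypothetical helper `prescribe_min_risk` with made-up values, the unconstrained per-patient choice is simply the column with the smallest entry. A policy tree targets the same total predicted mortality but forces the prescription to be constant within each leaf.

```julia
using Statistics

# Hypothetical helper: prescribe, for each patient (row), the option with the
# lowest predicted mortality risk. Column order assumed: [observation, splenectomy].
prescribe_min_risk(R::AbstractMatrix{<:Real}; arms = ["observation", "splenectomy"]) =
    [arms[argmin(R[i, :])] for i in 1:size(R, 1)]

# Hypothetical rewards: rows are patients, entries are predicted mortality risks
toy_R = [0.02 0.04;
         0.30 0.12;
         0.10 0.10]
toy_policy = prescribe_min_risk(toy_R)   # ["observation", "splenectomy", "observation"]

# Mean predicted mortality under the per-patient choice vs. treating nobody
mean(minimum(toy_R; dims = 2)), mean(toy_R[:, 1])
```

Ties (the third row) go to the first column, so observation is kept when treatment offers no predicted benefit.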


In [None]:
# Train Optimal Policy Tree with GridSearch
grid_opt_road = IAI.GridSearch(
    IAI.OptimalPolicyTreeClassifier(
        random_seed = seed,
    ),
    max_depth = [4, 5],
    minbucket = [20, 50],
)

IAI.fit_cv!(
    grid_opt_road,
    X_matched,
    rewards_matched;
    n_folds = 5,
    validation_criterion = :policy_auc,
)

In [None]:
IAI.show_in_browser(grid_opt_road)

In [None]:
# Get the best OPT model
opt_road = IAI.get_learner(grid_opt_road)

# Predict treatment recommendations for test set control patients (historically untreated)
test_predictions = IAI.predict(opt_road, X_test_control)

# Sensitivity: out of historically-untreated patients that expired, 
# what fraction were prescribed treatment by the model
expired_mask = y_test_control .== 1
prescribed_treatment_mask = test_predictions .== "splenectomy"
sensitivity = sum(expired_mask .& prescribed_treatment_mask) / sum(expired_mask)

# Specificity: out of historically-untreated patients that did not expire (survived),
# what fraction were not prescribed treatment by the model
survived_mask = y_test_control .== 0
not_prescribed_treatment_mask = test_predictions .== "observation"
specificity = sum(survived_mask .& not_prescribed_treatment_mask) / sum(survived_mask)

println("Sensitivity: ", round(sensitivity, digits=4))
println("Specificity: ", round(specificity, digits=4))

**Interpretation of Sensitivity and Specificity:**

- **Sensitivity** measures how well the model identifies patients who would have died without treatment and recommends treatment for them. A high sensitivity means the model is good at catching patients who need treatment to survive.

- **Specificity** measures how well the model avoids unnecessary treatment for patients who would have survived anyway. A high specificity means the model is conservative and doesn't over-prescribe treatment.

**Critical Assumption:**

A critical assumption we make when training and evaluating the R.O.A.D. model is that **treatment assignment is ignorable** (or unconfounded) conditional on the observed features and the matching/bucketing process. Specifically, we assume that:

1. After matching on observed features (sex, age, BMI) and mortality-risk buckets, the treatment and control groups are balanced on observed confounders; any remaining imbalance is attributed to unobserved confounding, which the ρ reweighting is meant to address.

2. The treatment effect is **homogeneous**, or can at least be estimated consistently from the matched pairs and the counterfactual models.

3. There are **no unmeasured confounders** beyond those accounted for by matching and reweighting. If unobserved factors affect both treatment assignment and outcomes and are not captured by the reweighting, our estimates may still be biased.

This assumption is critical because, if it is violated, the R.O.A.D. methodology may not fully remove confounding, and its policy recommendations may not generalize to new patients.
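
Assumption 3 can be illustrated with a small synthetic simulation (not the spleen data): an unobserved severity flag `u` makes sicker patients both more likely to receive surgery and more likely to die, while surgery itself lowers mortality by 5 percentage points. All probabilities are made up for illustration, and `simulate` is a hypothetical helper.

```julia
using Random, Statistics

# Hypothetical data-generating process with an unobserved confounder u (severity).
function simulate(n::Int; rng = MersenneTwister(42))
    u = rand(rng, n) .< 0.5                      # unobserved severity flag
    t = rand(rng, n) .< ifelse.(u, 0.8, 0.2)     # sicker patients receive surgery more often
    p_death = 0.10 .+ 0.30 .* u .- 0.05 .* t     # true effect of surgery: -0.05
    y = rand(rng, n) .< p_death                  # observed mortality
    return u, t, y
end

u_sim, t_sim, y_sim = simulate(200_000)

# Naive treated-minus-control mortality difference, ignoring u (confounded)
naive = mean(y_sim[t_sim]) - mean(y_sim[.!t_sim])

# Stratify on u (possible only because this is a simulation), then weight by stratum size
stratum_effect(s) = mean(y_sim[t_sim .& (u_sim .== s)]) - mean(y_sim[.!t_sim .& (u_sim .== s)])
adjusted = mean(u_sim) * stratum_effect(true) + (1 - mean(u_sim)) * stratum_effect(false)
```

In expectation the naive difference is +0.13 (surgery looks harmful), while stratifying on `u` recovers roughly the true -0.05. On real data `u` is unobserved, so matching on recorded features alone cannot perform this adjustment.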

### Part 2(f) - 4 points

We will now compare the R.O.A.D. model with an OPT trained on the unadulterated training data. Using the *__original__* data, train a *__doubly-robust__* reward estimator (see [documentation](https://docs.interpretable.ai/stable/OptimalTrees/quickstart/policy_categorical/)) on the training set to get a rewards matrix. Use the original data and this new rewards matrix to train an `Optimal Policy Tree` via `GridSearch` and use cross-validation to fine-tune the hyperparameters amongst the same possible values as 2(e). Provide a screenshot of the tree in your writeup (or make sure it is visible in your code). What is the sensitivity and specificity?

How does this tree differ from the tree trained via the R.O.A.D. methodology? How does the sensitivity/specificity differ? What does this mean? Please be concise in your answer (<=5 sentences) and use the provided *__markdown__* space.

In [None]:
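# Train a doubly-robust reward estimator on the original training data
# (assumes the IAI v3 reward-estimation API; check against the linked documentation)
reward_lnr = IAI.CategoricalClassificationRewardEstimator(
    propensity_estimator = IAI.RandomForestClassifier(random_seed = seed),
    outcome_estimator = IAI.RandomForestClassifier(random_seed = seed),
    reward_estimator = :doubly_robust,
    random_seed = seed,
)

# Positional order is features, treatments, outcomes
train_pred, train_reward_score = IAI.fit_predict!(
    reward_lnr,
    X_train,
    t_train,
    y_train;
    propensity_score_criterion = :misclassification,
    outcome_score_criterion = :auc,
)
rewards_train = train_pred[:reward]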
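
For intuition about what the `:reward` column contains, here is a minimal sketch of the standard doubly-robust (augmented inverse propensity weighted) reward, $\Gamma_i(t) = \hat\mu_t(x_i) + \mathbb{1}\{T_i = t\}\,(y_i - \hat\mu_t(x_i))/\hat p_t(x_i)$, where $\hat\mu_t$ is the outcome model and $\hat p_t$ the propensity model. IAI's implementation may differ in details (for example, propensity clipping); `dr_rewards` and the toy inputs are hypothetical.

```julia
# Hypothetical fitted-model outputs for 3 patients; columns are (observation, splenectomy).
mu_hat = [0.20 0.10;    # outcome model: estimated P(death | x, t)
          0.30 0.40;
          0.10 0.05]
p_hat = [0.70 0.30;     # propensity model: estimated P(t | x)
         0.50 0.50;
         0.80 0.20]
t_obs = [1, 2, 1]       # column index of the treatment each patient actually received
y_obs = [0, 1, 0]       # observed mortality

# Doubly-robust (AIPW) rewards: outcome-model estimate for every arm, plus an
# inverse-propensity-weighted residual correction for the arm actually observed.
function dr_rewards(mu::AbstractMatrix, p::AbstractMatrix, t_obs, y_obs)
    R = float.(mu)      # copy, so the inputs are not modified
    for i in axes(mu, 1)
        j = t_obs[i]
        R[i, j] += (y_obs[i] - mu[i, j]) / p[i, j]
    end
    return R
end

dr_toy = dr_rewards(mu_hat, p_hat, t_obs, y_obs)
```

Individual doubly-robust rewards can fall outside [0, 1] (here 1.6 for patient 2); they are only unbiased on average, which is why the policy tree aggregates them over many patients.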

In [None]:
# Train Optimal Policy Tree with GridSearch on original training data
grid_opt_regular = IAI.GridSearch(
    IAI.OptimalPolicyTreeClassifier(
        random_seed = seed,
    ),
    max_depth = [4, 5],
    minbucket = [20, 50],
)

IAI.fit_cv!(
    grid_opt_regular,
    X_train,
    rewards_train;
    n_folds = 5,
    validation_criterion = :reward,
)

In [None]:
IAI.show_in_browser(grid_opt_regular)

In [None]:
# Get the best OPT model
opt_regular = IAI.get_learner(grid_opt_regular)

# Predict treatment recommendations for test set control patients (historically untreated)
test_predictions_regular = IAI.predict(opt_regular, X_test_control)

# Sensitivity: out of historically-untreated patients that expired, 
# what fraction were prescribed treatment by the model
expired_mask = y_test_control .== 1
prescribed_treatment_mask_regular = test_predictions_regular .== "splenectomy"
sensitivity_regular = sum(expired_mask .& prescribed_treatment_mask_regular) / sum(expired_mask)

# Specificity: out of historically-untreated patients that did not expire (survived),
# what fraction were not prescribed treatment by the model
survived_mask = y_test_control .== 0
not_prescribed_treatment_mask_regular = test_predictions_regular .== "observation"
specificity_regular = sum(survived_mask .& not_prescribed_treatment_mask_regular) / sum(survived_mask)

println("Sensitivity: ", round(sensitivity_regular, digits=4))
println("Specificity: ", round(specificity_regular, digits=4))

**Comparison of R.O.A.D. vs. Doubly-Robust Methods:**

The R.O.A.D. tree is trained on matched data balanced to remove observed confounding (through matching) and unobserved confounding (through ρ reweighting), whereas the doubly-robust tree is trained directly on the original, unadjusted training data.

**Differences in Sensitivity/Specificity:**

We expect the R.O.A.D. tree to give more conservative, and potentially more reliable, recommendations because it explicitly adjusts for confounding. The doubly-robust estimator is consistent if either its propensity model or its outcome model is correctly specified, but, like any estimator that relies only on observed covariates, it cannot correct for unobserved confounding.

**What this means:**

A gap between the two trees' sensitivity/specificity would suggest that confounding adjustment matters for this dataset, and that R.O.A.D.'s explicit handling of observed and unobserved confounding may yield better policy recommendations when treatment assignment is strongly selected.