# Assignment 4 - Part 2: Causal Forest (R)

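This notebook estimates heterogeneous treatment effects of a randomly assigned cash transfer program that encourages medical check-ups. Treatment and outcome are simulated on top of the processed Cleveland heart disease data, and the effects are estimated in R with a random forest that uses the treatment indicator as a feature.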

```r
# Load necessary libraries
library(randomForest)
library(rpart)
library(rpart.plot)
library(ggplot2)
library(dplyr)
library(tidyr)
library(reshape2)

# Set random seed for reproducibility
set.seed(123)

# Make sure the output folder used by ggsave() and png() exists
dir.create('../output', showWarnings = FALSE, recursive = TRUE)
```

```
randomForest 4.7-1.2
Type rfNews() to see new features/changes/bug fixes.

Attaching package: ‘ggplot2’

The following object is masked from ‘package:randomForest’:

    margin

Attaching package: ‘dplyr’

The following object is masked from ‘package:randomForest’:

    combine

The following objects are masked from ‘package:stats’:

    filter, lag

The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union

Attaching package: ‘reshape2’

The following object is masked from ‘package:tidyr’:

    smiths
```

## Load and Prepare Data

```r
# Load the dataset
column_names <- c('age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg', 
                  'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'hd')

df <- read.csv('../input/processed.cleveland.data', 
               header = FALSE,
               col.names = column_names,
               na.strings = '?')

# Remove missing values
df <- na.omit(df)

cat("Dataset shape:", dim(df), "\n")
head(df)
```

```
Dataset shape: 297 14 
```

| | age | sex | cp | restbp | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | hd |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| type | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `dbl` | `int` |
| 1 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | 6 | 0 |
| 2 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | 3 | 2 |
| 3 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | 7 | 1 |
| 4 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | 3 | 0 |
| 5 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | 3 | 0 |
| 6 | 56 | 1 | 2 | 120 | 236 | 0 | 0 | 178 | 0 | 0.8 | 1 | 0 | 3 | 0 |


## (0.5 points) Create binary treatment variable T

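```r
# Create binary treatment variable with random assignment (Bernoulli, p = 0.5)
# Note: `T` is also R's shorthand for TRUE. df$T and formulas with data = df
# refer to the column, but avoid using a bare T elsewhere.
set.seed(123)
df$T <- rbinom(nrow(df), 1, 0.5)

cat("Treatment distribution:\n")
table(df$T)
cat("\nProportion treated:", mean(df$T), "\n")
```

```
Treatment distribution:

  0   1 
153 144 

Proportion treated: 0.4848485 
```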
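Random assignment should make the treated and control groups similar on every covariate. A common check is the standardized mean difference (SMD) for each covariate, and a frequently used rule of thumb treats |SMD| < 0.1 as good balance. The sketch below simulates covariates with assumed distributions so that it runs on its own; with the notebook's data, the same helper could be applied as `sapply(df[c("age", "sex", "restbp")], smd, t = df$T)`.

```r
# Standardized mean difference between treated (t == 1) and control (t == 0),
# using the average of the two group variances as the pooled variance
smd <- function(x, t) {
  m1 <- mean(x[t == 1])
  m0 <- mean(x[t == 0])
  pooled_sd <- sqrt((var(x[t == 1]) + var(x[t == 0])) / 2)
  (m1 - m0) / pooled_sd
}

# Illustrative data: covariate distributions are assumptions, not the Cleveland data
set.seed(123)
n <- 297
sim <- data.frame(
  age    = rnorm(n, mean = 54, sd = 9),
  sex    = rbinom(n, 1, 0.68),
  restbp = rnorm(n, mean = 132, sd = 18)
)
treat <- rbinom(n, 1, 0.5)   # random assignment, as in the notebook

# One SMD per covariate
balance <- sapply(sim, smd, t = treat)
print(round(balance, 3))
```

With only 297 patients, chance imbalances of around 0.1 are not unusual even under perfect randomization, so large SMDs are a prompt to look closer rather than proof that assignment failed.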


## (1 point) Create outcome variable Y

```r
# Create outcome variable Y
# Y = (1 + 0.05*age + 0.3*sex + 0.2*restbp) * T + 0.5*oldpeak + epsilon
# epsilon ~ N(0, 1)

set.seed(123)
epsilon <- rnorm(nrow(df), 0, 1)

df$Y <- (1 + 0.05 * df$age + 0.3 * df$sex + 0.2 * df$restbp) * df$T + 
        0.5 * df$oldpeak + epsilon

cat("Outcome variable Y statistics:\n")
summary(df$Y)

# Visualize Y distribution by treatment group
df$T_label <- ifelse(df$T == 0, "Control (T=0)", "Treated (T=1)")

p <- ggplot(df, aes(x = Y, fill = T_label)) +
  geom_histogram(alpha = 0.5, bins = 30, position = "identity") +
  labs(title = "Distribution of Outcome Variable by Treatment Group",
       x = "Y (Health Improvement)",
       y = "Frequency",
       fill = "Treatment Group") +
  theme_minimal() +
  theme(text = element_text(size = 12))

ggsave('../output/outcome_distribution_R.png', p, width = 10, height = 6, dpi = 300)
```

```
Outcome variable Y statistics:

   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
-1.4618  0.3243  3.2410 15.1477 30.4540 45.1480 
```
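Because Y is simulated, each patient's true treatment effect is known: it is the factor multiplying T in the outcome equation, tau(x) = 1 + 0.05*age + 0.3*sex + 0.2*restbp. Its sample mean is the true average treatment effect, which gives a benchmark for the OLS and random forest estimates below. The sketch evaluates tau for the six patients printed by `head(df)` above; in the notebook, `df$tau_true <- true_cate(df$age, df$sex, df$restbp)` (a hypothetical column name) would add it for every patient.

```r
# True individual treatment effect implied by the simulated outcome
# Y = (1 + 0.05*age + 0.3*sex + 0.2*restbp) * T + 0.5*oldpeak + epsilon
true_cate <- function(age, sex, restbp) {
  1 + 0.05 * age + 0.3 * sex + 0.2 * restbp
}

# The first six rows of the processed Cleveland data shown above
patients <- data.frame(age    = c(63, 67, 67, 37, 41, 56),
                       sex    = c(1, 1, 1, 1, 0, 1),
                       restbp = c(145, 160, 120, 130, 130, 120))
patients$tau <- with(patients, true_cate(age, sex, restbp))
print(patients)
```

For these six patients the true effects are 33.45, 36.65, 28.65, 29.15, 29.05 and 28.10. The restbp term dominates: at 130 mmHg it contributes 26 on its own, while the age term for a 60-year-old contributes 3.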

## (1 point) Calculate treatment effect using OLS

```r
# Estimate treatment effect using OLS regression
# Simple model: Y ~ T
model_simple <- lm(Y ~ T, data = df)
cat("Simple OLS Model (Y ~ T):\n")
summary(model_simple)
cat("\nAverage Treatment Effect (ATE):", coef(model_simple)["T"], "\n")
```

```
Simple OLS Model (Y ~ T):

Call:
lm(formula = Y ~ T, data = df)

Residuals:
     Min       1Q   Median       3Q      Max 
-10.1548  -1.1920  -0.1594   1.0648  14.5280 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   0.5854     0.2364   2.477   0.0138 *  
T            30.0346     0.3395  88.473   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 2.924 on 295 degrees of freedom
Multiple R-squared:  0.9637,	Adjusted R-squared:  0.9636 
F-statistic:  7827 on 1 and 295 DF,  p-value: < 2.2e-16

Average Treatment Effect (ATE): 30.03462 
```


```r
# More complete model with covariates
model_full <- lm(Y ~ T + age + sex + restbp + oldpeak, data = df)
cat("\nFull OLS Model with Covariates:\n")
summary(model_full)
cat("\nAverage Treatment Effect (ATE) with controls:", coef(model_full)["T"], "\n")
```

```

Full OLS Model with Covariates:

Call:
lm(formula = Y ~ T + age + sex + restbp + oldpeak, data = df)

Residuals:
    Min      1Q  Median      3Q     Max 
-5.5147 -1.4364 -0.0146  1.3981  6.3087 

Coefficients:
              Estimate Std. Error t value Pr(>|t|)    
(Intercept) -13.827260   1.110129 -12.456  < 2e-16 ***
T            30.154743   0.246314 122.424  < 2e-16 ***
age           0.027838   0.014435   1.928   0.0548 .  
sex           0.169528   0.266416   0.636   0.5251    
restbp        0.091137   0.007338  12.419  < 2e-16 ***
oldpeak       0.681392   0.109787   6.206 1.86e-09 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 2.116 on 291 degrees of freedom
Multiple R-squared:  0.9812,	Adjusted R-squared:  0.9809 
F-statistic:  3044 on 5 and 291 DF,  p-value: < 2.2e-16

Average Treatment Effect (ATE) with controls: 30.15474 
```
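The covariate-adjusted model above assumes a single, constant effect of T. Because the simulated effect depends on age, sex and restbp, adding treatment-by-covariate interactions lets OLS estimate that dependence directly. The sketch below simulates data from the notebook's outcome equation under assumed covariate distributions (so the Cleveland file is not needed), fits `lm(Y ~ T * (age + sex + restbp) + oldpeak)`, and prints interaction estimates that can be compared with the true 0.05, 0.3 and 0.2.

```r
# Illustrative data simulated from the notebook's outcome equation;
# the covariate distributions are assumptions
set.seed(42)
n <- 2000
sim <- data.frame(
  age     = rnorm(n, 54, 9),
  sex     = rbinom(n, 1, 0.68),
  restbp  = rnorm(n, 132, 18),
  oldpeak = rexp(n, 1)
)
sim$T <- rbinom(n, 1, 0.5)
sim$Y <- (1 + 0.05 * sim$age + 0.3 * sim$sex + 0.2 * sim$restbp) * sim$T +
  0.5 * sim$oldpeak + rnorm(n)

# T * (age + sex + restbp) expands to main effects plus T:age, T:sex, T:restbp
fit_int <- lm(Y ~ T * (age + sex + restbp) + oldpeak, data = sim)
print(round(coef(fit_int)[c("T:age", "T:sex", "T:restbp")], 3))

# Individual effect implied by the fitted model; its mean estimates the ATE
b <- coef(fit_int)
cate_hat <- b["T"] + b["T:age"] * sim$age + b["T:sex"] * sim$sex +
  b["T:restbp"] * sim$restbp
print(mean(cate_hat))
```

Here the interaction model matches the data-generating process exactly, which is why it works so well; the random forest below does not need the functional form to be specified.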


## (2 points) Use Random Forest to estimate causal effects

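```r
# Prepare features for Random Forest.
# T is included as an ordinary feature (an "S-learner"): the forest learns one
# outcome model, and treatment effects are read off later by toggling T.
feature_cols <- c('age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg', 
                  'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'T')

X_rf <- df[, feature_cols]
y_rf <- df$Y

# Train Random Forest model
set.seed(123)
rf_model <- randomForest(x = X_rf, y = y_rf, 
                         ntree = 100, 
                         maxnodes = 50,   # at most 50 terminal nodes per tree
                         nodesize = 10)   # minimum size of terminal nodes

cat("Random Forest model trained successfully\n")
# rsq holds the out-of-bag pseudo R-squared after each tree; report the last one
cat("% Variance explained:", rf_model$rsq[length(rf_model$rsq)] * 100, "%\n")
```

```
Random Forest model trained successfully
% Variance explained: 96.51707 %
```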


```r
# Estimate individual treatment effects using Random Forest
# Create counterfactual datasets: every patient with T = 1, then with T = 0
X_treated <- X_rf
X_treated$T <- 1

X_control <- X_rf
X_control$T <- 0

# Predict outcomes under treatment and control
# (in-sample predictions from all trees, not out-of-bag predictions)
y_pred_treated <- predict(rf_model, X_treated)
y_pred_control <- predict(rf_model, X_control)

# Conditional Average Treatment Effect (CATE): predicted difference per patient
df$CATE <- y_pred_treated - y_pred_control

cat("Conditional Average Treatment Effect (CATE) statistics:\n")
summary(df$CATE)

# Visualize CATE distribution
mean_cate <- mean(df$CATE)
p <- ggplot(df, aes(x = CATE)) +
  geom_histogram(bins = 30, fill = "steelblue", color = "black") +
  geom_vline(xintercept = mean_cate, color = "red", linetype = "dashed",
             linewidth = 1) +
  # Place the label just right of the mean line so it stays inside the panel
  annotate("text", x = mean_cate, y = Inf, vjust = 2, hjust = -0.05,
           label = paste("Mean CATE =", round(mean_cate, 4)), color = "red") +
  labs(title = "Distribution of Estimated Treatment Effects",
       x = "Conditional Average Treatment Effect (CATE)",
       y = "Frequency") +
  theme_minimal() +
  theme(text = element_text(size = 12))

ggsave('../output/cate_distribution_R.png', p, width = 10, height = 6, dpi = 300)
```

```
Conditional Average Treatment Effect (CATE) statistics:

   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
  20.00   24.79   26.76   26.68   28.68   32.16 
```
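The forest above is an S-learner: T enters as one feature among many, and a tree that never splits on T predicts the same outcome under T = 1 and T = 0, so it contributes no effect. This can pull the estimated CATEs toward zero. A T-learner instead fits separate outcome models for treated and control patients and subtracts their predictions. The sketch below uses `lm()` as a stand-in learner (a random forest could be substituted) on simulated data with assumed covariate distributions, and scores the estimates against the known effect with a small hypothetical helper, `evaluate_cate()`. In the notebook, the same helper could compare `df$CATE` with the true effect implied by the outcome equation.

```r
# Illustrative data simulated from the notebook's outcome equation
set.seed(7)
n <- 1000
sim <- data.frame(age = rnorm(n, 54, 9), sex = rbinom(n, 1, 0.68),
                  restbp = rnorm(n, 132, 18), oldpeak = rexp(n, 1))
sim$T <- rbinom(n, 1, 0.5)
tau <- 1 + 0.05 * sim$age + 0.3 * sim$sex + 0.2 * sim$restbp  # true effect
sim$Y <- tau * sim$T + 0.5 * sim$oldpeak + rnorm(n)

# T-learner: one outcome model per arm, same covariates in both
f <- Y ~ age + sex + restbp + oldpeak
m1 <- lm(f, data = sim[sim$T == 1, ])
m0 <- lm(f, data = sim[sim$T == 0, ])

# CATE = predicted outcome under treatment minus under control, for every unit
cate_t <- predict(m1, newdata = sim) - predict(m0, newdata = sim)

# Hypothetical helper: bias, RMSE and correlation against the ground truth
evaluate_cate <- function(est, truth) {
  c(bias = mean(est - truth),
    rmse = sqrt(mean((est - truth)^2)),
    cor  = cor(est, truth))
}
print(round(evaluate_cate(cate_t, tau), 3))
```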

## (2 points) Plot representative tree capturing heterogeneous treatment effects

```r
# Fit a shallow regression tree to the estimated CATEs (not to Y), so that the
# splits describe how the predicted treatment effect varies with covariates.
# T is excluded: each CATE is already a contrast between T = 1 and T = 0.
covariate_cols <- c('age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg',
                    'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal')
tree_data <- data.frame(df[, covariate_cols], CATE = df$CATE)
tree_model <- rpart(CATE ~ ., data = tree_data,
                    control = rpart.control(maxdepth = 2, minsplit = 10))

# Plot the tree
png('../output/representative_tree_R.png', width = 1400, height = 800)
rpart.plot(tree_model, 
           main = "Representative Tree (maxdepth = 2) of Estimated Treatment Effects",
           box.palette = "RdBu", shadow.col = "gray", cex = 0.8)
dev.off()

cat("Tree interpretation:\n")
cat("This tree shows how patient characteristics are associated with different predicted treatment effects.\n")
cat("The splits indicate which features best separate patients with larger and smaller estimated effects.\n")
```

```
Tree interpretation:
This tree shows how patient characteristics are associated with different predicted treatment effects.
The splits indicate which features best separate patients with larger and smaller estimated effects.
```


**Interpretation:**

The representative tree summarizes, with at most four leaves, how the estimated treatment effect varies with patient characteristics:

**Key Findings:**

1. **Tree Structure:** With `maxdepth = 2` the tree has at most three splits and four leaves, so it shows only the strongest partitioning rules, not the full structure learned by the forest.

2. **Feature Splits:** In the simulated outcome, the treatment effect is `1 + 0.05*age + 0.3*sex + 0.2*restbp`. Splits on `restbp` (0.2 per mmHg) and, to a lesser degree, `age` (0.05 per year) therefore reflect genuine heterogeneity, while splits on other covariates reflect noise in the forest's estimates.

3. **Subgroup Identification:** Each leaf is a subgroup of patients, and the value shown in it is the mean response within that subgroup. Effects still vary inside a leaf, so leaves are summaries rather than homogeneous groups, but they indicate which kinds of patients are predicted to benefit more or less.

4. **Heterogeneity Insights:** Differences between leaf means show how much the estimated effect changes across patient profiles. Because each leaf averages over many patients, the leaf means understate the spread of the individual CATEs (20.00 to 32.16).
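`rpart` grows a regression tree by choosing, at each node, the covariate and cutpoint that most reduce the within-node sum of squares of the response. To describe treatment-effect heterogeneity, that response should be the estimated CATE rather than Y; a tree fit to Y with T as a feature would likely spend its first split on T, since T dominates Y. The base-R sketch below performs that search for a single split. The CATE values are simulated (an assumption standing in for the forest's estimates), and `best_split()` and `min_leaf` are hypothetical names, not rpart's API.

```r
# Illustrative covariates and CATE values; distributions are assumptions
set.seed(1)
n <- 297
X <- data.frame(age = rnorm(n, 54, 9), sex = rbinom(n, 1, 0.68),
                restbp = rnorm(n, 132, 18), chol = rnorm(n, 247, 52))
cate <- 1 + 0.05 * X$age + 0.3 * X$sex + 0.2 * X$restbp + rnorm(n, 0, 0.5)

# Exhaustive search for the split "x <= cut_pt" that minimizes the total
# within-child sum of squared errors of y
best_split <- function(X, y, min_leaf = 10) {
  best <- list(var = NA_character_, cut = NA_real_, sse = Inf)
  for (v in names(X)) {
    x <- X[[v]]
    # Every observed value except the largest is a candidate cutpoint
    for (cut_pt in head(sort(unique(x)), -1)) {
      left <- x <= cut_pt
      if (sum(left) < min_leaf || sum(!left) < min_leaf) next
      sse <- sum((y[left] - mean(y[left]))^2) +
        sum((y[!left] - mean(y[!left]))^2)
      if (sse < best$sse) best <- list(var = v, cut = cut_pt, sse = sse)
    }
  }
  best
}

first_split <- best_split(X, cate)
cat("First split:", first_split$var, "<=", round(first_split$cut, 1), "\n")
```

Applying the same search recursively to each child, up to depth 2, gives the kind of tree plotted above. In this simulation the first split falls on `restbp`, the covariate with the largest contribution to the effect.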

## (1.5 points) Compute and visualize feature importances

```r
# Get feature importances from Random Forest.
# rf_model was fit with the default importance = FALSE, so importance() returns
# one column, IncNodePurity: the total decrease in residual sum of squares from
# splits on each feature, averaged over all trees. It measures how useful a
# feature is for predicting Y, not how much it drives effect heterogeneity.
importances <- importance(rf_model)
feature_importance_df <- data.frame(
  feature = rownames(importances),
  importance = importances[, "IncNodePurity"]
)
feature_importance_df <- feature_importance_df[order(-feature_importance_df$importance), ]

cat("Feature Importances:\n")
print(feature_importance_df)

# Plot feature importances (fix the factor order so bars follow the sorting)
feature_importance_df$feature <- factor(feature_importance_df$feature,
                                        levels = feature_importance_df$feature)

p <- ggplot(feature_importance_df, aes(x = importance, y = feature)) +
  geom_bar(stat = "identity", fill = "steelblue") +
  labs(title = "Feature Importances from Random Forest Model",
       x = "Increase in node purity (IncNodePurity)",
       y = "Feature") +
  theme_minimal() +
  theme(text = element_text(size = 12))

ggsave('../output/feature_importances_R.png', p, width = 10, height = 8, dpi = 300)
```

```
Feature Importances:
        feature importance
T             T 55797.4316
restbp   restbp  2347.2852
chol       chol  1771.2770
age         age  1741.5985
thalach thalach  1467.1143
oldpeak oldpeak  1466.6394
ca           ca   630.2115
cp           cp   420.6563
exang     exang   320.6492
slope     slope   310.7139
thal       thal   197.3308
restecg restecg   194.7799
sex         sex   135.9087
fbs         fbs    78.4238
```
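Impurity-based importance tends to favor continuous variables with many possible split points, which is consistent with `chol`, absent from the outcome equation, ranking above `age` here. A permutation measure is an alternative: shuffle one column, re-predict, and record how much the error rises. randomForest computes such a measure (%IncMSE) when the forest is fit with `importance = TRUE` and queried with `importance(rf_model, type = 1)`. The model-agnostic base-R sketch below shows the idea; the data are simulated under assumed distributions, `lm()` stands in for the forest, and `permutation_importance()` is a hypothetical helper name.

```r
# Illustrative data from the notebook's outcome equation; chol is pure noise
set.seed(99)
n <- 1000
sim <- data.frame(age = rnorm(n, 54, 9), sex = rbinom(n, 1, 0.68),
                  restbp = rnorm(n, 132, 18), chol = rnorm(n, 247, 52),
                  oldpeak = rexp(n, 1), T = rbinom(n, 1, 0.5))
sim$Y <- (1 + 0.05 * sim$age + 0.3 * sim$sex + 0.2 * sim$restbp) * sim$T +
  0.5 * sim$oldpeak + rnorm(n)

fit <- lm(Y ~ T * (age + sex + restbp + chol) + oldpeak, data = sim)

# Increase in mean squared error when one column is randomly permuted,
# averaged over n_rep shuffles
permutation_importance <- function(model, data, outcome, vars, n_rep = 5) {
  base_mse <- mean((data[[outcome]] - predict(model, data))^2)
  sapply(vars, function(v) {
    mean(replicate(n_rep, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])
      mean((data[[outcome]] - predict(model, shuffled))^2) - base_mse
    }))
  })
}

imp <- permutation_importance(fit, sim, "Y",
                              c("T", "age", "sex", "restbp", "chol", "oldpeak"))
print(sort(round(imp, 3), decreasing = TRUE))
```

Like IncNodePurity, this still ranks features by their role in predicting Y, so T comes first; ranking drivers of heterogeneity requires applying the idea to a model of the CATE instead.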


## (2 points) Plot distribution of standardized covariates by predicted treatment effect terciles

```r
# Standardize all covariates
covariate_cols <- c('age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg', 
                    'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal')

df_standardized <- df
df_standardized[covariate_cols] <- scale(df[covariate_cols])

cat("Covariates standardized successfully\n")
```

```
Covariates standardized successfully
```


```r
# Divide CATE into terciles
df_standardized$tercile <- cut(df_standardized$CATE, 
                               breaks = quantile(df_standardized$CATE, probs = c(0, 1/3, 2/3, 1)),
                               labels = c('Low', 'Medium', 'High'),
                               include.lowest = TRUE)

cat("CATE tercile distribution:\n")
table(df_standardized$tercile)
```

```
CATE tercile distribution:

   Low Medium   High 
    99     99     99 
```
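`cut()` with `quantile(probs = c(0, 1/3, 2/3, 1))` places the tercile boundaries at the 1/3 and 2/3 quantiles of the CATE. These are not the quartiles printed by `summary()`, so the "1st Qu." and "3rd Qu." values of the CATE summary cannot be read as tercile cutoffs. The sketch below uses simulated CATE values (an assumption; the mean and SD only loosely mimic the summary above), prints both pairs of cutoffs, and reproduces the 99/99/99 split. With the notebook's data, `quantile(df$CATE, c(1/3, 2/3))` gives the actual cutoffs.

```r
# Illustrative CATE values; distribution is an assumption
set.seed(2024)
cate <- rnorm(297, mean = 26.7, sd = 2.8)

# Tercile cutoffs versus the quartiles reported by summary()
tercile_cuts  <- unname(quantile(cate, probs = c(1/3, 2/3)))
quartile_cuts <- unname(quantile(cate, probs = c(0.25, 0.75)))
cuts <- rbind(terciles = tercile_cuts, quartiles = quartile_cuts)
colnames(cuts) <- c("lower", "upper")
print(round(cuts, 2))

# Same grouping rule as the notebook
tercile <- cut(cate, breaks = quantile(cate, probs = c(0, 1/3, 2/3, 1)),
               labels = c("Low", "Medium", "High"), include.lowest = TRUE)
print(table(tercile))
```

The lower tercile cutoff always lies above the first quartile and the upper one below the third quartile, so the Medium group is narrower than the interquartile range.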

```r
# Compute mean of each covariate within each tercile
tercile_means <- df_standardized %>%
  group_by(tercile) %>%
  summarise(across(all_of(covariate_cols), mean)) %>%
  as.data.frame()

cat("Mean standardized covariates by tercile:\n")
print(tercile_means)

# Prepare data for heatmap
tercile_means_long <- melt(tercile_means, id.vars = "tercile")
colnames(tercile_means_long) <- c("tercile", "covariate", "value")

# Create heatmap
p <- ggplot(tercile_means_long, aes(x = covariate, y = tercile, fill = value)) +
  geom_tile() +
  geom_text(aes(label = round(value, 2)), size = 3) +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red", 
                       midpoint = 0, name = "Standardized\nMean") +
  labs(title = "Distribution of Standardized Covariates by Predicted Treatment Effect Terciles",
       x = "Covariates",
       y = "Treatment Effect Tercile") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        text = element_text(size = 12))

ggsave('../output/tercile_heatmap_R.png', p, width = 12, height = 8, dpi = 300)
```

```
Mean standardized covariates by tercile:
  tercile           age         sex           cp      restbp       chol
1     Low -0.2585785290  0.04312064  0.045365204 -0.72542151 -0.1784595
2  Medium -0.0007441109  0.12936193 -0.048854835  0.07297812  0.1910216
3    High  0.2593226399 -0.17248258  0.003489631  0.65244339 -0.0125621
          fbs     restecg     thalach       exang       oldpeak       slope
1 -0.06686742 -0.08798962  0.12401587  0.05733988 -0.0008662047 -0.07625211
2  0.01910498  0.03384216 -0.05034017  0.10034478  0.2035580996  0.10348501
3  0.04776244  0.05414746 -0.07367570 -0.15768466 -0.2026918949 -0.02723290
           ca         thal
1 -0.04303042 -0.105944584
2  0.11833366  0.107681381
3 -0.07530324 -0.001736796
```


**Interpretation:**

The heatmap shows the mean of each standardized covariate within three equally sized groups (n = 99 each), formed by splitting the estimated CATE at its terciles:

**Understanding the Heatmap:**

1. **Color Coding:**
   - **Red cells** (positive values): Patients in this tercile have **above-average** values for that covariate
   - **Blue cells** (negative values): Patients in this tercile have **below-average** values for that covariate  
   - **White cells** (near zero): Patients in this tercile have **average** values for that covariate

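2. **CATE Distribution:**
   - **Low tercile**: the third of patients with the smallest estimated CATEs (minimum 20.00)
   - **Medium tercile**: the middle third
   - **High tercile**: the third of patients with the largest estimated CATEs (maximum 32.16)
   - The cutoffs between terciles are the 1/3 and 2/3 quantiles of the CATE, not the quartiles (24.79 and 28.68) printed by `summary()`
   - **Mean CATE**: 26.68 across all patients. This is below the OLS estimates of the ATE (30.03 without and 30.15 with controls), which are unbiased here because T is randomized, so the forest appears to understate the average effect

3. **Clinical Insights:**
   - `restbp` rises steadily from the Low to the High tercile (-0.73, 0.07, 0.65 standard deviations), and so does `age` (-0.26, 0.00, 0.26). These are the continuous covariates in the simulated effect `1 + 0.05*age + 0.3*sex + 0.2*restbp`, so the forest ranks patients mainly by the intended drivers
   - `sex` shows no steady trend (0.04, 0.13, -0.17), consistent with its small coefficient of 0.3 on a binary variable
   - Other covariates, for example `oldpeak` (0.00, 0.20, -0.20) and `chol` (-0.18, 0.19, -0.01), show no monotone pattern; `oldpeak` shifts the baseline outcome but not the treatment effect
   - Patients in the High group are, on average, older and have higher resting blood pressure

4. **Heterogeneity Evidence:**
   - Estimated CATEs range from 20.00 to 32.16, so the largest estimated effect is about 60% larger than the smallest (32.16 / 20.00 ≈ 1.61)
   - Every estimated effect is positive, so all patients are predicted to benefit; using patient characteristics to target the program matters mainly if it has costs or limited capacity
   - These are in-sample predictions from a single forest, so the extreme values partly reflect estimation noise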
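Because T was randomly assigned, the effect within each CATE tercile can also be estimated directly from the data, as the difference in mean Y between treated and control patients in that tercile. If those group effects increase from Low to High, the forest's ranking of patients reflects real heterogeneity; this is similar in spirit to the sorted group average treatment effect (GATES) diagnostic. Grouping on in-sample forest predictions and then evaluating on the same data can overstate the differences, so held-out or out-of-bag predictions are safer. The sketch below simulates data under assumed distributions, with `cate_hat` as a noisy stand-in for forest estimates; with the notebook's objects, the same code would use `df$Y`, `df$T` and `df_standardized$tercile`.

```r
# Illustrative data; distributions and the noise in cate_hat are assumptions
set.seed(11)
n <- 3000
age <- rnorm(n, 54, 9)
restbp <- rnorm(n, 132, 18)
treat <- rbinom(n, 1, 0.5)                  # randomized treatment
tau <- 1 + 0.05 * age + 0.2 * restbp        # true effect
Y <- tau * treat + rnorm(n)
cate_hat <- tau + rnorm(n, 0, 1)            # stand-in for estimated CATEs

# Terciles of the estimated CATE, as in the notebook
group <- cut(cate_hat, quantile(cate_hat, c(0, 1/3, 2/3, 1)),
             labels = c("Low", "Medium", "High"), include.lowest = TRUE)

# Difference in mean outcomes between treated and control within each tercile
gate <- sapply(levels(group), function(g) {
  in_g <- group == g
  mean(Y[in_g & treat == 1]) - mean(Y[in_g & treat == 0])
})
print(round(gate, 2))
```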

**Conclusion:** The estimated treatment effects vary across patients, and the variation lines up with resting blood pressure and age, the covariates through which the simulation builds in heterogeneity. The random forest captures the direction of this heterogeneity, but its mean CATE (26.68) falls below the OLS estimates of about 30, so the magnitudes of individual estimates should be read with caution.