# Assignment 4: Predicting Heart Disease Using Decision Trees and Causal Forest

This notebook implements:
1. Classification tree for heart disease prediction
2. Causal-forest-style analysis of heterogeneous treatment effects (approximated with a scikit-learn random forest)

## Part 1: Predicting Heart Disease Using a Classification Tree

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
# Suppress all warnings so notebook output stays uncluttered
warnings.filterwarnings('ignore')

# Seed NumPy's legacy global RNG once up front so stochastic steps are reproducible
RANDOM_SEED = 123
np.random.seed(RANDOM_SEED)

### 1.1 Data Cleaning (2 points)

In [None]:
# Load the Cleveland heart-disease data; '?' marks missing entries in the raw file
column_names = ['age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'hd']
df = pd.read_csv('../input/processed.cleveland.data', names=column_names, na_values='?')

missing_counts = df.isnull().sum()
print(f"Original dataset shape: {df.shape}")
print(f"\nMissing values:\n{missing_counts}")

# Drop every row that has at least one missing value
df = df.dropna()
print(f"\nDataset shape after removing missing values: {df.shape}")

In [None]:
# Binary target y: the raw `hd` column is 0 for no disease and 1-4 for disease levels,
# so any positive value collapses to 1 ("has heart disease")
df['y'] = df['hd'].gt(0).astype(int)

print("Distribution of heart disease:")
print(df['y'].value_counts())
share_with_hd = df['y'].mean() * 100
print(f"\nPercentage with heart disease: {share_with_hd:.2f}%")

In [None]:
# One-hot encode the categorical predictors; drop_first=True drops one level per variable
# so the dummy columns for a variable are not perfectly collinear
categorical_vars = ['sex', 'cp', 'fbs', 'restecg', 'exang', 'slope', 'ca', 'thal']
df_encoded = pd.get_dummies(df, columns=categorical_vars, drop_first=True)

print(f"Dataset shape after creating dummy variables: {df_encoded.shape}")
print(f"\nColumn names: {df_encoded.columns.tolist()}")

In [None]:
# Features exclude both the raw multi-level label `hd` and the binary target `y`
X = df_encoded.drop(columns=['hd', 'y'])
y = df_encoded['y']

for label, obj in (("Features", X), ("Target", y)):
    print(f"{label} shape: {obj.shape}")

### 1.2 Data Analysis (8 points)

#### (1 point) Split data and plot classification tree

In [None]:
# Hold out 30% of patients for testing; the fixed random_state keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=123
)

n_train, n_test = len(X_train), len(X_test)
print(f"Training set size: {n_train}")
print(f"Test set size: {n_test}")

In [None]:
# Fit a tree with no ccp_alpha (no cost-complexity pruning) as the overfitting baseline
clf = DecisionTreeClassifier(random_state=123)
clf.fit(X_train, y_train)

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(clf, filled=True, feature_names=X.columns, class_names=["No HD", "Has HD"], fontsize=10, ax=ax)
ax.set_title("Initial Classification Tree (Unpruned)")
fig.savefig('../output/initial_tree.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Tree depth: {clf.get_depth()}")
print(f"Number of leaves: {clf.get_n_leaves()}")

#### (2 points) Plot confusion matrix and interpret

In [None]:
# Evaluate the unpruned tree on the held-out test set
y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
accuracy = accuracy_score(y_test, y_pred)

fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Does not have HD", "Has HD"])
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title(f'Confusion Matrix - Initial Tree\nAccuracy: {accuracy:.4f}')
fig.savefig('../output/confusion_matrix_initial.png', dpi=300, bbox_inches='tight')
plt.show()

# Rows are true labels, columns are predictions (0 = no HD, 1 = HD)
tn, fp, fn, tp = cm[0, 0], cm[0, 1], cm[1, 0], cm[1, 1]
print(f"\nTest Accuracy: {accuracy:.4f}")
print("\nConfusion Matrix:")
print(cm)
print(f"\nTrue Negatives: {tn}")
print(f"False Positives: {fp}")
print(f"False Negatives: {fn}")
print(f"True Positives: {tp}")

**Interpretation of Initial Confusion Matrix:**
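- An unpruned tree tends to fit its training data almost perfectly. Overfitting therefore shows up as a gap between that near-perfect training fit and the test accuracy above (training accuracy itself is not computed in this notebook).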
- True Positives: Correctly identified patients with heart disease
- True Negatives: Correctly identified patients without heart disease
- False Positives: Patients incorrectly classified as having heart disease (Type I error)
- False Negatives: Patients incorrectly classified as not having heart disease (Type II error - more concerning in medical diagnosis)

#### (1.5 points) Fix overfitting using cross-validation

In [None]:
# Candidate cost-complexity penalties: 50 values equally spaced on a natural-log scale
# from e^-10 (about 4.5e-5) up to 0.05. np.logspace defaults to base 10, so the base is
# set explicitly; otherwise the grid would start at 10^-10 instead of e^-10.
alphas = np.logspace(-10, np.log(0.05), 50, base=np.e)
print(f"Alpha range: {alphas.min():.10f} to {alphas.max():.4f}")
print(f"Number of alphas: {len(alphas)}")

In [None]:
# Select ccp_alpha by 4-fold cross-validation on the training set only
from sklearn.model_selection import KFold

# Shuffled folds with a fixed seed, so every alpha is scored on identical splits
kfold = KFold(n_splits=4, shuffle=True, random_state=123)


def _mean_cv_accuracy(ccp_alpha):
    """Mean 4-fold CV accuracy of a decision tree pruned with the given ccp_alpha."""
    model = DecisionTreeClassifier(ccp_alpha=ccp_alpha, random_state=123)
    return cross_val_score(model, X_train, y_train, cv=kfold, scoring='accuracy').mean()


cv_scores = np.array([_mean_cv_accuracy(a) for a in alphas])

# argmax returns the first maximum, i.e. the smallest alpha among ties
optimal_idx = np.argmax(cv_scores)
optimal_alpha = alphas[optimal_idx]
optimal_cv_score = cv_scores[optimal_idx]

print(f"Optimal alpha: {optimal_alpha:.10f}")
print(f"Best CV accuracy: {optimal_cv_score:.4f}")

#### (1.5 points) Plot Inaccuracy Rate (1 - Accuracy) against alpha

In [None]:
# CV error rate for each candidate alpha
inaccuracy_rates = 1 - cv_scores

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(alphas, inaccuracy_rates, marker='o', markersize=3, linewidth=2)
ax.axvline(optimal_alpha, color='r', linestyle='--', label=f'Optimal α = {optimal_alpha:.6f}')
ax.set_xscale('log')
ax.set_xlabel('Alpha (log scale)', fontsize=12)
ax.set_ylabel('Inaccuracy Rate (1 - Accuracy)', fontsize=12)
ax.set_title('Inaccuracy Rate vs Alpha (4-fold Cross-Validation)', fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=10)
fig.savefig('../output/inaccuracy_vs_alpha.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Minimum inaccuracy rate: {inaccuracy_rates.min():.4f}")
print(f"Maximum inaccuracy rate: {inaccuracy_rates.max():.4f}")

#### (2 points) Plot optimal tree and confusion matrix with interpretation

In [None]:
# Refit on the full training set using the CV-selected pruning strength
clf_optimal = DecisionTreeClassifier(ccp_alpha=optimal_alpha, random_state=123).fit(X_train, y_train)

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(clf_optimal, filled=True, feature_names=X.columns, class_names=["No HD", "Has HD"], fontsize=10, ax=ax)
ax.set_title(f"Optimal Classification Tree (α = {optimal_alpha:.6f})")
fig.savefig('../output/optimal_tree.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Optimal tree depth: {clf_optimal.get_depth()}")
print(f"Optimal number of leaves: {clf_optimal.get_n_leaves()}")

In [None]:
# Evaluate the pruned tree on the held-out test set
y_pred_optimal = clf_optimal.predict(X_test)
cm_optimal = confusion_matrix(y_test, y_pred_optimal)
accuracy_optimal = accuracy_score(y_test, y_pred_optimal)

fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_optimal, display_labels=["Does not have HD", "Has HD"])
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title(f'Confusion Matrix - Optimal Tree (α = {optimal_alpha:.6f})\nAccuracy: {accuracy_optimal:.4f}')
fig.savefig('../output/confusion_matrix_optimal.png', dpi=300, bbox_inches='tight')
plt.show()

# Rows are true labels, columns are predictions (0 = no HD, 1 = HD)
tn, fp, fn, tp = cm_optimal[0, 0], cm_optimal[0, 1], cm_optimal[1, 0], cm_optimal[1, 1]
print(f"\nOptimal Test Accuracy: {accuracy_optimal:.4f}")
print(f"Initial Test Accuracy: {accuracy:.4f}")
print("\nConfusion Matrix (Optimal):")
print(cm_optimal)
print(f"\nTrue Negatives: {tn}")
print(f"False Positives: {fp}")
print(f"False Negatives: {fn}")
print(f"True Positives: {tp}")

# Sensitivity: share of diseased patients caught; specificity: share of healthy patients cleared
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
print(f"\nSensitivity (True Positive Rate): {sensitivity:.4f}")
print(f"Specificity (True Negative Rate): {specificity:.4f}")

**Interpretation and Discussion:**

1. **Tree Complexity**: The optimal tree with regularization (α) is simpler than the initial unpruned tree, reducing overfitting.

2. **Performance Comparison**: 
   - The pruned tree may have slightly lower training accuracy but better generalization on test data
   - The cross-validation process helped select an alpha that balances bias and variance

3. **Clinical Implications**:
   - Sensitivity measures the ability to correctly identify patients with heart disease
   - Specificity measures the ability to correctly identify patients without heart disease
   - In medical diagnosis, high sensitivity is often preferred to avoid missing cases (minimize false negatives)

4. **Model Insights**: 
   - The tree reveals which features are most important for heart disease prediction
   - Cost-complexity pruning removed splits whose reduction in training impurity did not justify the extra leaves at the chosen penalty α
   - The optimal model provides interpretable decision rules for clinical use

## Part 2: Causal Forest Analysis

### (0.5 points) Create binary treatment variable T

In [None]:
# Re-seed so the treatment draw below is reproducible regardless of earlier cells
np.random.seed(123)

# Randomized treatment: each patient is treated independently with probability 0.5
df['T'] = np.random.binomial(1, 0.5, size=len(df))

print("Treatment distribution:")
print(df['T'].value_counts())
print(f"\nProportion treated: {df['T'].mean():.4f}")

### (1 point) Create outcome variable Y

In [None]:
# Simulated outcome with a heterogeneous treatment effect:
#   Y = (1 + 0.05*age + 0.3*sex + 0.2*restbp) * T + 0.5*oldpeak + eps,   eps ~ N(0, 1)
# The bracketed term is each patient's true individual treatment effect.
epsilon = np.random.normal(0, 1, size=len(df))
true_effect = 1 + 0.05 * df['age'] + 0.3 * df['sex'] + 0.2 * df['restbp']
df['Y'] = true_effect * df['T'] + 0.5 * df['oldpeak'] + epsilon

treated_mean = df.loc[df['T'] == 1, 'Y'].mean()
control_mean = df.loc[df['T'] == 0, 'Y'].mean()

print("Outcome variable Y statistics:")
print(df['Y'].describe())
print(f"\nMean Y for treated: {treated_mean:.4f}")
print(f"Mean Y for control: {control_mean:.4f}")
print(f"Raw difference: {treated_mean - control_mean:.4f}")

### (1 point) Calculate treatment effect using OLS

In [None]:
# OLS regression of the simulated outcome on the treatment indicator: Y ~ 1 + T.
# With a single binary regressor the slope equals the treated-minus-control difference in means.
X_ols = df[['T']].values
y_ols = df['Y'].values

ols_model = LinearRegression()
ols_model.fit(X_ols, y_ols)

treatment_effect_ols = ols_model.coef_[0]
intercept = ols_model.intercept_

# Classical (homoskedastic) inference for the slope.
# NOTE(review): the simulated effect varies across patients, so residual variance likely
# differs between arms; heteroskedasticity-robust SEs would be a safer choice.
predictions = ols_model.predict(X_ols)
residuals = y_ols - predictions
n = len(y_ols)
k = 1  # number of predictors, excluding the intercept
dof = n - k - 1
mse = np.sum(residuals**2) / dof

# SE(beta_T) = sqrt(MSE / sum((T_i - mean(T))^2)), and sum((T_i - mean(T))^2) = (n - 1) * var(T, ddof=1)
var_T = np.var(X_ols, ddof=1)
se_coef = np.sqrt(mse / ((n - 1) * var_T))
t_stat = treatment_effect_ols / se_coef
# Two-sided p-value; the survival function avoids the precision loss of 1 - cdf for large |t|
p_value = 2 * stats.t.sf(np.abs(t_stat), dof)

print("OLS Regression Results: Y ~ T")
print("="*50)
print(f"Intercept: {intercept:.4f}")
print(f"Treatment Effect (β_T): {treatment_effect_ols:.4f}")
print(f"Standard Error: {se_coef:.4f}")
print(f"t-statistic: {t_stat:.4f}")
print(f"p-value: {p_value:.6f}")
print(f"R-squared: {ols_model.score(X_ols, y_ols):.4f}")

### (2 points) Use Random Forest to estimate causal effects

In [None]:
# Covariates used to model effect heterogeneity (treatment and outcome excluded)
covariates = ['age', 'sex', 'cp', 'restbp', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal']
X_cf = df[covariates].copy()
T_cf = df['T'].values
Y_cf = df['Y'].values

# scikit-learn has no causal forest, so heterogeneous effects are approximated with a single
# random forest (an "S-learner") fit on covariates, T, and T x covariate interaction features.
X_cf_with_treatment = X_cf.copy()
X_cf_with_treatment['T'] = T_cf
for col in covariates:
    X_cf_with_treatment[f'T_x_{col}'] = X_cf_with_treatment['T'] * X_cf_with_treatment[col]

# oob_score=True adds an out-of-bag R-squared (rf_model.oob_score_) computed after fitting,
# a fairer measure of fit than scoring the forest on its own training rows.
rf_model = RandomForestRegressor(n_estimators=100, random_state=123, max_depth=10,
                                 min_samples_leaf=5, oob_score=True)
rf_model.fit(X_cf_with_treatment, Y_cf)

# Counterfactual design matrices with the training column order:
# everyone treated (T=1, interactions = covariates) and everyone untreated (T=0, interactions = 0)
X_treated = X_cf_with_treatment.assign(T=1, **{f'T_x_{col}': X_cf[col] for col in covariates})
X_control = X_cf_with_treatment.assign(T=0, **{f'T_x_{col}': 0 for col in covariates})

Y_pred_treated = rf_model.predict(X_treated)
Y_pred_control = rf_model.predict(X_control)

# Individual treatment effect = predicted outcome if treated minus predicted outcome if untreated.
# NOTE: these are in-sample predictions (each patient's own outcome was used in fitting).
individual_treatment_effects = Y_pred_treated - Y_pred_control
df['ITE'] = individual_treatment_effects
ate = individual_treatment_effects.mean()

print("Random Forest Causal Effects Estimation")
print("="*50)
print(f"Average Treatment Effect (ATE): {ate:.4f}")
print(f"Standard Deviation of ITE: {individual_treatment_effects.std():.4f}")
print(f"Min ITE: {individual_treatment_effects.min():.4f}")
print(f"Max ITE: {individual_treatment_effects.max():.4f}")
print(f"\nIn-sample R-squared: {rf_model.score(X_cf_with_treatment, Y_cf):.4f}")
print(f"Out-of-bag R-squared: {rf_model.oob_score_:.4f}")

# Distribution of estimated individual treatment effects
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(individual_treatment_effects, bins=30, edgecolor='black', alpha=0.7)
ax.axvline(ate, color='r', linestyle='--', linewidth=2, label=f'Mean ITE = {ate:.4f}')
ax.set_xlabel('Individual Treatment Effect', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Distribution of Individual Treatment Effects', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
fig.savefig('../output/ite_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

### (2 points) Plot representative tree with max_depth=2

In [None]:
# Representative tree for treatment-effect heterogeneity: a shallow regression tree fit to the
# forest's predicted individual treatment effects (df['ITE']) using only patient covariates.
# Fitting it to the outcome Y with T among the features would likely let the large main effect
# of T dominate the splits, describing outcome levels rather than how the effect varies.
# Each leaf value is the mean predicted treatment effect of the patients in that leaf.
tree_model = DecisionTreeRegressor(max_depth=2, random_state=123, min_samples_leaf=10)
tree_model.fit(X_cf, df['ITE'])

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(tree_model, filled=True, feature_names=list(X_cf.columns), fontsize=10, rounded=True, ax=ax)
ax.set_title("Representative Tree (max_depth=2) for Treatment Effect Heterogeneity", fontsize=14)
fig.savefig('../output/representative_tree.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Tree depth: {tree_model.get_depth()}")
print(f"Number of leaves: {tree_model.get_n_leaves()}")

**Interpretation of Representative Tree:**

This shallow tree (max_depth=2) reveals the most important heterogeneous treatment effects:
- The tree shows which patient characteristics lead to different treatment effects
- Each split represents a key decision point that differentiates treatment response
- Leaf nodes show the predicted outcome for patients in that subgroup
- The interaction terms (T_x_*) capture how treatment effects vary by patient characteristics

### (1.5 points) Feature importance visualization

In [None]:
# Impurity-based importances from the fitted forest, most important first
feature_importance = (
    pd.DataFrame({'feature': X_cf_with_treatment.columns, 'importance': rf_model.feature_importances_})
    .sort_values('importance', ascending=False)
)

top_features = feature_importance.head(15)
positions = np.arange(len(top_features))

fig, ax = plt.subplots(figsize=(12, 8))
ax.barh(positions, top_features['importance'], color='steelblue')
ax.set_yticks(positions)
ax.set_yticklabels(top_features['feature'])
ax.set_xlabel('Importance', fontsize=12)
ax.set_ylabel('Feature', fontsize=12)
ax.set_title('Top 15 Feature Importances (Random Forest)', fontsize=14)
ax.invert_yaxis()  # largest bar at the top
ax.grid(True, alpha=0.3, axis='x')
fig.tight_layout()
fig.savefig('../output/feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nTop 10 Most Important Features:")
print(feature_importance.head(10).to_string(index=False))

### (2 points) Covariate distribution by treatment effect terciles

In [None]:
# z-score each covariate so tercile means are comparable across features
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
standardized_values = scaler.fit_transform(X_cf)
X_standardized = pd.DataFrame(standardized_values, columns=covariates, index=X_cf.index)

# Bucket patients into equal-sized groups by predicted treatment effect
tercile_labels = ['Low', 'Medium', 'High']
df['ITE_tercile'] = pd.qcut(df['ITE'], q=3, labels=tercile_labels)

print("Treatment Effect Terciles:")
print(df.groupby('ITE_tercile')['ITE'].describe())

In [None]:
# Mean of each standardized covariate within each predicted-effect tercile
# (rows = terciles in Low/Medium/High order, columns = covariates)
tercile_order = ['Low', 'Medium', 'High']
heatmap_data = pd.DataFrame(
    {label: X_standardized[df['ITE_tercile'] == label].mean() for label in tercile_order}
).T

# Diverging palette centred at 0, the overall mean after standardization
fig, ax = plt.subplots(figsize=(14, 6))
sns.heatmap(heatmap_data, annot=True, fmt='.2f', cmap='RdBu_r', center=0,
            cbar_kws={'label': 'Mean Standardized Value'}, linewidths=0.5, ax=ax)
ax.set_xlabel('Covariates', fontsize=12)
ax.set_ylabel('Treatment Effect Tercile', fontsize=12)
ax.set_title('Mean Standardized Covariates by Predicted Treatment Effect Terciles', fontsize=14)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()
fig.savefig('../output/terciles_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nMean Standardized Covariates by Tercile:")
print(heatmap_data)

**Interpretation of Tercile Analysis:**

This heatmap shows how patient characteristics differ across treatment effect terciles:

- **Low tercile**: Patients with lowest predicted treatment effects
- **Medium tercile**: Patients with moderate predicted treatment effects  
- **High tercile**: Patients with highest predicted treatment effects

The color intensity indicates how each covariate's mean differs from zero (the population mean after standardization):
- Red colors indicate above-average values for that tercile
- Blue colors indicate below-average values for that tercile

This analysis helps identify which patient characteristics are associated with higher or lower treatment effects, informing targeted intervention strategies.

## Summary of Results

### Part 1: Classification Tree
- Successfully built and pruned a classification tree for heart disease prediction
- Used cross-validation to find optimal complexity parameter (alpha)
- Achieved reasonable accuracy while maintaining interpretability

### Part 2: Causal Forest
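- Estimated heterogeneous treatment effects with a Random Forest fit on covariates, treatment, and treatment × covariate interactions (an approximation to a causal forest)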
- Identified key characteristics associated with treatment response
- Visualized treatment effect heterogeneity across patient subgroups

All figures have been saved to the `../output/` directory.