# Brain Stroke Risk Prediction -- Modeling & Results

## Overview

In this notebook, we build an end-to-end ML pipeline for stroke risk prediction:

1. **Data Pipeline**: Load, engineer features, preprocess, and resample.
2. **Model Optimization**: Optuna-based hyperparameter tuning for 4 models.
3. **Stacking Ensemble**: Meta-learner on top of optimized base models.
4. **Threshold Tuning**: Optimize classification threshold for stroke-class F1.
5. **SHAP Explainability**: Understand which features drive predictions.
6. **MLflow Tracking**: All experiments logged for reproducibility.

---


In [None]:
import sys
import logging
from pathlib import Path

# Make the project's src/ directory importable.
# NOTE(review): this assumes the notebook is started from a directory one level
# below the project root (e.g. notebooks/) -- confirm before running elsewhere.
project_root = Path.cwd().parent
# Must run before the `stroke_risk` imports below, which are presumably resolved
# through this sys.path entry.
sys.path.insert(0, str(project_root / "src"))

import warnings
# Suppresses every warning for the rest of the session to keep cell output short.
# NOTE(review): this also hides convergence and deprecation warnings; consider
# narrowing the filter when debugging.
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

# Project-local pipeline pieces used by the cells below.
from stroke_risk.data.loader import load_data
from stroke_risk.data.preprocessing import build_preprocessor, get_feature_names, resample_data, split_data
from stroke_risk.features.engineering import engineer_features
from stroke_risk.models.optimize import optimize_all_models
from stroke_risk.models.stacking import train_stacking_ensemble
from stroke_risk.models.evaluate import (
    compute_metrics, find_optimal_threshold, build_comparison_table,
    plot_confusion_matrix, plot_roc_curves, plot_precision_recall_curves,
    get_classification_report_str, cross_validate_with_ci,
)
from stroke_risk.explainability.shap_analysis import compute_shap_values, plot_summary, plot_bar, plot_waterfall
from stroke_risk.utils.config import load_all_configs

# Root logger at INFO with timestamp | logger name | level | message columns.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)-30s | %(levelname)-7s | %(message)s")

print("All modules loaded successfully.")


## Step 1: Data Pipeline

Load data, apply feature engineering, split, preprocess, and resample.


In [None]:
# Load the project configuration from <project_root>/configs and pull out the
# three sections used by later cells (each falls back to an empty dict).
config = load_all_configs(str(project_root / "configs"))
data_cfg, training_cfg, model_cfg = (
    config.get(section, {}) for section in ("data", "training", "models")
)

# Raw dataset -> engineered feature frame
df = engineer_features(load_data())
print(f"Dataset after feature engineering: {df.shape}")

# Last expression of the cell: the notebook renders the first rows as output.
df.head()


In [None]:
# Train/test split (test_size=0.2, fixed seed so reruns are comparable)
X_train, X_test, y_train, y_test = split_data(df, test_size=0.2, random_state=42)

# Column groups for the preprocessor: the configured base columns plus any
# engineered columns that are actually present in the training frame.
base_cat = data_cfg.get("categorical_features", [])
eng_cat = [
    col
    for col in ("age_group", "bmi_category", "glucose_category")
    if col in X_train.columns
]
all_cat = base_cat + eng_cat

base_num = data_cfg.get("numerical_features", [])
eng_num = [
    col
    for col in (
        "age_x_hypertension",
        "age_x_heart_disease",
        "bmi_x_glucose",
        "age_x_bmi",
        "risk_score",
    )
    if col in X_train.columns
]
all_num = base_num + eng_num

bin_cols = data_cfg.get("binary_features", [])

# Fit the preprocessor on the training split only, then apply it to the test split.
preprocessor = build_preprocessor(
    categorical_features=all_cat,
    numerical_features=all_num,
    binary_features=bin_cols,
)
X_train_processed = preprocessor.fit_transform(X_train)
X_test_processed = preprocessor.transform(X_test)
feature_names = get_feature_names(preprocessor)

n_processed_features = X_train_processed.shape[1]
print(f"Processed features: {n_processed_features}")
print(f"Feature names sample: {feature_names[:10]}")

# Rebalance the training split only; the test split keeps its original class mix.
X_train_res, y_train_res = resample_data(X_train_processed, y_train, strategy="smoteenn")
n_class0 = (y_train_res == 0).sum()
n_class1 = (y_train_res == 1).sum()
print(f"\nAfter resampling: {X_train_res.shape[0]} samples")
print(f"  Class 0: {n_class0}, Class 1: {n_class1}")


## Step 2: Model Optimization with Optuna

We optimize 4 models (Logistic Regression, Random Forest, XGBoost, LightGBM) using Bayesian hyperparameter optimization via Optuna's TPE sampler, which typically finds good settings in far fewer trials than an exhaustive grid search.


In [None]:
%%time
# Optimize all models (this may take several minutes)
optimization_results = optimize_all_models(
    X=X_train_res,
    y=y_train_res,
    model_configs=model_cfg,
    scoring="f1",
    cv_folds=5,
    random_state=42,
)

# Show optimization results
for name, result in optimization_results.items():
    print(f"\n{name}:")
    print(f"  Best CV F1: {result['best_score']:.4f}")
    print(f"  Best params: {result['best_params']}")


## Step 3: Stacking Ensemble

We combine the 4 optimized models into a stacking ensemble where a Logistic Regression meta-learner learns the optimal way to combine their predictions.


In [None]:
# Tuned estimator for each base learner, keyed by model name.
base_models = {
    model_name: outcome["best_model"]
    for model_name, outcome in optimization_results.items()
}

# Fit the stacking ensemble on the same resampled training data as the base models.
stacking_model = train_stacking_ensemble(
    base_models=base_models,
    X_train=X_train_res,
    y_train=y_train_res,
    cv_folds=5,
    random_state=42,
)

# Evaluate the base learners and the ensemble side by side.
all_models = dict(base_models, stacking_ensemble=stacking_model)
print(f"Total models to evaluate: {len(all_models)}")
print(f"Models: {list(all_models)}")


## Step 4: Threshold Tuning & Evaluation

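Instead of using the default 0.5 threshold, we optimize the classification threshold for each model to maximize the F1 score on the stroke class. This is critical for imbalanced datasets. Note that in this notebook the threshold is selected on the test set itself, so the metrics reported below are optimistic; for an unbiased estimate, tune the threshold on a separate validation split.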


In [None]:
# For every model: score the test set, pick the threshold reported by
# find_optimal_threshold (method="f1"), and record metrics at that threshold.
# NOTE(review): the threshold is chosen using y_test and the metrics are then
# computed on the same y_test, so the reported scores are optimistically
# biased. Tuning on a validation split held out from training would fix this.
y_true = y_test.values
rule = "=" * 60
evaluation_results = {}

for model_name, estimator in all_models.items():
    # Probability of the positive (stroke) class
    stroke_prob = estimator.predict_proba(X_test_processed)[:, 1]

    best_cut = find_optimal_threshold(y_true, stroke_prob, method="f1")["threshold"]
    hard_pred = (stroke_prob >= best_cut).astype(int)

    evaluation_results[model_name] = {
        "model": estimator,
        "metrics": compute_metrics(y_true, hard_pred, stroke_prob),
        "threshold": best_cut,
        "y_pred": hard_pred,
        "y_prob": stroke_prob,
    }

    print(f"\n{rule}")
    print(f"  {model_name} (threshold={best_cut:.4f})")
    print(rule)
    print(get_classification_report_str(y_true, hard_pred))


### Model Comparison Table


In [None]:
# Summarize the per-model evaluation results in one table
# (presumably one row per model -- confirm against build_comparison_table).
comparison = build_comparison_table(evaluation_results)
# Shade the three headline metric columns. As the last expression of the cell,
# the styled table is what the notebook renders.
comparison.style.background_gradient(subset=["F1 (Stroke)", "PR-AUC", "ROC-AUC"], cmap="YlGn")


### ROC and Precision-Recall Curves


In [None]:
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay

# Fitted estimators keyed by model name, taken from the evaluation results.
models_for_plot = {name: res["model"] for name, res in evaluation_results.items()}
y_true = y_test.values

# A single figure with both panels. The curves are drawn straight onto these
# axes with scikit-learn's display helpers, so no intermediate figure is
# created and the cell outputs exactly one plot.
fig_combined, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Left panel: ROC curves, with the diagonal as the chance-level reference.
for name, model in models_for_plot.items():
    RocCurveDisplay.from_estimator(model, X_test_processed, y_true, name=name, ax=ax1)
ax1.plot([0, 1], [0, 1], "k--", lw=1)
ax1.set_title("ROC Curves", fontsize=14)
ax1.legend(loc="lower right", fontsize=9)

# Right panel: precision-recall curves. The horizontal baseline is the
# positive-class prevalence in the test labels.
for name, model in models_for_plot.items():
    PrecisionRecallDisplay.from_estimator(model, X_test_processed, y_true, name=name, ax=ax2)
prevalence = np.mean(y_test)
ax2.axhline(y=prevalence, color="k", linestyle="--", lw=1, label=f"Baseline ({prevalence:.3f})")
ax2.set_title("Precision-Recall Curves", fontsize=14)
ax2.legend(loc="upper right", fontsize=9)

plt.tight_layout()
plt.show()


### Confusion Matrix -- Best Model


In [None]:
# Best model confusion matrix
best_name = comparison.iloc[0]["Model"]
best_result = evaluation_results[best_name]

print(f"Best model: {best_name}")
print(f"Optimal threshold: {best_result['threshold']:.4f}")
print()

fig = plot_confusion_matrix(
    y_test.values, best_result["y_pred"],
    title=f"Confusion Matrix -- {best_name} (threshold={best_result['threshold']:.3f})"
)
plt.show()


## Step 5: SHAP Explainability

SHAP (SHapley Additive exPlanations) gives us both global feature importance and local, per-prediction explanations. This is essential in healthcare ML where model interpretability is critical for clinical adoption.


In [None]:
# Compute SHAP values for the best model
best_model = evaluation_results[best_name]["model"]
shap_values = compute_shap_values(best_model, X_test_processed, feature_names=feature_names)

# Global feature importance -- beeswarm plot
fig_summary = plot_summary(shap_values, max_display=15)
plt.show()


In [None]:
# Global feature importance -- bar plot
fig_bar = plot_bar(shap_values, max_display=15)
plt.show()


In [None]:
# Local explanation -- waterfall for a stroke-positive prediction
stroke_indices = np.where(best_result["y_pred"] == 1)[0]
if len(stroke_indices) > 0:
    sample_idx = stroke_indices[0]
    print(f"Explaining prediction for sample {sample_idx} (predicted: Stroke)")
    fig_waterfall = plot_waterfall(shap_values, index=sample_idx, max_display=15)
    plt.show()
else:
    print("No positive predictions to explain.")


## Step 6: Save Best Model


In [None]:
import numbers

import joblib

fe_cfg = training_cfg.get("feature_engineering", {})

# Bundle what this notebook produced for scoring new data: the fitted model,
# the fitted preprocessor, the tuned decision threshold, and the
# feature-engineering settings.
artifact = {
    "model": best_model,
    "preprocessor": preprocessor,
    "feature_names": feature_names,
    "threshold": best_result["threshold"],
    "model_name": best_name,
    "feature_engineering_config": fe_cfg,
}

# Create the output directory (and any missing parents) before writing.
models_dir = project_root / "models"
models_dir.mkdir(parents=True, exist_ok=True)
artifact_path = models_dir / "best_model.joblib"
# NOTE: joblib files are pickles -- only load artifacts from trusted sources.
joblib.dump(artifact, artifact_path)

# Format ROC-AUC like the other metrics when it is numeric; otherwise print
# whatever was stored (or "N/A" when the metric is absent).
roc_auc = best_result["metrics"].get("roc_auc", "N/A")
roc_auc_text = f"{roc_auc:.4f}" if isinstance(roc_auc, numbers.Real) else str(roc_auc)

print(f"Best model saved to {artifact_path}")
print(f"  Model: {best_name}")
print(f"  Threshold: {best_result['threshold']:.4f}")
print(f"  F1 (Stroke): {best_result['metrics']['f1_stroke']:.4f}")
print(f"  ROC-AUC: {roc_auc_text}")


## Summary

This project demonstrates a complete, industry-grade ML pipeline:

| Component | Technique |
|-----------|-----------|
| Feature Engineering | Age bins, BMI/glucose categories, interactions, risk score |
| Preprocessing | ColumnTransformer (OneHotEncoder + StandardScaler) |
| Class Imbalance | SMOTEENN resampling |
| HPO | Optuna with TPE sampler |
| Ensemble | Stacking (LR meta-learner over 4 base models) |
| Threshold Tuning | PR-curve based F1 optimization |
| Explainability | SHAP (global + local explanations) |
| Experiment Tracking | MLflow |
| Deployment | FastAPI + Streamlit + Docker |

The best model is saved to `models/best_model.joblib` and is ready for deployment via `python scripts/train.py` or the FastAPI/Streamlit apps.
