# Brain Stroke Risk Prediction -- Exploratory Data Analysis

## The Problem

Stroke is the **2nd leading cause of death** globally, responsible for approximately 11% of total deaths according to the World Health Organization (WHO). Early identification of high-risk individuals can enable preventive interventions and save lives.

In this notebook, we explore a clinical dataset to understand the key risk factors associated with brain stroke, build intuition about the data, and prepare for predictive modeling.

---


In [None]:
import sys
from pathlib import Path

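# Make the project's src/ directory importable
# (assumes this notebook lives one level below the project root, e.g. notebooks/)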
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / "src"))

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from stroke_risk.data.loader import load_data
from stroke_risk.features.engineering import engineer_features

# Plotting style
sns.set_theme(style="whitegrid", palette="muted", font_scale=1.2)
plt.rcParams["figure.figsize"] = (12, 6)
plt.rcParams["figure.dpi"] = 100

print("Setup complete.")


## 1. Loading the Dataset

The dataset contains **4,981 patient records** with 10 demographic and clinical features and a binary target (`stroke`) indicating whether the patient experienced a stroke.


In [None]:
df = load_data()
print(f"Dataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
df.head(10)


In [None]:
df.info()


In [None]:
df.describe(include="all").round(2)


### Key Observations
- No explicit missing values (NaNs) in the dataset.
- Features include both categorical (gender, work_type, etc.) and numerical (age, BMI, glucose) types.
- `smoking_status` has an "Unknown" category -- this may represent missing data.
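One way to test whether "Unknown" behaves like missing data is to measure how common it is and whether those rows differ from the rest. The sketch below runs on a small made-up frame named `toy`. Its values are hypothetical and do not come from the real dataset. In the notebook, `df` would take its place.

```python
import pandas as pd

# Hypothetical toy data standing in for the real `df` (values are made up)
toy = pd.DataFrame({
    "smoking_status": ["never smoked", "Unknown", "smokes", "Unknown",
                       "formerly smoked", "never smoked", "Unknown", "smokes"],
    "age": [45, 5, 60, 12, 70, 30, 40, 55],
    "stroke": [0, 0, 1, 0, 1, 0, 0, 0],
})

is_unknown = toy["smoking_status"].eq("Unknown")

# Percentage of records with no recorded smoking status
unknown_share = is_unknown.mean() * 100
print(f"'Unknown' smoking status: {unknown_share:.1f}% of records")

# Compare the "Unknown" rows with the rest; systematic differences
# (e.g. in age) suggest the value is not missing at random
summary = toy.groupby(is_unknown.rename("status_unknown")).agg(
    n=("stroke", "size"),
    stroke_rate=("stroke", "mean"),
    median_age=("age", "median"),
)
print(summary)
```

If the "Unknown" group differs noticeably on the real data, the missingness itself carries information. That supports keeping it as a separate category rather than imputing it.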

---

## 2. Target Variable Analysis -- Class Imbalance

One of the most critical aspects of this dataset is the severe **class imbalance**. Let's quantify it.


In [None]:
stroke_counts = df["stroke"].value_counts()
stroke_pct = df["stroke"].value_counts(normalize=True) * 100

print("Stroke Distribution:")
print(f"  No Stroke (0): {stroke_counts[0]:,} ({stroke_pct[0]:.1f}%)")
print(f"  Stroke    (1): {stroke_counts[1]:,} ({stroke_pct[1]:.1f}%)")
print(f"  Imbalance Ratio: {stroke_counts[0] / stroke_counts[1]:.1f}:1")

fig = px.bar(
    x=["No Stroke", "Stroke"],
    y=stroke_counts.values,
    color=["No Stroke", "Stroke"],
    color_discrete_map={"No Stroke": "#4CAF50", "Stroke": "#F44336"},
    text=stroke_counts.values,
    title="Target Variable Distribution -- Severe Class Imbalance",
    labels={"x": "Class", "y": "Count"},
)
fig.update_traces(textposition="outside")
fig.update_layout(showlegend=False, height=400)
fig.show()


> **Insight**: Only ~5% of patients experienced a stroke. This extreme imbalance means naive models will achieve ~95% accuracy by simply predicting "No Stroke" for everyone -- making accuracy a misleading metric. We'll need specialized techniques like SMOTEENN resampling and threshold tuning.
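The accuracy trap is easy to reproduce. This minimal sketch uses only NumPy and synthetic labels with roughly the dataset's ~5% positive rate. It scores a "model" that always predicts No Stroke:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic labels with a ~5% positive (stroke) rate; not the real data
y_true = (rng.random(5000) < 0.05).astype(int)

# "Model" that always predicts the majority class (No Stroke)
y_pred = np.zeros_like(y_true)

accuracy = (y_pred == y_true).mean()
# Recall (sensitivity): fraction of actual strokes the model catches
recall = (y_pred[y_true == 1] == 1).mean()

print(f"Accuracy: {accuracy:.3f}")  # close to 0.95
print(f"Recall:   {recall:.3f}")    # 0.000, every stroke is missed
```

Metrics that focus on the minority class expose this failure immediately. Examples include recall, precision, F1, and precision-recall AUC.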

---

## 3. Numerical Feature Distributions


In [None]:
numerical_cols = ["age", "avg_glucose_level", "bmi"]

fig = make_subplots(rows=1, cols=3, subplot_titles=numerical_cols)

for i, col in enumerate(numerical_cols, 1):
    for stroke_val, color, name in [(0, "#4CAF50", "No Stroke"), (1, "#F44336", "Stroke")]:
        subset = df[df["stroke"] == stroke_val][col]
        fig.add_trace(
            go.Histogram(
                x=subset, name=name, opacity=0.6,
                marker_color=color, showlegend=(i == 1),
            ),
            row=1, col=i,
        )

fig.update_layout(
    title="Numerical Feature Distributions by Stroke Status",
    barmode="overlay", height=400,
)
fig.show()


> **Insights**:
> - **Age** is the strongest differentiator: stroke cases cluster at older ages (roughly 60+ years).
> - **Average glucose level** shows a bimodal distribution in stroke patients, with a notable spike above 200 mg/dL.
> - **BMI** distributions are similar between groups, suggesting BMI alone is not a strong discriminator.
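Histogram overlap is a visual judgment. A standardized mean difference (Cohen's d) puts each feature's separation on a common scale. This is an added check, not part of the original analysis. Glucose is skewed and bimodal, so for that feature a rank-based measure may be more appropriate. The tiny arrays below are made up.

```python
import numpy as np

def cohens_d(a, b):
    """Standardized mean difference between two samples, using the pooled SD."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    n_a, n_b = len(a), len(b)
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    return (a.mean() - b.mean()) / np.sqrt(pooled_var)

# Made-up ages for two small groups
stroke = np.array([70.0, 75.0, 80.0, 65.0])
no_stroke = np.array([40.0, 45.0, 50.0, 35.0])
print(f"d = {cohens_d(stroke, no_stroke):.2f}")
```

In the notebook, this could be applied per column. For example, loop over `numerical_cols` and call `cohens_d(df.loc[df["stroke"] == 1, col], df.loc[df["stroke"] == 0, col])`. Values near 0 indicate heavy overlap, as suggested for BMI.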

---

## 4. Categorical Feature Analysis


In [None]:
cat_cols = ["gender", "hypertension", "heart_disease", "ever_married",
            "work_type", "Residence_type", "smoking_status"]

fig, axes = plt.subplots(2, 4, figsize=(24, 12))
axes = axes.ravel()

for i, col in enumerate(cat_cols):
    # Calculate stroke rate per category
    ct = pd.crosstab(df[col], df["stroke"], normalize="index") * 100

    ct.plot(kind="bar", stacked=True, ax=axes[i],
            color=["#4CAF50", "#F44336"], alpha=0.8)
    axes[i].set_title(f"{col}", fontsize=13)
    axes[i].set_ylabel("Percentage")
    axes[i].set_xticklabels(axes[i].get_xticklabels(), rotation=45, ha="right")
    axes[i].legend(["No Stroke", "Stroke"], fontsize=9)

axes[-1].set_visible(False)
plt.suptitle("Stroke Rate by Categorical Features", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()


> **Insights**:
> - **Hypertension** and **heart disease** show notably higher stroke rates.
> - **Married** individuals have higher stroke rates, likely because married patients tend to be older.
> - **Self-employed** workers show a slightly elevated stroke rate, which may also partly reflect age.
> - **Gender** and **residence type** show minimal differences in stroke rates.
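The "likely correlated with age" explanation can be checked by stratifying. Compare married and unmarried stroke rates *within* age bands. If the gap shrinks inside each band, age is the likely driver (confounding). The toy data below is made up and deliberately constructed so the gap disappears. The two-band split is an arbitrary choice for illustration.

```python
import pandas as pd

# Hypothetical toy data: married people are older, but within each
# age band both groups have the same stroke rate
toy = pd.DataFrame({
    "age":          [25, 30, 28, 33, 35, 65, 70, 68, 72, 66, 74],
    "ever_married": ["No", "No", "No", "No", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No"],
    "stroke":       [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
})

# Crude comparison: ignores age
crude = toy.groupby("ever_married")["stroke"].mean()

# Stratified comparison: stroke rate by marital status within age bands
toy["age_band"] = pd.cut(toy["age"], bins=[0, 59, 120], labels=["<60", "60+"])
stratified = (
    toy.groupby(["age_band", "ever_married"], observed=True)["stroke"]
    .mean()
    .unstack()
)
print(crude, stratified, sep="\n\n")
```

Here the crude rates differ (0.40 married vs. about 0.17 unmarried), while the within-band rates are identical.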

---

## 5. Correlation Analysis


In [None]:
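# Encode categoricals as integer codes so they can enter the correlation matrix.
# Codes for multi-level nominal columns (work_type, smoking_status) follow an
# arbitrary lexical order, so their correlations are only rough signals.
df_encoded = df.copy()
for col in ["gender", "ever_married", "work_type", "Residence_type", "smoking_status"]:
    df_encoded[col] = df_encoded[col].astype("category").cat.codes

corr_matrix = df_encoded.corr(numeric_only=True)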

fig, ax = plt.subplots(figsize=(12, 10))
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(
    corr_matrix, mask=mask, annot=True, fmt=".2f", cmap="RdBu_r",
    center=0, square=True, linewidths=0.5, ax=ax,
    vmin=-1, vmax=1,
)
ax.set_title("Feature Correlation Matrix", fontsize=16)
plt.tight_layout()
plt.show()

# Correlation with target
print("\nCorrelation with Stroke (sorted):")
print(corr_matrix["stroke"].drop("stroke").sort_values(ascending=False).to_string())


> **Insights**:
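> - **Age** has the highest positive correlation with stroke (~0.25).
> - **Hypertension**, **heart disease**, and **avg_glucose_level** also show positive correlations, though smaller than age's.
> - No pair of features is extremely correlated. However, pairwise correlations alone cannot rule out multicollinearity, which can involve three or more features at once. Some features are also expected to move together (for example, age and ever_married, as noted in Section 4).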
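Variance inflation factors (VIF) are the usual check for multicollinearity beyond pairwise correlations. The sketch below computes them with plain NumPy least squares rather than a specific statistics library. In the synthetic data, the third column is nearly the sum of the first two. No pairwise correlation is extreme (roughly 0.7 at most), yet every VIF is large.

```python
import numpy as np

def vif(X):
    """Variance inflation factor for each column of a 2-D numeric array.

    VIF_j = 1 / (1 - R^2_j), where R^2_j is the R^2 of an ordinary
    least-squares regression of column j on all other columns plus an intercept.
    """
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    vifs = np.empty(p)
    for j in range(p):
        target = X[:, j]
        design = np.column_stack([np.ones(n), np.delete(X, j, axis=1)])
        coef, *_ = np.linalg.lstsq(design, target, rcond=None)
        resid = target - design @ coef
        r2 = 1.0 - resid.var() / target.var()
        vifs[j] = 1.0 / (1.0 - r2)
    return vifs

# Synthetic demo: c is almost exactly a + b
rng = np.random.default_rng(42)
a = rng.normal(size=500)
b = rng.normal(size=500)
c = a + b + rng.normal(scale=0.1, size=500)
X = np.column_stack([a, b, c])

print("Pairwise correlations:\n", np.corrcoef(X, rowvar=False).round(2))
print("VIF:", vif(X).round(1))
```

On the notebook data, this could be called as `vif(df_encoded.drop(columns="stroke").to_numpy())`. A common rule of thumb flags values above 5 or 10. The integer-coded nominal columns make their own VIFs hard to interpret, so one-hot encoding (dropping one level per feature) is a cleaner input.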

---

## 6. Age Deep-Dive -- The Strongest Predictor


In [None]:
fig = make_subplots(rows=1, cols=2, subplot_titles=["Age Distribution by Stroke", "Stroke Rate by Age Group"])

# Age distribution
for stroke_val, color, name in [(0, "#4CAF50", "No Stroke"), (1, "#F44336", "Stroke")]:
    subset = df[df["stroke"] == stroke_val]
    fig.add_trace(
        go.Histogram(x=subset["age"], name=name, opacity=0.6, marker_color=color, nbinsx=30),
        row=1, col=1,
    )

# Stroke rate by age group
df_temp = df.copy()
df_temp["age_group"] = pd.cut(df_temp["age"], bins=[0, 17, 39, 59, 74, 120],
                               labels=["0-17", "18-39", "40-59", "60-74", "75+"],
                               include_lowest=True)  # keep age == 0 in the first bin
stroke_by_age = df_temp.groupby("age_group", observed=False)["stroke"].mean() * 100

fig.add_trace(
    go.Bar(
        x=stroke_by_age.index.astype(str), y=stroke_by_age.values,
        marker_color="#F44336", text=[f"{v:.1f}%" for v in stroke_by_age.values],
        textposition="outside", name="Stroke Rate", showlegend=False,
    ),
    row=1, col=2,
)

fig.update_layout(barmode="overlay", height=400, title="Age is the Strongest Predictor of Stroke")
fig.show()


> **Key Finding**: Stroke risk increases dramatically with age. Patients aged 75+ have a stroke rate of ~15%, compared to virtually 0% for those under 18. This is consistent with age being a well-established, non-modifiable risk factor for stroke.
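The age groups rely on `pd.cut`'s default right-closed intervals: (0, 17], (17, 39], and so on. Two edge cases follow from that. An age of exactly 0 would fall outside every bin and become NaN unless `include_lowest=True` is passed. A fractional age such as 17.5 lands in the group labeled "18-39". The ages below are made up to show both cases.

```python
import pandas as pd

# Made-up ages chosen to sit on or near the bin edges
ages = pd.Series([0.0, 0.08, 17, 17.5, 18, 39, 40, 74, 75, 82])
bins = [0, 17, 39, 59, 74, 120]
labels = ["0-17", "18-39", "40-59", "60-74", "75+"]

# Default: intervals are (0, 17], (17, 39], ... so age 0 is left out
default = pd.cut(ages, bins=bins, labels=labels)
# include_lowest=True also places age exactly 0 in the first bin
inclusive = pd.cut(ages, bins=bins, labels=labels, include_lowest=True)

print(pd.DataFrame({"age": ages, "default": default, "include_lowest": inclusive}))
```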

---

## 7. Feature Engineering Preview

Based on our EDA insights, we've designed several engineered features intended to improve model performance:


In [None]:
df_eng = engineer_features(df)
new_cols = [c for c in df_eng.columns if c not in df.columns]

print(f"Original features: {len(df.columns)}")
print(f"After engineering: {len(df_eng.columns)}")
print(f"\nNew features ({len(new_cols)}):")
for col in new_cols:
    print(f"  - {col}: {df_eng[col].dtype} | example values: {df_eng[col].head(3).tolist()}")


In [None]:
# Risk score distribution
fig = px.histogram(
    df_eng, x="risk_score", color="stroke",
    color_discrete_map={0: "#4CAF50", 1: "#F44336"},
    barmode="overlay", opacity=0.6,
    title="Engineered Risk Score Distribution by Stroke Status",
    labels={"risk_score": "Composite Risk Score", "stroke": "Stroke"},
    category_orders={"stroke": [0, 1]},
)
fig.update_layout(height=400)
fig.show()


> **Insight**: The engineered risk score appears to separate stroke and non-stroke patients more clearly than any individual feature. Because this is a visual impression, it is worth confirming with a threshold-free metric such as ROC AUC. This composite feature combines age, hypertension, heart disease, glucose, and BMI risk factors.
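ROC AUC can be read as the probability that a randomly chosen stroke patient scores higher than a randomly chosen non-stroke patient. It does not depend on any threshold. This sketch computes it from ranks (the Mann-Whitney U statistic) with NumPy only. Tied scores get averaged ranks, which matters for a discrete composite score. `sklearn.metrics.roc_auc_score` should give the same value.

```python
import numpy as np

def roc_auc(scores, labels):
    """ROC AUC via the Mann-Whitney U statistic; tied scores get average ranks."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    # 1-based ranks in ascending score order
    order = scores.argsort(kind="mergesort")
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)
    # Replace ranks of tied scores with their mean
    for value in np.unique(scores):
        tied = scores == value
        ranks[tied] = ranks[tied].mean()
    n_pos, n_neg = labels.sum(), (~labels).sum()
    u = ranks[labels].sum() - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)

# Toy check: 3 of the 4 (positive, negative) pairs are ordered correctly
print(roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

In the notebook, a comparison could look like `{col: roc_auc(df_eng[col], df_eng["stroke"]) for col in ["risk_score", "age", "avg_glucose_level"]}`. These are in-sample numbers, so they describe separation in this dataset rather than predictive performance.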

---

## Summary of EDA Findings

| Finding | Implication |
|---------|-------------|
| Severe class imbalance (95:5) | Need SMOTEENN / threshold tuning |
| Age is the strongest predictor | Age bins and age interactions are valuable |
| Hypertension and heart disease increase risk | Interaction features (age x condition) |
| High glucose correlates with stroke | Glucose category feature useful |
| BMI alone is weak | BMI x glucose interaction may help |
| No missing values | Clean dataset, no imputation needed |
| "Unknown" smoking status | Treat as its own category |
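Threshold tuning can be as simple as scanning candidate cutoffs on held-out predicted probabilities and keeping the best one. This sketch assumes a model that outputs probabilities. It optimizes F1, which is only one possible target; recall at a minimum precision is another. The labels and probabilities are made up. The threshold should be chosen on a validation split, never on the test set. Resampling such as SMOTEENN (from imbalanced-learn) is not shown here.

```python
import numpy as np

def best_f1_threshold(y_true, y_prob, thresholds=None):
    """Return (threshold, f1) that maximizes F1 over the candidate thresholds."""
    y_true = np.asarray(y_true).astype(bool)
    y_prob = np.asarray(y_prob, dtype=float)
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)
    best_t, best_f1 = 0.5, -1.0
    for t in thresholds:
        pred = y_prob >= t
        tp = np.sum(pred & y_true)
        fp = np.sum(pred & ~y_true)
        fn = np.sum(~pred & y_true)
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when there are no true positives
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1

# Made-up validation labels and predicted probabilities
y_true = [0, 0, 0, 0, 1, 1]
y_prob = [0.05, 0.10, 0.20, 0.30, 0.25, 0.60]

best_t, best_f1 = best_f1_threshold(y_true, y_prob)
_, f1_default = best_f1_threshold(y_true, y_prob, thresholds=[0.5])
print(f"Default 0.5 -> F1={f1_default:.2f}; tuned {best_t:.2f} -> F1={best_f1:.2f}")
```

With rare positives, predicted probabilities often sit well below 0.5, so the default cutoff tends to miss cases that a lower, tuned threshold catches.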

In the next notebook, we'll build, optimize, and evaluate multiple ML models using these insights.
