# Stroke Prediction Analysis

This notebook implements a complete data science workflow for stroke prediction, addressing class imbalance and comparing multiple models including Random Forest, Logistic Regression, XGBoost, and a Neural Network.

## 1. Imports & Setup

In [None]:
import polars as pl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from plotnine import *

from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, precision_score, recall_score, f1_score

from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline

from xgboost import XGBClassifier

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, callbacks

# Pin the global NumPy and TensorFlow random generators to one shared seed
# so that stochastic steps give the same result on every run.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

## 2. Data Loading
Loading the dataset using Polars. We handle potential parsing errors by increasing the schema inference length and specifying null values.

In [None]:
# Read the raw CSV with Polars.
#   * infer_schema_length: rows scanned before column dtypes are fixed
#   * null_values: literal strings that are parsed as nulls
stroke_df = pl.read_csv(
    "healthcare-dataset-stroke-data.csv",
    infer_schema_length=10000,
    null_values=["N/A", ""],
)

print(f"Dataset shape: {stroke_df.shape}")
stroke_df.head()

## 3. Data Cleaning & Feature Engineering

We perform the following steps:
1.  Impute missing BMI values with a K-nearest-neighbours imputer (k = 5) applied to the scaled, ordinal-encoded features.
2.  Remove the single observation with "Other" gender.
3.  Create binary encodings for categorical variables.
4.  Engineer new features: Age groups, BMI categories, Glucose categories, and a composite Risk Score.
5.  Cast the target variable to integer.

In [None]:
import polars as pl
import pandas as pd
import numpy as np
from sklearn.impute import KNNImputer
from sklearn.preprocessing import OrdinalEncoder, StandardScaler

# 1. Load and Initial Filter
# infer_schema_length scans more rows before fixing dtypes; null_values turns the
# literal "N/A" strings into real nulls so that `bmi` is parsed as a float column.
df = pl.read_csv(
    "healthcare-dataset-stroke-data.csv", 
    infer_schema_length=10000, 
    null_values=["N/A"]
)
df = df.filter(pl.col("gender") != "Other")

# 2. Prepare Data for KNN Imputation
# We use Pandas/Sklearn for this specific step as Polars lacks a native KNN Imputer.
pandas_df = df.to_pandas()

# Identify columns
cat_cols = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']
num_cols = ['age', 'avg_glucose_level', 'hypertension', 'heart_disease']
target_impute = ['bmi']

# A. Temporarily Encode Categoricals (Ordinal is fine for distance approximation)
encoder = OrdinalEncoder()
df_encoded = pandas_df.copy()
df_encoded[cat_cols] = encoder.fit_transform(df_encoded[cat_cols])

# B. Scale Data (Crucial for KNN so no single numeric column dominates the distance).
# StandardScaler ignores NaNs when fitting and leaves them in place when transforming,
# so the missing BMI entries are still missing after this step.
scaler = StandardScaler()
cols_to_scale = num_cols + target_impute
df_encoded[cols_to_scale] = scaler.fit_transform(df_encoded[cols_to_scale])

# C. Apply KNN Imputer
imputer = KNNImputer(n_neighbors=5)
# We impute on everything except ID and Stroke
impute_data = df_encoded[cat_cols + num_cols + target_impute]
imputed_matrix = imputer.fit_transform(impute_data)

# D. Recover BMI
# Select the BMI column by name rather than by "last position" so that reordering the
# column lists above cannot silently pick the wrong column.
bmi_col_idx = list(impute_data.columns).index('bmi')
imputed_bmi_scaled = imputed_matrix[:, bmi_col_idx].reshape(-1, 1)

# Undo the scaling with the statistics the scaler actually used. StandardScaler divides
# by the population standard deviation (ddof=0), whereas pandas' Series.std() defaults
# to the sample standard deviation (ddof=1); mixing the two distorts every BMI value.
bmi_scaler_idx = cols_to_scale.index('bmi')
bmi_mean = scaler.mean_[bmi_scaler_idx]
bmi_std = scaler.scale_[bmi_scaler_idx]
final_bmi = (imputed_bmi_scaled * bmi_std) + bmi_mean

# 3. Back to Polars for Feature Engineering
# Every derived column is declared as a named expression first; the chain at the
# bottom applies them to the filtered frame in dependency order.

# Imputed BMI values, replacing the original `bmi` column
imputed_bmi = pl.Series(name="bmi", values=final_bmi.flatten())

# --- Binary encodings and Int8 casts of the existing 0/1 columns ---
encoding_exprs = [
    pl.col("ever_married").replace({"Yes": 1, "No": 0}).cast(pl.Int8).alias("ever_married_binary"),
    pl.col("Residence_type").replace({"Urban": 1, "Rural": 0}).cast(pl.Int8).alias("residence_urban"),
    pl.col("gender").replace({"Male": 1, "Female": 0}).cast(pl.Int8).alias("gender_male"),
    pl.col("hypertension").cast(pl.Int8),
    pl.col("heart_disease").cast(pl.Int8),
    pl.col("stroke").cast(pl.Int8),
]

# --- Age bands ---
age = pl.col("age")
age_group_expr = (
    pl.when(age < 18).then(pl.lit("0-17"))
    .when(age < 40).then(pl.lit("18-39"))
    .when(age < 60).then(pl.lit("40-59"))
    .when(age < 80).then(pl.lit("60-79"))
    .otherwise(pl.lit("80+"))
    .alias("age_group")
)

# --- BMI bands (computed on the imputed BMI) ---
bmi = pl.col("bmi")
bmi_category_expr = (
    pl.when(bmi < 18.5).then(pl.lit("Underweight"))
    .when(bmi < 25).then(pl.lit("Normal"))
    .when(bmi < 30).then(pl.lit("Overweight"))
    .otherwise(pl.lit("Obese"))
    .alias("bmi_category")
)

# --- Glucose bands ---
glucose = pl.col("avg_glucose_level")
glucose_category_expr = (
    pl.when(glucose < 100).then(pl.lit("Normal"))
    .when(glucose < 126).then(pl.lit("Prediabetic"))
    .otherwise(pl.lit("Diabetic"))
    .alias("glucose_category")
)

# --- Composite risk score: count of risk flags that are set for the row ---
age_55_plus = (age >= 55).cast(pl.Int8)
glucose_126_plus = (glucose >= 126).cast(pl.Int8)
bmi_30_plus = (bmi >= 30).cast(pl.Int8)
risk_score_expr = (
    pl.col("hypertension")
    + pl.col("heart_disease")
    + age_55_plus
    + glucose_126_plus
    + bmi_30_plus
).alias("risk_score")

stroke_final = (
    df
    .with_columns(imputed_bmi)
    .with_columns(*encoding_exprs)
    .with_columns(age_group_expr, bmi_category_expr, glucose_category_expr)
    .with_columns(risk_score_expr)   # uses the Int8-cast flags created above
    .drop("id")
)

# Export for non-CV models / EDA
stroke_final.write_csv("stroke_cleaned.csv")
print("Cleaned data exported with KNN imputation.")

## 4. CV Pipeline

We split the data into training and testing sets, then apply scaling and one-hot encoding. Crucially, we apply SMOTE (Synthetic Minority Over-sampling Technique) only to the training data to address the class imbalance without causing data leakage.

In [None]:
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from imblearn.pipeline import Pipeline as ImbPipeline
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier

# Custom Feature Engineer for the Pipeline
class FeatureEngineer(BaseEstimator, TransformerMixin):
    """Stateless transformer that appends derived columns to a pandas DataFrame.

    Added columns:
      * ``age_group`` - categorical age band (child / adult / senior / elderly)
      * ``glucose_age_interaction`` - ``avg_glucose_level * age``

    The input frame must contain ``age`` and ``avg_glucose_level``.
    """

    def fit(self, X, y=None):
        """Nothing is learned from the data; returns ``self`` for pipeline compatibility."""
        return self

    def transform(self, X):
        """Return a copy of ``X`` with the derived columns added (``X`` is not modified)."""
        X = X.copy()
        # Add Age Group.
        # include_lowest closes the first bin at 0 and the open-ended last bin covers
        # ages above 100, so no valid age silently becomes NaN. Ages in (0, 100] are
        # binned exactly as with the previous [0, 18, 45, 65, 100] edges.
        X['age_group'] = pd.cut(
            X['age'],
            bins=[0, 18, 45, 65, np.inf],
            labels=['child', 'adult', 'senior', 'elderly'],
            include_lowest=True,
        )
        # Add Risk Interaction (Glucose * Age)
        X['glucose_age_interaction'] = X['avg_glucose_level'] * X['age']
        return X

# 1. Prepare Data (Raw split)
# We use the pandas version of the raw filtered data (BMI still contains NaNs here;
# imputation happens inside the pipeline so it is re-fitted on every training fold).
X = pandas_df.drop(['stroke', 'id'], axis=1)
y = pandas_df['stroke']

# 2. Train/Test Split (stratified so both parts keep the same stroke rate)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# 3. Define Preprocessing (Scale -> KNN-Impute for numerics, One-Hot for categoricals)
numeric_features = ['age', 'avg_glucose_level', 'bmi', 'glucose_age_interaction']
# Already 0/1 indicators: passed through unchanged. Previously these two columns were in
# no feature list and were silently discarded by remainder='drop'.
binary_features = ['hypertension', 'heart_disease']
categorical_features = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status', 'age_group']

# Scaling comes before imputation on purpose: StandardScaler ignores NaNs, and the
# KNN imputer needs features on a comparable scale to compute meaningful distances.
numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler()), 
    ('imputer', KNNImputer(n_neighbors=5)) 
])

categorical_transformer = Pipeline(steps=[
    ('encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False))
])

# Output column order is: numeric, binary, one-hot categorical.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('bin', 'passthrough', binary_features),
        ('cat', categorical_transformer, categorical_features)
    ],
    remainder='drop'
)

# 4. Construct the Master CV Pipeline
cv_pipeline = ImbPipeline(steps=[
    ('feature_engineering', FeatureEngineer()),  # Creates columns
    ('preprocessor', preprocessor),              # Handles NaNs, Scaling, Encoding
    ('smote', SMOTE(random_state=42)),           # Resampling inside the fold
    ('model', RandomForestClassifier(random_state=42))
])

# 5. Run CV
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(cv_pipeline, X_train, y_train, cv=cv, scoring='roc_auc')

print(f"CV ROC-AUC Scores: {scores}")
print(f"Mean ROC-AUC: {scores.mean():.4f}")

## 5. Model Training & Evaluation

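We compare four models (Random Forest, Logistic Regression, XGBoost, and a Stacking Classifier built on Random Forest and Logistic Regression base learners) using 5-fold stratified cross-validation. SMOTE runs inside the pipeline, so each model is trained on resampled training folds and scored on untouched validation folds.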

In [None]:
from sklearn.model_selection import cross_validate
from xgboost import XGBClassifier
from sklearn.ensemble import StackingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression

# 1. Candidate models
# The stacking classifier runs its own internal 5-fold CV to build meta-features and is
# itself cross-validated by the outer loop below.
base_learners = [
    ('rf', RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)),
    ('lr', LogisticRegression(class_weight="balanced", max_iter=1000, random_state=42)),
]

models = {
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1),
    "Logistic Regression": LogisticRegression(class_weight="balanced", max_iter=1000, random_state=42),
    "XGBoost": XGBClassifier(
        n_estimators=200,
        max_depth=6,
        learning_rate=0.1,
        scale_pos_weight=5,
        random_state=42,
        eval_metric="logloss",
    ),
    "Stacking Classifier": StackingClassifier(
        estimators=base_learners,
        final_estimator=LogisticRegression(),
        cv=5,
    ),
}

# 2. Metrics tracked in every fold (scorer name -> sklearn scoring string)
scoring_metrics = {
    metric: metric
    for metric in ('accuracy', 'recall', 'precision', 'f1', 'roc_auc')
}

# Report column -> key of the per-fold score array returned by cross_validate
report_columns = {
    "Accuracy": 'test_accuracy',
    "Recall": 'test_recall',
    "Precision": 'test_precision',
    "F1 Score": 'test_f1',
    "ROC-AUC": 'test_roc_auc',
}

# 3. Cross-validate each candidate inside the full preprocessing + SMOTE pipeline
cv_results_data = []
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

print("Starting Cross-Validation Pipeline...")

for name, model in models.items():
    print(f"Evaluating {name}...")

    # Fresh pipeline per model; the raw training data goes in and SMOTE is applied
    # only to the training part of each fold.
    model_pipeline = ImbPipeline(steps=[
        ('engineer', FeatureEngineer()),
        ('preprocessor', preprocessor),
        ('smote', SMOTE(random_state=42)),
        ('classifier', model),
    ])

    fold_scores = cross_validate(
        model_pipeline,
        X_train,
        y_train,
        scoring=scoring_metrics,
        cv=cv,
        n_jobs=-1,  # Use all CPU cores
    )

    # One summary row per model: the mean of each metric over the 5 folds
    summary_row = {"Model": name}
    for column, score_key in report_columns.items():
        summary_row[column] = fold_scores[score_key].mean()
    cv_results_data.append(summary_row)

# 4. Display Comparison
results_df = pd.DataFrame(cv_results_data)
print("\n=== Cross-Validation Performance (Mean of 5 Folds) ===")
print(results_df.round(4).sort_values(by='ROC-AUC', ascending=False))

## 6. Elastic Net Grid Search

We perform a grid search to optimize the Logistic Regression model using Elastic Net regularization.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, confusion_matrix
from imblearn.pipeline import Pipeline as ImbPipeline
from imblearn.over_sampling import SMOTE

# 1. Setup Pipeline and Grid Search
# ---------------------------------------------------------
# Elastic-net logistic regression; `saga` is the solver used with this penalty here.
elastic_net_clf = LogisticRegression(
    penalty="elasticnet",
    solver="saga",
    max_iter=2000,
    random_state=42,
    class_weight="balanced"
)

# Full pipeline: feature engineering -> preprocessing -> SMOTE -> classifier
lr_pipeline = ImbPipeline(steps=[
    ('engineer', FeatureEngineer()),
    ('preprocessor', preprocessor),
    ('smote', SMOTE(random_state=42)),
    ('classifier', elastic_net_clf),
])

# The 'classifier__' prefix routes each parameter to the pipeline's classifier step
param_grid = {
    "classifier__C": [0.001, 0.01, 0.1, 1, 10, 100],
    "classifier__l1_ratio": [0, 0.25, 0.5, 0.75, 1.0]
}

inner_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
grid_search = GridSearchCV(
    lr_pipeline,
    param_grid,
    scoring="roc_auc",
    cv=inner_cv,
    n_jobs=-1
)

print("Running Grid Search on Pipeline...")
# Raw training data goes in; every transformation happens inside the pipeline.
grid_search.fit(X_train, y_train)

# 2. Extract Best Model and Evaluate
# ---------------------------------------------------------
best_pipeline = grid_search.best_estimator_
best_params = grid_search.best_params_

print(f"\n=== BEST MODEL SELECTED ===")
print(f"Parameters: {best_params}")

# Predictions on the raw test set (the fitted pipeline transforms it internally)
y_proba = best_pipeline.predict_proba(X_test)[:, 1]
y_pred = best_pipeline.predict(X_test)

# Threshold-based metrics use the hard predictions, ROC-AUC uses the probabilities
acc = accuracy_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_proba)
rec = recall_score(y_test, y_pred)
prec = precision_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)

print(f"\n=== TEST SET METRICS ===")
print(f"Accuracy:  {acc:.4f}")
print(f"ROC-AUC:   {auc:.4f}")
print(f"Recall:    {rec:.4f}")
print(f"Precision: {prec:.4f}")
print(f"F1 Score:  {f1:.4f}")
print("\nConfusion Matrix:")
print(cm)

# 3. Generate Coefficient Path Diagram
# ---------------------------------------------------------
print("\nGenerating Coefficient Path...")

# Preprocess the training data once and refit only the classifier for each C value
# (re-running the full pipeline for every C would repeat identical preprocessing).

# A. Extract the fitted steps from the best pipeline
engineer_step = best_pipeline.named_steps['engineer']
preprocessor_step = best_pipeline.named_steps['preprocessor']
smote_step = best_pipeline.named_steps['smote']

# B. Transform X_train to the state the classifier sees
X_eng = engineer_step.transform(X_train)         # Add features
X_scaled = preprocessor_step.transform(X_eng)    # Scale/Encode
X_resampled, y_resampled = smote_step.fit_resample(X_scaled, y_train) # Apply SMOTE

# C. Get Feature Names
# Only the errors that signal "names unavailable" are caught (an unfitted or unsupported
# transformer). A bare `except:` would also hide KeyboardInterrupt and real bugs.
try:
    feature_names = preprocessor_step.get_feature_names_out()
except (AttributeError, ValueError) as err:
    print(f"Feature names unavailable ({err}); falling back to positional labels.")
    feature_names = [f"Feature {i}" for i in range(X_resampled.shape[1])]

Cs = param_grid['classifier__C']
best_l1 = best_params['classifier__l1_ratio']
coeffs = []

# Refit the classifier alone for every C at the best l1_ratio found by the grid search
for c in Cs:
    clf = LogisticRegression(
        penalty="elasticnet",
        solver="saga",
        C=c,
        l1_ratio=best_l1,
        max_iter=2000,
        random_state=42,
        class_weight="balanced"
    )
    # Fit on the pre-processed, resampled data
    clf.fit(X_resampled, y_resampled)
    coeffs.append(clf.coef_.ravel())

# Rows = C values, columns = features
coeffs = np.array(coeffs)

# Plotting: one line per feature across log10(C)
plt.figure(figsize=(12, 8))
for i in range(coeffs.shape[1]):
    plt.plot(np.log10(Cs), coeffs[:, i], marker='o', label=feature_names[i])

best_c = best_params['classifier__C']
plt.axvline(x=np.log10(best_c), color='black', linestyle='--', label=f"Best C ({best_c})")

plt.title(f"Coefficient Path (Elastic Net, l1_ratio={best_l1})")
plt.xlabel("log10(C) - Inverse Regularization Strength")
plt.ylabel("Coefficient Value")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') 
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 4. Standard Heatmap
# ---------------------------------------------------------
# Mean cross-validated AUC for every (C, l1_ratio) pair: rows are C, columns are l1_ratio
results_pd = pd.DataFrame(grid_search.cv_results_)
pivot_table = results_pd.pivot(
    index="param_classifier__C",
    columns="param_classifier__l1_ratio",
    values="mean_test_score",
)

fig, ax = plt.subplots(figsize=(8, 5))
sns.heatmap(pivot_table, annot=True, fmt=".4f", cmap="viridis", ax=ax)
ax.set_title("Grid Search AUC Performance")
ax.set_xlabel("L1 Ratio")
ax.set_ylabel("C")
plt.show()

## 7. SHAP Feature Importance Analysis

We fit a Random Forest on the preprocessed, SMOTE-resampled training data and use SHAP values to see which features drive its stroke predictions on the test set.

In [None]:
import shap
from sklearn.base import clone

# --- SHAP ANALYSIS ---
# We use a Random Forest model for the explanation as it handles non-linearities well.
#
# This cell builds its own processed matrices instead of relying on variables created
# by a later cell, so the notebook runs top to bottom on a fresh kernel. The shared
# `preprocessor` is cloned so fitting it here does not alter the object other cells use.
shap_engineer = FeatureEngineer()
shap_preprocessor = clone(preprocessor)

X_train_shap = shap_preprocessor.fit_transform(shap_engineer.fit_transform(X_train))
X_test_shap = shap_preprocessor.transform(shap_engineer.transform(X_test))

# SMOTE on the training data only; the test matrix stays untouched
X_train_shap_res, y_train_shap_res = SMOTE(random_state=42).fit_resample(X_train_shap, y_train)

print("Training Explainer Model (Random Forest)...")
rf_explainer = RandomForestClassifier(n_estimators=100, random_state=42)
rf_explainer.fit(X_train_shap_res, y_train_shap_res)

# Initialize SHAP Tree Explainer
explainer = shap.TreeExplainer(rf_explainer)

# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test_shap)

# Depending on the SHAP version, a binary classifier yields either a list with one
# (n_samples, n_features) array per class or a single (n_samples, n_features, n_classes)
# array. Select the positive class (Stroke = 1) in both cases.
if isinstance(shap_values, list):
    shap_values_pos = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_values_pos = shap_values[:, :, 1]
else:
    shap_values_pos = shap_values

# Take the names from the fitted transformer so they always match the column order
# of the matrices above.
feature_names_all = list(shap_preprocessor.get_feature_names_out())

print("Generating SHAP Summary Plot...")
# Summary plot for the positive class (Stroke = 1)
shap.summary_plot(shap_values_pos, X_test_shap, feature_names=feature_names_all)

# Optional: force plot for a single prediction (the shape of expected_value also
# depends on the SHAP version, so index it accordingly)
# shap.force_plot(explainer.expected_value[1], shap_values_pos[0, :], X_test_shap[0, :], feature_names=feature_names_all, matplotlib=True)

## 8. Neural Network

We implement a binary classification neural network using PyTorch. Training runs on Apple Metal (MPS) or CUDA when available and falls back to the CPU otherwise.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import roc_auc_score, recall_score, precision_score, f1_score, confusion_matrix, precision_recall_curve, auc
import numpy as np
from imblearn.over_sampling import SMOTE

# ---------------------------------------------------------
# 1. Manual Preprocessing for PyTorch
# (Since we can't put a PyTorch model directly into an sklearn Pipeline easily)
# ---------------------------------------------------------

# A. Derived features (the transformer is stateless, so fit learns nothing)
fe = FeatureEngineer()
X_train_eng, X_test_eng = fe.fit_transform(X_train), fe.transform(X_test)

# B. Scaling, imputation and encoding: statistics come from the training split only
#    and are then applied unchanged to the test split
X_train_processed = preprocessor.fit_transform(X_train_eng)
X_test_processed = preprocessor.transform(X_test_eng)

# C. Oversample the minority class with SMOTE, on the training matrix only
smote = SMOTE(random_state=42)
X_train_resampled, y_train_resampled = smote.fit_resample(
    X_train_processed,
    y_train,
)

# ---------------------------------------------------------
# 2. PyTorch Setup
# ---------------------------------------------------------
# Seed PyTorch's global RNG. The notebook seeds NumPy and TensorFlow at the top but not
# torch, so weight initialisation and DataLoader shuffling would otherwise differ on
# every run.
torch.manual_seed(42)

# Device selection: MPS (Apple Silicon) takes precedence, then CUDA, then CPU.
# `torch.backends.mps` does not exist in older torch builds, hence the getattr guard.
mps_backend = getattr(torch.backends, "mps", None)
if mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on device: {device}")

# Convert to Tensors
# np.asarray accepts both pandas objects and NumPy arrays, so this works whichever
# container type the preprocessing / resampling steps returned.
X_train_tensor = torch.tensor(np.asarray(X_train_resampled), dtype=torch.float32)
y_train_tensor = torch.tensor(np.asarray(y_train_resampled), dtype=torch.float32).unsqueeze(1)
X_test_tensor = torch.tensor(np.asarray(X_test_processed), dtype=torch.float32)
y_test_tensor = torch.tensor(np.asarray(y_test), dtype=torch.float32).unsqueeze(1)

# Create DataLoaders (targets are shaped (n, 1) to match the model's single logit)
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True) # Increased batch size for stability

# ---------------------------------------------------------
# 3. Define Model
# ---------------------------------------------------------
class StrokeClassifier(nn.Module):
    """Feed-forward binary classifier: input_dim -> 64 -> 32 -> 1.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout. The output is a
    single raw logit per row (no sigmoid), as expected by ``BCEWithLogitsLoss``.
    """

    def __init__(self, input_dim):
        super().__init__()
        first_hidden, second_hidden = 64, 32
        stack = [
            # hidden block 1
            nn.Linear(input_dim, first_hidden),
            nn.BatchNorm1d(first_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            # hidden block 2
            nn.Linear(first_hidden, second_hidden),
            nn.BatchNorm1d(second_hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
            # output logit
            nn.Linear(second_hidden, 1),
        ]
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Return raw logits of shape (batch, 1) for inputs of shape (batch, input_dim)."""
        logits = self.main(x)
        return logits

# Network sized to the width of the processed feature matrix, placed on the chosen device
n_input_features = X_train_resampled.shape[1]
model = StrokeClassifier(input_dim=n_input_features).to(device)

# ---------------------------------------------------------
# 4. Weighted Loss & Optimizer
# ---------------------------------------------------------
# The training set was balanced with SMOTE, so the positive class gets a neutral
# weight of 1.0 in the loss.
pos_weight_value = torch.tensor([1.0]).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_value)

# Adam with a small L2 penalty (weight decay)
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-5,
)

# ---------------------------------------------------------
# 5. Training Loop
# ---------------------------------------------------------
epochs = 100
patience = 15            # epochs without AUC improvement tolerated before stopping
best_val_auc = 0
patience_counter = 0
best_model_state = None  # detached copy of the weights at the best AUC seen so far
history = {'loss': [], 'val_auc': [], 'val_recall': []}

print(f"\nStarting Training (Max {epochs} epochs)...")

# NOTE(review): the held-out test set doubles as the validation set for early stopping
# here, so the final test metrics are optimistically biased. Carving a separate
# validation split out of the training data (before SMOTE) would remove that leak.
X_test_gpu = X_test_tensor.to(device)  # loop-invariant, so transferred once

for epoch in range(epochs):
    # --- Training phase: one pass over the resampled training data ---
    model.train()
    running_loss = 0.0
    
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
        optimizer.zero_grad()
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # Mean of the per-batch losses for this epoch
    avg_loss = running_loss / len(train_loader)
    
    # --- Validation phase (eval mode disables dropout and freezes batch-norm stats) ---
    model.eval()
    with torch.no_grad():
        test_logits = model(X_test_gpu)
        test_probs = torch.sigmoid(test_logits).cpu().numpy()
        
        # Calculate Val AUC and recall at the default 0.5 threshold
        val_auc = roc_auc_score(y_test, test_probs)
        val_recall = recall_score(y_test, (test_probs > 0.5).astype(int))

    history['loss'].append(avg_loss)
    history['val_auc'].append(val_auc)
    history['val_recall'].append(val_recall)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | Val AUC: {val_auc:.4f} | Recall(0.5): {val_recall:.4f}")
    
    # --- Early Stopping ---
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors, which later
        # optimizer steps overwrite in place. Clone them so the snapshot really holds
        # the weights of the best epoch.
        best_model_state = {
            key: tensor.detach().clone()
            for key, tensor in model.state_dict().items()
        }
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the best weights (covers both early stopping and running all epochs)
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# ---------------------------------------------------------
# 6. Visualization
# ---------------------------------------------------------
# A. Training curves: one panel each for loss, validation AUC and validation recall
fig, (ax_loss, ax_auc, ax_recall) = plt.subplots(1, 3, figsize=(15, 5))

ax_loss.plot(history['loss'], label='Train Loss')
ax_loss.set_title('Training Loss')

ax_auc.plot(history['val_auc'], label='Val AUC', color='green')
ax_auc.set_title('Validation ROC-AUC')

ax_recall.plot(history['val_recall'], label='Val Recall', color='orange')
ax_recall.set_title('Validation Recall')

plt.show()

# B. Multi-Threshold Analysis
model.eval()
with torch.no_grad():
    final_logits = model(X_test_tensor.to(device))
    final_probs = torch.sigmoid(final_logits).cpu().numpy()

# Metric Calculations
# Flatten labels and scores to 1-D so element-wise comparisons cannot broadcast.
y_test_flat = np.asarray(y_test).ravel()
final_probs_flat = final_probs.ravel()

precision_curve, recall_curve, thresholds = precision_recall_curve(y_test_flat, final_probs_flat)
f1_scores = 2 * (precision_curve * recall_curve) / (precision_curve + recall_curve + 1e-10)
# The last precision/recall point (recall = 0) has no matching threshold, so it is
# excluded to keep the chosen index valid for `thresholds`.
best_f1_idx = np.argmax(f1_scores[:-1])
best_f1_thresh = thresholds[best_f1_idx]

# 90% Recall Threshold: recall falls as the threshold rises, so the last index with
# recall >= 0.90 is the highest threshold that still meets the target.
recall_target_idx = np.where(recall_curve[:-1] >= 0.90)[0][-1]
high_recall_thresh = thresholds[recall_target_idx]

def get_metrics(probs, y_true, thresh):
    """Summarise classification quality at one probability threshold.

    Args:
        probs: predicted positive-class probabilities; any shape, flattened to 1-D.
        y_true: true 0/1 labels (pandas Series, list or array), flattened to 1-D.
        thresh: rows with ``prob >= thresh`` are predicted positive. ``>=`` matches the
            convention of ``precision_recall_curve`` thresholds, so a threshold picked
            from that curve reproduces the precision/recall reported there.

    Returns:
        dict with the formatted threshold, recall, precision, accuracy, F1 and the
        TP / FN / FP counts.
    """
    y_arr = np.asarray(y_true).ravel()   # pandas Series has no .reshape
    preds = (np.asarray(probs).ravel() >= thresh).astype(int)
    # Fixed label order keeps the matrix 2x2 even if only one class is predicted
    cm = confusion_matrix(y_arr, preds, labels=[0, 1])
    return {
        "Threshold": f"{thresh:.4f}",
        "Recall": recall_score(y_arr, preds),
        "Precision": precision_score(y_arr, preds, zero_division=0),
        "Accuracy": (preds == y_arr).mean(),
        "F1": f1_score(y_arr, preds),
        "TP": cm[1,1], "FN": cm[1,0], "FP": cm[0,1]
    }

results = [
    get_metrics(final_probs, y_test, 0.5),
    get_metrics(final_probs, y_test, best_f1_thresh),
    get_metrics(final_probs, y_test, high_recall_thresh)
]

df_results = pd.DataFrame(results, index=["Standard (0.5)", "Best F1 Balance", "High Recall (~90%)"])
print("\n=== COMPARATIVE RESULTS ===")
print(df_results.to_string())

# Plot Confusion Matrix for Best F1 (same >= rule as get_metrics)
best_preds = (final_probs_flat >= best_f1_thresh).astype(int)
cm = confusion_matrix(y_test_flat, best_preds)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title(f'Confusion Matrix (Best F1 Threshold: {best_f1_thresh:.4f})')
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show()