# 03: Model Training - ML Models for HbA1c Estimation

This notebook demonstrates the complete ML training workflow for HbA1c estimation:

1. **Loading** cleaned NHANES data from previous notebooks
2. **Feature Engineering** using raw biomarkers and mechanistic estimators
3. **Stratified Splitting** to preserve the mix of HbA1c clinical ranges in the train and test sets
4. **Training** Ridge, Random Forest, and LightGBM models
5. **Cross-Validation** for robust performance estimation
6. **Model Comparison** and saving the best model
7. **Test Set Evaluation** on held-out data

---

## Background

We use a hybrid approach that combines:
- Raw biomarker features (FPG, TG, HDL, age, hemoglobin, MCV)
- Ratio features (TG/HDL, FPG-age interaction)
- Mechanistic estimator predictions (ADAG, kinetic, regression)

This allows ML models to learn refinements on top of established clinical relationships.
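
For reference, the ADAG study's linear relationship between HbA1c and estimated average glucose (eAG, in mg/dL), together with its inversion, is shown below. How the `hba1cE` ADAG estimator applies it (for example, whether it substitutes fasting plasma glucose for eAG) is not shown in this notebook and is an assumption here.

$$
\text{eAG}\ (\text{mg/dL}) = 28.7 \times \text{HbA1c}\ (\%) - 46.7
\quad\Longrightarrow\quad
\text{HbA1c}\ (\%) = \frac{\text{eAG} + 46.7}{28.7}
$$

As a worked check, an eAG of 154 mg/dL gives $(154 + 46.7)/28.7 = 200.7/28.7 \approx 6.99\%$, matching the familiar pairing of 7% HbA1c with about 154 mg/dL.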

In [None]:
# Standard library imports
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path.cwd().parent))

# Third-party imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge

# Local imports
from hba1cE.train import (
    create_features,
    stratified_split,
    train_ridge,
    train_random_forest,
    train_lightgbm,
    cross_validate_model,
    save_model,
)

# Configure matplotlib
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11

print("Imports successful!")

---

## Step 1: Load Cleaned Data

Load the cleaned NHANES glycemic data generated in Notebook 01.

In [None]:
# Load cleaned data
DATA_DIR = Path.cwd().parent / "data"
PROCESSED_DIR = DATA_DIR / "processed"

df = pd.read_csv(PROCESSED_DIR / "nhanes_glycemic_cleaned.csv")

print(f"Loaded dataset shape: {df.shape}")
print(f"\nColumns: {list(df.columns)}")
print(f"\nHbA1c range: {df['hba1c_percent'].min():.1f}% - {df['hba1c_percent'].max():.1f}%")
print(f"\nFirst 5 rows:")
df.head()

---

## Step 2: Feature Engineering

Create the feature matrix including:
- Raw biomarker features
- Ratio and interaction features (TG/HDL, FPG × age)
- Mechanistic estimator predictions
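
The hybrid feature set can be sketched as follows. This is a hypothetical stand-in for `create_features()`: the column names (`fpg_mg_dl`, `tg_mg_dl`, and so on) are assumptions, and only the ADAG-style estimator is shown. The kinetic and regression estimators are omitted, and using FPG as a proxy for average glucose is an assumption.

```python
import pandas as pd

# Hypothetical raw biomarker column names (not taken from the hba1cE package)
RAW_COLUMNS = ["fpg_mg_dl", "tg_mg_dl", "hdl_mg_dl", "age_years", "hgb_g_dl", "mcv_fl"]

def build_hybrid_features(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical stand-in for create_features(); all column names are assumptions."""
    feats = df[RAW_COLUMNS].astype(float).copy()                # raw biomarkers
    feats["tg_hdl_ratio"] = df["tg_mg_dl"] / df["hdl_mg_dl"]     # lipid ratio
    feats["fpg_x_age"] = df["fpg_mg_dl"] * df["age_years"]       # interaction term
    # Mechanistic estimator: invert the ADAG line, using FPG as a proxy for eAG (assumption)
    feats["adag_hba1c"] = (df["fpg_mg_dl"] + 46.7) / 28.7
    return feats

demo = pd.DataFrame({
    "fpg_mg_dl": [100.0, 154.0], "tg_mg_dl": [150.0, 200.0], "hdl_mg_dl": [50.0, 40.0],
    "age_years": [40, 60], "hgb_g_dl": [14.0, 13.0], "mcv_fl": [90.0, 88.0],
})
feats = build_hybrid_features(demo)
print(feats[["tg_hdl_ratio", "fpg_x_age", "adag_hba1c"]].round(2))
```

Because the mechanistic estimates are ordinary columns, every downstream model can use them, reweight them, or ignore them.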

In [None]:
# Create features
X, feature_names = create_features(df)

print(f"Feature matrix shape: {X.shape}")
print(f"\nFeatures ({len(feature_names)} total):")
for i, name in enumerate(feature_names):
    print(f"  {i+1}. {name}")

In [None]:
# Show feature statistics
feature_df = pd.DataFrame(X, columns=feature_names)
print("\nFeature Statistics:")
feature_df.describe().round(2)

---

## Step 3: Stratified Train/Test Split

Split the data 70/30 with stratification, so that each HbA1c clinical range appears in the same proportion in the training and test sets:
- <5.7% (normal)
- 5.7-6.4% (prediabetes)
- 6.5-8% (mild diabetes)
- 8-10% (moderate diabetes)
- >10% (severe diabetes)
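
A minimal sketch of stratifying on binned HbA1c is shown below. `stratified_split_sketch` and `RANGE_EDGES` are hypothetical names, and treating each range as a half-open `[low, high)` bin is an assumption about how `stratified_split` handles boundary values.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Lower edges of the five ranges above, assumed to be [low, high) bins:
# <5.7 -> 0, 5.7 to <6.5 -> 1, 6.5 to <8 -> 2, 8 to <10 -> 3, >=10 -> 4
RANGE_EDGES = [5.7, 6.5, 8.0, 10.0]

def stratified_split_sketch(X, y, test_size=0.3, random_state=42):
    """Hypothetical stand-in for stratified_split(): stratify on binned HbA1c."""
    strata = np.digitize(np.asarray(y, dtype=float), RANGE_EDGES)
    return train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=strata
    )

# Synthetic demo data (not NHANES)
rng = np.random.default_rng(0)
y_demo = pd.Series(rng.uniform(4.5, 12.0, size=500), name="hba1c_percent")
X_demo = pd.DataFrame({"fpg": rng.normal(120, 30, size=500)})
X_tr, X_te, y_tr, y_te = stratified_split_sketch(X_demo, y_demo)
print(len(X_tr), len(X_te))  # 350 150
```

Stratifying on the binned target keeps rarer high-HbA1c cases from being concentrated in one split by chance.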

In [None]:
# Stratified split
X_train, X_test, y_train, y_test = stratified_split(df, test_size=0.3, random_state=42)

print(f"Training set: {X_train.shape[0]} samples ({X_train.shape[0]/len(df)*100:.1f}%)")
print(f"Test set:     {X_test.shape[0]} samples ({X_test.shape[0]/len(df)*100:.1f}%)")

# Show HbA1c distribution in train/test
print(f"\nHbA1c distribution:")
print(f"  Train - mean: {y_train.mean():.2f}%, std: {y_train.std():.2f}%")
print(f"  Test  - mean: {y_test.mean():.2f}%, std: {y_test.std():.2f}%")

---

## Step 4: Train Models

Train three different models:
1. **Ridge Regression** - Linear baseline with L2 regularization
2. **Random Forest** - Ensemble of decision trees for nonlinear patterns
3. **LightGBM** - Gradient-boosted trees with early stopping, often the strongest of the three on tabular data
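
Because the L2 penalty acts on coefficient magnitudes, Ridge is sensitive to feature scale (for example, FPG in mg/dL versus a TG/HDL ratio near 1). Whether `train_ridge` standardizes features internally is not shown here. The sketch below assumes it does not and places scikit-learn's `StandardScaler` in front of `Ridge`; the data are synthetic.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Standardize each feature to zero mean and unit variance before fitting, so the
# L2 penalty does not depend on each feature's measurement units.
scaled_ridge = make_pipeline(StandardScaler(), Ridge(alpha=1.0))

rng = np.random.default_rng(0)
# Synthetic stand-in data: one feature on a mg/dL-like scale, one on a ratio-like scale
X_demo = np.column_stack([rng.normal(110, 30, 300), rng.normal(2.5, 1.0, 300)])
y_demo = 3.5 + 0.02 * X_demo[:, 0] + 0.3 * X_demo[:, 1] + rng.normal(0, 0.2, 300)

scaled_ridge.fit(X_demo, y_demo)
r2 = scaled_ridge.score(X_demo, y_demo)
print(f"R^2 on training data: {r2:.3f}")
```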

In [None]:
# Hold out 20% of the training data as a validation set for LightGBM early stopping.
# A shuffled split avoids bias if X_train happens to be ordered (e.g., by stratum).
from sklearn.model_selection import train_test_split

X_train_lgb, X_val, y_train_lgb, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42
)

print(f"LightGBM validation set: {len(y_val)} samples")
print(f"LightGBM training set: {len(y_train_lgb)} samples")

In [None]:
# Train Ridge Regression
print("Training Ridge Regression...")
ridge_model = train_ridge(X_train, y_train, alpha=1.0)
print(f"  Coefficients: {len(ridge_model.coef_)}")
print(f"  Intercept: {ridge_model.intercept_:.4f}")
print("  Done!")

In [None]:
# Train Random Forest
print("Training Random Forest (200 trees)...")
rf_model = train_random_forest(X_train, y_train, n_estimators=200)
print(f"  Trees: {rf_model.n_estimators}")
print("  Done!")

In [None]:
# Train LightGBM
print("Training LightGBM with early stopping...")
lgb_model = train_lightgbm(
    X_train_lgb, y_train_lgb,
    X_val, y_val,
    n_estimators=1000,
    early_stopping_rounds=20
)
print(f"  Best iteration: {lgb_model.best_iteration_}")
print("  Done!")
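
Early stopping keeps adding trees while the validation loss improves and halts once `early_stopping_rounds` consecutive rounds pass without a new best. The best iteration is the round with the lowest validation loss. The helper below is a hypothetical, dependency-free illustration of that rule applied to a precomputed list of validation losses. It is not LightGBM's implementation.

```python
from typing import Sequence

def early_stopping_best_round(val_losses: Sequence[float], patience: int = 20) -> int:
    """Hypothetical helper: return the 1-based round with the lowest validation loss,
    stopping once `patience` consecutive rounds fail to improve on the best so far."""
    best_loss = float("inf")
    best_round = 0
    for round_num, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_round = loss, round_num
        elif round_num - best_round >= patience:
            break  # no improvement for `patience` rounds: stop adding trees
    return best_round

losses = [1.00, 0.80, 0.70, 0.72, 0.71, 0.73, 0.65]
print(early_stopping_best_round(losses, patience=3))  # 3
```

With `patience=3`, training stops at round 6 and never sees the lower loss at round 7. A larger patience value trades extra compute for a lower chance of stopping too early.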

---

## Step 5: Cross-Validation

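Evaluate each model with 10-fold cross-validation on the training set; the test set stays untouched. The CV configurations are simpler than the models trained in Step 4. Random Forest uses 100 trees instead of 200, and LightGBM uses a fixed 100 boosting rounds instead of early stopping. The CV scores therefore approximate, rather than exactly reproduce, the performance of the trained models.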
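
The sketch below shows one way a helper like `cross_validate_model` could compute these numbers with scikit-learn's `KFold` and `cross_validate`. `cross_validate_sketch` is a hypothetical name. The shuffling, the seed, and the use of population standard deviation across folds are assumptions, not confirmed details of `hba1cE.train`.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold, cross_validate

def cross_validate_sketch(model, X, y, n_splits=10, random_state=42):
    """Hypothetical equivalent of cross_validate_model(): shuffled K-fold CV
    returning the mean and (population) std of per-fold RMSE and MAE."""
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    scores = cross_validate(
        model, X, y, cv=cv,
        scoring=["neg_root_mean_squared_error", "neg_mean_absolute_error"],
    )
    # scikit-learn scorers follow "higher is better", so error metrics come back negated
    rmse = -scores["test_neg_root_mean_squared_error"]
    mae = -scores["test_neg_mean_absolute_error"]
    return {
        "RMSE_mean": rmse.mean(), "RMSE_std": rmse.std(),
        "MAE_mean": mae.mean(), "MAE_std": mae.std(),
    }

# Synthetic demo data (not NHANES)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 3))
y_demo = 6.0 + X_demo @ np.array([0.5, -0.2, 0.1]) + rng.normal(scale=0.3, size=200)
res = cross_validate_sketch(Ridge(alpha=1.0), X_demo, y_demo)
print({k: round(float(v), 3) for k, v in res.items()})
```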

In [None]:
# Cross-validate all models
print("Running 10-fold cross-validation...\n")

# Ridge CV
print("Ridge Regression:")
ridge_cv = cross_validate_model(Ridge(alpha=1.0), X_train, y_train, n_splits=10)
print(f"  RMSE: {ridge_cv['RMSE_mean']:.4f} ± {ridge_cv['RMSE_std']:.4f}")
print(f"  MAE:  {ridge_cv['MAE_mean']:.4f} ± {ridge_cv['MAE_std']:.4f}")

In [None]:
# Random Forest CV
from sklearn.ensemble import RandomForestRegressor

print("Random Forest:")
rf_cv = cross_validate_model(
    RandomForestRegressor(n_estimators=100, random_state=42),
    X_train, y_train, n_splits=10
)
print(f"  RMSE: {rf_cv['RMSE_mean']:.4f} ± {rf_cv['RMSE_std']:.4f}")
print(f"  MAE:  {rf_cv['MAE_mean']:.4f} ± {rf_cv['MAE_std']:.4f}")

In [None]:
# LightGBM CV (using simpler config for CV)
from lightgbm import LGBMRegressor

print("LightGBM:")
lgb_cv = cross_validate_model(
    LGBMRegressor(n_estimators=100, random_state=42, verbose=-1),
    X_train, y_train, n_splits=10
)
print(f"  RMSE: {lgb_cv['RMSE_mean']:.4f} ± {lgb_cv['RMSE_std']:.4f}")
print(f"  MAE:  {lgb_cv['MAE_mean']:.4f} ± {lgb_cv['MAE_std']:.4f}")

---

## Step 6: Results Comparison

Compare all models in a summary table and visualize results.

In [None]:
# Create comparison table
results_df = pd.DataFrame({
    'Model': ['Ridge Regression', 'Random Forest', 'LightGBM'],
    'RMSE_mean': [ridge_cv['RMSE_mean'], rf_cv['RMSE_mean'], lgb_cv['RMSE_mean']],
    'RMSE_std': [ridge_cv['RMSE_std'], rf_cv['RMSE_std'], lgb_cv['RMSE_std']],
    'MAE_mean': [ridge_cv['MAE_mean'], rf_cv['MAE_mean'], lgb_cv['MAE_mean']],
    'MAE_std': [ridge_cv['MAE_std'], rf_cv['MAE_std'], lgb_cv['MAE_std']],
})

# Add formatted columns
results_df['RMSE'] = results_df.apply(
    lambda r: f"{r['RMSE_mean']:.4f} ± {r['RMSE_std']:.4f}", axis=1
)
results_df['MAE'] = results_df.apply(
    lambda r: f"{r['MAE_mean']:.4f} ± {r['MAE_std']:.4f}", axis=1
)

print("\n" + "="*60)
print("CROSS-VALIDATION RESULTS COMPARISON")
print("="*60)
print(results_df[['Model', 'RMSE', 'MAE']].to_string(index=False))
print("="*60)

In [None]:
# Visualize results: one bar chart per metric, with ±1 SD error bars across CV folds
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

models = ['Ridge', 'Random Forest', 'LightGBM']
colors = ['#3498db', '#2ecc71', '#e74c3c']
cv_results = [ridge_cv, rf_cv, lgb_cv]

for ax, metric in zip(axes, ['RMSE', 'MAE']):
    means = [r[f'{metric}_mean'] for r in cv_results]
    stds = [r[f'{metric}_std'] for r in cv_results]
    bars = ax.bar(models, means, yerr=stds, color=colors,
                  edgecolor='black', linewidth=1.2, capsize=5)
    ax.set_ylabel(f'{metric} (HbA1c %)')
    ax.set_title(f'Cross-Validation {metric} Comparison')
    # Leave headroom above the tallest error bar, not just the tallest bar
    ax.set_ylim(0, max(m + s for m, s in zip(means, stds)) * 1.3)
    # Place each value label just above the top of its error bar
    for bar, mean, std in zip(bars, means, stds):
        ax.text(bar.get_x() + bar.get_width() / 2, mean + std + 0.01,
                f'{mean:.3f}', ha='center', va='bottom', fontsize=11)

plt.tight_layout()
plt.savefig(DATA_DIR / 'model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nComparison plot saved to: {DATA_DIR / 'model_comparison.png'}")

---

## Step 7: Save Best Model

Identify and save the best-performing model, defined as the one with the lowest mean cross-validated RMSE.

In [None]:
# Identify best model by lowest mean CV RMSE
best_idx = results_df['RMSE_mean'].idxmin()
best_model_name = results_df.loc[best_idx, 'Model']
best_rmse = results_df.loc[best_idx, 'RMSE_mean']

print(f"Best model: {best_model_name}")
print(f"RMSE: {best_rmse:.4f}%")

# Map the model names used in results_df to the trained model objects
trained_models = {
    'Ridge Regression': ridge_model,
    'Random Forest': rf_model,
    'LightGBM': lgb_model,
}
best_model = trained_models[best_model_name]

In [None]:
# Create models directory and save
MODELS_DIR = Path.cwd().parent / "models"
MODELS_DIR.mkdir(exist_ok=True)

# Save best model
model_filename = f"best_model_{best_model_name.lower().replace(' ', '_')}.joblib"
save_model(best_model, str(MODELS_DIR / model_filename))
print(f"\nBest model saved to: {MODELS_DIR / model_filename}")

# Also save all models for comparison
save_model(ridge_model, str(MODELS_DIR / "ridge_model.joblib"))
save_model(rf_model, str(MODELS_DIR / "random_forest_model.joblib"))
save_model(lgb_model, str(MODELS_DIR / "lightgbm_model.joblib"))
print(f"\nAll models saved to: {MODELS_DIR}")

---

## Step 8: Test Set Evaluation

Final evaluation on the held-out test set, which was not used for training, cross-validation, or model selection.

In [None]:
# Predict on test set
y_pred_ridge = ridge_model.predict(X_test)
y_pred_rf = rf_model.predict(X_test)
y_pred_lgb = lgb_model.predict(X_test)

# Calculate test metrics
def calc_metrics(y_true, y_pred):
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))
    return rmse, mae

ridge_rmse, ridge_mae = calc_metrics(y_test, y_pred_ridge)
rf_rmse, rf_mae = calc_metrics(y_test, y_pred_rf)
lgb_rmse, lgb_mae = calc_metrics(y_test, y_pred_lgb)

print("\n" + "="*50)
print("TEST SET PERFORMANCE")
print("="*50)
print(f"{'Model':<20} {'RMSE':>10} {'MAE':>10}")
print("-"*50)
print(f"{'Ridge Regression':<20} {ridge_rmse:>10.4f} {ridge_mae:>10.4f}")
print(f"{'Random Forest':<20} {rf_rmse:>10.4f} {rf_mae:>10.4f}")
print(f"{'LightGBM':<20} {lgb_rmse:>10.4f} {lgb_mae:>10.4f}")
print("="*50)

In [None]:
# Scatter plot: Predicted vs Actual for best model
# Look up the best model's test predictions by the name used in results_df
test_predictions = {
    'Ridge Regression': y_pred_ridge,
    'Random Forest': y_pred_rf,
    'LightGBM': y_pred_lgb,
}
y_pred_best = test_predictions[best_model_name]

fig, ax = plt.subplots(figsize=(8, 8))

ax.scatter(y_test, y_pred_best, alpha=0.4, s=15, c='steelblue')
ax.plot([4, 14], [4, 14], 'r--', linewidth=2, label='Perfect prediction')
ax.plot([4, 14], [4.5, 14.5], 'k:', alpha=0.5, label='±0.5% bounds')
ax.plot([4, 14], [3.5, 13.5], 'k:', alpha=0.5)

ax.set_xlabel('Measured HbA1c (%)', fontsize=12)
ax.set_ylabel('Predicted HbA1c (%)', fontsize=12)
ax.set_title(f'Test Set: {best_model_name}\nRMSE = {calc_metrics(y_test, y_pred_best)[0]:.3f}%', fontsize=14)
ax.legend()
ax.set_xlim(4, 14)
ax.set_ylim(4, 14)
ax.set_aspect('equal')

plt.tight_layout()
plt.savefig(DATA_DIR / 'best_model_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPlot saved to: {DATA_DIR / 'best_model_predictions.png'}")
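
The dotted lines in the scatter plot mark ±0.5 HbA1c percentage points. The share of predictions inside that band can also be reported as a number. `fraction_within` is a hypothetical helper, and the 0.5-point tolerance simply mirrors the plotted bounds.

```python
import numpy as np

def fraction_within(y_true, y_pred, tol=0.5):
    """Hypothetical helper: share of predictions within `tol` HbA1c percentage points."""
    abs_err = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    return float(np.mean(abs_err <= tol))

# Toy values; in the notebook the call would be fraction_within(y_test, y_pred_best)
share = fraction_within([5.0, 6.0, 7.0, 9.0], [5.3, 6.6, 7.4, 8.0])
print(share)  # 0.5
```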

---

## Summary

This notebook demonstrated the complete ML training workflow for HbA1c estimation:

1. **Loaded** cleaned NHANES glycemic data
2. **Engineered features** including biomarkers, ratios, and mechanistic estimator predictions
3. **Split data** with stratification by HbA1c clinical ranges
4. **Trained** Ridge Regression, Random Forest, and LightGBM models
5. **Cross-validated** all models with 10-fold CV
6. **Compared** model performance and saved the best model
7. **Evaluated** all three models on the held-out test set

### Key Findings

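- All ML models receive the mechanistic estimator predictions as features (hybrid approach); measuring how much these features help requires retraining without them
- Gradient boosting (LightGBM) often performs best on tabular data like this; the CV and test-set tables above show whether it does here
- Random Forest provides a nonlinear ensemble baseline
- Ridge regression serves as an interpretable linear baseline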

### Next Steps

Continue to **Notebook 04: Evaluation** for comprehensive performance analysis including:
- Bland-Altman analysis
- Subgroup evaluation (anemia, age groups)
- Lin's Concordance Correlation Coefficient
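
As a preview, two of these agreement measures can be sketched in a few lines of NumPy. Lin's CCC is computed here with population (1/n) moments, and the Bland-Altman differences are taken as predicted minus measured with limits at bias ± 1.96 SD (sample SD). Both conventions are assumptions about how Notebook 04 will compute them, and the function names are hypothetical.

```python
import numpy as np

def lins_ccc(y_true, y_pred):
    """Lin's concordance correlation coefficient, using population (1/n) moments."""
    x = np.asarray(y_true, dtype=float)
    y = np.asarray(y_pred, dtype=float)
    mx, my = x.mean(), y.mean()
    sxy = np.mean((x - mx) * (y - my))
    return 2 * sxy / (x.var() + y.var() + (mx - my) ** 2)

def bland_altman_limits(y_true, y_pred):
    """Bias (mean of predicted - measured) and 95% limits of agreement (bias +/- 1.96 SD)."""
    diff = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    bias = diff.mean()
    sd = diff.std(ddof=1)
    return bias, bias - 1.96 * sd, bias + 1.96 * sd

measured = np.array([5.2, 6.1, 7.4, 8.8, 10.3])
predicted = measured + 0.2  # constant offset: perfect correlation, imperfect agreement
print(round(lins_ccc(measured, predicted), 4))
print(bland_altman_limits(measured, predicted))
```

Unlike Pearson correlation, CCC penalizes the constant 0.2-point offset, and Bland-Altman reports that offset directly as the bias.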