# 03 — Model Training

This notebook trains three ML models on NHANES kidney-function data and compares
their performance via 10-fold cross-validation.  The models are:

| Model | Type | Notes |
|-------|------|-------|
| **Ridge** | Linear (L2) | Simple baseline |
| **Random Forest** | Ensemble (bagging) | Captures non-linear patterns |
| **LightGBM** | Gradient boosting | Best expected performance |

The best model (lowest CV RMSE) is saved to `models/` for downstream use.

In [None]:
import os
import sys
import warnings

import numpy as np
import pandas as pd
import matplotlib
# Select the Agg backend before pyplot is imported; figures in this notebook
# are written to disk with savefig.
# NOTE(review): presumably chosen so the notebook can execute headless — confirm.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Ensure the project root is on sys.path so that `eGFR` can be imported.
# The root is taken to be the parent of the current working directory, so the
# kernel must be started from a directory directly below the project root.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from eGFR.data import read_xpt, clean_kidney_data
from eGFR.train import (
    create_features,
    stratified_split,
    train_ridge,
    train_random_forest,
    train_lightgbm,
    cross_validate_model,
    save_model,
)

# Suppress UserWarning-category warnings for the remainder of the session.
warnings.filterwarnings("ignore", category=UserWarning)
print("Imports OK")

## 1 — Load Data

If real NHANES XPT files are available in `data/raw/` we use them.  Otherwise we
fall back to a synthetic dataset that mirrors the NHANES schema so that this
notebook can execute in any environment.

In [None]:
# Directory that is searched for raw NHANES XPT files.
RAW_DIR = os.path.join(PROJECT_ROOT, "data", "raw")


def _try_load_nhanes():
    """Attempt to load real NHANES data from ``data/raw/``.

    Cycles are tried in the order listed (newest first).  A cycle is only
    attempted when its BIOPRO, DEMO and BMX files are all present; the file
    extension may be ``.XPT`` or ``.xpt``.  The first cycle that is read and
    cleaned without raising is returned.

    Returns
    -------
    The value returned by ``clean_kidney_data(biopro, demo, bmx)`` for the
    first usable cycle, or ``None`` when no cycle has all three files or every
    candidate cycle failed to parse.
    """
    cycles_to_try = [
        ("2017-2018", "_J"),
        ("2015-2016", "_I"),
        ("2013-2014", "_H"),
    ]
    # Order matters: it is the positional order expected by clean_kidney_data.
    tables = ("BIOPRO", "DEMO", "BMX")

    def _locate(table, suffix):
        """Return the path of an existing XPT file, or None if it is absent."""
        for ext in (".XPT", ".xpt"):
            candidate = os.path.join(RAW_DIR, f"{table}{suffix}{ext}")
            if os.path.isfile(candidate):
                return candidate
        return None

    for cycle, suffix in cycles_to_try:
        paths = [_locate(table, suffix) for table in tables]
        if any(path is None for path in paths):
            continue  # incomplete cycle: try the next one
        try:
            print(f"Loading real NHANES data for cycle {cycle}")
            biopro, demo, bmx = (read_xpt(path) for path in paths)
            return clean_kidney_data(biopro, demo, bmx)
        except Exception as exc:
            # Deliberate best-effort: report the failure and fall through to
            # the next cycle so the caller can still use the synthetic fallback.
            print(f"  Could not parse {cycle} XPT files: {exc}")
    return None


def _generate_synthetic(n=2000, seed=42):
    """Generate a synthetic clinical dataset mimicking NHANES schema."""
    rng = np.random.default_rng(seed)
    ages = rng.integers(18, 85, size=n).astype(float)
    sexes = rng.choice([1, 2], size=n)  # NHANES coding: 1=M, 2=F
    cr = np.exp(rng.normal(np.log(1.0), 0.35, size=n)).clip(0.3, 12.0)
    weights = rng.normal(80, 15, size=n).clip(40, 160)
    heights = rng.normal(170, 10, size=n).clip(140, 210)

    df = pd.DataFrame({
        "cr_mgdl": cr,
        "age_years": ages,
        "sex": sexes,
        "weight_kg": weights,
        "height_cm": heights,
    })
    print(f"Using synthetic dataset ({n} samples)")
    return df


# Prefer real NHANES files; fall back to the synthetic cohort when none load.
nhanes_df = _try_load_nhanes()
df = nhanes_df if nhanes_df is not None else _generate_synthetic()

n_rows_cols = df.shape
print(f"Dataset shape: {n_rows_cols}")
df.head()

## 2 — Feature Engineering & Stratified Split

In [None]:
# Hold out 30% of the rows as the test set.
X_train, X_test, y_train, y_test = stratified_split(df, test_size=0.3)

n_train, n_features = X_train.shape
n_test = X_test.shape[0]
train_mean = y_train.mean()
test_mean = y_test.mean()

print(f"Training set : {n_train:,} samples, {n_features} features")
print(f"Test set     : {n_test:,} samples")
print(f"\nFeature columns: {list(X_train.columns)}")
print(
    f"\ny target (CKD-EPI 2021 eGFR) — train mean: {train_mean:.1f}, "
    f"test mean: {test_mean:.1f}"
)

## 3 — Train Models

We train **Ridge**, **Random Forest**, and **LightGBM** on the training split.

In [None]:
from sklearn.model_selection import train_test_split

# ----- Ridge -----
ridge_model = train_ridge(X_train, y_train, alpha=1.0)
print("Ridge regression trained.")

# ----- Random Forest -----
rf_model = train_random_forest(X_train, y_train, n_estimators=200)
print("Random Forest trained (200 trees).")

# ----- LightGBM (needs a validation set for early stopping) -----
# The early-stopping set influences the fitted model (it decides how many
# boosting rounds are kept), so it must not be the held-out test set:
# otherwise the test metrics reported later are optimistically biased.
# We therefore carve a validation split out of the training data instead.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42
)
lgb_model = train_lightgbm(X_fit, y_fit, X_val, y_val)
print("LightGBM trained with early stopping.")

## 4 — Cross-Validation Comparison

We evaluate each model using 10-fold CV on the **full feature matrix** to get
unbiased performance estimates.

In [None]:
from lightgbm import LGBMRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge

TARGET_COL = "egfr_ckd_epi_2021"

# Build full feature matrix for CV
features_df, feature_names = create_features(df)

# Rows without a target value cannot be scored, so drop them from X and y.
valid = features_df[TARGET_COL].notna()
y_full = features_df.loc[valid, TARGET_COL].copy()

# Keep ONLY the predictor columns.  The frame returned by create_features also
# holds the target itself; passing it to the models would leak the answer into
# every fold and make the CV scores meaningless.
# NOTE(review): assumes `feature_names` lists predictor columns present in the
# returned frame — confirm against eGFR.train.create_features.
predictor_cols = [col for col in feature_names if col != TARGET_COL]
X_full = features_df.loc[valid, predictor_cols]

# LightGBM is cross-validated with a fixed number of boosting rounds (no early
# stopping), so it goes through the same helper as the other two models.
lgb_base = LGBMRegressor(
    n_estimators=500,
    learning_rate=0.05,
    random_state=42,
    verbosity=-1,
)

models_for_cv = {
    "Ridge": Ridge(alpha=1.0),
    "Random Forest": RandomForestRegressor(n_estimators=200, random_state=42),
    "LightGBM": lgb_base,
}

cv_results = {}

for name, model in models_for_cv.items():
    print(f"Running 10-fold CV for {name}...")
    cv_results[name] = cross_validate_model(model, X_full, y_full, n_splits=10)

print("\nCross-validation complete.")

In [None]:
# Assemble the comparison table: display label -> key in each CV result dict.
metric_columns = {
    "RMSE (mean)": "RMSE_mean",
    "RMSE (std)": "RMSE_std",
    "MAE (mean)": "MAE_mean",
    "MAE (std)": "MAE_std",
}

# One row per model, with every metric rendered to two decimals.
rows = [
    {
        "Model": model_name,
        **{label: f"{scores[key]:.2f}" for label, key in metric_columns.items()},
    }
    for model_name, scores in cv_results.items()
]

cv_table = pd.DataFrame(rows).set_index("Model")
print("\n=== 10-Fold Cross-Validation Results ===")
print(cv_table.to_string())

## 5 — Select & Save Best Model

In [None]:
# Pick the model whose mean CV RMSE is smallest (first one wins on ties).
best_name, best_rmse = min(
    ((model_name, scores["RMSE_mean"]) for model_name, scores in cv_results.items()),
    key=lambda pair: pair[1],
)
print(f"Best model: {best_name} (CV RMSE = {best_rmse:.2f})")

# Fitted estimators from the training step, keyed by the same names used in CV.
trained_models = {
    "Ridge": ridge_model,
    "Random Forest": rf_model,
    "LightGBM": lgb_model,
}
best_model = trained_models[best_name]

# Persist the winner under a fixed file name.
MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
os.makedirs(MODEL_DIR, exist_ok=True)
best_path = os.path.join(MODEL_DIR, "best_model.joblib")
save_model(best_model, best_path)
print(f"Best model saved to: {best_path}")

# Persist every fitted model as well, one file per model, for reference.
for label, fitted in trained_models.items():
    file_stem = "_".join(label.lower().split(" "))
    target_path = os.path.join(MODEL_DIR, f"{file_stem}.joblib")
    save_model(fitted, target_path)
    print(f"  {label} -> {target_path}")

## 6 — Test-Set Predictions (Quick Sanity Check)

We generate predictions on the held-out test set and show a scatter plot for the
best model.

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Score the selected model on the test split.
y_pred_test = best_model.predict(X_test)
rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
mae_test = mean_absolute_error(y_test, y_pred_test)
r2_test = r2_score(y_test, y_pred_test)

report_lines = [
    f"{best_name} — Test-set performance:",
    f"  RMSE : {rmse_test:.2f}",
    f"  MAE  : {mae_test:.2f}",
    f"  R2   : {r2_test:.4f}",
]
print("\n".join(report_lines))

# Shared axis range: from 0 up to 5% beyond the largest actual/predicted value.
axis_max = max(y_test.max(), np.max(y_pred_test)) * 1.05
lims = [0, axis_max]
scatter_path = os.path.join(PROJECT_ROOT, "models", "best_model_scatter.png")

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(y_test, y_pred_test, alpha=0.3, s=10, edgecolors="none")
# Dashed diagonal marks perfect agreement between prediction and target.
ax.plot(lims, lims, linestyle="--", color="red", linewidth=1, label="Identity")
ax.set(
    xlabel="Actual eGFR (CKD-EPI 2021)",
    ylabel="Predicted eGFR",
    title=f"{best_name} — Predicted vs Actual",
    xlim=lims,
    ylim=lims,
)
ax.legend()
fig.tight_layout()
fig.savefig(scatter_path, dpi=150)
plt.show()
print("Done.")

## Summary

- **Ridge**, **Random Forest**, and **LightGBM** models were trained on the
  training split of the available dataset (real NHANES or the synthetic fallback).
- 10-fold cross-validation was used to compare model performance.
- The best model (by CV RMSE) was saved to `models/best_model.joblib`.
- All individual models were also saved to `models/` for reference.

> **Note:** When using synthetic data, the models learn to reproduce the CKD-EPI
> 2021 equation from the engineered features.  True model utility will be
> assessed when real NHANES data with external measured-GFR validation is
> available (see notebooks 04 and 05).