# Notes 2 - Predictors of recorded Down syndrome (DS) live births

Next:

- Review the data to see whether we can start earlier than 2005 - check the race variables
- Hyperparameter tuning
- Consider bagging vs GOSS


## Preparation


In [None]:
import duckdb, joblib, optuna, os, shap
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
from lightgbm import early_stopping, log_evaluation
from scipy.spatial.distance import squareform
from scipy.cluster import hierarchy
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    brier_score_loss,
    log_loss,
    precision_recall_curve,
    precision_recall_fscore_support,
    roc_curve,
)
from sklearn.model_selection import train_test_split

from variables import Variables as vars
import repl_utils, stats_utils

# Notebook-wide setup: pandas/matplotlib options, seeding, and a per-run output directory.

# Enable pandas copy-on-write so column assignment on a selected frame never writes
# through to the frame it was selected from.
pd.options.mode.copy_on_write = True

# Shared matplotlib style sheet, resolved relative to the notebook's working directory.
plt.style.use("../../notebook.mplstyle")

os.makedirs("./output", exist_ok=True)

# One seed (taken from repl_utils) drives both the legacy global NumPy RNG and `rng`.
RANDOM_SEED = repl_utils.RANDOM_SEED
np.random.seed(RANDOM_SEED)
rng = np.random.default_rng(RANDOM_SEED)

# Physical core count; used as the default LightGBM thread count (NUM_THREADS).
N_CORES = joblib.cpu_count(only_physical_cores=True)
# Every artefact of this run is written below a directory named after the start time.
START_TIME = datetime.now()
OUTPUT_DIR = f"output/0002-predictors/{START_TIME:%Y%m%d-%H%M%S}"

# Passed as `save=` to the plot_* helpers; when True each figure is saved as PNG and SVG.
SAVE_PLOTS = True

os.makedirs(OUTPUT_DIR, exist_ok=True)

# NOTE(review): presumably prints interpreter/package versions - confirm in repl_utils.
repl_utils.print_environment_info()

print(f"\n--------------------\nOutput directory: {OUTPUT_DIR}\n--------------------\n")

### Options


In [None]:
# First birth year loaded from the database (rows with year >= START_YEAR are kept).
START_YEAR = 2005
# LightGBM threads
NUM_THREADS = N_CORES
# Splitting data for training, validation and calibration.
# The three fractions are shares of the full data set; calibration takes whatever
# training and validation leave over, so the three always sum to 1.
TRAINING_SPLIT = 0.5
VALIDATION_SPLIT = 0.25
CALIBRATION_SPLIT = 1 - TRAINING_SPLIT - VALIDATION_SPLIT
# Upper bound on boosting rounds (num_boost_round) and the patience, in rounds,
# given to the LightGBM early-stopping callback.
NUM_BOOST_ROUND = 10000
EARLY_STOPPING_ROUNDS = 50
# True to search for hyperparameters
# (False reuses the stored `last_best_params` instead of running an Optuna study).
SELECT_HYPERPARAMETERS = True
# Number of Optuna trials (n_trials) when SELECT_HYPERPARAMETERS is True.
OPTIMIZE_TRIALS = 50

### Load data


In [None]:
# Load the modelling frame from the births database.
#
# Each selected column is recoded in SQL so that values outside the listed valid
# range become NULL. Yes/no flags become 1/0; several coded variables keep their
# "unknown" code (9 / 99) as a level of its own.
#
# Changes relative to the previous version of this cell:
#   * apgar5 / apgar10 accepted only the value 10 (`>= 10 AND <= 10`), turning every
#     score from 0 to 9 into NULL; the valid range is now 0-10.
#   * ca_disor was selected twice under the same alias. The first copy (mislabelled
#     "congenital disorder", without the 'U' level) has been removed; the remaining
#     definition keeps 'U' as level 9.
#   * the connection is closed in a `finally` block so a failing query cannot leak it.
#
# START_YEAR is a trusted notebook constant, so interpolating it into the query text
# is safe here.
births_query = f"""
    SELECT
        -- (training label) indicated if C or P, not indicated if N, U and missing excluded from training
        CASE
            WHEN COALESCE (ca_down, ca_downs) = 'C' THEN 1::UTINYINT
            WHEN COALESCE (ca_down, ca_downs) = 'P' THEN 1::UTINYINT
            WHEN COALESCE (ca_down, ca_downs) = 'N' THEN 0::UTINYINT
            WHEN uca_downs = 1 THEN 1::UTINYINT
            WHEN uca_downs = 2 THEN 0::UTINYINT
            ELSE NULL
        END AS ca_down_c_p_n,
        -- ==================== date of birth ====================
        year,
        -- month of birth
        dob_mm,
        -- day of week of birth (1-7)
        dob_wk,
        -- time of birth (0: AM, 1: PM, 2: not stated)
        CASE
            WHEN dob_tt >= 0 AND dob_tt <= 1159 THEN 0::UTINYINT
            WHEN dob_tt >= 1200 AND dob_tt <= 2359 THEN 1::UTINYINT
            WHEN dob_tt = 9999 THEN 2::UTINYINT
            ELSE NULL
        END AS dob_tt_pm,
        -- ==================== birth location ====================
        -- birth place (1: hospital, 2: not hospital, 3: unknown/not stated)
        bfacil3,
        -- ==================== characteristics of baby ====================
        -- sex of baby
        CASE
            WHEN sex = 'M' THEN 1::UTINYINT
            WHEN sex = 'F' THEN 0::UTINYINT
            ELSE NULL
        END AS sex,
        -- birth weight (grams)
        CASE
            WHEN dbwt >= 227 AND dbwt <= 8165 THEN dbwt
            -- we ignore "Not stated" as we treat this variable as numeric rather than categorical
            ELSE NULL
        END AS dbwt,
        -- ==================== characteristics of pregnancy ====================
        -- plurality (1: single... 4 quadpruplet or higher)
        dplural,
        -- month prenatal care began (1 to 10, 0: no prenatal care)
        CASE
            WHEN precare >= 0 AND precare <= 10 THEN precare
            WHEN precare = 99 THEN precare
            ELSE NULL
        END AS precare,
        -- combined gestation estimate
        CASE
            WHEN gestrec10 >= 1 AND gestrec10 <= 10 THEN gestrec10
            WHEN gestrec10 = 99 THEN gestrec10
            ELSE NULL
        END AS gestrec10,
        -- pre-pregnancy weight recode (in pounds)
        CASE
            WHEN pwgt_r >= 75 AND pwgt_r <= 375 THEN pwgt_r
            -- we ignore "Unknown or not stated" as we treat this variable as numeric rather than categorical
            ELSE NULL
        END AS pwgt_r,
        -- weight gain in pounds (98 = 98+)
        CASE
            WHEN wtgain >= 0 AND wtgain <= 98 THEN wtgain
            -- we ignore "Unknown or not stated" as we treat this variable as numeric rather than categorical
            ELSE NULL
        END AS wtgain,
        -- maternal body mass index
        CASE
            WHEN bmi >= 13.0 AND bmi < 69.9 THEN bmi
            -- we ignore "Unknown or not stated" as we treat this variable as numeric rather than categorical
            ELSE NULL
        END
        AS bmi,
        -- ==================== pregnancy risk factors ====================
        -- pre-pregnancy diabetes
        CASE
            WHEN rf_pdiab = 'Y' THEN 1::UTINYINT
            WHEN rf_pdiab = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_pdiab,
        -- gestational diabetes
        CASE
            WHEN rf_gdiab = 'Y' THEN 1::UTINYINT
            WHEN rf_gdiab = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_gdiab,
        -- pre-pregnancy hypertension
        CASE
            WHEN rf_phype = 'Y' THEN 1::UTINYINT
            WHEN rf_phype = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_phype,
        -- gestational hypertension
        CASE
            WHEN rf_ghype = 'Y' THEN 1::UTINYINT
            WHEN rf_ghype = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_ghype,
        -- hypertension eclampsia
        CASE
            WHEN rf_ehype = 'Y' THEN 1::UTINYINT
            WHEN rf_ehype = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_ehype,
        -- previous preterm birth
        CASE
            WHEN rf_ppterm = 'Y' THEN 1::UTINYINT
            WHEN rf_ppterm = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_ppterm,
        -- infertility treatment used
        CASE
            WHEN rf_inftr = 'Y' THEN 1::UTINYINT
            WHEN rf_inftr = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_inftr,
        -- fertility enhancing drugs
        CASE
            WHEN rf_fedrg = 'Y' THEN 1::UTINYINT
            WHEN rf_fedrg = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_fedrg,
        -- asst. reproductive technology
        CASE
            WHEN rf_artec = 'Y' THEN 1::UTINYINT
            WHEN rf_artec = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS rf_artec,
        -- no risk factors reported
        CASE
            WHEN no_risks <= 1 THEN no_risks
            ELSE NULL
        END AS no_risks,
        -- ==================== labor and delivery ====================
        -- induction of labor
        CASE
            WHEN ld_indl = 'Y' THEN 1::UTINYINT
            WHEN ld_indl = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ld_indl,
        -- augmentation of labor
        CASE
            WHEN ld_augm = 'Y' THEN 1::UTINYINT
            WHEN ld_augm = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ld_augm,
        -- fetal presentation at delivery
        CASE
            WHEN me_pres >= 1 AND me_pres <= 3 THEN me_pres
            WHEN me_pres = 9 THEN me_pres
            ELSE NULL
        END AS me_pres,
        -- delivery method recode
        CASE
            WHEN dmeth_rec >= 1 AND dmeth_rec <= 2 THEN dmeth_rec
            WHEN dmeth_rec = 9 THEN dmeth_rec
            ELSE NULL
        END AS dmeth_rec,
        -- ==================== newborn health ====================
        -- five minute apgar score
        CASE
            WHEN apgar5 >= 0 AND apgar5 <= 10 THEN apgar5
            WHEN apgar5 = 99 THEN apgar5
            ELSE NULL
        END AS apgar5,
        -- ten minute apgar score
        CASE
            WHEN apgar10 >= 0 AND apgar10 <= 10 THEN apgar10
            WHEN apgar10 = 99 THEN apgar10
            ELSE NULL
        END AS apgar10,
        -- assisted ventilation (immediately)
        CASE
            WHEN ab_aven1 = 'Y' THEN 1::UTINYINT
            WHEN ab_aven1 = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ab_aven1,
        -- assisted ventilation > 6 hrs
        CASE
            WHEN ab_aven6 = 'Y' THEN 1::UTINYINT
            WHEN ab_aven6 = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ab_aven6,
        -- admitted to nicu
        CASE
            WHEN ab_nicu = 'Y' THEN 1::UTINYINT
            WHEN ab_nicu = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ab_nicu,
        -- surfactant
        CASE
            WHEN ab_surf = 'Y' THEN 1::UTINYINT
            WHEN ab_surf = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ab_surf,
        -- antibiotics for newborn
        CASE
            WHEN ab_anti = 'Y' THEN 1::UTINYINT
            WHEN ab_anti = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ab_anti,
        -- seizures
        CASE
            WHEN ab_seiz = 'Y' THEN 1::UTINYINT
            WHEN ab_seiz = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ab_seiz,
        -- no_abnorm
        CASE
            WHEN no_abnorm >= 0 AND no_abnorm <= 1 THEN no_abnorm
            WHEN no_abnorm = 9 THEN no_abnorm
            ELSE NULL
        END AS no_abnorm,
        -- ==================== identified disorders ====================
        -- anencephaly
        CASE
            WHEN ca_anen = 'Y' THEN 1::UTINYINT
            WHEN ca_anen = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_anen,
        -- meningomyelocele / spina bifida
        CASE
            WHEN ca_mnsb = 'Y' THEN 1::UTINYINT
            WHEN ca_mnsb = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_mnsb,
        -- congenital heart defect
        CASE
            WHEN ca_cchd = 'Y' THEN 1::UTINYINT
            WHEN ca_cchd = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_cchd,
        -- ca_cdh
        CASE
            WHEN ca_cdh = 'Y' THEN 1::UTINYINT
            WHEN ca_cdh = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_cdh,
        -- omphalocele
        CASE
            WHEN ca_omph = 'Y' THEN 1::UTINYINT
            WHEN ca_omph = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_omph,
        -- gastroschisis
        CASE
            WHEN ca_gast = 'Y' THEN 1::UTINYINT
            WHEN ca_gast = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_gast,
        -- limb reduction defect
        CASE
            WHEN ca_limb = 'Y' THEN 1::UTINYINT
            WHEN ca_limb = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_limb,
        -- cleft lip w/ or w/o cleft palate
        CASE
            WHEN ca_cleft = 'Y' THEN 1::UTINYINT
            WHEN ca_cleft = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_cleft,
        -- cleft palate alone
        CASE
            WHEN ca_clpal = 'Y' THEN 1::UTINYINT
            WHEN ca_clpal = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_clpal,
        -- Hypospadias
        CASE
            WHEN ca_hypo = 'Y' THEN 1::UTINYINT
            WHEN ca_hypo = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS ca_hypo,
        -- suspected chromosomal disorder
        CASE
            WHEN ca_disor = 'C' THEN 1::UTINYINT
            WHEN ca_disor = 'P' THEN 2::UTINYINT
            WHEN ca_disor = 'N' THEN 0::UTINYINT
            WHEN ca_disor = 'U' THEN 9::UTINYINT
            ELSE NULL
        END AS ca_disor,
        -- no_congen
        CASE
            WHEN no_congen >= 0 AND no_congen <= 1 THEN no_congen
            WHEN no_congen = 9 THEN no_congen
            ELSE NULL
        END AS no_congen,
        -- ==================== maternal characteristics ====================
        -- maternal age in years
        mage_c,
        -- maternal education
        CASE
            WHEN meduc >= 0 AND meduc < 10 THEN meduc
            ELSE NULL
        END AS meduc,
        -- maternal race
        CASE
            WHEN  mracehisp >= 1 AND mracehisp <= 8 THEN mracehisp
            ELSE NULL
        END AS mracehisp,
        -- ==================== paternal characteristics ====================
        -- father's combined age in years
        CASE
            WHEN fagecomb >= 9 AND fagecomb < 99 THEN fagecomb
            ELSE NULL
        END AS fagecomb,
        -- paternal education
        CASE
            WHEN  feduc < 9 THEN feduc
            ELSE NULL
        END AS feduc,
        -- paternal race
        CASE
            WHEN  fracehisp >= 1 AND fracehisp <= 10 THEN fracehisp
            ELSE NULL
        END AS fracehisp,
        -- ==================== socio-economic indicators ====================
        -- payment source recode
        CASE
            WHEN  pay_rec < 5 THEN pay_rec
            ELSE NULL
        END AS pay_rec,
        -- supplemental nutrition program for women, infants, and children
        CASE
            WHEN wic = 'Y' THEN 1::UTINYINT
            WHEN wic = 'N' THEN 0::UTINYINT
            ELSE NULL
        END AS wic
    FROM
        us_births
    WHERE year >= {START_YEAR} AND ca_down_c_p_n IS NOT NULL
    ORDER BY
        year, dob_mm, dob_wk
    """

con = duckdb.connect("./data/us_births.db", read_only=True)
try:
    df = con.execute(births_query).df()
finally:
    # Release the database handle even when the query fails.
    con.close()

df.describe().T

In [None]:
def ap_scorer(estimator, X, y):
    """
    Scorer with the ``(estimator, X, y)`` call signature that sklearn's model
    inspection / selection utilities expect.

    Takes the second column of ``estimator.predict_proba(X)`` as the score of the
    positive class and returns its average precision against ``y``.
    """
    positive_scores = estimator.predict_proba(X)[:, 1]
    ap = average_precision_score(y, positive_scores)
    return ap


class LGBMEstimator:
    """
    Thin sklearn-style facade over an already-trained LightGBM booster.

    The model is trained through LightGBM's native API; this adapter only adds the
    ``fit`` / ``predict`` / ``predict_proba`` methods that sklearn utilities such as
    permutation importance call.
    """

    def __init__(self, booster, threshold=0.5):
        # `threshold` is the probability cut-off applied by predict().
        self.threshold = threshold
        self.booster = booster

    def fit(self, X, y=None):
        # Nothing to fit: the wrapped booster is already trained.
        return self

    def _predict_p1(self, X):
        # Score with the early-stopped number of trees when the booster reports one.
        # Without num_iteration the booster would use every tree it currently holds.
        best = getattr(self.booster, "best_iteration", None)
        has_best_iteration = best is not None and best != 0
        if has_best_iteration:
            return self.booster.predict(X, num_iteration=best)
        return self.booster.predict(X)

    # ap_scorer relies on this method.
    def predict_proba(self, X):
        positive = self._predict_p1(X)
        # Two columns: P(class 0), P(class 1).
        return np.column_stack([1.0 - positive, positive])

    def predict(self, X):
        is_positive = self._predict_p1(X) >= self.threshold
        return is_positive.astype(int)

In [None]:
def group_permutation_importance(
    estimator,
    X: pd.DataFrame,
    y: pd.Series,
    groups: dict[str, list[str]],
    scorer=average_precision_score,
    n_repeats: int = 5,
    random_state: int = 0,
    use_predict_proba: bool = True,
):
    """
    Compute group permutation importance.

    All columns of a group are shuffled with the same row permutation, so the
    relationships between the group's columns are kept while their link to the
    label and to the remaining columns is broken.

    Parameters
    ----------
    estimator : object
        Must implement predict_proba(X) -> (n, 2) or predict(X) -> (n,).
        The LGBMEstimator wrapper defined in this notebook works.
    X : DataFrame
        Evaluation data. It is never modified; every repeat works on a copy.
    y : Series/array
        Labels (0/1).
    groups : dict
        Mapping group_name -> list of column names to permute together.
        Column names that are not in ``X`` are ignored; a group with no
        remaining column is skipped.
    scorer : callable
        Called as ``scorer(y, predictions)``. For AP, pass average_precision_score.
    n_repeats : int
        Permutation repeats per group.
    random_state : int
        Seed.
    use_predict_proba : bool
        If True, scorer uses predict_proba[:,1], else uses predict.

    Returns
    -------
    DataFrame with mean/std importance per group (higher = more important),
    sorted by decreasing mean importance. The frame has the columns ``group``,
    ``n_features``, ``features``, ``baseline_score``, ``importance_mean`` and
    ``importance_std``, and is empty (with those columns) when no group matched.
    Importance is measured as decrease in score when permuted: (baseline - permuted).
    """
    rng = np.random.default_rng(random_state)

    def _score(frame):
        # Score `frame` with whichever prediction method the caller selected.
        if use_predict_proba:
            return scorer(y, estimator.predict_proba(frame)[:, 1])
        return scorer(y, estimator.predict(frame))

    baseline = _score(X)
    n_rows = len(X)

    result_columns = [
        "group",
        "n_features",
        "features",
        "baseline_score",
        "importance_mean",
        "importance_std",
    ]
    results = []

    for gname, cols in groups.items():
        cols = [c for c in cols if c in X.columns]
        if len(cols) == 0:
            continue

        drops = []
        for _ in range(n_repeats):
            # Permute rows consistently across the whole group
            perm = rng.permutation(n_rows)

            X_perm = X.copy()
            for c in cols:
                # Index the column's backing array rather than to_numpy(): this keeps
                # extension dtypes such as `category`. Converting to a plain NumPy
                # array would silently change the dtype of categorical columns, which
                # LightGBM depends on when predicting from a DataFrame.
                X_perm[c] = X[c].array[perm]

            drops.append(baseline - _score(X_perm))

        results.append(
            {
                "group": gname,
                "n_features": len(cols),
                "features": cols,
                "baseline_score": baseline,
                "importance_mean": float(np.mean(drops)),
                "importance_std": (
                    float(np.std(drops, ddof=1)) if n_repeats > 1 else 0.0
                ),
            }
        )

    # Passing the columns explicitly keeps sort_values working when `results` is empty.
    out = (
        pd.DataFrame(results, columns=result_columns)
        .sort_values("importance_mean", ascending=False)
        .reset_index(drop=True)
    )
    return out

In [None]:
def score_metrics(y_true, p_valid):
    """
    Score predicted probabilities against binary labels.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    p_valid : array-like of shape (n_samples,)
        Predicted probabilities or scores.

    Returns
    -------
    tuple of six items, in this order:
        auc : float
            Area under the ROC curve.
        ap : float
            Average precision.
        ll : float
            Log loss, computed with the label set fixed to [0, 1].
        fpr : array of shape (n_thresholds,)
            False positive rates of the ROC curve.
        tpr : array of shape (n_thresholds,)
            True positive rates of the ROC curve.
        thresholds : array of shape (n_thresholds,)
            Thresholds at which the ROC curve was evaluated.
    """
    scalar_scores = (
        roc_auc_score(y_true, p_valid),
        average_precision_score(y_true, p_valid),
        log_loss(y_true, p_valid, labels=[0, 1]),
    )
    # roc_curve yields (fpr, tpr, thresholds); append them to the scalar scores.
    roc_parts = tuple(roc_curve(y_true, p_valid))
    return scalar_scores + roc_parts


def precision_recall_at_k(y_true, p_valid, K: int = 10000):
    """
    Compute precision and recall at top K predictions.

    Both inputs are converted with ``np.asarray``, so lists, NumPy arrays and pandas
    Series are all accepted.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    p_valid : array-like of shape (n_samples,)
        Predicted probabilities or scores.
    K : int
        Number of top predictions to consider. If K exceeds the number of samples,
        all samples are used.

    Returns
    -------
    precision_at_k : float
        Share of positives among the K highest-scored samples
        (NaN when that selection is empty).
    recall_at_k : float
        Share of all positives found among the K highest-scored samples
        (NaN when ``y_true`` contains no positives).
    """
    labels = np.asarray(y_true)
    scores = np.asarray(p_valid, dtype=float)

    # Sort by decreasing score and keep the labels of the K best-ranked samples.
    order = np.argsort(-scores)
    top_k = labels[order][:K]

    n_positive = labels.sum()
    # Explicit NaN instead of relying on a 0/0 division (and its RuntimeWarning).
    precision_at_k = top_k.mean() if top_k.size else float("nan")
    recall_at_k = top_k.sum() / n_positive if n_positive else float("nan")
    return precision_at_k, recall_at_k


def precision_recall_at_threshold(y_true, p_valid, thr: float = 0.01):
    """
    Precision, recall and F1 of the classifier obtained by cutting scores at ``thr``.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    p_valid : array of shape (n_samples,)
        Predicted probabilities or scores; a sample is predicted positive when
        its score is >= ``thr``.
    thr : float
        Decision threshold.

    Returns
    -------
    prec : float
        Precision at the given threshold.
    rec : float
        Recall at the given threshold.
    f1 : float
        F1-score at the given threshold.
    """
    predicted_positive = (p_valid >= thr).astype(int)
    # With average="binary" the result is (precision, recall, fscore, support);
    # the support entry is discarded.
    precision, recall, f_score, _support = precision_recall_fscore_support(
        y_true, predicted_positive, average="binary"
    )
    return precision, recall, f_score


def get_metrics(y_true, p_valid, K: int = 10000, thr: float = 0.01):
    """
    Collect the validation metrics into a two-column table and pass through the
    ROC curve arrays.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    p_valid : array-like of shape (n_samples,)
        Predicted probabilities or scores.
    K : int
        Number of top predictions to consider for precision/recall at K.
    thr : float
        Threshold for precision/recall calculation.

    Returns
    -------
    metrics_df : pd.DataFrame
        One row per metric, with the columns ``metric`` (label) and ``value``.
        The F1 score at the threshold is computed but not reported.
    fpr : array
        False positive rates of the ROC curve.
    tpr : array
        True positive rates of the ROC curve.
    thresholds : array
        Thresholds of the ROC curve.
    """
    auc, ap, ll, fpr, tpr, thresholds = score_metrics(y_true, p_valid)
    top_k_precision, top_k_recall = precision_recall_at_k(y_true, p_valid, K=K)
    thr_precision, thr_recall, _f1 = precision_recall_at_threshold(
        y_true, p_valid, thr=thr
    )

    # (label, value) pairs in reporting order.
    rows = [
        ("Validation AUC", auc),
        ("Validation AP", ap),
        ("Validation log loss", ll),
        (f"Precision at {K}", top_k_precision),
        (f"Recall at {K}", top_k_recall),
        (f"Precision (threshold={thr})", thr_precision),
        (f"Recall (threshold={thr})", thr_recall),
    ]
    metrics_df = pd.DataFrame(rows, columns=["metric", "value"])

    return metrics_df, fpr, tpr, thresholds

In [None]:
def plot_roc_curve(fpr, tpr, model_idx: int, save: bool = False):
    """
    Plot a ROC curve together with the diagonal of a random classifier.

    Parameters
    ----------
    fpr : array-like
        False positive rates (x axis).
    tpr : array-like
        True positive rates (y axis).
    model_idx : int
        Model number, used in the title and in the output file names.
    save : bool
        If True, write the figure to OUTPUT_DIR as both PNG (300 dpi) and SVG.
    """
    plt.figure(figsize=(4, 4))
    plt.plot(fpr, tpr, label="ROC curve")
    plt.plot([0, 1.0], [0, 1], "--", color="#999999", label="Random classifier")
    plt.xlim([-0.03, 1.03])
    plt.ylim([0, 1.03])
    plt.xlabel("False Positive Rate (FPR)")
    plt.ylabel("True Positive Rate (TPR)")
    plt.title(f"Model {model_idx}: Receiver Operating Characteristic (ROC) Curve")
    plt.legend(loc="lower right")
    if save:
        # Take the timestamp once so the PNG and the SVG always share a file stem,
        # even when the two saves straddle a minute boundary.
        stem = f"{OUTPUT_DIR}/model_{model_idx}_roc_curve_{datetime.now():%Y%m%d%H%M}"
        plt.savefig(f"{stem}.png", dpi=300, bbox_inches="tight")
        plt.savefig(f"{stem}.svg", bbox_inches="tight")
    plt.show()


def plot_precision_recall_curve(fpr, tpr, model_idx: int, save: bool = False):
    """
    Plot a precision-recall curve.

    The parameter names are historical and are kept for backward compatibility;
    the axes are labelled as recall (x) and precision (y), so the caller must pass
    recall and precision values, not ROC rates.

    Parameters
    ----------
    fpr : array-like
        Values for the x axis: recall.
    tpr : array-like
        Values for the y axis: precision.
    model_idx : int
        Model number, used in the title and in the output file names.
    save : bool
        If True, write the figure to OUTPUT_DIR as both PNG (300 dpi) and SVG.
    """
    plt.figure(figsize=(4, 4))
    plt.plot(fpr, tpr, label="Precision-Recall curve")
    plt.xlim([-0.03, 1.03])
    plt.ylim([0, 1.03])
    plt.xlabel("Recall (TP / (TP + FN))")
    plt.ylabel("Precision (TP / (TP + FP))")
    plt.title(f"Model {model_idx}: Precision-Recall Curve")
    plt.legend(loc="lower right")
    if save:
        # Take the timestamp once so the PNG and the SVG always share a file stem,
        # even when the two saves straddle a minute boundary.
        stem = (
            f"{OUTPUT_DIR}/model_{model_idx}_precision_recall_curve_"
            f"{datetime.now():%Y%m%d%H%M}"
        )
        plt.savefig(f"{stem}.png", dpi=300, bbox_inches="tight")
        plt.savefig(f"{stem}.svg", bbox_inches="tight")
    plt.show()


def plot_permutation_importances(result, X_eval, model_idx: int, save: bool = False):
    """
    Draw horizontal box plots of per-feature permutation importances.

    Parameters
    ----------
    result : object
        Must expose ``importances_mean`` (one value per feature) and ``importances``
        (features x repeats), as returned by sklearn's ``permutation_importance``.
    X_eval : pd.DataFrame
        Frame the importances were computed on; only its column names are used.
    model_idx : int
        Model number, used in the title and in the output file names.
    save : bool
        If True, write the figure to OUTPUT_DIR as both PNG (300 dpi) and SVG.
    """
    # Ascending mean importance, so the most important feature is drawn at the top.
    sorted_importances_idx = result.importances_mean.argsort()

    importances = pd.DataFrame(
        result.importances[sorted_importances_idx].T,
        columns=X_eval.columns[sorted_importances_idx],
    )
    # Figure height grows with the number of features, clamped to 4-6 inches.
    fig_height = max(4, min(6, 0.25 * importances.shape[1]))
    ax = importances.plot.box(vert=False, whis=10, figsize=(6, fig_height))
    ax.set_title(f"Model {model_idx}: Permutation importances")
    ax.axvline(x=0, color="k", linestyle="--")
    # NOTE(review): the label says "accuracy"; confirm it matches the scorer that
    # produced `result`.
    ax.set_xlabel("Decrease in accuracy score")
    ax.set_ylabel("Predictor variable")

    if save:
        # Take the timestamp once so the PNG and the SVG always share a file stem,
        # even when the two saves straddle a minute boundary.
        stem = (
            f"{OUTPUT_DIR}/model_{model_idx}_feature_importance_"
            f"{datetime.now():%Y%m%d%H%M}"
        )
        plt.savefig(f"{stem}.png", dpi=300, bbox_inches="tight")
        plt.savefig(f"{stem}.svg", bbox_inches="tight")
    plt.show()

### Define initial feature set


In [None]:
# Predictors treated as continuous values by the model.
numeric = [
    "year",
    "dbwt",
    "pwgt_r",
    "wtgain",
    "bmi",
    "mage_c",
    "fagecomb",
]

# Predictors treated as categorical: these columns are cast to pandas `category`
# and passed to LightGBM as `categorical_feature`.
# NOTE(review): `dplural` is selected by the load query but appears in neither
# list, so it is not used as a predictor - confirm this is intentional.
categorical = [
    "dob_mm",
    "dob_wk",
    "dob_tt_pm",
    "bfacil3",
    "sex",
    "precare",
    "gestrec10",
    "rf_pdiab",
    "rf_gdiab",
    "rf_phype",
    "rf_ghype",
    "rf_ehype",
    "rf_ppterm",
    "rf_inftr",
    "rf_fedrg",
    "rf_artec",
    "no_risks",
    "ld_indl",
    "ld_augm",
    "me_pres",
    "dmeth_rec",
    "apgar5",
    "apgar10",
    "ab_aven1",
    "ab_aven6",
    "ab_nicu",
    "ab_surf",
    "ab_anti",
    "ab_seiz",
    "no_abnorm",
    "ca_anen",
    "ca_mnsb",
    "ca_cchd",
    "ca_cdh",
    "ca_omph",
    "ca_gast",
    "ca_limb",
    "ca_cleft",
    "ca_clpal",
    "ca_hypo",
    "ca_disor",
    "no_congen",
    "meduc",
    "mracehisp",
    "feduc",
    "fracehisp",
    "pay_rec",
    "wic",
]

# Column order of the design matrix: categorical predictors first, then numeric ones.
features = categorical + numeric

# Label vector, then the design matrix with every categorical predictor cast to
# the pandas `category` dtype (numeric predictors keep their dtype).
y = df["ca_down_c_p_n"]
X = df[features].astype({name: "category" for name in categorical})

### Split training, validation and calibration data


In [None]:
# use half the data for the training set
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X, y, test_size=TRAINING_SPLIT, stratify=y, random_state=RANDOM_SEED
)

# use a quarter of the data for training validation and a quarter for calibration
X_valid, X_cal, y_valid, y_cal = train_test_split(
    X_tmp,
    y_tmp,
    test_size=(VALIDATION_SPLIT / TRAINING_SPLIT),
    stratify=y_tmp,
    random_state=RANDOM_SEED,
)

train_data = lgb.Dataset(
    X_train, label=y_train, categorical_feature=categorical, free_raw_data=False
)

valid_data = lgb.Dataset(
    X_valid,
    label=y_valid,
    categorical_feature=categorical,
    reference=train_data,
    free_raw_data=False,
)

In [None]:
# Class balance of the three splits.
neg_count_train = y_train[y_train == 0].count()
pos_count_train = y_train[y_train == 1].count()
# scale_pos_weight = neg_count_train / pos_count_train

neg_count_valid = y_valid[y_valid == 0].count()
pos_count_valid = y_valid[y_valid == 1].count()

neg_count_cal = y_cal[y_cal == 0].count()
pos_count_cal = y_cal[y_cal == 1].count()

# The reported probability is positives / (positives + negatives). The previous
# version divided by the negatives only, which is the odds, not the probability.
for split_name, neg_count, pos_count in (
    ("Training", neg_count_train, pos_count_train),
    ("Validation", neg_count_valid, pos_count_valid),
    ("Calibration", neg_count_cal, pos_count_cal),
):
    print(
        f"{split_name} set: {neg_count} negatives, {pos_count} positives, "
        f"probability positive {pos_count / (neg_count + pos_count):.8f}"
    )

In [None]:
# Fixed LightGBM settings shared by every tuning trial and by the final fit; the
# tuned (or stored) hyperparameters are merged on top of this dictionary.
base_params = {
    "objective": "binary",
    # Both metrics are evaluated on the validation sets each round.
    "metric": ["average_precision", "binary_logloss"],
    "boosting_type": "gbdt",
    "max_bin": 255,  # GPU 63/127; CPU 255
    # for now, we do not scale for better interpretability of outputs
    "scale_pos_weight": 1,
    "force_col_wise": True,
    "seed": RANDOM_SEED,
    "num_threads": NUM_THREADS,
    "verbosity": 1,
}

# Hyperparameters used when SELECT_HYPERPARAMETERS is False, i.e. when no Optuna
# study is run. The keys match the names suggested inside `objective`.
last_best_params = {
    "learning_rate": 0.01710687967617717,
    "num_leaves": 158,
    "min_data_in_leaf": 1128,
    "min_gain_to_split": 0.464491915106811,
    "feature_fraction": 0.8526051891727362,
    "bagging_fraction": 0.883228962083461,
    "bagging_freq": 6,
    "lambda_l1": 2.8601599535567946e-06,
    "lambda_l2": 8.149358490508543,
}

## Hyperparameter tuning


In [None]:
def objective(trial):
    """
    Optuna objective: train one LightGBM model with trial-suggested hyperparameters
    and return its best validation average precision.

    Reads the module-level `base_params`, `train_data`, `valid_data`,
    `NUM_BOOST_ROUND` and `EARLY_STOPPING_ROUNDS`.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample hyperparameters and to report intermediate scores
        for pruning.

    Returns
    -------
    float
        ``best_score["valid_0"]["average_precision"]`` of the trained booster.
    """
    trial_params = {
        # required to change min_data_in_leaf across trials without rebuilding the Dataset
        "feature_pre_filter": False,
        # Speed / stability
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.5, log=True),
        # Tree complexity
        "num_leaves": trial.suggest_int("num_leaves", 32, 512, log=True),
        "min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 500, 10000, log=True),
        "min_gain_to_split": trial.suggest_float("min_gain_to_split", 0.0, 1.0),
        # Sampling
        "feature_fraction": trial.suggest_float("feature_fraction", 0.6, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.6, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
        # Regularization
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
    }

    # Merge: base_params always present; trial_params override if same key exists
    params = {**base_params, **trial_params}

    # Reports the validation average precision to Optuna so that the study's pruner
    # can stop unpromising trials early.
    pruning_cb = optuna.integration.LightGBMPruningCallback(trial, "average_precision")

    # NOTE(review): early_stopping is created without first_metric_only, while
    # base_params lists two metrics - confirm that stopping is meant to consider
    # binary_logloss as well as average_precision.
    gbm = lgb.train(
        params,
        train_data,
        num_boost_round=NUM_BOOST_ROUND,
        valid_sets=[valid_data],
        callbacks=[
            lgb.early_stopping(stopping_rounds=EARLY_STOPPING_ROUNDS),
            lgb.log_evaluation(period=10),
            pruning_cb,
        ],
    )

    # Best score on validation
    return gbm.best_score["valid_0"]["average_precision"]


# Either search for hyperparameters with Optuna or fall back to the stored set.
if SELECT_HYPERPARAMETERS:
    study = optuna.create_study(
        direction="maximize",
        # Seed the sampler so the sequence of suggested hyperparameters is
        # reproducible; it was previously unseeded although RANDOM_SEED exists.
        sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
        pruner=optuna.pruners.HyperbandPruner(),
    )

    study.optimize(objective, n_trials=OPTIMIZE_TRIALS)

    print(study.best_params, study.best_value)

    best = study.best_params
else:
    best = last_best_params

# Merge (best overrides base if there are collisions)
params = {**base_params, **best}
params["feature_pre_filter"] = True  # reset to default for final training

print("Parameters for training:")
for k, v in params.items():
    print(f'  "{k}": {v}')

## Model 0 (all initial predictors)


In [None]:
model_idx = 0
model_name = f"Model {model_idx} (all initial predictors)"

# Final fit with the selected parameters. The training set is included in
# valid_sets so that its metrics are logged alongside the validation metrics.
gbm = lgb.train(
    params,
    train_data,
    num_boost_round=NUM_BOOST_ROUND,
    valid_sets=[train_data, valid_data],
    valid_names=["train", "valid"],
    callbacks=[
        early_stopping(stopping_rounds=EARLY_STOPPING_ROUNDS),
        log_evaluation(period=10),
    ],
)

# Persist only the early-stopped number of trees.
best_iter = gbm.best_iteration
# The timestamp uses a format spec instead of a nested strftime("...") call:
# reusing double quotes inside a double-quoted f-string is a SyntaxError before
# Python 3.12.
gbm.save_model(
    f"{OUTPUT_DIR}/model_{model_idx}_{datetime.now():%Y%m%d%H%M}.txt",
    num_iteration=best_iter,
)

### Predict

Produce predictions on validation set


In [None]:
# Score the validation set using only the early-stopped number of trees.
p_valid = gbm.predict(X_valid, num_iteration=best_iter)

### Evaluate predictions


In [None]:
# Validation metrics table plus the ROC curve arrays for the plot below.
metrics_df, p_valid_fpr, p_valid_tpr, p_valid_thresholds = get_metrics(
    y_valid, p_valid, K=10000, thr=0.01
)

# The timestamp uses a format spec instead of a nested strftime("...") call:
# reusing double quotes inside a double-quoted f-string is a SyntaxError before
# Python 3.12.
metrics_df.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_validation_metrics_{datetime.now():%Y%m%d%H%M}.csv",
    index=False,
)

metrics_df

In [None]:
# ROC curve of the validation predictions (arrays returned by get_metrics).
plot_roc_curve(p_valid_fpr, p_valid_tpr, model_idx, save=SAVE_PLOTS)

In [None]:
# The precision-recall plot needs precision/recall pairs. The previous call passed
# the ROC arrays (FPR/TPR), so the figure labelled "Precision-Recall" was in fact
# a second ROC curve.
p_valid_precision, p_valid_recall, _ = precision_recall_curve(y_valid, p_valid)

# The plot helper takes the x values (recall) first, then the y values (precision).
plot_precision_recall_curve(
    p_valid_recall, p_valid_precision, model_idx, save=SAVE_PLOTS
)

#### Feature importance


In [None]:
# Gain importance reported by the trained booster, one value per feature.
importance_gain = gbm.feature_importance(importance_type="gain")

# `features` is the column order of the training frame, so it lines up with the
# booster's importance vector.
df_imp_gain = pd.DataFrame(
    {"feature": features, "importance_gain": importance_gain}
).sort_values("importance_gain", ascending=False)

# The timestamp uses a format spec instead of a nested strftime("...") call:
# reusing double quotes inside a double-quoted f-string is a SyntaxError before
# Python 3.12.
df_imp_gain.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_feature_importance_gain_{datetime.now():%Y%m%d%H%M}.csv",
    index=False,
)
df_imp_gain

In [None]:
def build_explain_set(
    booster,
    X_valid,
    y_valid,
    categorical,
    n_neg_rand=100_000,
    n_neg_hard=100_000,
    seed=RANDOM_SEED,
):
    """
    Build a validation subset for explanation by combining all positives,
    a random sample of negatives, and a sample of hard negatives (highest predicted
    probabilities among negatives).

    Rows picked by more than one rule are included once; the final row order
    is shuffled.

    Parameters
    ----------
    booster : lightgbm.Booster
        Trained LightGBM booster; used to score `X_valid` when hard negatives
        are requested.
    X_valid : pd.DataFrame
        Validation feature set.
    y_valid : pd.Series
        Validation target values (1 = positive, 0 = negative).
    categorical : list of str
        List of categorical feature names.
    n_neg_rand : int
        Number of random negatives to include (capped at the number of
        negatives available).
    n_neg_hard : int
        Number of hard negatives to include (capped at the number of
        negatives available). Zero or less selects no hard negatives.
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    X_eval : pd.DataFrame
        Explanation feature set: the selected rows cast to float64, with the
        `categorical` columns converted to the pandas "category" dtype.
    y_eval : pd.Series
        Explanation target values, indexed like `X_eval`.
    """
    rng = np.random.default_rng(seed)

    yv = np.asarray(y_valid)
    idx_pos = np.flatnonzero(yv == 1)
    idx_neg = np.flatnonzero(yv == 0)

    # random negatives
    n_neg_rand = min(n_neg_rand, idx_neg.size)
    idx_neg_rand = rng.choice(idx_neg, size=n_neg_rand, replace=False)

    # hard negatives (top predicted p among negatives)
    n_neg_hard = min(n_neg_hard, idx_neg.size)
    if n_neg_hard > 0:
        # predict once on valid to pick hard negatives
        p_valid = booster.predict(X_valid, num_iteration=booster.best_iteration)
        p_neg = p_valid[idx_neg]
        hard_local = np.argpartition(p_neg, -n_neg_hard)[-n_neg_hard:]
        idx_neg_hard = idx_neg[hard_local]
    else:
        # Guard: with a count of 0 the slice `[-0:]` would select *every*
        # negative, and argpartition fails on an empty array.
        idx_neg_hard = idx_neg[:0]

    idx = np.unique(np.concatenate([idx_pos, idx_neg_rand, idx_neg_hard]))
    rng.shuffle(idx)

    X_eval = X_valid.iloc[idx].astype(np.float64).replace({pd.NA: np.nan}).copy()
    X_eval[categorical] = X_eval[categorical].astype("category")
    y_eval = pd.Series(yv[idx], index=X_valid.index[idx])
    return X_eval, y_eval

## Evaluation subset for permutation importances and SHAP analysis


In [None]:
# Explanation subset: all positives plus random and hard negatives (default sample sizes).
X_eval, y_eval = build_explain_set(gbm, X_valid, y_valid, categorical)

In [None]:
# Wrap the booster for use with scikit-learn's permutation_importance below.
# NOTE(review): LGBMEstimator is defined elsewhere in the notebook — presumably an
# estimator-style adapter around Booster.predict; confirm.
model_wrapped = LGBMEstimator(gbm)

In [None]:
result = permutation_importance(
    model_wrapped,
    X_eval,
    y_eval,
    scoring=ap_scorer,
    n_repeats=10,
    n_jobs=NUM_THREADS,
    random_state=RANDOM_SEED,
)

perm_importance = pd.DataFrame(
    {
        "feature": X_eval.columns,
        "importance_mean": result.importances_mean,
        "importance_std": result.importances_std,
    }
).sort_values("importance_mean", ascending=False)

perm_importance.to_csv(
    f"{OUTPUT_DIR}/permutation_importance_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
perm_importance

In [None]:
# Plot the permutation-importance result computed above for the current model.
plot_permutation_importances(result, X_eval, model_idx, save=SAVE_PLOTS)

In [None]:
# Pairwise dissimilarity between predictors -> condensed form -> Ward linkage.
dissimilarity = stats_utils.distance_corr_dissimilarity(X_eval)
distance, corr = dissimilarity
condensed = squareform(distance, checks=True)
dist_linkage = hierarchy.linkage(condensed, method="ward", metric="euclidean")

In [None]:
# Dendrogram of the predictor clustering, with a dashed reference line at 0.5.
fig = plt.figure(figsize=(5, 8))
ax = plt.axes()
dendro_0 = hierarchy.dendrogram(
    dist_linkage,
    labels=X_eval.columns.to_list(),
    orientation="right",
    ax=ax,
)
ax.vlines(0.5, 0, 500, linestyle="--", color="#b2b4549f", linewidth=2)
ax.set_xlabel("Ward linkage distance (increase in within-cluster variance)")
ax.set_ylabel("Predictors")
ax.set_title(f"Model {model_idx}: Hierarchical clustering of predictors")
plt.show()

In [None]:
# Correlation heatmap with rows/columns arranged in dendrogram leaf order.
leaf_order = dendro_0["leaves"]
leaf_labels = dendro_0["ivl"]
dendro_idx = np.arange(len(leaf_labels))

heatmap_rc = {"ytick.labelsize": 12, "xtick.labelsize": 12, "axes.titlesize": 12}
with plt.rc_context(heatmap_rc):
    plt.figure(figsize=(18, 18))
    ax = plt.axes()
    im = ax.imshow(corr[np.ix_(leaf_order, leaf_order)], cmap="viridis")
    ax.set_title(f"Model {model_idx}: Correlation heatmap of predictors")
    ax.set_xticks(dendro_idx)
    ax.set_xticklabels(leaf_labels, rotation="vertical")
    ax.set_yticks(dendro_idx)
    ax.set_yticklabels(leaf_labels)
    plt.colorbar(im, ax=ax, fraction=0.03, pad=0.025)
    plt.show()

In [None]:
explainer = shap.TreeExplainer(
    gbm, feature_perturbation="tree_path_dependent", model_output="raw"
)

explanation = explainer(X_eval)

clustering = shap.utils.hclust(
    X_eval, y_eval, linkage="average", random_state=RANDOM_SEED
)

shap_values = explanation.values

shap_importance = pd.DataFrame(
    {
        "feature": X_eval.columns,
        "mean_abs_shap": np.mean(np.abs(shap_values), axis=0),
    }
).sort_values("mean_abs_shap", ascending=False)

shap_importance.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_shap_importance_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
shap_importance

In [None]:
# Feature names ordered by descending mean |SHAP|.
shap_importance["feature"].tolist()

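We cannot plot millions of observations, so for SHAP analysis we use the evaluation subset built above: all positives, up to 100,000 randomly sampled negatives, and up to 100,000 hard negatives (the negatives with the highest predicted probabilities).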


In [None]:
# SHAP bar plot (mean |SHAP| per predictor), grouped by the feature clustering.
with plt.rc_context({"axes.titlesize": 12}):
    plot = plt.figure(figsize=(8, 10))
    ax = plot.subplots()
    # Title follows model_idx, like the other plots, instead of a hard-coded "Model 0".
    ax.set_title(f"Model {model_idx}: SHAP values for predictor variables")
    shap.plots.bar(explanation, clustering=clustering, max_display=35, ax=ax)

In [None]:
# SHAP beeswarm plot of the per-observation SHAP values.
beeswarm_title = f"Model {model_idx}: SHAP values for predictor variables"
with plt.rc_context({"axes.titlesize": 12}):
    plot, ax = plt.subplots()
    ax.set_title(beeswarm_title)
    shap.plots.beeswarm(explanation, max_display=35, plot_size=(8, 10))

In [None]:
# SHAP scatter plot for `year`, with points colored by `mage_c`.
shap.plots.scatter(explanation[:, "year"], color=explanation[:, "mage_c"])

## Model 1


In [None]:
# Predictors dropped after reviewing Model 0 (reason noted per feature).
features_to_remove_0 = [
    "fracehisp",  # correlated with mracehisp
    "ab_aven6",  # correlated with ab_aven1
    "pwgt_r",  # low importance, correlated with bmi
    "ld_indl",  # low importance
    "no_risks",  # low importance
    "feduc",  # low importance, correlated with meduc
    "pay_rec",  # low importance
    "rf_ghype",  # low importance
    "ab_surf",  # low importance
    "rf_inftr",  # low importance
    "ld_augm",  # low importance
    "rf_gdiab",  # low importance
    "sex",  # low importance
    "ab_seiz",  # low importance
    "rf_ppterm",  # low importance
    "wic",  # low importance
    "rf_artec",  # low importance
    "rf_fedrg",  # low importance
    "apgar10",  # low importance, low SHAP value
]

X_train = X_train.drop(features_to_remove_0, axis="columns")
X_valid = X_valid.drop(features_to_remove_0, axis="columns")

features = X_train.columns.to_list()
removed_0 = set(features_to_remove_0)
categorical = [name for name in categorical if name not in removed_0]

# Rebuild the LightGBM datasets on the reduced feature set; the validation
# set is aligned to the training set via `reference`.
dataset_options = {"categorical_feature": categorical, "free_raw_data": False}
train_data = lgb.Dataset(X_train, label=y_train, **dataset_options)
valid_data = lgb.Dataset(
    X_valid, label=y_valid, reference=train_data, **dataset_options
)

model_idx = 1
model_name = (
    f"Model {model_idx} (after removing {len(features_to_remove_0)} predictors)"
)

print(f"Features removed: {features_to_remove_0}")
print(f"Features remaining: {features}")
print(f"Number of features in {model_name}: {X_train.shape[1]}")

In [None]:
gbm = lgb.train(
    params,
    train_data,
    num_boost_round=NUM_BOOST_ROUND,
    valid_sets=[train_data, valid_data],
    valid_names=["train", "valid"],
    callbacks=[
        early_stopping(stopping_rounds=EARLY_STOPPING_ROUNDS),
        log_evaluation(period=10),
    ],
)

best_iter = gbm.best_iteration
gbm.save_model(
    f"{OUTPUT_DIR}/model_{model_idx}_{datetime.now().strftime("%Y%m%d%H%M")}.txt",
    num_iteration=best_iter,
)

In [None]:
# Score the validation set using only the trees up to the best iteration.
p_valid = gbm.predict(X_valid, num_iteration=best_iter)

In [None]:
metrics_df, p_valid_fpr, p_valid_tpr, p_valid_thresholds = get_metrics(
    y_valid, p_valid, K=10000, thr=0.01
)

metrics_df.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_validation_metrics_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)

metrics_df

In [None]:
# ROC curve for the validation predictions, from the arrays returned by get_metrics.
plot_roc_curve(p_valid_fpr, p_valid_tpr, model_idx, save=SAVE_PLOTS)

In [None]:
# NOTE(review): the arguments are named as ROC arrays (FPR/TPR); a precision-recall
# curve normally needs precision/recall values — confirm what
# plot_precision_recall_curve expects (it is defined elsewhere in the notebook).
plot_precision_recall_curve(p_valid_fpr, p_valid_tpr, model_idx, save=SAVE_PLOTS)

In [None]:
importance_gain = gbm.feature_importance(importance_type="gain")

df_imp_gain = pd.DataFrame(
    {"feature": features, "importance_gain": importance_gain}
).sort_values("importance_gain", ascending=False)

df_imp_gain.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_feature_importance_gain_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
df_imp_gain

In [None]:
# Rebuild the explanation subset for the current model
# (all positives plus random and hard negatives, default sample sizes).
X_eval, y_eval = build_explain_set(gbm, X_valid, y_valid, categorical)

# Wrap the booster for use with scikit-learn's permutation_importance below.
model_wrapped = LGBMEstimator(gbm)

In [None]:
result = permutation_importance(
    model_wrapped,
    X_eval,
    y_eval,
    scoring=ap_scorer,
    n_repeats=10,
    n_jobs=NUM_THREADS,
    random_state=RANDOM_SEED,
)

perm_importance = pd.DataFrame(
    {
        "feature": X_eval.columns,
        "importance_mean": result.importances_mean,
        "importance_std": result.importances_std,
    }
).sort_values("importance_mean", ascending=False)

perm_importance.to_csv(
    f"{OUTPUT_DIR}/permutation_importance_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
perm_importance

In [None]:
# Plot the permutation-importance result computed above for the current model.
plot_permutation_importances(result, X_eval, model_idx, save=SAVE_PLOTS)

In [None]:
# Pairwise dissimilarity between predictors -> condensed form -> Ward linkage.
dissimilarity = stats_utils.distance_corr_dissimilarity(X_eval)
distance, corr = dissimilarity
condensed = squareform(distance, checks=True)
dist_linkage = hierarchy.linkage(condensed, method="ward", metric="euclidean")

In [None]:
# Dendrogram of the predictor clustering, with a dashed reference line at 0.5.
fig = plt.figure(figsize=(5, 6))
ax = plt.axes()
dendro_0 = hierarchy.dendrogram(
    dist_linkage,
    labels=X_eval.columns.to_list(),
    orientation="right",
    ax=ax,
)
ax.vlines(0.5, 0, 500, linestyle="--", color="#b2b4549f", linewidth=2)
ax.set_xlabel("Ward linkage distance (increase in within-cluster variance)")
ax.set_ylabel("Predictors")
ax.set_title(f"Model {model_idx}: Hierarchical clustering of predictors")
plt.show()

In [None]:
# Correlation matrix rearranged into dendrogram leaf order.
leaf_order = dendro_0["leaves"]
C = corr[np.ix_(leaf_order, leaf_order)]
labels = dendro_0["ivl"]
dendro_idx = np.arange(len(labels))

threshold = 0.30  # adjust: 0.4, 0.5, 0.6 ...

heatmap_rc = {"ytick.labelsize": 12, "xtick.labelsize": 12, "axes.titlesize": 12}
with plt.rc_context(heatmap_rc):
    fig, ax = plt.subplots(figsize=(12, 12))
    im = ax.imshow(C, cmap="viridis")

    ax.set_title(f"Model {model_idx}: Correlation heatmap of predictors")
    ax.set_xticks(dendro_idx)
    ax.set_xticklabels(labels, rotation="vertical")
    ax.set_yticks(dendro_idx)
    ax.set_yticklabels(labels)

    # Annotate off-diagonal cells whose absolute value reaches the threshold;
    # white text below 0.6, black text from 0.6 upwards.
    n = C.shape[0]
    for i, j in np.ndindex(n, n):
        cell = C[i, j]
        if i == j or abs(cell) < threshold:
            continue
        text_color = "white" if abs(cell) < 0.6 else "black"
        ax.text(
            j, i, f"{cell:.2f}", ha="center", va="center", fontsize=8, color=text_color
        )

    fig.colorbar(im, ax=ax, fraction=0.03, pad=0.025)
    plt.show()

In [None]:
explainer = shap.TreeExplainer(
    gbm, feature_perturbation="tree_path_dependent", model_output="raw"
)

explanation = explainer(X_eval)

clustering = shap.utils.hclust(
    X_eval, y_eval, linkage="average", random_state=RANDOM_SEED
)

shap_values = explanation.values

shap_importance = pd.DataFrame(
    {
        "feature": X_eval.columns,
        "mean_abs_shap": np.mean(np.abs(shap_values), axis=0),
    }
).sort_values("mean_abs_shap", ascending=False)

shap_importance.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_shap_importance_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
shap_importance

In [None]:
# Feature names ordered by descending mean |SHAP|.
shap_importance["feature"].tolist()

In [None]:
# SHAP bar plot (mean |SHAP| per predictor) for the current model.
with plt.rc_context({"axes.titlesize": 12}):
    plot = plt.figure(figsize=(8, 8))
    ax = plot.subplots()
    # A stray closing brace after {model_idx} made this f-string a SyntaxError.
    ax.set_title(f"Model {model_idx}: SHAP values for predictor variables")
    shap.plots.bar(explanation, max_display=35, ax=ax)

In [None]:
# SHAP beeswarm plot of the per-observation SHAP values.
beeswarm_title = f"Model {model_idx}: SHAP values for predictor variables"
with plt.rc_context({"axes.titlesize": 12}):
    plot, ax = plt.subplots()
    ax.set_title(beeswarm_title)
    shap.plots.beeswarm(explanation, max_display=30, plot_size=(8, 7))

In [None]:
# SHAP scatter plot for `year`, with points colored by `mage_c`.
shap.plots.scatter(explanation[:, "year"], color=explanation[:, "mage_c"])

## Model 2


In [None]:
# Predictors dropped after reviewing Model 1 (reason noted per feature).
features_to_remove_1 = [
    "precare",  # low importance, low SHAP value
    "bmi",  # low importance, low SHAP value
    "meduc",  # low importance, low SHAP value
    "fagecomb",  # low importance, low SHAP value, correlated with mage_c
    "rf_pdiab",  # low importance, low SHAP value
    "apgar5",  # low importance, low SHAP value
    "bfacil3",  # low importance, low SHAP value
    "dmeth_rec",  # low importance, low SHAP value
    "rf_phype",  # low importance, low SHAP value
    "me_pres",  # low importance, low SHAP value
    "mracehisp",  # partially correlated with year, possible bias
    "wtgain",  # low importance
]

X_train = X_train.drop(features_to_remove_1, axis="columns")
X_valid = X_valid.drop(features_to_remove_1, axis="columns")

features = X_train.columns.to_list()
removed_1 = set(features_to_remove_1)
categorical = [name for name in categorical if name not in removed_1]

# Rebuild the LightGBM datasets on the reduced feature set; the validation
# set is aligned to the training set via `reference`.
dataset_options = {"categorical_feature": categorical, "free_raw_data": False}
train_data = lgb.Dataset(X_train, label=y_train, **dataset_options)
valid_data = lgb.Dataset(
    X_valid, label=y_valid, reference=train_data, **dataset_options
)

model_idx = 2
model_name = (
    f"Model {model_idx} (after removing {len(features_to_remove_1)} predictors)"
)

print(f"Features removed: {features_to_remove_1}")
print(f"Features remaining: {features}")
print(f"Number of features in {model_name}: {X_train.shape[1]}")

In [None]:
gbm = lgb.train(
    params,
    train_data,
    num_boost_round=NUM_BOOST_ROUND,
    valid_sets=[train_data, valid_data],
    valid_names=["train", "valid"],
    callbacks=[
        early_stopping(stopping_rounds=EARLY_STOPPING_ROUNDS),
        log_evaluation(period=10),
    ],
)

best_iter = gbm.best_iteration
gbm.save_model(
    f"{OUTPUT_DIR}/model_{model_idx}_{datetime.now().strftime("%Y%m%d%H%M")}.txt",
    num_iteration=best_iter,
)

In [None]:
# Score the validation set using only the trees up to the best iteration.
p_valid = gbm.predict(X_valid, num_iteration=best_iter)

In [None]:
metrics_df, p_valid_fpr, p_valid_tpr, p_valid_thresholds = get_metrics(
    y_valid, p_valid, K=10000, thr=0.01
)

metrics_df.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_validation_metrics_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)

metrics_df

In [None]:
# ROC curve for the validation predictions, from the arrays returned by get_metrics.
plot_roc_curve(p_valid_fpr, p_valid_tpr, model_idx, save=SAVE_PLOTS)

In [None]:
# NOTE(review): the arguments are named as ROC arrays (FPR/TPR); a precision-recall
# curve normally needs precision/recall values — confirm what
# plot_precision_recall_curve expects (it is defined elsewhere in the notebook).
plot_precision_recall_curve(p_valid_fpr, p_valid_tpr, model_idx, save=SAVE_PLOTS)

In [None]:
importance_gain = gbm.feature_importance(importance_type="gain")

df_imp_gain = pd.DataFrame(
    {"feature": features, "importance_gain": importance_gain}
).sort_values("importance_gain", ascending=False)

df_imp_gain.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_feature_importance_gain_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
df_imp_gain

In [None]:
# Rebuild the explanation subset for the current model
# (all positives plus random and hard negatives, default sample sizes).
X_eval, y_eval = build_explain_set(gbm, X_valid, y_valid, categorical)

# Wrap the booster for use with scikit-learn's permutation_importance below.
model_wrapped = LGBMEstimator(gbm)

In [None]:
result = permutation_importance(
    model_wrapped,
    X_eval,
    y_eval,
    scoring=ap_scorer,
    n_repeats=10,
    n_jobs=NUM_THREADS,
    random_state=RANDOM_SEED,
)

perm_importance = pd.DataFrame(
    {
        "feature": X_eval.columns,
        "importance_mean": result.importances_mean,
        "importance_std": result.importances_std,
    }
).sort_values("importance_mean", ascending=False)

perm_importance.to_csv(
    f"{OUTPUT_DIR}/permutation_importance_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
perm_importance

In [None]:
# Plot the permutation-importance result computed above for the current model.
plot_permutation_importances(result, X_eval, model_idx, save=SAVE_PLOTS)

In [None]:
# Pairwise dissimilarity between predictors -> condensed form -> Ward linkage.
dissimilarity = stats_utils.distance_corr_dissimilarity(X_eval)
distance, corr = dissimilarity
condensed = squareform(distance, checks=True)
dist_linkage = hierarchy.linkage(condensed, method="ward", metric="euclidean")

In [None]:
# Dendrogram of the predictor clustering, with a dashed reference line at 0.5.
fig = plt.figure(figsize=(5, 4))
ax = plt.axes()
dendro_0 = hierarchy.dendrogram(
    dist_linkage,
    labels=X_eval.columns.to_list(),
    orientation="right",
    ax=ax,
)
ax.vlines(0.5, 0, 500, linestyle="--", color="#b2b4549f", linewidth=2)
ax.set_xlabel("Ward linkage distance (increase in within-cluster variance)")
ax.set_ylabel("Predictors")
ax.set_title(f"Model {model_idx}: Hierarchical clustering of predictors")
plt.show()

In [None]:
# Correlation matrix rearranged into dendrogram leaf order.
leaf_order = dendro_0["leaves"]
C = corr[np.ix_(leaf_order, leaf_order)]
labels = dendro_0["ivl"]
dendro_idx = np.arange(len(labels))

threshold = 0.30  # adjust: 0.4, 0.5, 0.6 ...

heatmap_rc = {"ytick.labelsize": 12, "xtick.labelsize": 12, "axes.titlesize": 12}
with plt.rc_context(heatmap_rc):
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(C, cmap="viridis")

    ax.set_title(f"Model {model_idx}: Correlation heatmap of predictors")
    ax.set_xticks(dendro_idx)
    ax.set_xticklabels(labels, rotation="vertical")
    ax.set_yticks(dendro_idx)
    ax.set_yticklabels(labels)

    # Annotate off-diagonal cells whose absolute value reaches the threshold;
    # white text below 0.6, black text from 0.6 upwards.
    n = C.shape[0]
    for i, j in np.ndindex(n, n):
        cell = C[i, j]
        if i == j or abs(cell) < threshold:
            continue
        text_color = "white" if abs(cell) < 0.6 else "black"
        ax.text(
            j, i, f"{cell:.2f}", ha="center", va="center", fontsize=8, color=text_color
        )

    fig.colorbar(im, ax=ax, fraction=0.03, pad=0.025)
    plt.show()

In [None]:
explainer = shap.TreeExplainer(
    gbm, feature_perturbation="tree_path_dependent", model_output="raw"
)

explanation = explainer(X_eval)

clustering = shap.utils.hclust(
    X_eval, y_eval, linkage="average", random_state=RANDOM_SEED
)

shap_values = explanation.values

shap_importance = pd.DataFrame(
    {
        "feature": X_eval.columns,
        "mean_abs_shap": np.mean(np.abs(shap_values), axis=0),
    }
).sort_values("mean_abs_shap", ascending=False)

shap_importance.to_csv(
    f"{OUTPUT_DIR}/model_{model_idx}_shap_importance_{datetime.now().strftime("%Y%m%d%H%M")}.csv",
    index=False,
)
shap_importance

In [None]:
# Feature names ordered by descending mean |SHAP|.
shap_importance["feature"].tolist()

In [None]:
# SHAP bar plot (mean |SHAP| per predictor) for the current model.
bar_title = f"Model {model_idx}: SHAP values for predictor variables"
with plt.rc_context({"axes.titlesize": 12}):
    plot, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(bar_title)
    shap.plots.bar(explanation, max_display=35, ax=ax)

In [None]:
# SHAP beeswarm plot of the per-observation SHAP values.
beeswarm_title = f"Model {model_idx}: SHAP values for predictor variables"
with plt.rc_context({"axes.titlesize": 12}):
    plot, ax = plt.subplots()
    ax.set_title(beeswarm_title)
    shap.plots.beeswarm(explanation, max_display=30, plot_size=(8, 7))

In [None]:
# SHAP scatter plot for `year`, with points colored by `mage_c`.
shap.plots.scatter(explanation[:, "year"], color=explanation[:, "mage_c"])

Model 2

## Calibration


In [None]:
# Raw booster scores (raw_score=True) for the calibration split.
# X_train/X_valid were reduced to the retained predictors before this model was
# trained, but X_cal was not; select the booster's own feature columns so the
# input matches the training feature set (LightGBM rejects a feature-count mismatch).
# NOTE(review): assumes X_cal is a DataFrame that still contains every retained
# predictor under the same column names — confirm against the split cell.
p_cal_raw = gbm.predict(
    X_cal[gbm.feature_name()], num_iteration=gbm.best_iteration, raw_score=True
)

In [None]:
# Calibrator: a one-feature logistic regression fitted on the booster's raw
# calibration-split scores (reshaped to a single column) against y_cal.
calibrator = LogisticRegression(solver="lbfgs", max_iter=1000)
calibrator.fit(p_cal_raw.reshape(-1, 1), y_cal)

In [None]:
def predict_proba_calibrated(gbm, calibrator, X_new, num_iteration=None):
    """
    Return calibrated positive-class probabilities for `X_new`.

    The booster's raw scores (``raw_score=True``) are reshaped into a single
    feature column and passed to ``calibrator.predict_proba``; column 1 of
    that result is returned.

    Parameters
    ----------
    gbm : object with ``predict(X, num_iteration=..., raw_score=...)``
        Trained booster.
    calibrator : object with ``predict_proba``
        Fitted calibrator that takes a single-column score matrix.
    X_new : array-like
        Observations to score.
    num_iteration : int or None
        Forwarded unchanged to ``gbm.predict``.

    Returns
    -------
    1-D array with one calibrated probability per row of `X_new`.
    """
    raw_scores = gbm.predict(X_new, num_iteration=num_iteration, raw_score=True)
    score_column = raw_scores.reshape(-1, 1)
    class_probabilities = calibrator.predict_proba(score_column)
    return class_probabilities[:, 1]

## Prediction


In [None]:
# Compare calibrated and uncalibrated validation probabilities.
# Note: `p_valid_raw` holds the booster's ordinary predictions (raw_score is
# not requested here), i.e. the uncalibrated probabilities.
p_valid_cal = predict_proba_calibrated(
    gbm,
    calibrator,
    X_valid,
    num_iteration=gbm.best_iteration,
)
p_valid_raw = gbm.predict(X_valid, num_iteration=gbm.best_iteration)

raw_logloss = log_loss(y_valid, p_valid_raw, labels=[0, 1])
print("Raw logloss:", raw_logloss)
cal_logloss = log_loss(y_valid, p_valid_cal, labels=[0, 1])
print("Cal logloss:", cal_logloss)
raw_brier = brier_score_loss(y_valid, p_valid_raw)
print("Raw brier:  ", raw_brier)
cal_brier = brier_score_loss(y_valid, p_valid_cal)
print("Cal brier:  ", cal_brier)