# 001 — CatBoost Classification

Binary classification with CatBoost, Optuna hyperparameter tuning, threshold
optimization, SHAP explanations, and structured metrics output.

**Lifecycle stage:** seedling (model-garden)

All code is self-contained in this notebook — no external library imports
from a shared `src/` package.

In [None]:
# ---------------------------------------------------------------------------
# Papermill parameters  (this cell is tagged "parameters")
# ---------------------------------------------------------------------------

# Every name below is a run-level knob that later cells read as a global.

# Data loading
feature_paths: list[str] = []          # local or gs:// URIs; empty → synthetic
target_paths: list[str] = []           # optional separate target files
join_key: str | None = None            # key to join features ↔ targets
feature_cols: list[str] | None = None  # subset of columns; None → all
target_col: str = "target"             # name of the label column
positive_label: str | int | None = None  # if set, rows whose target equals this (compared as strings) become 1, others 0
entity_id: str | None = "entity_id"    # column for group-based train/test split

# Splitting
test_size: float = 0.2                 # forwarded as test_size to the splitter
random_state: int = 42                 # seed reused for synthetic data, splitting, CV folds, CatBoost and SHAP sampling
stratify: bool = True                  # only consulted by the random (non-group) split

# Optuna
optuna_n_trials: int = 30              # forwarded as n_trials to study.optimize
optuna_timeout_s: int | None = None    # forwarded as timeout to study.optimize
optimize_metric: str = "f1"  # "f1", "precision", "recall"
threshold_grid: list[float] = [round(x * 0.05, 2) for x in range(1, 20)]  # 0.05, 0.10, …, 0.95

# Outputs
metrics_json_path: str = "outputs/metrics/metrics.json"  # where the structured metrics JSON is written
model_output_path: str = "outputs/models/model.cbm"      # where the fitted CatBoost model is saved
executed_notebook_path: str | None = None                # only echoed in the final summary when set
plots_dir: str = "outputs/plots"                         # directory every figure is saved into

# SHAP / feature importance
shap_sample_size: int = 1000           # upper bound on test rows sampled for SHAP
enable_shap: bool = True               # False skips the SHAP section
enable_feature_importance: bool = True  # False skips the feature-importance plot

In [None]:
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
import json
import os
import warnings
from datetime import datetime, timezone
from pathlib import Path

# Suppress all warnings (keeps notebook output and rendered HTML clean)
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from catboost import CatBoostClassifier, Pool
from sklearn.datasets import make_classification
from sklearn.metrics import (
    auc,
    average_precision_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.calibration import calibration_curve
from sklearn.model_selection import GroupShuffleSplit, StratifiedKFold, train_test_split

# Ensure output dirs exist.  The default locations are kept, and the
# directories implied by the papermill parameters are added so that overriding
# plots_dir / metrics_json_path / model_output_path does not make a later
# savefig() fail on a missing directory.
for d in [
    "outputs/runs",
    "outputs/plots",
    "outputs/models",
    "outputs/metrics",
    plots_dir,
    Path(metrics_json_path).parent,
    Path(model_output_path).parent,
]:
    Path(d).mkdir(parents=True, exist_ok=True)

print(f"Run started at {datetime.now(timezone.utc).isoformat()}")

## 1 — Data Loading

In [None]:
# ---------------------------------------------------------------------------
# Data loading helpers (all inline, no external modules)
# ---------------------------------------------------------------------------

def _read_file(path: str) -> pl.DataFrame:
    """Load one file into a DataFrame.

    After stripping surrounding whitespace, paths ending in ``.parquet`` or
    ``.pq`` are read as Parquet; every other path is read as CSV.
    """
    cleaned = path.strip()
    reader = pl.read_parquet if cleaned.endswith((".parquet", ".pq")) else pl.read_csv
    return reader(cleaned)


def load_data(
    feature_paths: list[str],
    target_paths: list[str],
    join_key: str | None,
    feature_cols: list[str] | None,
    target_col: str,
    entity_id: str | None,
) -> pl.DataFrame:
    """Load features (and optional targets), returning a single DataFrame.

    Args:
        feature_paths: Files to read and stack vertically. When empty, a
            synthetic 5,000-row binary dataset is generated and returned
            immediately; the remaining loading options are then not applied.
        target_paths: Optional files holding the target; they are stacked
            vertically and attached to the features.
        join_key: Column for an inner join between features and targets. When
            falsy, the two frames are concatenated horizontally instead.
        feature_cols: Columns to keep (``target_col`` and ``entity_id`` are
            always added). ``None`` keeps every column.
        target_col: Name of the label column.
        entity_id: Name of the grouping column, or ``None`` for no grouping.

    Returns:
        The assembled DataFrame.

    Note:
        The synthetic branch seeds its generators from the notebook-level
        ``random_state`` parameter, which is read as a global.
    """

    if not feature_paths:
        # Generate synthetic dataset with entity_id
        print("No feature_paths provided — generating synthetic dataset.")
        n_samples = 5_000
        X, y = make_classification(
            n_samples=n_samples,
            n_features=20,
            n_informative=12,
            n_redundant=4,
            n_classes=2,
            weights=[0.7, 0.3],
            flip_y=0.03,
            random_state=random_state,
        )
        cols = {f"feat_{i:02d}": X[:, i] for i in range(X.shape[1])}
        cols[target_col] = y
        # Assign ~3 rows per entity on average so group split is meaningful
        if entity_id:
            rng = np.random.RandomState(random_state)
            n_entities = n_samples // 3
            cols[entity_id] = rng.randint(0, n_entities, size=n_samples)
        return pl.DataFrame(cols)

    # Read and concat feature files
    dfs = [_read_file(p) for p in feature_paths]
    df = pl.concat(dfs, how="vertical_relaxed")

    # Optionally join separate target files
    if target_paths:
        tgt_dfs = [_read_file(p) for p in target_paths]
        tgt = pl.concat(tgt_dfs, how="vertical_relaxed")
        if join_key:
            df = df.join(tgt, on=join_key, how="inner")
        else:
            df = pl.concat([df, tgt], how="horizontal")

    # Subset columns if requested
    if feature_cols is not None:
        extras = [target_col] + ([entity_id] if entity_id else [])
        # dict.fromkeys de-duplicates while keeping the requested order.  A set
        # would make the column (and therefore feature) order depend on string
        # hashing, which differs between interpreter runs.
        wanted = list(dict.fromkeys([*feature_cols, *extras]))
        missing = [c for c in wanted if c not in df.columns]
        if missing:
            # Still best-effort (missing columns are skipped), but no longer silent.
            print(f"Warning: requested columns not found and skipped: {missing}")
        df = df.select([c for c in wanted if c in df.columns])

    return df


# Build the working frame from the notebook parameters and report its size.
df = load_data(
    feature_paths=feature_paths,
    target_paths=target_paths,
    join_key=join_key,
    feature_cols=feature_cols,
    target_col=target_col,
    entity_id=entity_id,
)
print(f"Loaded DataFrame: {df.height:,} rows × {df.width} cols")
if entity_id and entity_id in df.columns:
    print(f"Unique entities ({entity_id}): {df[entity_id].n_unique():,}")

In [None]:
# Recode the target to 0/1 when a positive label was supplied: values are
# compared as strings against str(positive_label).
if positive_label is not None:
    _is_positive = pl.col(target_col).cast(pl.Utf8) == str(positive_label)
    df = df.with_columns(_is_positive.cast(pl.Int8).alias(target_col))
    print(f"Recoded target: positive_label='{positive_label}' → 1")

## 2 — EDA

In [None]:
# ---------------------------------------------------------------------------
# Schema & nulls
# ---------------------------------------------------------------------------
# One line per column: name, dtype and null count.
print("Schema:")
for name, dtype in df.schema.items():
    print(f"  {name:30s}  {str(dtype):15s}  nulls={df[name].null_count()}")

# Class counts, ordered by target value.
target_counts = df[target_col].value_counts().sort(target_col)
print(f"\nTarget distribution ({target_col}):")
print(target_counts)

In [None]:
# ---------------------------------------------------------------------------
# Summary stats for numeric columns
# ---------------------------------------------------------------------------
# Dtypes treated as numeric for the EDA summaries.
_NUMERIC_DTYPES = (
    pl.Float32, pl.Float64,
    pl.Int8, pl.Int16, pl.Int32, pl.Int64,
    pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
)
numeric_cols = [c for c in df.columns if df[c].dtype in _NUMERIC_DTYPES]

# The target and the entity identifier are not features: the split cell
# excludes both, so the EDA statistics and plots must not describe them either.
_eda_excluded = {target_col, entity_id}
feature_numeric = [c for c in numeric_cols if c not in _eda_excluded]

if feature_numeric:
    print(df.select(feature_numeric).describe())

In [None]:
# ---------------------------------------------------------------------------
# Histograms (first 8 numeric features)
# ---------------------------------------------------------------------------
plot_cols = feature_numeric[:8]
if plot_cols:
    n = len(plot_cols)
    ncols = min(4, n)
    nrows = (n + ncols - 1) // ncols  # ceiling division
    # squeeze=False always yields a 2-D array of axes, so one flattening path
    # covers the single-panel, single-row and multi-row layouts.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 3 * nrows), squeeze=False
    )
    axes_flat = axes.ravel()
    for panel, col in zip(axes_flat, plot_cols):
        panel.hist(df[col].drop_nulls().to_numpy(), bins=40, edgecolor="white")
        panel.set_title(col, fontsize=10)
    # Hide grid cells that have no feature assigned.
    for spare in axes_flat[n:]:
        spare.set_visible(False)
    fig.tight_layout()
    hist_path = f"{plots_dir}/eda_histograms.png"
    fig.savefig(hist_path, dpi=120)
    plt.show()
    print(f"Saved → {hist_path}")

In [None]:
# ---------------------------------------------------------------------------
# Correlation heatmap (cap at 20 features to keep runtime sane)
# ---------------------------------------------------------------------------
corr_cols = feature_numeric[:20]
if len(corr_cols) >= 2:
    n_corr = len(corr_cols)
    # Pairwise correlation computed by pandas on the selected columns.
    corr = df.select(corr_cols).to_pandas().corr()

    fig, ax = plt.subplots(figsize=(max(6, n_corr * 0.6), max(5, n_corr * 0.5)))
    im = ax.imshow(corr.values, aspect="auto", cmap="RdBu_r", vmin=-1, vmax=1)
    tick_positions = range(n_corr)
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(corr_cols, rotation=90, fontsize=7)
    ax.set_yticklabels(corr_cols, fontsize=7)
    fig.colorbar(im, ax=ax, shrink=0.8)
    ax.set_title("Pairwise Correlation")
    fig.tight_layout()
    corr_path = f"{plots_dir}/eda_correlation.png"
    fig.savefig(corr_path, dpi=120)
    plt.show()
    print(f"Saved → {corr_path}")

## 3 — Train / Test Split

In [None]:
# ---------------------------------------------------------------------------
# Split — group-based on entity_id (no entity leaks between train/test)
# ---------------------------------------------------------------------------
# Columns that must never be fed to the model.
non_feature_cols = {target_col}
has_entity_col = bool(entity_id) and entity_id in df.columns
if has_entity_col:
    non_feature_cols.add(entity_id)

feature_names = [c for c in df.columns if c not in non_feature_cols]
X_all = df.select(feature_names).to_numpy()
y_all = df[target_col].to_numpy().astype(int)

if has_entity_col:
    # Group split: every entity lands entirely in train or entirely in test.
    # The `stratify` parameter is not consulted on this path.
    groups = df[entity_id].to_numpy()
    gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    train_idx, test_idx = next(gss.split(X_all, y_all, groups=groups))
    X_train, X_test = X_all[train_idx], X_all[test_idx]
    y_train, y_test = y_all[train_idx], y_all[test_idx]
    # Build each entity set once and reuse it for the counts and the leak check.
    train_entities = set(groups[train_idx])
    test_entities = set(groups[test_idx])
    n_train_entities = len(train_entities)
    n_test_entities = len(test_entities)
    assert len(train_entities & test_entities) == 0, "Entity leak detected!"
    print(f"Group split on '{entity_id}': {n_train_entities:,} train entities, {n_test_entities:,} test entities")
else:
    split_kw = dict(test_size=test_size, random_state=random_state)
    if stratify:
        split_kw["stratify"] = y_all
    X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, **split_kw)
    print("Random split (no entity_id column found)")

# Compute scale_pos_weight from training set class imbalance
n_neg = int((y_train == 0).sum())
n_pos = int((y_train == 1).sum())
if n_pos == 0 or n_neg == 0:
    # Without this guard n_pos == 0 raises a bare ZeroDivisionError and
    # n_neg == 0 yields a weight of 0; fail early with an actionable message.
    raise ValueError(
        "Training split must contain both classes to compute scale_pos_weight "
        f"(neg={n_neg}, pos={n_pos}). Check target_col / positive_label and the split settings."
    )
scale_pos_weight = n_neg / n_pos

print(f"Train: {X_train.shape[0]:,}  |  Test: {X_test.shape[0]:,}")
print(f"Train target rate: {y_train.mean():.3f}  |  Test target rate: {y_test.mean():.3f}")
print(f"Class balance — neg: {n_neg:,}  pos: {n_pos:,}  → scale_pos_weight: {scale_pos_weight:.4f}")

## 4 — Optuna Hyperparameter Tuning

In [None]:
# ---------------------------------------------------------------------------
# Optuna objective
# ---------------------------------------------------------------------------
import optuna

# Keep Optuna quiet: only warnings and above are logged.
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Scorer for each supported `optimize_metric` value.
_METRIC_FN = dict(
    f1=f1_score,
    precision=precision_score,
    recall=recall_score,
)


def objective(trial: optuna.Trial) -> float:
    """Score one hyperparameter draw with 3-fold stratified CV.

    Each fold fits a CatBoost model with early stopping against the held-out
    part, then takes the best value of the configured metric over
    ``threshold_grid``. The trial value is the mean of those per-fold bests.
    """
    # Search space; the suggestion order is kept as-is because it is part of
    # the sampler's call sequence.
    params = dict(
        iterations=trial.suggest_int("iterations", 200, 1500),
        depth=trial.suggest_int("depth", 3, 10),
        learning_rate=trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
        l2_leaf_reg=trial.suggest_float("l2_leaf_reg", 1e-2, 10.0, log=True),
        border_count=trial.suggest_int("border_count", 32, 255),
        subsample=trial.suggest_float("subsample", 0.5, 1.0),
        random_strength=trial.suggest_float("random_strength", 1e-3, 10.0, log=True),
        scale_pos_weight=scale_pos_weight,
    )

    scorer = _METRIC_FN[optimize_metric]
    folds = StratifiedKFold(n_splits=3, shuffle=True, random_state=random_state)

    per_fold_best = []
    for fit_rows, held_rows in folds.split(X_train, y_train):
        x_fit, y_fit = X_train[fit_rows], y_train[fit_rows]
        x_held, y_held = X_train[held_rows], y_train[held_rows]

        clf = CatBoostClassifier(
            **params,
            eval_metric="Logloss",
            random_seed=random_state,
            verbose=0,
            early_stopping_rounds=50,
        )
        clf.fit(x_fit, y_fit, eval_set=(x_held, y_held), verbose=0)

        held_proba = clf.predict_proba(x_held)[:, 1]
        # Metric at every candidate cutoff; the fold contributes its maximum.
        scores_by_threshold = [
            scorer(y_held, (held_proba >= cutoff).astype(int), zero_division=0)
            for cutoff in threshold_grid
        ]
        per_fold_best.append(max(scores_by_threshold))

    return float(np.mean(per_fold_best))


# Seed the sampler so that hyperparameter search is reproducible for a given
# random_state.  TPESampler is Optuna's default sampler, so only the seed
# changes compared with an unseeded create_study().
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=random_state),
)
study.optimize(objective, n_trials=optuna_n_trials, timeout=optuna_timeout_s)

best_params = study.best_params
print(f"Best CV {optimize_metric}: {study.best_value:.4f}")
print(f"Best params: {json.dumps(best_params, indent=2)}")

## 5 — Train Final Model

In [None]:
# ---------------------------------------------------------------------------
# Final model on full train set
# ---------------------------------------------------------------------------
# Refit on the whole training split with the tuned hyperparameters.
# NOTE(review): no eval_set / early stopping is used here, unlike the CV fits,
# so the model trains for the full tuned `iterations` — confirm this is intended.
final_model = CatBoostClassifier(
    eval_metric="Logloss",
    random_seed=random_state,
    verbose=100,
    scale_pos_weight=scale_pos_weight,
    **best_params,
)
final_model.fit(X_train, y_train)

# Save model
model_dir = Path(model_output_path).parent
model_dir.mkdir(parents=True, exist_ok=True)
final_model.save_model(model_output_path)
print(f"\nModel saved → {model_output_path}")
print(f"scale_pos_weight used: {scale_pos_weight:.4f}")

## 6 — Evaluation & Threshold Optimization

In [None]:
# ---------------------------------------------------------------------------
# Per-threshold metrics on test set
# ---------------------------------------------------------------------------
y_proba = final_model.predict_proba(X_test)[:, 1]


def _threshold_row(cutoff: float) -> dict:
    """Return test-set precision/recall/F1 and confusion counts at one cutoff."""
    predicted = (y_proba >= cutoff).astype(int)
    # With labels=[0, 1] the matrix is [[tn, fp], [fn, tp]].
    (tn, fp), (fn, tp) = confusion_matrix(y_test, predicted, labels=[0, 1])
    scores = {
        key: round(scorer(y_test, predicted, zero_division=0), 4)
        for key, scorer in (
            ("precision", precision_score),
            ("recall", recall_score),
            ("f1", f1_score),
        )
    }
    return {
        "threshold": round(cutoff, 4),
        **scores,
        "tp": int(tp),
        "fp": int(fp),
        "tn": int(tn),
        "fn": int(fn),
    }


threshold_rows = [_threshold_row(cutoff) for cutoff in threshold_grid]

threshold_df = pl.DataFrame(threshold_rows)
print(threshold_df)

In [None]:
# ---------------------------------------------------------------------------
# Best threshold per metric
# ---------------------------------------------------------------------------
def _best_row_for(metric: str) -> dict:
    """Return the threshold row that maximises ``metric``, with fixed tie-breaks.

    Polars' ``sort`` does not promise any order among equal keys, and the
    rounded scores tie often (recall in particular). Ties are therefore broken
    explicitly: higher F1 first, then the lower threshold.
    """
    sort_keys = list(dict.fromkeys([metric, "f1", "threshold"]))
    return threshold_df.sort(
        sort_keys,
        descending=[key != "threshold" for key in sort_keys],
    ).row(0, named=True)


best_f1_row = _best_row_for("f1")
best_precision_row = _best_row_for("precision")
best_recall_row = _best_row_for("recall")

# Primary best threshold is based on optimize_metric
best_row = {"f1": best_f1_row, "precision": best_precision_row, "recall": best_recall_row}[optimize_metric]
best_threshold = best_row["threshold"]

print(f"Best threshold (by {optimize_metric}): {best_threshold}")
print(f"  precision={best_row['precision']:.4f}  recall={best_row['recall']:.4f}  f1={best_row['f1']:.4f}")
print()
print(f"Best F1 threshold:        {best_f1_row['threshold']}  (F1={best_f1_row['f1']:.4f})")
print(f"Best Precision threshold:  {best_precision_row['threshold']}  (Precision={best_precision_row['precision']:.4f})")
print(f"Best Recall threshold:     {best_recall_row['threshold']}  (Recall={best_recall_row['recall']:.4f})")

# AUC scores (threshold-independent)
roc_auc = roc_auc_score(y_test, y_proba)
pr_auc = average_precision_score(y_test, y_proba)
print(f"\nROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC (Average Precision): {pr_auc:.4f}")

In [None]:
# ---------------------------------------------------------------------------
# Individual metric vs threshold plots
# ---------------------------------------------------------------------------
thresholds = threshold_df["threshold"].to_list()

# (column, axis label, best row for that metric, line colour)
metric_configs = [
    ("f1", "F1 Score", best_f1_row, "#2196F3"),
    ("precision", "Precision", best_precision_row, "#4CAF50"),
    ("recall", "Recall", best_recall_row, "#FF9800"),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, (metric_key, metric_label, best_metric_row, color) in zip(axes, metric_configs):
    peak_thr = best_metric_row["threshold"]
    peak_val = best_metric_row[metric_key]

    ax.plot(thresholds, threshold_df[metric_key].to_list(), marker=".", color=color, linewidth=2)
    # Dashed guide plus a star on the best point for this metric.
    ax.axvline(peak_thr, color="grey", linestyle="--", alpha=0.7)
    ax.plot(peak_thr, peak_val, "r*", markersize=15, zorder=5)
    ax.annotate(
        f"  thr={peak_thr}\n  {metric_label}={peak_val:.4f}",
        xy=(peak_thr, peak_val),
        fontsize=8,
        color="red",
    )
    ax.set(
        xlabel="Threshold",
        ylabel=metric_label,
        title=f"{metric_label} vs Threshold",
        ylim=(-0.05, 1.05),
    )
    ax.grid(True, alpha=0.3)

fig.tight_layout()
individual_path = f"{plots_dir}/threshold_individual.png"
fig.savefig(individual_path, dpi=120)
plt.show()
print(f"Saved → {individual_path}")

In [None]:
# ---------------------------------------------------------------------------
# Combined overlay: Precision / Recall / F1 vs Threshold
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(9, 5))

# One line per metric; F1 is drawn thicker than the other two.
overlay_series = [
    ("precision", dict(label="Precision", marker=".", color="#4CAF50")),
    ("recall", dict(label="Recall", marker=".", color="#FF9800")),
    ("f1", dict(label="F1", marker=".", linewidth=2, color="#2196F3")),
]
for column, line_kw in overlay_series:
    ax.plot(thresholds, threshold_df[column].to_list(), **line_kw)

ax.axvline(best_threshold, color="red", linestyle="--", alpha=0.7, label=f"Best thr ({optimize_metric})={best_threshold}")
ax.set(
    xlabel="Threshold",
    ylabel="Score",
    title="Precision / Recall / F1 vs Threshold",
    ylim=(-0.05, 1.05),
)
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
overlay_path = f"{plots_dir}/threshold_overlay.png"
fig.savefig(overlay_path, dpi=120)
plt.show()
print(f"Saved → {overlay_path}")

In [None]:
# ---------------------------------------------------------------------------
# ROC Curve
# ---------------------------------------------------------------------------
fpr, tpr, roc_thresholds = roc_curve(y_test, y_proba)

# Operating point implied by best_threshold: FPR and TPR from the 2×2 matrix.
y_pred_bt = (y_proba >= best_threshold).astype(int)
(tn_bt, fp_bt), (fn_bt, tp_bt) = confusion_matrix(y_test, y_pred_bt, labels=[0, 1])
fpr_bt = fp_bt / (fp_bt + tn_bt)
tpr_bt = tp_bt / (tp_bt + fn_bt)

fig, ax = plt.subplots(figsize=(7, 6))
ax.plot(fpr, tpr, linewidth=2, label=f"ROC (AUC = {roc_auc:.4f})")
ax.plot([0, 1], [0, 1], "k--", alpha=0.4, label="Random classifier")
ax.plot(fpr_bt, tpr_bt, "r*", markersize=15, zorder=5, label=f"Best thr={best_threshold}")

ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="ROC Curve",
    xlim=(-0.02, 1.02),
    ylim=(-0.02, 1.02),
)
ax.legend(loc="lower right")
ax.grid(True, alpha=0.3)
fig.tight_layout()
roc_path = f"{plots_dir}/roc_curve.png"
fig.savefig(roc_path, dpi=120)
plt.show()
print(f"Saved → {roc_path}")

In [None]:
# ---------------------------------------------------------------------------
# Precision-Recall Curve
# ---------------------------------------------------------------------------
pr_precision, pr_recall, pr_thresholds = precision_recall_curve(y_test, y_proba)

# Reference level: share of positives in the test set.
prevalence = y_test.mean()
# Operating point at best_threshold, taken from the per-threshold table.
prec_bt, rec_bt = best_row["precision"], best_row["recall"]

fig, ax = plt.subplots(figsize=(7, 6))
ax.plot(pr_recall, pr_precision, linewidth=2, label=f"PR curve (AP = {pr_auc:.4f})")
ax.axhline(prevalence, color="k", linestyle="--", alpha=0.4, label=f"Baseline (prevalence={prevalence:.3f})")
ax.plot(rec_bt, prec_bt, "r*", markersize=15, zorder=5, label=f"Best thr={best_threshold}")

ax.set(
    xlabel="Recall",
    ylabel="Precision",
    title="Precision-Recall Curve",
    xlim=(-0.02, 1.02),
    ylim=(-0.02, 1.05),
)
ax.legend(loc="upper right")
ax.grid(True, alpha=0.3)
fig.tight_layout()
pr_path = f"{plots_dir}/pr_curve.png"
fig.savefig(pr_path, dpi=120)
plt.show()
print(f"Saved → {pr_path}")

In [None]:
# ---------------------------------------------------------------------------
# Predicted Probability Distribution by Class
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(8, 5))

mask_neg = y_test == 0
mask_pos = y_test == 1
# Overlaid histograms of predicted probability, one per true class.
class_layers = [
    (mask_neg, "Class 0 (negative)", "#2196F3"),
    (mask_pos, "Class 1 (positive)", "#FF5722"),
]
for class_mask, class_label, class_color in class_layers:
    ax.hist(y_proba[class_mask], bins=50, alpha=0.6, label=class_label, color=class_color, edgecolor="white")

ax.axvline(best_threshold, color="red", linestyle="--", linewidth=2, label=f"Best thr={best_threshold}")
ax.set(
    xlabel="Predicted Probability",
    ylabel="Count",
    title="Predicted Probability Distribution by Class",
)
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
proba_path = f"{plots_dir}/probability_distribution.png"
fig.savefig(proba_path, dpi=120)
plt.show()
print(f"Saved → {proba_path}")

In [None]:
# ---------------------------------------------------------------------------
# Calibration Curve
# ---------------------------------------------------------------------------
# Reliability data: 10 equal-width probability bins.
prob_true, prob_pred = calibration_curve(y_test, y_proba, n_bins=10, strategy="uniform")

fig, ax = plt.subplots(figsize=(7, 6))
ax.plot(prob_pred, prob_true, "s-", linewidth=2, label="Model")
ax.plot([0, 1], [0, 1], "k--", alpha=0.4, label="Perfectly calibrated")
ax.set(
    xlabel="Mean Predicted Probability",
    ylabel="Fraction of Positives",
    title="Calibration Curve (Reliability Diagram)",
    xlim=(-0.02, 1.02),
    ylim=(-0.02, 1.02),
)
ax.legend(loc="lower right")
ax.grid(True, alpha=0.3)
fig.tight_layout()
calibration_path = f"{plots_dir}/calibration_curve.png"
fig.savefig(calibration_path, dpi=120)
plt.show()
print(f"Saved → {calibration_path}")

In [None]:
# ---------------------------------------------------------------------------
# Cumulative Gains & Lift Chart
# ---------------------------------------------------------------------------
# Sort by predicted probability descending
order = np.argsort(-y_proba)
y_sorted = y_test[order]
# Dedicated names: reusing `n` / `n_pos` here would overwrite the notebook-level
# training-set positive count computed in the split cell.
n_test_rows = len(y_sorted)
n_test_pos = int(y_sorted.sum())

if n_test_pos == 0:
    # Gains and lift are undefined without positives (division by zero).
    print("Test set has no positive rows — skipping cumulative gains / lift chart.")
else:
    # Cumulative gains: fraction of positives captured vs fraction of population examined
    cum_pos = np.cumsum(y_sorted)
    pct_population = np.arange(1, n_test_rows + 1) / n_test_rows
    pct_captured = cum_pos / n_test_pos

    # Lift: ratio of cumulative gains to random baseline
    lift = pct_captured / pct_population

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Cumulative gains
    ax = axes[0]
    ax.plot(pct_population, pct_captured, linewidth=2, label="Model")
    ax.plot([0, 1], [0, 1], "k--", alpha=0.4, label="Random")
    # A perfect ranking captures every positive within the first n_pos/n of rows.
    ax.plot([0, n_test_pos / n_test_rows, 1], [0, 1, 1], "g--", alpha=0.4, label="Perfect")
    ax.set_xlabel("Fraction of Population")
    ax.set_ylabel("Fraction of Positives Captured")
    ax.set_title("Cumulative Gains Chart")
    ax.legend(loc="lower right")
    ax.grid(True, alpha=0.3)

    # Lift
    ax = axes[1]
    # Downsample for smooth plotting
    step = max(1, n_test_rows // 200)
    ax.plot(pct_population[::step], lift[::step], linewidth=2, color="#9C27B0")
    ax.axhline(1.0, color="k", linestyle="--", alpha=0.4, label="Baseline (lift=1)")
    ax.set_xlabel("Fraction of Population")
    ax.set_ylabel("Lift")
    ax.set_title("Lift Chart")
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(f"{plots_dir}/cumulative_gains_lift.png", dpi=120)
    plt.show()
    print(f"Saved → {plots_dir}/cumulative_gains_lift.png")

In [None]:
# ---------------------------------------------------------------------------
# Confusion matrix at best threshold
# ---------------------------------------------------------------------------
y_pred_final = (y_proba >= best_threshold).astype(int)

fig, ax = plt.subplots(figsize=(5, 4))
# labels=[0, 1] matches the other confusion_matrix calls in this notebook and
# keeps the matrix 2×2 even if one class is absent from truth and predictions.
ConfusionMatrixDisplay.from_predictions(
    y_test, y_pred_final, labels=[0, 1], ax=ax, cmap="Blues"
)
ax.set_title(f"Confusion Matrix (threshold={best_threshold})")
fig.tight_layout()
fig.savefig(f"{plots_dir}/confusion_matrix.png", dpi=120)
plt.show()
print(f"Saved → {plots_dir}/confusion_matrix.png")

print("\nClassification Report:")
# zero_division=0 matches the per-threshold metric calls, so the undefined
# case is reported as 0 explicitly rather than via a suppressed warning.
print(classification_report(y_test, y_pred_final, zero_division=0))

## 7 — Metrics JSON Output

In [None]:
# ---------------------------------------------------------------------------
# Write structured metrics JSON
# ---------------------------------------------------------------------------
# Run configuration and headline scores.
run_metadata = {
    "timestamp": datetime.now(timezone.utc).isoformat(),
    "target_col": target_col,
    "entity_id": entity_id,
    "test_size": test_size,
    "random_state": random_state,
    "stratify": stratify,
    "scale_pos_weight": round(scale_pos_weight, 4),
    "optuna_n_trials": optuna_n_trials,
    "optimize_metric": optimize_metric,
    "n_train": int(X_train.shape[0]),
    "n_test": int(X_test.shape[0]),
    "n_features": int(X_train.shape[1]),
    "feature_names": feature_names,
    "best_optuna_cv_score": round(study.best_value, 4),
    "best_params": best_params,
    "best_threshold": best_threshold,
    "model_output_path": model_output_path,
    "roc_auc": round(roc_auc, 4),
    "pr_auc": round(pr_auc, 4),
}

# Winning threshold and score for each metric, in a fixed key order.
best_thresholds_per_metric = {
    metric_name: {"threshold": row["threshold"], "value": row[metric_name]}
    for metric_name, row in (
        ("f1", best_f1_row),
        ("precision", best_precision_row),
        ("recall", best_recall_row),
    )
}

metrics_output = {
    "run_metadata": run_metadata,
    "best_thresholds_per_metric": best_thresholds_per_metric,
    "per_threshold": threshold_rows,
    "best_threshold_metrics": best_row,
}

# default=str stringifies anything the JSON encoder cannot handle natively.
metrics_file = Path(metrics_json_path)
metrics_file.parent.mkdir(parents=True, exist_ok=True)
metrics_file.write_text(json.dumps(metrics_output, indent=2, default=str))

print(f"Metrics JSON saved → {metrics_json_path}")

## 8 — Feature Importance

In [None]:
# ---------------------------------------------------------------------------
# CatBoost feature importance
# ---------------------------------------------------------------------------
if not enable_feature_importance:
    print("Feature importance disabled.")
else:
    importances = final_model.get_feature_importance()
    imp_df = pl.DataFrame({"feature": feature_names, "importance": importances}).sort(
        "importance", descending=True
    )
    top_n = min(20, len(imp_df))
    top = imp_df.head(top_n)

    # barh draws its first entry at the bottom, so feed it least-important
    # first to put the most important feature on top.
    bar_labels = list(reversed(top["feature"].to_list()))
    bar_values = list(reversed(top["importance"].to_list()))

    fig, ax = plt.subplots(figsize=(8, max(4, top_n * 0.35)))
    ax.barh(bar_labels, bar_values)
    ax.set_xlabel("Importance")
    ax.set_title(f"Top {top_n} Feature Importances")
    fig.tight_layout()
    importance_path = f"{plots_dir}/feature_importance.png"
    fig.savefig(importance_path, dpi=120)
    plt.show()
    print(f"Saved → {importance_path}")

## 9 — SHAP

In [None]:
# ---------------------------------------------------------------------------
# SHAP explanations
# ---------------------------------------------------------------------------
if enable_shap:
    import shap

    # Explain a reproducible random sample of test rows, capped at shap_sample_size.
    n_sample = min(shap_sample_size, X_test.shape[0])
    rng = np.random.RandomState(random_state)
    idx = rng.choice(X_test.shape[0], size=n_sample, replace=False)
    X_shap = X_test[idx]

    explainer = shap.TreeExplainer(final_model)
    shap_values = explainer.shap_values(X_shap)

    # (output file, extra summary_plot kwargs): bar summary first, then the
    # default summary plot.
    shap_plot_specs = [
        ("shap_bar.png", {"plot_type": "bar"}),
        ("shap_summary.png", {}),
    ]
    for shap_file, shap_plot_kw in shap_plot_specs:
        fig, ax = plt.subplots(figsize=(8, 6))
        shap.summary_plot(
            shap_values, X_shap,
            feature_names=feature_names,
            show=False,
            **shap_plot_kw,
        )
        plt.tight_layout()
        plt.savefig(f"{plots_dir}/{shap_file}", dpi=120, bbox_inches="tight")
        plt.show()
        print(f"Saved → {plots_dir}/{shap_file}")
else:
    print("SHAP disabled.")

## 10 — Summary

In [None]:
# ---------------------------------------------------------------------------
# Final summary
# ---------------------------------------------------------------------------
# Assemble the report line by line, then emit it in a single print.
banner = "=" * 60
summary_lines = [
    banner,
    "RUN COMPLETE",
    banner,
    f"  Model saved to:       {model_output_path}",
    f"  Metrics JSON:         {metrics_json_path}",
    f"  Plots directory:      {plots_dir}",
]
if executed_notebook_path:
    summary_lines.append(f"  Executed notebook:    {executed_notebook_path}")
summary_lines += [
    f"  Best threshold:       {best_threshold}",
    f"  F1 at best threshold: {best_row['f1']:.4f}",
    f"  ROC-AUC:              {roc_auc:.4f}",
    f"  PR-AUC:               {pr_auc:.4f}",
    banner,
]
print("\n".join(summary_lines))