# In [None]:
"""
shap_credit_risk_pipeline.py

Requirements:
pip install numpy pandas scikit-learn xgboost shap matplotlib joblib

Place credit_risk_dataset.csv in the same directory as this script.
"""

import os
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, classification_report
from sklearn.calibration import calibration_curve
from xgboost import XGBClassifier
import shap
import warnings
warnings.filterwarnings("ignore")

# -------------------------
# CONFIG
# -------------------------
# Seed and hold-out fraction reused by the split and the model search below.
RANDOM_STATE = 42
TEST_SIZE = 0.25

# Input CSV (expected next to this script) and the folder receiving all artifacts.
DATA_PATH = "credit_risk_dataset.csv"
OUTPUT_DIR = "outputs"

# Create the artifact folder up front; a pre-existing folder is not an error.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------
# LOAD DATA
# -------------------------
# Read the whole CSV into memory, then print a quick sanity overview:
# overall shape, per-column dtypes and the first rows.
df = pd.read_csv(DATA_PATH)
print("Loaded dataset shape:", df.shape)
for overview in (df.dtypes, df.head()):
    print(overview)

# -------------------------
# PREPARE FEATURES / TARGET
# -------------------------
# Name of the label column; every other column is used as a model feature.
target_col = "default"

if target_col in df.columns:
    y = df[target_col]
    X = df.drop(columns=[target_col])
else:
    # Fail early with a clear message rather than a KeyError further down.
    raise ValueError(f"Target column '{target_col}' not found in dataset.")

# Column order of X, kept for labelling the SHAP importance table later.
# NOTE(review): the columns are assumed to be numeric — not checked here.
feature_names = X.columns.to_list()

# -------------------------
# SPLIT
# -------------------------
# Stratified hold-out split: the class ratio of y is preserved in both parts,
# and RANDOM_STATE makes the partition reproducible.
split_parts = train_test_split(
    X,
    y,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    stratify=y,
)
X_train, X_test, y_train, y_test = split_parts

print("Train size:", X_train.shape, "Test size:", X_test.shape)

# -------------------------
# BASE MODEL + HYPERPARAMETER TUNING (RandomizedSearch)
# -------------------------
# Settings shared by every candidate model. The estimator itself runs on a
# single thread (n_jobs=1) because the search below already uses all cores.
base_clf = XGBClassifier(
    n_jobs=1,
    random_state=RANDOM_STATE,
    eval_metric="logloss",
    use_label_encoder=False,
    objective="binary:logistic",
)

# Discrete candidate values sampled by the randomized search.
param_dist = dict(
    n_estimators=[100, 200, 300, 500],
    max_depth=[3, 4, 5, 6, 8],
    learning_rate=[0.01, 0.03, 0.05, 0.1],
    subsample=[0.6, 0.8, 1.0],
    colsample_bytree=[0.6, 0.8, 1.0],
    reg_alpha=[0.0, 0.1, 0.5],
    reg_lambda=[1.0, 2.0, 5.0],
)

# 20 sampled configurations x 3-fold CV, ranked by ROC-AUC, run in parallel.
rs = RandomizedSearchCV(
    base_clf,
    param_dist,
    n_iter=20,
    scoring="roc_auc",
    cv=3,
    n_jobs=-1,
    verbose=1,
    random_state=RANDOM_STATE,
)

print("Starting RandomizedSearchCV...")
rs.fit(X_train, y_train)
print("Best params:", rs.best_params_)
print("Best CV ROC-AUC:", rs.best_score_)

# The search's best estimator is what gets evaluated and explained below.
best_model = rs.best_estimator_

# Persist the winning model next to the other artifacts.
model_path = os.path.join(OUTPUT_DIR, "xgb_best_model.joblib")
joblib.dump(best_model, model_path)
print("Saved model to", model_path)

# -------------------------
# PREDICTIONS & METRICS
# -------------------------
# Column 1 of predict_proba is used as the default score
# (assumes a binary 0/1 target — TODO confirm against the dataset).
y_prob = best_model.predict_proba(X_test)[:, 1]
# Hard labels at a fixed 0.5 cut-off, used only for the classification report.
y_pred = (y_prob >= 0.5).astype(int)

# Threshold-free ranking metrics computed on the raw probabilities.
roc_auc = roc_auc_score(y_test, y_prob)
precision, recall, _ = precision_recall_curve(y_test, y_prob)
# Area under the precision-recall curve (recall on x, precision on y).
pr_auc = auc(recall, precision)

print(f"\nTest ROC-AUC: {roc_auc:.4f}")
print(f"Test PR-AUC: {pr_auc:.4f}")
print("\nClassification report (threshold=0.5):")
print(classification_report(y_test, y_pred))

# Calibration curve: observed positive rate vs. mean predicted probability,
# computed over 10 probability bins of the test-set predictions.
prob_true, prob_pred = calibration_curve(y_test, y_prob, n_bins=10)

# Draw on an explicit figure/axes pair so exactly this figure is saved and
# closed, independent of whatever pyplot's "current figure" happens to be.
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(prob_pred, prob_true, marker="o", linewidth=2, label="Model")
ax.plot([0, 1], [0, 1], linestyle="--", label="Perfectly calibrated")
ax.set_xlabel("Predicted probability")
ax.set_ylabel("Observed probability")
ax.set_title("Calibration curve")
ax.grid(True)
# Without this call the label= arguments above are never rendered.
ax.legend(loc="best")
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, "calibration_curve.png"))
plt.close(fig)
print("Saved calibration curve.")

# -------------------------
# SHAP ANALYSIS
# -------------------------
# Tree-specific explainer for the tuned XGBoost model; attributions are
# computed once for the whole test set and reused by every plot below.
explainer = shap.TreeExplainer(best_model, feature_perturbation="tree_path_dependent")
shap_values = explainer.shap_values(X_test)

# Two summary views of the same SHAP values:
#   - default summary plot (dot)   -> shap_summary_dot.png
#   - bar plot (feature importance) -> shap_summary_bar.png
# Each entry is (extra keyword arguments for summary_plot, output file name).
summary_variants = (
    ({}, "shap_summary_dot.png"),
    ({"plot_type": "bar"}, "shap_summary_bar.png"),
)
for summary_kwargs, summary_png in summary_variants:
    plt.figure(figsize=(8,6))
    # show=False keeps the figure open so it can be written to disk.
    shap.summary_plot(shap_values, X_test, show=False, **summary_kwargs)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, summary_png), bbox_inches="tight")
    plt.close()
    print(f"Saved {summary_png}")

# DEPENDENCE PLOT (choose top feature)
# Rank features by mean(|SHAP|) over the test rows and pick the largest.
mean_abs_shap = np.abs(shap_values).mean(axis=0)
top_idx = np.argmax(mean_abs_shap)
top_feature = X_test.columns[top_idx]
print("Top SHAP feature:", top_feature)

# shap.dependence_plot opens its own figure unless an Axes is supplied, which
# would leave a preceding plt.figure() empty and never closed. Passing ax=
# makes the plot land on the 8x6 figure created here, and closing that
# specific figure guarantees nothing is leaked.
fig, ax = plt.subplots(figsize=(8, 6))
shap.dependence_plot(top_feature, shap_values, X_test, ax=ax, show=False)
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, f"shap_dependence_{top_feature}.png"), bbox_inches="tight")
plt.close(fig)
print(f"Saved shap_dependence_{top_feature}.png")

# FORCE PLOT for a single instance -> save as html
# initjs() only matters for inline rendering inside a notebook and requires
# IPython. The HTML written by shap.save_html carries its own JavaScript, so
# when IPython is missing (plain script run) the call is skipped instead of
# aborting the pipeline after all the training work is done.
try:
    shap.initjs()
except (AssertionError, ImportError):
    print("IPython not available; skipping shap.initjs() (not needed for saved HTML).")

# Explain the first test row: baseline value, its SHAP vector, its features.
instance_idx = 0
force_html = os.path.join(OUTPUT_DIR, f"shap_force_instance_{instance_idx}.html")
force_plot = shap.force_plot(
    explainer.expected_value,
    shap_values[instance_idx, :],
    X_test.iloc[instance_idx, :],
    matplotlib=False,  # keep the interactive JS visualisation so it can be saved as HTML
)
shap.save_html(force_html, force_plot)
print(f"Saved SHAP force plot html: {force_html}")

# -------------------------
# SAVE METRICS & FEATURE IMPORTANCE AS CSV
# -------------------------
# One-row summary: test ROC-AUC, test PR-AUC and the winning hyperparameters
# (the parameter dict ends up in a single CSV cell).
metrics = dict(
    roc_auc=roc_auc,
    pr_auc=pr_auc,
    best_params=rs.best_params_,
)
metrics_df = pd.DataFrame([metrics])
metrics_csv = os.path.join(OUTPUT_DIR, "metrics_summary.csv")
metrics_df.to_csv(metrics_csv, index=False)

# Per-feature mean |SHAP|, most influential feature first.
feat_imp_df = pd.DataFrame({
    "feature": feature_names,
    "mean_abs_shap": mean_abs_shap,
})
feat_imp_df = feat_imp_df.sort_values("mean_abs_shap", ascending=False)
importance_csv = os.path.join(OUTPUT_DIR, "shap_feature_importance.csv")
feat_imp_df.to_csv(importance_csv, index=False)

print("Saved metrics and feature importance CSVs in", OUTPUT_DIR)
print("All done. Check the 'outputs' folder for plots, model, and CSVs.")
