
# Explainability in Machine Learning with SHAP (Financial Use Case)

**Session goals**
- Understand *why* explainability matters in finance (trust, regulation, debugging).
- Learn how **SHAP** (SHapley Additive exPlanations) explains model predictions globally and locally.
- Apply SHAP to a **Random Forest** trained on a **credit risk** dataset.
- Practice interpreting explanations and turning them into business insights.



## 1. Why Explainability in Finance?

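- **Regulation & Compliance:** Many jurisdictions require lenders to explain credit decisions; in the US, for example, the Equal Credit Opportunity Act requires giving applicants specific reasons for an adverse action.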
- **Trust & Adoption:** Business users and auditors need to understand *why* predictions are made.
- **Debugging & Governance:** Explanations help detect data leakage, bias, or spurious correlations.

> **Key idea:** SHAP assigns each feature a contribution to a prediction, grounded in Shapley values from cooperative game theory.
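
To make the key idea concrete, here is a minimal sketch that computes exact Shapley values by brute force for a hypothetical three-feature scoring function. `toy_score`, the feature names, and the baseline profile are all invented for illustration and are not part of the credit model used later. A feature is "absent" from a coalition when it takes its baseline value, which is one common convention and an assumption here.

```python
from itertools import combinations
from math import factorial


# Hypothetical toy scoring function (illustration only, not the credit model).
def toy_score(x):
    return (
        0.1
        + 0.3 * x["late_payment"]
        + 0.2 * x["high_utilization"]
        + 0.1 * x["late_payment"] * x["high_utilization"]  # interaction term
        - 0.05 * x["long_history"]
    )


def shapley_values(score_fn, instance, baseline):
    """Exact Shapley values by enumerating every coalition of features."""
    names = list(instance)
    n = len(names)

    def value(subset):
        # Features in the coalition take the instance's value, the rest the baseline's.
        mixed = {k: (instance[k] if k in subset else baseline[k]) for k in names}
        return score_fn(mixed)

    phi = {}
    for name in names:
        others = [k for k in names if k != name]
        total = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                coalition = set(subset)
                total += weight * (value(coalition | {name}) - value(coalition))
        phi[name] = total
    return phi


instance = {"late_payment": 1, "high_utilization": 1, "long_history": 0}
baseline = {"late_payment": 0, "high_utilization": 0, "long_history": 1}

phi = shapley_values(toy_score, instance, baseline)
base_score = toy_score(baseline)
full_score = toy_score(instance)

for name, contribution in phi.items():
    print(f"{name:>18}: {contribution:+.3f}")
print(f"baseline {base_score:.3f} + contributions {sum(phi.values()):.3f} = {full_score:.3f}")
```

The contributions add up to the gap between the baseline score and the applicant's score, and the interaction term is split equally between the two features involved. Enumerating every coalition grows exponentially with the number of features, which is why SHAP's explainers rely on faster algorithms in practice.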



## 2. Setup


In [None]:

# If running locally and you don't have these installed, uncomment:
# (xlrd is what pandas uses to read the legacy .xls file downloaded below)
# !pip install pandas numpy scikit-learn shap matplotlib xlrd requests

import os
import io
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, classification_report, ConfusionMatrixDisplay

import shap

# Enable interactive JS for force plots
shap.initjs()


## 3. Load the Financial Dataset (UCI Credit Default)

We'll auto-download the **Default of Credit Card Clients** dataset from the UCI Machine Learning Repository.
If the first URL fails, the next URL in the list is tried.

**Target:** `default_payment_next_month` (1 = default, 0 = non-default)


In [None]:

import requests

def download_uci_credit_default():
    urls = [
        # Primary UCI link
        "https://archive.ics.uci.edu/ml/machine-learning-databases/00350/default%20of%20credit%20card%20clients.xls",
        # Backup slot: currently the same URL, so it only acts as a retry.
        # Replace it with a real mirror if you have one.
        "https://archive.ics.uci.edu/ml/machine-learning-databases/00350/default%20of%20credit%20card%20clients.xls"
    ]
    for url in urls:
        try:
            r = requests.get(url, timeout=30)
            r.raise_for_status()
            return io.BytesIO(r.content)
        except Exception as e:
            print(f"Failed to download from {url}: {e}")
    raise RuntimeError("Could not download dataset from UCI. Please check connectivity or replace URL.")

bio = download_uci_credit_default()
# The data is on the first sheet. header=1 skips the top row of generic labels
# and takes the column names from the second row.
df_raw = pd.read_excel(bio, header=1)

# Standardize column names
df_raw.columns = [str(c).strip().lower().replace(" ", "_").replace("-", "_") for c in df_raw.columns]

# Rename target for convenience if needed
if "default_payment_next_month" in df_raw.columns:
    df_raw.rename(columns={"default_payment_next_month": "default"}, inplace=True)

df_raw.head()


| Feature     | Description                                                                         | Example Values     |
| ----------- | ----------------------------------------------------------------------------------- | ------------------ |
| `LIMIT_BAL` | Amount of given credit (NT dollar), includes individual and family credit.          | 20,000 – 1,000,000 |
| `SEX`       | Gender (1 = male, 2 = female).                                                      | 1, 2               |
| `EDUCATION` | Education level (1 = graduate school, 2 = university, 3 = high school, 4 = others). | 1–4                |
| `MARRIAGE`  | Marital status (1 = married, 2 = single, 3 = others).                               | 1–3                |
| `AGE`       | Age in years.                                                                       | 21–79              |


| Month (2005)  | Repayment Status (`PAY_X`)             | Bill Amount (`BILL_AMTX`)                    | Payment Amount (`PAY_AMTX`)          | What It Means                                                                                                |
| ------------- | -------------------------------------- | -------------------------------------------- | ------------------------------------ | ------------------------------------------------------------------------------------------------------------ |
| **September** | `PAY_0`: repayment status in September | `BILL_AMT1`: amount owed at end of September | `PAY_AMT1`: amount paid in September | The most recent month before prediction (October). Captures latest payment behavior and outstanding balance. |
| **August**    | `PAY_2`: repayment status in August    | `BILL_AMT2`: amount owed at end of August    | `PAY_AMT2`: amount paid in August    | Shows whether the client was starting to delay payments or carrying a growing balance.                       |
| **July**      | `PAY_3`: repayment status in July      | `BILL_AMT3`: bill at end of July             | `PAY_AMT3`: amount paid in July      | Midpoint of the 6-month history; reveals repayment trends.                                                   |
| **June**      | `PAY_4`: repayment status in June      | `BILL_AMT4`: bill at end of June             | `PAY_AMT4`: amount paid in June      |                                                                                                              |
| **May**       | `PAY_5`: repayment status in May       | `BILL_AMT5`: bill at end of May              | `PAY_AMT5`: amount paid in May       |                                                                                                              |
| **April**     | `PAY_6`: repayment status in April     | `BILL_AMT6`: bill at end of April            | `PAY_AMT6`: amount paid in April     | The oldest month in the dataset; starts the 6-month lookback period.                                         |


💳 1. Repayment Status Columns (PAY_0 – PAY_6)

| Value   | Meaning                              | Interpretation                                                      |
| ------- | ------------------------------------ | ------------------------------------------------------------------- |
| **-2**  | No consumption that month            | The client didn't use the card at all, so there was no bill due.    |
| **-1**  | Paid in full                         | Client cleared the balance completely; healthy repayment behavior.  |
| **0**   | Paid the minimum due or paid on time | Normal, revolving behavior; no delay.                               |
| **1**   | Payment delayed by 1 month           | Slight delay; mild risk signal.                                     |
| **2**   | Payment delayed by 2 months          | Missed two consecutive billing cycles; higher risk.                 |
| **3–9** | Payment delayed by 3–9 months        | Serious delinquency; strong default indicator.                      |

🧾 2. Bill Statement Amount Columns (BILL_AMT1 – BILL_AMT6)

| Value Type                        | Meaning                  | Interpretation                                             |
| --------------------------------- | ------------------------ | ---------------------------------------------------------- |
| **Positive value (e.g., 50,000)** | Outstanding bill balance | The total amount owed by the end of that month.            |
| **0**                             | No balance               | The client had no outstanding credit card debt that month. |
| **Negative (rare)**               | Overpayment / refund     | Occasionally occurs if the client paid more than owed.     |


💵 3. Payment Amount Columns (PAY_AMT1 – PAY_AMT6)

| Value Type                       | Meaning                 | Interpretation                                                              |
| -------------------------------- | ----------------------- | --------------------------------------------------------------------------- |
| **Positive value (e.g., 5,000)** | Amount paid that month  | How much the client paid to reduce their balance.                           |
| **0**                            | No payment made         | Could indicate missed payment, especially if there was an outstanding bill. |
| **Very large value**             | Full or advance payment | The client paid off or exceeded their balance.                              |
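
The coding schemes above can be turned into readable labels and a simple flag. The sketch below runs on a small made-up frame that uses the lowercase column names produced by the loading cell. `describe_pay_status` and the `missed_payment` rule are illustrative helpers, not part of the dataset or of any library.

```python
import pandas as pd


# Illustrative helper: labels follow the repayment-status table above.
def describe_pay_status(code: int) -> str:
    if code == -2:
        return "no consumption"
    if code == -1:
        return "paid in full"
    if code == 0:
        return "paid minimum / on time"
    if 1 <= code <= 9:
        return f"delayed {code} month" + ("s" if code > 1 else "")
    return "unknown code"


# Made-up rows for illustration only.
toy = pd.DataFrame(
    {
        "pay_0": [-2, -1, 0, 1, 3],
        "bill_amt1": [0, 1200, 50000, 30000, 80000],
        "pay_amt1": [0, 1200, 2000, 0, 0],
    }
)

toy["pay_0_label"] = toy["pay_0"].map(describe_pay_status)

# Assumed rule of thumb: an outstanding bill with no payment suggests a missed payment.
toy["missed_payment"] = (toy["bill_amt1"] > 0) & (toy["pay_amt1"] == 0)

print(toy)
```

Labels like these are for reading explanations, not for training: the tree model can keep using the raw integer codes.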



## 4. Quick EDA

Let's explore the schema and class balance.


In [None]:

print("Shape:", df_raw.shape)
print("\nColumns:", list(df_raw.columns))
print("\nTarget distribution (default):")
print(df_raw["default"].value_counts(normalize=True).rename("share"))

df_raw.describe(include="all").T.head(15)


In [None]:
# Visualize target distribution
df_raw["default"].value_counts().plot(kind="bar")
plt.title("Target distribution: default vs non-default")
plt.xlabel("default (1=default, 0=non-default)")
plt.ylabel("count")
plt.show()



## 5. Minimal Cleaning & Feature Selection

We'll keep numeric/coded features as-is for a tree model. You can extend this step with domain-driven cleaning.
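
As one possible example of domain-driven cleaning, the sketch below folds category codes that fall outside the documented sets into the "others" code. Whether such codes occur in your copy of the data is an assumption to verify first (for instance with `value_counts()`). `collapse_undocumented_codes` is a hypothetical helper and the rows are made up.

```python
import pandas as pd


# Hypothetical helper: map any code outside `documented` to `other_code`.
def collapse_undocumented_codes(frame, column, documented, other_code):
    out = frame.copy()
    out[column] = out[column].where(out[column].isin(documented), other_code)
    return out


# Made-up rows for illustration only.
toy = pd.DataFrame(
    {
        "education": [1, 2, 0, 5, 6, 3],
        "marriage": [1, 2, 0, 3, 1, 2],
    }
)

cleaned = collapse_undocumented_codes(toy, "education", documented=[1, 2, 3, 4], other_code=4)
cleaned = collapse_undocumented_codes(cleaned, "marriage", documented=[1, 2, 3], other_code=3)

print(cleaned)
```

Collapsing rare codes also makes the later SHAP plots easier to read, since the color scale for a coded feature then covers only documented values.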


In [None]:

# Drop obvious non-feature if present
drop_cols = [c for c in ["id"] if c in df_raw.columns]
df = df_raw.drop(columns=drop_cols, errors="ignore").copy()

# Train/test split (stratified so both sets keep the same default rate)
X = df.drop(columns=["default"])
y = df["default"].astype(int)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

X_train.shape, X_test.shape



## 6. Train a Random Forest Classifier


In [None]:

# Define a Random Forest model tuned for explainability and speed
rf = RandomForestClassifier(
    n_estimators=100,          # Number of trees in the forest.
                              # Fewer trees make the model faster to train and explain with SHAP.
                              # (400+ trees could slightly improve accuracy but slow things down.)

    max_depth=8,              # Maximum depth of each tree.
                              # Limits how many decision splits each tree can make.
                              # Shallower trees (e.g., depth=8) are quicker to explain and help limit overfitting.

    min_samples_split=5,      # Minimum number of samples required to split an internal node.
                              # Prevents the model from creating very small, unreliable splits.

    min_samples_leaf=4,       # Minimum number of samples required to be at a leaf node.
                              # Each final decision (leaf) must represent at least 4 training samples,
                              # which improves generalization and model stability.

    n_jobs=-1,                # Use all available CPU cores for parallel training to speed things up.

    random_state=42           # Random seed for reproducibility.
                              # Ensures you get the same results every time you run this code.
)


rf.fit(X_train, y_train)

proba = rf.predict_proba(X_test)[:, 1]
auc = roc_auc_score(y_test, proba)
print(f"Test ROC-AUC: {auc:.3f}")

print("\nClassification report (threshold=0.5):")
pred = (proba >= 0.5).astype(int)
print(classification_report(y_test, pred))

disp = ConfusionMatrixDisplay.from_predictions(y_test, pred)
plt.title("Confusion Matrix (threshold=0.5)")
plt.show()



## 7. Explainability with SHAP


In [None]:
import shap

# Newer, unified SHAP API: automatically picks the right explainer
explainer = shap.Explainer(rf, X_train)

# Compute SHAP values (returns a modern Explanation object)
shap_values_ebm = explainer(X_test)


In [None]:
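# Select the positive ("default") class across SHAP versions:
# - legacy API: a list with one array per class; index 1 is the positive class
# - unified API: an Explanation, typically shaped (n_samples, n_features, n_classes)
if isinstance(shap_values_ebm, list):
    shap_vals_pos = shap_values_ebm[1]
elif len(shap_values_ebm.shape) == 3:
    shap_vals_pos = shap_values_ebm[:, :, 1]
else:
    shap_vals_pos = shap_values_ebm

# Baseline for the positive class: the last entry when there is one value per class,
# or the single value when the explainer returns a scalar.
expected_value = float(np.ravel(explainer.expected_value)[-1])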
print(f"\nExpected value (baseline): {expected_value:.3f}")


In [None]:

print(f"""
Business interpretation:
- The model's average predicted probability of default is about {expected_value:.1%}.
- This is the 'baseline risk': the average prediction over the background sample,
  before considering any personal or behavioral features.
- Each customer's SHAP values then push this baseline up or down depending on their profile:
    * Positive total SHAP effect → higher-than-average default risk
    * Negative total SHAP effect → lower-than-average default risk
"""
)



### 7.1 Global Feature Importance (SHAP Summary Plot)
- **What it shows:** Overall impact of each feature on the model's predictions.
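- **Reading tips:** Features are ordered top to bottom by overall importance; color encodes the feature's value (high vs. low); horizontal position is the SHAP value, so points far from zero are large effects (right pushes toward default, left away from it).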
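
The ranking behind a plot like this can be reproduced as the mean absolute SHAP value per feature. The sketch below uses a small made-up SHAP matrix (`toy_shap`, four clients by three features) so the arithmetic is easy to follow; the numbers are not results from the credit model.

```python
import numpy as np
import pandas as pd

# Made-up SHAP values for illustration: rows are clients, columns are features.
feature_names = ["pay_0", "limit_bal", "age"]
toy_shap = np.array(
    [
        [0.20, -0.05, 0.01],
        [-0.10, -0.02, 0.00],
        [0.30, 0.04, -0.02],
        [-0.08, -0.03, 0.01],
    ]
)

# Global importance: average magnitude of each feature's contribution, sign ignored.
global_importance = (
    pd.Series(np.abs(toy_shap).mean(axis=0), index=feature_names)
    .sort_values(ascending=False)
)

print(global_importance)
```

Taking the absolute value matters: a feature that pushes some clients strongly up and others strongly down would average out to roughly zero otherwise, even though it drives many predictions. In this notebook the same calculation could be applied to `shap_values_ebm.values[:, :, 1]` with `X_test.columns` as the index, assuming the three-dimensional output shape used in the plotting cell.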


In [None]:
shap.summary_plot(shap_values_ebm[:, :, 1], X_test)


### 7.2 Local Explanations (Per-Applicant)

We'll inspect a single applicant's prediction with an **interactive force plot** (requires JS).  
This shows how each feature pushes the prediction **higher** (towards default) or **lower** (towards non-default).
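
The same information can be summarized as a short list of drivers for one applicant, which is often easier to hand to a business user than a plot. The sketch below uses made-up contributions and a hypothetical baseline; `top_drivers` is an illustrative helper, not a SHAP function.

```python
import pandas as pd

# Made-up SHAP contributions for one client (positive = pushes toward default).
contrib = pd.Series(
    {"pay_0": 0.18, "limit_bal": -0.06, "pay_amt1": -0.03, "bill_amt1": 0.04, "age": 0.01}
)
baseline = 0.22  # hypothetical baseline probability of default


def top_drivers(contributions, k=2):
    """Return the k strongest risk-increasing and risk-decreasing features."""
    increasing = contributions[contributions > 0].sort_values(ascending=False).head(k)
    decreasing = contributions[contributions < 0].sort_values().head(k)
    return increasing, decreasing


risk_up, risk_down = top_drivers(contrib)
predicted = baseline + contrib.sum()

print(f"Predicted probability of default: {predicted:.2f}")
print("Raises risk:", list(risk_up.index))
print("Lowers risk:", list(risk_down.index))
```

A list like this is a starting point for a conversation, not a causal statement: it describes what moved the model's output for this client, not what would change the client's actual risk.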


In [None]:

# Pick one sample to visualize
i = 0  # change this index to view other clients

shap.plots.force(
    explainer.expected_value[1],      # baseline for "default" class
    shap_values_ebm.values[i, :, 1],  # SHAP values for that instance and class
    X_test.iloc[i, :],                # actual feature values for that client
    show=True
)


In [None]:

# Pick another sample to compare against the first
i = 28  # change this index to view other clients

shap.plots.force(
    explainer.expected_value[1],      # baseline for "default" class
    shap_values_ebm.values[i, :, 1],  # SHAP values for that instance and class
    X_test.iloc[i, :],                # actual feature values for that client
    show=True
)