# SHAP analysis

This notebook contains the [SHAP (SHapley Additive exPlanations)](https://shap.readthedocs.io/en/latest/) analysis for the Random Forest model used in the localization predictor. It visualizes the impact of features on model predictions using SHAP values.

SHAP values are computed on a random sample of 200 rows from the test set (drawn with a fixed seed so the run is reproducible) and then visualized with several kinds of SHAP plots.

In [None]:
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap

In [None]:
# Root folder for every figure written by this notebook. It is created up front
# (including missing parents) so later `savefig` calls can write into it.
OUT_VIS = "results/figures/shap_analysis"
os.makedirs(name=OUT_VIS, exist_ok=True)

In [None]:
# Load the persisted model and pull out the step registered under the name "rf";
# that step is the object handed to SHAP further down.
# NOTE(review): joblib.load unpickles the file, so only point this at a trusted artifact.
# NOTE(review): `rf` is a single entry of `model.named_steps`. If `model` also contains
# preprocessing steps, the raw feature frames loaded below may not match what `rf`
# was fitted on — confirm against the training code.
MODEL_PATH = "results/models/rf_best.pkl"
model = joblib.load(MODEL_PATH)
rf = model.named_steps["rf"]

In [None]:
# Keep only numeric columns from both splits, then reorder/select the test columns
# so they line up exactly with the training columns (a KeyError here means the test
# file is missing a numeric training column).
numeric_dtypes = [np.number]
X_train = pd.read_csv("data/processed/X/train.csv").select_dtypes(include=numeric_dtypes)
X_test_numeric = pd.read_csv("data/processed/X/test.csv").select_dtypes(include=numeric_dtypes)
X_test = X_test_numeric[X_train.columns]

In [None]:
# Initialise SHAP's notebook (JavaScript) support for interactive plots.
# NOTE(review): every plot visible in this notebook is drawn through matplotlib and
# written with savefig, so this call may be optional here — confirm before removing.
shap.initjs()

In [None]:
# Explain a fixed-size, seeded subset of the test set rather than every row;
# all plots below are generated from this subset.
N_SHAP_SAMPLES = 200
SHAP_SAMPLE_SEED = 42
X_test_shap = shap.sample(X_test, N_SHAP_SAMPLES, random_state=SHAP_SAMPLE_SEED)

In [None]:
# Build a tree-specific SHAP explainer around the "rf" step taken from the loaded model,
# then compute SHAP values for the sampled test rows.
# NOTE(review): the rows are passed to the explainer exactly as read from CSV; if the
# loaded model applies transformations before "rf", the same transformations would
# need to be applied to X_test_shap first — confirm.
explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_test_shap)

In [None]:
# Downstream cells index `shap_values[class_index]` and expect a 2-D array back.
# When the explainer returned one stacked 3-D array, split it along its last axis
# into a list with one 2-D slice per entry; anything else is left untouched.
# (presumably the last axis enumerates classes — verify for the installed shap version)
is_stacked_array = isinstance(shap_values, np.ndarray) and shap_values.ndim == 3
if is_stacked_array:
    n_outputs = shap_values.shape[2]
    shap_values = [shap_values[..., k] for k in range(n_outputs)]

In [None]:
# Bar-style summary over all classes at once. The figure is not shown
# (show=False) so that it can be laid out, saved to disk and closed here.
class_names = getattr(rf, "classes_", None)
summary_kwargs = {
    "feature_names": X_test_shap.columns,
    "plot_type": "bar",
    "class_names": class_names,
    "show": False,
}
shap.summary_plot(shap_values, X_test_shap, **summary_kwargs)
plt.tight_layout()
summary_path = f"{OUT_VIS}/shap_summary_all_classes.png"
plt.savefig(summary_path)
plt.close()

In [None]:
# One sub-folder per class under OUT_VIS; the folder name is the class label with
# spaces replaced by underscores and lower-cased. The per-class plotting cells
# below write into these folders.
for label in rf.classes_:
    class_dir = f"{OUT_VIS}/{label.replace(' ', '_').lower()}"
    os.makedirs(class_dir, exist_ok=True)

In [None]:
# For every class and every sampled row, write two per-sample explanations into the
# class folder: a force plot and a (legacy) waterfall plot. Each figure is closed
# right after saving so that open figures do not pile up across iterations.
n_rows = len(X_test_shap)
for class_index, class_shap in enumerate(shap_values):
    print(f"Generating SHAP plots for class {class_index}")

    # Folder name: class label (or a generic fallback) with spaces -> "_" and lower-cased.
    if hasattr(rf, "classes_"):
        label = rf.classes_[class_index]
    else:
        label = f"Class {class_index}"
    label = label.replace(" ", "_").lower()

    # The baseline for this class is the same for every row, so look it up once.
    base_value = explainer.expected_value[class_index]

    for i in range(n_rows):
        row_shap = class_shap[i, :]
        row_features = X_test_shap.iloc[i, :]

        shap.force_plot(
            base_value,
            row_shap,
            row_features,
            matplotlib=True,
            show=False,
        )
        force_path = f"{OUT_VIS}/{label}/shap_force_sample{i}.png"
        plt.savefig(force_path, bbox_inches="tight")
        plt.close()

        shap.plots._waterfall.waterfall_legacy(
            base_value,
            row_shap,
            row_features,
            show=False,
        )
        waterfall_path = f"{OUT_VIS}/{label}/shap_waterfall_sample{i}.png"
        plt.savefig(waterfall_path, bbox_inches="tight")
        plt.close()

In [None]:
# One decision plot per class, covering all sampled rows, saved into the class folder.
for class_index, class_shap in enumerate(shap_values):
    # Folder name: class label (or a generic fallback) with spaces -> "_" and lower-cased.
    if hasattr(rf, "classes_"):
        label = rf.classes_[class_index]
    else:
        label = f"Class {class_index}"
    label = label.replace(" ", "_").lower()

    base_value = explainer.expected_value[class_index]
    shap.decision_plot(base_value, class_shap, X_test_shap, show=False)

    decision_path = f"{OUT_VIS}/{label}/shap_decision.png"
    plt.savefig(decision_path, bbox_inches="tight")
    plt.close()

In [None]:
# SHAP interaction values for the sampled rows.
shap_interact = explainer.shap_interaction_values(X_test_shap)

# The interaction summary plot works on ONE (samples, features, features) array, but a
# multi-output model yields one such array per output — either as a list or stacked
# into a single 4-D array. Normalise both layouts to a list of 3-D arrays.
# NOTE(review): the 4-D case assumes outputs sit on the last axis, mirroring how
# `shap_values` is split earlier in this notebook — confirm for the installed shap version.
if isinstance(shap_interact, np.ndarray) and shap_interact.ndim == 4:
    interact_per_class = [
        shap_interact[..., k] for k in range(shap_interact.shape[-1])
    ]
elif isinstance(shap_interact, (list, tuple)):
    interact_per_class = list(shap_interact)
else:
    interact_per_class = [shap_interact]

single_output = len(interact_per_class) == 1
for class_index, class_interact in enumerate(interact_per_class):
    # show=False is essential: otherwise the figure is displayed (and cleared) before
    # savefig runs and the file on disk ends up blank.
    shap.summary_plot(
        class_interact,
        X_test_shap,
        feature_names=X_test_shap.columns,
        plot_type="dot",
        show=False,
    )

    if single_output:
        # Single output: keep the original file name.
        interaction_path = f"{OUT_VIS}/shap_interaction_summary.png"
    else:
        # Several outputs: one file per class, next to that class's other plots.
        if hasattr(rf, "classes_"):
            label = rf.classes_[class_index]
        else:
            label = f"Class {class_index}"
        label = str(label).replace(" ", "_").lower()
        os.makedirs(f"{OUT_VIS}/{label}", exist_ok=True)
        interaction_path = f"{OUT_VIS}/{label}/shap_interaction_summary.png"

    plt.savefig(interaction_path, bbox_inches="tight")
    plt.close()

## Speeding Up SHAP Plot Generation

(Note: This section is a sandbox for testing and may not be fully functional yet.)

To accelerate plotting over many samples and classes, consider leveraging `joblib.Parallel` to generate plots concurrently.

Below is an example of how to parallelize force and waterfall plot generation.

In [None]:
from joblib import Parallel, delayed


def gen_plots_for(i, class_index, label):
    """Save the force plot and the legacy waterfall plot for one row of one class.

    Relies on the notebook-level ``shap_values``, ``explainer``, ``X_test_shap`` and
    ``OUT_VIS``; both images are written to ``{OUT_VIS}/{label}/``.

    Args:
        i: Positional row index into ``X_test_shap`` / the class's SHAP array.
        class_index: Index selecting the class in ``shap_values`` and in
            ``explainer.expected_value``.
        label: Name of the sub-folder of ``OUT_VIS`` to write into.

    Returns:
        None. The two figures are saved and then closed.
    """
    base_value = explainer.expected_value[class_index]
    class_shap = shap_values[class_index]

    # Force plot: pass the row as a one-row 2-D slice / one-row frame.
    force_fig = shap.plots.force(
        base_value,
        class_shap[i : i + 1, :],
        X_test_shap.iloc[i : i + 1, :],
        matplotlib=True,
        show=False,
    )
    force_fig.savefig(f"{OUT_VIS}/{label}/shap_force_sample{i}.png", bbox_inches="tight")
    plt.close(force_fig)

    # Waterfall plot: the call's return value is not used, so grab the current
    # figure afterwards to save and close it.
    shap.plots._waterfall.waterfall_legacy(
        base_value,
        class_shap[i, :],
        X_test_shap.iloc[i, :],
        show=False,
    )
    waterfall_fig = plt.gcf()
    waterfall_fig.savefig(
        f"{OUT_VIS}/{label}/shap_waterfall_sample{i}.png", bbox_inches="tight"
    )
    plt.close(waterfall_fig)


# Demo run: plot a small, seeded selection of rows for the first class in parallel.
# Rows are drawn WITHOUT replacement so no index repeats — a repeated index would make
# two workers write the same output files at the same time. The draw is capped at the
# number of available rows so the cell also works on very small samples.
n_demo_rows = min(20, len(X_test_shap))
indices = (
    np.random.default_rng(1)
    .choice(len(X_test_shap), size=n_demo_rows, replace=False)
    .tolist()  # plain Python ints for file names and positional indexing
)

class_idx = 0
label = str(rf.classes_[class_idx]).replace(" ", "_").lower()
# Make the cell self-contained: the target folder must exist before workers save into it.
os.makedirs(f"{OUT_VIS}/{label}", exist_ok=True)

# Trailing semicolon suppresses the notebook echo of the list of None results.
Parallel(n_jobs=4)(delayed(gen_plots_for)(i, class_idx, label) for i in indices);