# Imports and setup

In [None]:
import sys
import os
import pandas as pd
import shap

# Make the directory two levels above the notebook's working directory
# importable, so the `src.*` project modules imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from src.explainability.shap_explainer import SHAPExplainer

# Load and prepare data

In [None]:
# Load the preprocessed dataset and separate the "class" target column
# from the feature columns.
df = pd.read_csv("../../data/processed/cleaned_balanced_fraud_data.csv")
y = df["class"]
X = df.drop(columns="class")


In [None]:
# Stratified hold-out split (sklearn's default 25% test fraction, written out
# explicitly) with a fixed seed for reproducibility.
# NOTE(review): the input file name suggests class balancing happened before
# this split; if that involved resampling, the test set may share duplicated or
# synthetic rows with the training set — confirm.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    stratify=y,
    random_state=42,
)

# Train RandomForest model

In [None]:
# 100-tree random forest with a fixed seed so training is reproducible.
model = RandomForestClassifier(
    random_state=42,
    n_estimators=100,
)
model.fit(X_train, y_train)

# SHAP for explainability

### Sample 500 rows from the training set

In [None]:
# Number of training rows handed to the SHAP explainer. The count is capped at
# the training-set size because DataFrame.sample(n=...) without replacement
# raises ValueError when n exceeds the number of available rows.
SHAP_SAMPLE_SIZE = 500
X_train_sampled = X_train.sample(
    n=min(SHAP_SAMPLE_SIZE, len(X_train)),
    random_state=42,
)

### Use the sampled rows for the SHAP explanation only (the model was trained on the full training set)

In [None]:
# Wrap the trained tree model in the project's SHAP helper. Only the sampled
# training rows are passed as X_train; the model itself was fit on the full
# training set.
explainer = SHAPExplainer(
    model=model,
    model_type="tree",
    X_train=X_train_sampled,
    X_test=X_test,
)
# NOTE(review): the force-plot cell indexes `.values` / `.base_values` on this
# result, so fit() presumably returns a shap.Explanation — confirm.
shap_values = explainer.fit()

# Global feature importance (summary plot)

In [None]:
# Global feature importance as a bar-style SHAP summary plot
# (rendering is delegated to SHAPExplainer.plot_summary).
explainer.plot_summary(
    plot_type="bar",
)

# Local explanation (force plot for a single prediction)

In [None]:
import numpy as np  # only this cell needs numpy

# Row of the test set to explain.
# NOTE(review): this assumes explainer.fit() returned SHAP values computed on
# X_test with row order preserved — confirm against SHAPExplainer.
ROW_IDX = 0
# sklearn sorts `model.classes_` ascending, so for a binary 0/1 target index 1
# is the label 1 (presumably "fraud" — confirm the label encoding).
POSITIVE_CLASS_IDX = 1

# Tree explainers on a classifier can return one set of SHAP values per class,
# i.e. (n_features, n_classes) for a single row and one base value per class.
# force_plot needs a 1-D SHAP vector and a scalar base value, so select the
# positive-class column when per-class output is present.
row_values = np.asarray(shap_values.values[ROW_IDX])
if row_values.ndim == 2:
    row_values = row_values[:, POSITIVE_CLASS_IDX]

row_base = np.ravel(shap_values.base_values[ROW_IDX])
row_base_value = float(row_base[POSITIVE_CLASS_IDX] if row_base.size > 1 else row_base[0])

force_plot_html = shap.force_plot(
    base_value=row_base_value,
    shap_values=row_values,
    features=X_test.iloc[ROW_IDX],
    feature_names=list(X_test.columns),
    matplotlib=False,  # interactive JS plot so it can be saved as HTML
)
# Save the interactive force plot as a standalone HTML file.
shap.save_html("shap_force_plot_randomForest.html", force_plot_html)