In [None]:
# Cell 1: Imports and loading model/data
import shap
import joblib
import pandas as pd
import matplotlib.pyplot as plt
from src.data_preprocessing import load_data, preprocess_data

# Where the persisted artifacts and the raw dataset live (relative to the
# notebook's working directory).
model_path = "../outputs/models/printability_model.pkl"
preprocessor_path = "../outputs/models/preprocessor.pkl"
dataset_path = "../data/extended_printability_dataset_with_gelatin_silk.csv"

# Name of the label column handed to preprocess_data.
target_column = "Printable"

# Restore the trained model and the saved preprocessor from disk.
# NOTE(review): `preprocessor` is not referenced by the cells visible here.
model = joblib.load(model_path)
preprocessor = joblib.load(preprocessor_path)

# Read the dataset, then let preprocess_data produce the train/test pieces.
# preprocess_data returns five values; the fifth is intentionally discarded.
df = load_data(dataset_path)
splits = preprocess_data(df, target_column)
X_train, X_test, y_train, y_test, _ = splits


In [None]:
# Cell 2: SHAP Beeswarm Plot
# Explains the loaded model on the first rows of the test split and draws a
# SHAP summary (beeswarm) plot. Defines `explainer`, `X_sample` and
# `shap_values` for use by later cells.
explainer = shap.TreeExplainer(model)

# Number of leading test rows to explain (was a hard-coded 100).
N_EXPLAIN_ROWS = 100

# Densify the sample if the test matrix is sparse (exposes .toarray()).
X_sample = X_test[:N_EXPLAIN_ROWS]
if hasattr(X_sample, "toarray"):
    X_sample = X_sample.toarray()

# Compute SHAP values
shap_values = explainer.shap_values(X_sample)

# For classifiers SHAP may return one block of values per class: either a
# list of 2-D arrays or a single 3-D array (samples, features, classes),
# depending on the SHAP release. summary_plot only draws a beeswarm for a
# single 2-D matrix, so in the binary case keep one class.
# NOTE(review): index 1 is presumably the "Printable" positive class --
# confirm against model.classes_.
if isinstance(shap_values, list):
    if len(shap_values) == 2:
        shap_values = shap_values[1]
elif getattr(shap_values, "ndim", 2) == 3 and shap_values.shape[-1] == 2:
    shap_values = shap_values[:, :, 1]

# Beeswarm plot; show=False so tight_layout can run before the figure is shown.
shap.summary_plot(shap_values, X_sample, show=False)
plt.tight_layout()
plt.show()


In [None]:
# Cell 3: SHAP Bar Plot
# Draws the bar-type summary plot from the `shap_values` and `X_sample`
# computed in the beeswarm cell; that cell must have been run first.
bar_plot_options = {"plot_type": "bar", "show": False}
shap.summary_plot(shap_values, X_sample, **bar_plot_options)

# Tighten the layout of the figure SHAP just drew, then render it.
current_figure = plt.gcf()
current_figure.tight_layout()
plt.show()
