In [1]:
from smt.surrogate_models import KRG
from smt_explainability.problems import MixedCantileverBeam
from smt.design_space import (
    DesignSpace,
    FloatVariable,
    CategoricalVariable,
)
from smt.surrogate_models import (
    KPLS,
    MixIntKernelType,
    MixHrcKernelType,
)
from smt.applications.mixed_integer import MixedIntegerKrigingModel

from smt_explainability.shap.shap_display import ShapDisplay
from smt_explainability.shap.shap_feature_importance_display import ShapFeatureImportanceDisplay

from sklearn.metrics import mean_squared_error
import numpy as np
import time

In [2]:
# Size of the design of experiments; the first 80% of the samples are used
# for training and the remainder for testing.
ndoe = 300
n_train = int(0.8 * ndoe)
fun = MixedCantileverBeam()
# Name of the features
feature_names = [r'$\tilde{I}$', r'$L$', r'$S$']
# Index for categorical features
categorical_feature_indices = [0]
# Design space: one categorical variable with labels "1".."12" followed by
# two continuous variables.
ds = DesignSpace([
    CategoricalVariable(values=[str(i + 1) for i in range(12)]),
    FloatVariable(10.0, 20.0),
    FloatVariable(1.0, 2.0),
])
# Create the mappings for the categories:
#   categories_map[feature][index] -> label
#   inverse_categories_map[feature][label] -> index
categories_map = {}
inverse_categories_map = {}
for feature_idx in categorical_feature_indices:
    # Go through the public `design_variables` accessor instead of the
    # private `_design_variables` attribute, and look the labels up once.
    category_values = ds.design_variables[feature_idx].values
    categories_map[feature_idx] = dict(enumerate(category_values))
    inverse_categories_map[feature_idx] = {
        value: i for i, value in enumerate(category_values)
    }

# Sample the problem and evaluate the reference function on the samples.
X = fun.sample(ndoe)
y = fun(X)

# Split by position: rows before n_train are training data, the rest test data.
X_tr, y_tr = X[:n_train, :], y[:n_train]
X_te, y_te = X[n_train:, :], y[n_train:]

class GroundTruthModel:
    """Adapter exposing the reference function `fun` through `predict_values`."""

    def predict_values(self, X):
        """Evaluate `fun` on `X` and return the result unchanged."""
        exact_values = fun(X)
        return exact_values


# Instance of the adapter, usable wherever an object with `predict_values` is expected.
gtm = GroundTruthModel()

In [3]:
# Mixed-integer Kriging wrapper around a KPLS surrogate defined on `ds`.
sm = MixedIntegerKrigingModel(
    surrogate=KPLS(
        design_space=ds,
        categorical_kernel=MixIntKernelType.HOMO_HSPHERE,
        hierarchical_kernel=MixHrcKernelType.ARC_KERNEL,
        theta0=np.array([4.43799547e-04, 4.39993134e-01, 1.59631650e+00]),
        corr="squar_exp",
        n_start=1,
        cat_kernel_comps=[2],
        n_comp=2,
        print_global=False,
    ),
)

# perf_counter() is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments the way time.time() differences can.
start_time = time.perf_counter()
sm.set_training_values(X_tr, np.array(y_tr))
sm.train()
print("run time (s):", time.perf_counter() - start_time)

print("Surrogate model")
y_pred = sm.predict_values(X_te)
# `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6;
# taking the square root of the MSE works on every version.
# NOTE(review): assumes a single output column, where sqrt(MSE) equals the
# value the old `squared=False` call returned.
rmse = np.sqrt(mean_squared_error(y_te, y_pred))
rrmse = rmse / y_te.mean()
# Scientific notation: with a fixed 4-decimal format a small RMSE prints as 0.0000.
print(f"RMSE: {rmse:.4e}")
print(f"rRMSE: {rrmse:.4f}")



run time (s): 35.65931701660156
Surrogate model
RMSE: 0.0000
rRMSE: 0.0200




In [4]:
# SHAP display computed with the "kernel" method for the trained surrogate.
# The training set is used both as the instances to explain and as the third
# positional argument (presumably the reference data; confirm against
# ShapDisplay.from_surrogate_model).
instances, model = X_tr, sm

kernel_shap_explainer = ShapDisplay.from_surrogate_model(
    instances,
    model,
    X_tr,
    method="kernel",
    feature_names=feature_names,
    categorical_feature_indices=categorical_feature_indices,
    categories_map=categories_map,
)

In [5]:
# SHAP display computed with the "exact" method, using the same inputs as the
# kernel-based display so the two can be compared.
instances, model = X_tr, sm

exact_shap_explainer = ShapDisplay.from_surrogate_model(
    instances,
    model,
    X_tr,
    method="exact",
    feature_names=feature_names,
    categorical_feature_indices=categorical_feature_indices,
    categories_map=categories_map,
)

In [20]:
# Dependence plots for features 0, 1 and 2 with importance-based sorting
# disabled; one PNG is written per SHAP method.
kernel_shap_dependence_plot = kernel_shap_explainer.dependence_plot(
    [0, 1, 2],
    sort_based_on_importance=False,
)
kernel_shap_dependence_plot.savefig("kernel_shap_dependence_mixed.png")

exact_shap_dependence_plot = exact_shap_explainer.dependence_plot(
    [0, 1, 2],
    sort_based_on_importance=False,
)
exact_shap_dependence_plot.savefig("exact_shap_dependence_mixed.png")

In [9]:
# Feature-importance displays built from the surrogate `sm` on the training
# set, once per SHAP method ("kernel" and "exact").
kernel_shap_feature_importance = ShapFeatureImportanceDisplay.from_surrogate_model(
    sm,
    X_tr,
    method="kernel",
    feature_names=feature_names,
    categorical_feature_indices=categorical_feature_indices,
)

exact_shap_feature_importance = ShapFeatureImportanceDisplay.from_surrogate_model(
    sm,
    X_tr,
    method="exact",
    feature_names=feature_names,
    categorical_feature_indices=categorical_feature_indices,
)

In [13]:
# Render each feature-importance display and write it to its own PNG file.
kernel_shap_feature_importance_plot = (
    kernel_shap_feature_importance.plot()
)
kernel_shap_feature_importance_plot.savefig(
    "kernel_shap_importance_mixed.png"
)

# NOTE(review): the doubled "_plot_plot" suffix is kept as-is because other
# cells may reference this name.
exact_shap_feature_importance_plot_plot = (
    exact_shap_feature_importance.plot()
)
exact_shap_feature_importance_plot_plot.savefig(
    "exact_shap_importance_mixed.png"
)

In [14]:
# Summary plot for each SHAP method, saved to a separate PNG file.
kernel_shap_summary_plot = (
    kernel_shap_explainer.summary_plot()
)
kernel_shap_summary_plot.savefig(
    "kernel_shap_summary_mixed.png"
)

exact_shap_summary_plot = (
    exact_shap_explainer.summary_plot()
)
exact_shap_summary_plot.savefig(
    "exact_shap_summary_mixed.png"
)

In [17]:
# Feature index pairs handed to interaction_plot; both SHAP methods receive
# the same pairs and options so the resulting figures are comparable.
feature_pairs = [(2, 1), (1, 0)]

kernel_shap_interaction_plot = kernel_shap_explainer.interaction_plot(
    feature_pairs,
    sort_based_on_importance={0: False},
    n_color=5,
)
exact_shap_interaction_plot = exact_shap_explainer.interaction_plot(
    feature_pairs,
    sort_based_on_importance={0: False},
    n_color=5,
)

# Save both figures once both have been created.
kernel_shap_interaction_plot.savefig("kernel_shap_interaction_mixed.png")
exact_shap_interaction_plot.savefig("exact_shap_interaction_mixed.png")

In [21]:
# Individual plot requested with index=10 from each explainer, one PNG per
# SHAP method.
kernel_shap_individual_plot = kernel_shap_explainer.individual_plot(
    index=10,
)
kernel_shap_individual_plot.savefig(
    "kernel_shap_individual_mixed.png"
)

exact_shap_individual_plot = exact_shap_explainer.individual_plot(
    index=10,
)
exact_shap_individual_plot.savefig(
    "exact_shap_individual_mixed.png"
)

  fig, ax = plt.subplots(1, 1, figsize=(length, width))
