In [1]:
import pandas as pd
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import plotly.express as px

In [2]:
# imports for onnx conversion and inference
import onnx
import ebm2onnx
import onnxruntime as rt
import numpy as np
import tempfile

# Train a classfication model

In [3]:
# load dataset
df = pd.read_csv('titanic_train.csv')
df = df.dropna()

In [4]:
# train the model
feature_columns = ['Age', 'Fare', 'Pclass', 'Embarked']
label_column = "Survived"

y = df[[label_column]]
le = LabelEncoder()
y_enc = le.fit_transform(y)
x = df[feature_columns]
x_train, x_test, y_train, y_test = train_test_split(x, y_enc)
ebm = ExplainableBoostingClassifier(
    interactions=2,
    feature_types=['continuous', 'continuous', 'continuous','categorical']
)
ebm.fit(x_train, y_train)


A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().



ExplainableBoostingClassifier(feature_names=['Age', 'Fare', 'Pclass',
                                             'Embarked', 'Age x Fare',
                                             'Fare x Embarked'],
                              feature_types=['continuous', 'continuous',
                                             'continuous', 'categorical',
                                             'interaction', 'interaction'],
                              interactions=2)

In [5]:
# convert the model to onnx
onnx_model = ebm2onnx.to_onnx(
    model=ebm,
    explain=True,  # Generate a dedicated output for local explanations
    dtype={
        'Age': 'double',
        'Fare': 'double',
        'Pclass': 'int',
    },
    name="ebm",
)

_, filename = tempfile.mkstemp()
onnx.save_model(onnx_model, filename)

In [6]:
# predict on test set with ONNX-Runtime
sess = rt.InferenceSession(filename)
onnx_pred = sess.run(None, {
    'Age': x_test['Age'].values,
    'Fare': x_test['Fare'].values,
    'Pclass': x_test['Pclass'].values,
    'Embarked': x_test['Embarked'].values,
})

# Local explanation

In [34]:
def show_onnx_local_explanation(predictions, sample_to_explain):
    scores = predictions[1]    
    scores = scores[sample_to_explain][:,0]
    abs_scores = np.abs(scores)
    sorted_indices = np.argsort(abs_scores)
    scores = scores[sorted_indices]
    
    colors = [s > 0 for s in scores]

    fig = px.bar(  
        scores,
        color=colors,
        orientation='h',
        color_discrete_map={
            True: '#FF7F0E',
            False: '#1F77B4',
        },
        text=[ebm.feature_names[i] for i in sorted_indices],
        height=300,
    )

    fig.update(layout_showlegend=False)
    fig.show()

In [8]:
# For reference, we plot the local explanations as provided by interpretml
ebm_local = ebm.explain_local(x_test, y_test)
show(ebm_local)

In [35]:
# The ONNX predictions contain also the local explanation
# We can display the same plots.

show_onnx_local_explanation(onnx_pred, 4)