In [1]:
import sys
import os

# Resolve the directory the notebook kernel is currently running in.
current_directory = os.getcwd()

# Climb three directory levels (one per '..' below) to reach the repository root,
# which is expected to contain the 'interpretml' folder.
root_path = os.path.abspath(os.path.join(current_directory, '..', '..', '..'))

# Guard the append so that re-running this cell does not stack duplicate
# entries onto sys.path.
if root_path not in sys.path:
    sys.path.append(root_path)

In [2]:
import pandas as pd
import numpy as np
import interpret
from interpret import show

In [3]:
# The raw UCI file ships without a header row, so the column names are supplied here.
iris_columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
iris = pd.read_csv(
    'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',
    header=None,
    names=iris_columns,
)

# Collapse the label into a binary target: 1 for Iris-setosa, 0 for every other species.
is_setosa = iris['species'] == 'Iris-setosa'
iris['species'] = np.where(is_setosa, 1, 0)

# Features are every column except the target.
X = iris.drop(columns='species')
y = iris['species']

# Hold out 20% of the rows for testing; the fixed seed keeps the split repeatable.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [4]:
# Pull both glassbox classifiers straight from their private interpret modules.
from interpret.glassbox._naivebayes import NaiveBayesClassifier
from interpret.glassbox._linear import LogisticRegression

# Instantiate the two models with their default settings.
nb = NaiveBayesClassifier()
lr = LogisticRegression()

# Train both on the same split. The logistic-regression fit stays last so that
# its return value is what the cell displays.
nb.fit(X_train, y_train)
lr.fit(X_train, y_train)

<interpret.glassbox._linear.LogisticRegression at 0x183d4767a60>

In [5]:
# Report the size of the held-out set, then preview three of its rows.
print(X_test.shape)
# Seed the sampler (same seed as the train/test split) so the preview shows the
# same rows on every Restart & Run All instead of a fresh random draw.
X_test.sample(3, random_state=42)

(30, 4)


Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width
145,6.7,3.0,5.2,2.3
128,6.4,2.8,5.6,2.1
9,4.9,3.1,1.5,0.1


In [6]:
# Class predictions from the fitted naive Bayes model for every held-out row.
nb.predict(X_test)

array([0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 1], dtype=int64)

In [7]:
# Print the fitted parameters of the model wrapped by `nb`: theta_ first, then var_.
# NOTE(review): presumably these are scikit-learn GaussianNB's per-class feature
# means and variances -- confirm against the wrapper's implementation.
for param_name in ('theta_', 'var_'):
    print(getattr(nb._model(), param_name))

[[6.21875 2.86625 4.865   1.6525 ]
 [4.99    3.44    1.4525  0.2425 ]]
[[0.44427344 0.10923594 0.663775   0.17599375]
 [0.1239     0.1549     0.03299375 0.01144375]]


In [8]:
# Build per-instance explanations from the naive Bayes model for the held-out rows
# (the true labels are passed alongside), then render them with interpret's `show`.
nb_local = nb.explain_local(X_test, y_test)
show(nb_local)

In [9]:
# Class predictions from the fitted logistic regression for every held-out row,
# for side-by-side comparison with the naive Bayes predictions.
lr.predict(X_test)

array([0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 1], dtype=int64)

In [10]:
# Build per-instance explanations from the logistic regression for the held-out rows
# (the true labels are passed alongside), then render them with interpret's `show`.
lr_local = lr.explain_local(X_test, y_test)
show(lr_local)

## Discretize features

In [11]:
from interpret.glassbox._categoricalnaivebayes import NaiveBayesClassifier as CategoricalNaiveBayesClassifier

In [12]:
from sklearn.preprocessing import KBinsDiscretizer

# Bin every feature into 5 equal-width intervals encoded as ordinal codes.
# The bin edges are learned on the training split only and then reused for the
# test split, so no test information leaks into the discretisation.
kbd = KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='uniform', subsample=200)

# Rebuild DataFrames around the transformed values. The original row index is
# carried over explicitly: otherwise the frames get a fresh 0..n-1 index while
# y_train / y_test keep the shuffled labels from the split, and any index-aware
# operation would silently pair features with the wrong targets.
X_train_discrete = pd.DataFrame(
    kbd.fit_transform(X_train),
    columns=X_train.columns,
    index=X_train.index,
)
X_test_discrete = pd.DataFrame(
    kbd.transform(X_test),
    columns=X_test.columns,
    index=X_test.index,
)

In [13]:
# Fit the categorical naive Bayes variant, with default settings, on the
# discretised training features.
CNB = CategoricalNaiveBayesClassifier()
CNB.fit(X_train_discrete, y_train)

<interpret.glassbox._categoricalnaivebayes.NaiveBayesClassifier at 0x183d36ebd00>

In [14]:
# Inspect `category_count_` on the model wrapped by CNB; the displayed value is a
# list holding one 2-D array per feature.
# NOTE(review): presumably each array is indexed (class, bin) with training counts -- confirm.
CNB.model.category_count_

[array([[ 2., 14., 33., 21., 10.],
        [15., 23.,  2.,  0.,  0.]]),
 array([[ 9., 34., 35.,  2.,  0.],
        [ 1.,  1., 20., 14.,  4.]]),
 array([[ 0.,  1., 25., 36., 18.],
        [40.,  0.,  0.,  0.,  0.]]),
 array([[ 0.,  7., 33., 24., 16.],
        [39.,  1.,  0.,  0.,  0.]])]

In [15]:
# Test-set accuracy: the fraction of held-out rows whose prediction equals the true label.
(CNB.predict(X_test_discrete) == y_test).mean()

1.0

In [16]:
# Build per-instance explanations from the categorical naive Bayes model for the
# discretised held-out rows (the true labels are passed alongside), then render
# them with interpret's `show`.
CNBlocal = CNB.explain_local(X_test_discrete, y_test)
show(CNBlocal)

Instance 0
0.6931471805599453
[0.4        0.41176471 0.43529412 0.4       ] [0.06666667 0.04444444 0.02222222 0.02222222]
[1.79175947 2.22621211 2.97492915 2.89037176]

Instance 1
0.6931471805599453
[0.4        0.03529412 0.01176471 0.01176471] [0.06666667 0.33333333 0.91111111 0.88888889]
[ 1.79175947 -2.24542668 -4.34956083 -4.32486822]

Instance 2
0.6931471805599453
[0.12941176 0.41176471 0.22352941 0.2       ] [0.02222222 0.04444444 0.02222222 0.02222222]
[1.76190651 2.22621211 2.30845021 2.19722458]

Instance 3
0.6931471805599453
[0.4        0.41176471 0.43529412 0.4       ] [0.06666667 0.04444444 0.02222222 0.02222222]
[1.79175947 2.22621211 2.97492915 2.89037176]

Instance 4
0.6931471805599453
[0.25882353 0.41176471 0.43529412 0.4       ] [0.02222222 0.04444444 0.02222222 0.02222222]
[2.45505369 2.22621211 2.97492915 2.89037176]

Instance 5
0.6931471805599453
[0.17647059 0.42352941 0.01176471 0.01176471] [0.53333333 0.46666667 0.91111111 0.88888889]
[-1.1059924  -0.09699227 -4.3