In [1]:
import numpy as np
import os
from time import time
import datetime
import gc

from sklearn.metrics import accuracy_score, f1_score, auc
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz
from matplotlib import pyplot as plt
from matplotlib import colors
import pydotplus


import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorboard.plugins.hparams import api as hp

from read_dataset_for_constraint import switch_dataset, data_to_interpret

from utils import plot_boundaries_hyperrect
from sTGMA import SoftTruncatedGaussianMixtureAnalysis

from black_box import BlackBoxNN
from config import  config_params, hyper_params

In [43]:
 
dataset_name = "pima_indian_diabetes"    

X_train, y_train, X_test, y_test, y_train_onehot, y_test_onehot, scaler, color_map, column_names = \
        data_to_interpret(dataset_name, 0)

In [44]:
X_train.shape

(614, 8)

In [45]:
column_names

['Pregnancies',
 'Glucose',
 'BloodPressure',
 'SkinThickness',
 'Insulin',
 'BMI',
 'DiabetesPedigreeFunction',
 'Age',
 'Class']

In [48]:
n_components = 3
_lambda = 5
arch = [128, 128]

black_box = tf.saved_model.load(f"./results/holdout/{dataset_name}/components_{n_components}_lambda_{_lambda}_toto/fold_0/bb_weights")

stgma = tf.saved_model.load(f"./results/holdout/{dataset_name}/components_{n_components}_lambda_{_lambda}_toto/fold_0/stgma_weights")




In [49]:
lower = stgma.lower.numpy()
upper = stgma.upper.numpy()

np.set_printoptions(precision=3)
inv_lower = np.around(scaler.inverse_transform(lower),2)
inv_upper = np.around(scaler.inverse_transform(upper),2)

columns_lower = [f"{column_names[i]}_{k}{c}_less_{str(inv_lower[c,k,i])}" for c in range(lower.shape[0]) for k in range(lower.shape[1]) for i in range(lower.shape[2])]
columns_upper = [f"{column_names[i]}_{k}{c}_less_{str(inv_upper[c,k,i])}" for c in range(lower.shape[0]) for k in range(lower.shape[1]) for i in range(lower.shape[2])]


def binarise(lower, upper, X):
    alpha1 = np.reshape(lower, (lower.shape[0]*lower.shape[1], -1))
    alpha2 = np.reshape(upper, (upper.shape[0]*lower.shape[1], -1))
    
    X_low = (np.expand_dims(X, axis = 1) < alpha1) + 0

    X_upp = (np.expand_dims(X, axis = 1) < alpha2) + 0
    
    X_ = np.concatenate((X_low, X_upp), axis=1)
    
    return np.reshape(X_, (X_.shape[0], X_.shape[1]*X_.shape[2]))

In [50]:
X_train_exp = binarise(lower, upper, X_train)

X_test_exp = binarise(lower, upper, X_test)

bb_y_train = np.argmax(black_box.predict(X_train).numpy(), axis = -1)
bb_y_test =  np.argmax(black_box.predict(X_test).numpy(), axis = -1)
accuracy_score(bb_y_test, y_test)


0.7402597402597403

In [143]:
n_min = round(X_train.shape[0]/90)
n_min

7

In [183]:

K = 4

dt1 = DecisionTreeClassifier(max_depth = K, min_samples_leaf=n_min)

dt1.fit(X_train, y_train)

dt2 = DecisionTreeClassifier(max_depth=K, min_samples_leaf=n_min)

dt2.fit(X_train_exp, bb_y_train)

dt4 = DecisionTreeClassifier(max_depth=K, min_samples_leaf=n_min)

dt4.fit(X_train, bb_y_train)


DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=4, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=7, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')

In [184]:
accuracy_score(bb_y_test, y_test)

0.7402597402597403

In [185]:
accuracy_score(dt1.predict(X_test), y_test)


0.7142857142857143

In [186]:
accuracy_score(dt4.predict(X_test), y_test)

0.7077922077922078

In [187]:
accuracy_score(dt2.predict(X_test_exp), y_test)

0.7077922077922078

In [188]:
accuracy_score(dt4.predict(X_test), bb_y_test)

0.9545454545454546

In [189]:
accuracy_score(dt2.predict(X_test_exp), bb_y_test)

0.8766233766233766

array([0, 0, 0, ..., 0, 0, 0])

In [204]:
graph=export_graphviz(dt1, filled=True, feature_names=column_names[:-1])
graph=pydotplus.graph_from_dot_data(graph)
graph.write_png('dt1.png') 

True

In [205]:
graph = export_graphviz(dt2, filled=True, feature_names=columns_upper + columns_upper)
graph=pydotplus.graph_from_dot_data(graph)
graph.write_png('dt2.png') 

True

In [206]:
graph = export_graphviz(dt4, filled=True, feature_names=column_names[:-1])
graph=pydotplus.graph_from_dot_data(graph)
graph.write_png('dt4.png') 

True

In [None]:
import weka.core.jvm as jvm
jvm.start()


In [157]:
from weka.core.dataset import create_instances_from_matrices
from weka.classifiers import Classifier, Evaluation
import weka.plot.classifiers as plot_cls
import weka.plot.graph as plot_graph
from weka.filters import Filter
cls = Classifier(classname="weka.classifiers.rules.JRip", options=["-N", f"{str(n_min)}", "-P"])

In [164]:
dataset = create_instances_from_matrices(X_train, bb_y_train, name="generated from matrices")
dataset.class_is_last()


nominal = Filter(classname="weka.filters.unsupervised.attribute.NumericToNominal", options=["-R", "last"])
nominal.inputformat(dataset)
nominaldata1 = nominal.filter(dataset)
nominaldata1.class_is_last()

cls.build_classifier(nominaldata1)

In [165]:
print(cls)


JRIP rules:

(x2 >= 0.680699) and (x6 >= -0.252935) => y=1 (117.0/0.0)
(x2 >= 0.437074) and (x1 >= 0.333032) => y=1 (27.0/0.0)
(x2 >= 0.162996) and (x6 >= 0.512967) and (x4 >= 0.356128) => y=1 (17.0/0.0)
(x2 >= 1.046137) => y=1 (8.0/0.0)
(x6 >= -0.559295) and (x2 >= 0.650246) and (x1 >= -0.54768) => y=1 (7.0/0.0)
 => y=0 (438.0/15.0)

Number of Rules : 6



In [168]:
dataset = create_instances_from_matrices(X_test, bb_y_test, name="Test")
dataset.class_is_last()


nominal = Filter(classname="weka.filters.unsupervised.attribute.NumericToNominal", options=["-R", "last"])
nominal.inputformat(dataset)
nominaldata1 = nominal.filter(dataset)
nominaldata1.class_is_last()

In [169]:
evaluation = Evaluation(nominaldata1)
evl = evaluation.test_model(cls, nominaldata1)

print(evaluation.summary())


Correctly Classified Instances         146               94.8052 %
Incorrectly Classified Instances         8                5.1948 %
Kappa statistic                          0.8709
Mean absolute error                      0.0749
Root mean squared error                  0.2259
Relative absolute error                 18.5389 %
Root relative squared error             50.3457 %
Total Number of Instances              154     



In [88]:
plt.rcParams["figure.figsize"] = (9,6.75)

X1 = X_train[:,0]
X2 = X_train[:,1]
steps = 1000
cmap = colors.ListedColormap(list(color_map.values())[:len(np.unique(y_train))])
# Define region of interest by data limits
deltaX = (max(X1) - min(X1))/10
deltaY = (max(X2) - min(X2))/10

xmin, xmax = min(X1) - deltaX, max(X1) + deltaX
ymin, ymax = min(X2) - deltaY, max(X2) + deltaY

x_span = np.linspace(xmin, xmax, steps)
y_span = np.linspace(ymin, ymax, steps)
xx, yy = np.meshgrid(x_span, y_span)

# Make predictions across region of interest
labels_dt = dt4.predict(np.c_[xx.ravel(), yy.ravel()])
#labels_dt = np.argmax(labels_bb, axis = 1)

#plt.subplot(2, 1, 1)
z1 = labels_dt.reshape(xx.shape)
ranges = np.linspace(z1.min(), z1.max(), len(color_map.values())+1)
norm = colors.BoundaryNorm(ranges, cmap.N)


plt.contourf(xx, yy, z1, alpha=0.2, cmap = cmap, norm=norm)

plt.scatter(X1, X2, c = [color_map[y_train[i]] for i in range(X_train.shape[0])],  edgecolor='k', lw=0, cmap="Set1")
currentAxis = plt.gca()

plt.savefig(f"dt_post_hoc.png")
plt.clf()

<Figure size 648x486 with 0 Axes>

In [90]:
X1 = X_train[:,0]
X2 = X_train[:,1]
steps = 1000
plt.rcParams["figure.figsize"] = (9,6.75)
cmap = colors.ListedColormap(list(color_map.values())[:len(np.unique(y_train))])
# Define region of interest by data limits
deltaX = (max(X1) - min(X1))/10
deltaY = (max(X2) - min(X2))/10

xmin, xmax = min(X1) - deltaX, max(X1) + deltaX
ymin, ymax = min(X2) - deltaY, max(X2) + deltaY

x_span = np.linspace(xmin, xmax, steps)
y_span = np.linspace(ymin, ymax, steps)
xx, yy = np.meshgrid(x_span, y_span)

# Make predictions across region of interest
labels_dt = dt2.predict(binarise(lower, upper,np.c_[xx.ravel(), yy.ravel()]))
#labels_dt = np.argmax(labels_bb, axis = 1)

#plt.subplot(2, 1, 1)
z1 = labels_dt.reshape(xx.shape)
ranges = np.linspace(z1.min(), z1.max(), len(color_map.values())+1)
norm = colors.BoundaryNorm(ranges, cmap.N)


plt.contourf(xx, yy, z1, alpha=0.2, cmap = cmap, norm=norm)

plt.scatter(X1, X2, c = [color_map[y_train[i]] for i in range(X_train.shape[0])],  edgecolor='k', lw=0, cmap="Set1")
currentAxis = plt.gca()

plt.savefig(f"dt_post_hoc_with_regul.png")
plt.clf()

<Figure size 648x486 with 0 Axes>

In [28]:
X_low = (np.expand_dims(X_train, axis = 1) < alpha1) + 0

X_upp = (np.expand_dims(X_train, axis = 1) < alpha2) + 0