# Interacting with ProtoDash

In this notebook we combine ProtoDash prototype selection with ITEA's Partial Effects to obtain feature importances for the digits classification task.

ProtoDash was proposed in _Gurumoorthy, Karthik & Dhurandhar, Amit & Cecchi, Guillermo & Aggarwal, Charu. (2019). Efficient Data Representation by Selecting Prototypes with Importance Weights. 260-269. 10.1109/ICDM.2019.00036_.

In [1]:
import numpy  as np
import pandas as pd

# automatically differentiable implementation of numpy
import jax.numpy as jnp

from sklearn import datasets

from sklearn.model_selection import train_test_split
from IPython.display         import display, Math, Latex

import matplotlib.pyplot as plt

from itea.classification import ITEA_classifier
from itea.inspection     import *

from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import classification_report

from aix360.algorithms.protodash import ProtodashExplainer, get_Gaussian_Data

In [None]:
# Load the digits dataset (64 pixel features per sample) and split it.
digits_data = datasets.load_digits()

X, y    = digits_data['data'], digits_data['target']
labels  = digits_data['feature_names']
targets = digits_data['target_names']

# Hold out a third of the samples; fixed seed for a reproducible split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

print(X_train.shape)

# Transformation functions for ITEA, written with jax.numpy so that their
# derivatives come from automatic differentiation instead of hand-written
# formulas.
# NOTE(review): many pixel values are 0, so `log` can produce -inf; presumably
# ITEA discards such invalid terms during evolution -- confirm.
tfuncs = dict(
    id=lambda x: x,
    log=jnp.log,
    exp=jnp.exp,
)

# NOTE(review): the captured training log estimates >1000 minutes remaining
# with these settings; consider caching the fitted model for re-runs.
itea_params = dict(
    gens=1000,
    popsize=200,
    max_terms=64,
    expolim=(0, 1),
    verbose=5,
    tfuncs=tfuncs,
    labels=labels,
    simplify_method='simplify_by_var',
    random_state=42,
)

clf = ITEA_classifier(**itea_params).fit(X_train, y_train)

(1203, 64)
gen 	 min_fitness 	 mean_fitness 	 max_fitness 	 remaining (s)
0 	 0.1055694098088113 	 0.10556940980881134 	 0.1055694098088113 	 145min55seg
5 	 0.1055694098088113 	 0.1057148794679967 	 0.10640066500415628 	 152min47seg
10 	 0.1055694098088113 	 0.10751870324189525 	 0.18121363258520365 	 378min57seg
15 	 0.10640066500415628 	 0.11966334164588527 	 0.18786367414796343 	 508min8seg
20 	 0.10640066500415628 	 0.18519950124688278 	 0.19201995012468828 	 1209min42seg


In [None]:
# Best expression found by the evolutionary search; the explainers below use it.
final_itexpr = clf.bestsol_

# Displayed as the cell output: the features used by that expression.
final_itexpr.selected_features_

In [None]:
# The data handed to ProtoDash is the training pixels with the class label
# appended as an extra last column, all one-hot encoded.
train_with_label = np.hstack((X_train, y_train.reshape(-1, 1)))
print(X_train.shape, y_train.reshape(-1, 1).shape)

# OneHotEncoder returns a scipy sparse matrix by default. Converting it with
# .toarray() works on every scikit-learn version, whereas `sparse=False` was
# renamed to `sparse_output` in 1.2 and removed in 1.4 (TypeError there).
onehot_encoder = OneHotEncoder()
onehot_encoded = onehot_encoder.fit_transform(train_with_label).toarray()

explainer = ProtodashExplainer()

# Call the ProtoDash explainer, selecting prototypes from the encoded training
# set to summarize that same set. m equals the number of classes, but nothing
# here forces one prototype per class.
# S contains indices (rows of X_train) of the selected prototypes
# W contains importance weights associated with the selected prototypes
(W, S, _) = explainer.explain(onehot_encoded, onehot_encoded, m=len(np.unique(y_train)))

In [None]:
# One panel per selected prototype. The grid is sized from len(S); a fixed
# 2x3 grid (with one hidden axis still present in fig.axes) would silently
# drop prototypes, since m equals the number of classes (10 for digits).
n_prototypes = len(S)
n_cols = max(1, min(n_prototypes, 5))
n_rows = max(1, -(-n_prototypes // n_cols))  # ceiling division

fig, axs = plt.subplots(n_rows, n_cols, figsize=(2.2 * n_cols, 2.6 * n_rows), squeeze=False)

# Hide only the panels left over after placing every prototype
for ax in axs.flat[n_prototypes:]:
    ax.set_visible(False)

# W holds the ProtoDash importance weight of each prototype in S
for s, w, ax in zip(S, W, axs.flat):
    ax.imshow(X_train[s].reshape(8, 8))
    ax.set_title(f"Prototype of class {y_train[s]}\n(weight {w:.3f})")

plt.tight_layout()
plt.show()

In [None]:
# Partial-effects explainer for the final ITEA expression; it is reused by
# the group-importance cells further down.
it_explainer = ITExpr_explainer(
    itexpr=final_itexpr,
    tfuncs=tfuncs
).fit(X_train, y_train)

# One importance map per prototype. The grid is sized from len(S) so that no
# prototype is dropped (a fixed 2x3 grid showed only 5 of them).
n_prototypes = len(S)
n_cols = max(1, min(n_prototypes, 5))
n_rows = max(1, -(-n_prototypes // n_cols))  # ceiling division

fig, axs = plt.subplots(n_rows, n_cols, figsize=(2.4 * n_cols, 2.8 * n_rows), squeeze=False)

for ax in axs.flat[n_prototypes:]:
    ax.set_visible(False)

for s, ax in zip(S, axs.flat):

    # Partial effects evaluated at the single prototype sample; the rows of
    # the result are summed into one map.
    # NOTE(review): presumably one row per class -- confirm against the
    # ITExpr_explainer.average_partial_effects documentation.
    importances = np.sum(
        it_explainer.average_partial_effects(X_train[s, :].reshape(1, -1)),
        axis=0
    )

    ax.imshow(importances.reshape(8, 8))
    ax.set_title(f"Feature importances for\nPrototype of class {y_train[s]}")

plt.tight_layout()
plt.show()

In [None]:
# Next, select a larger set of prototypes (three times the number of classes)
# so that the feature importances can be aggregated per class. ProtoDash
# selects from the whole training set, so nothing here forces an even split
# of prototypes across classes.
n_group_prototypes = len(np.unique(y_train)) * 3

explainer = ProtodashExplainer()

(W, S, _) = explainer.explain(onehot_encoded, onehot_encoded, m=n_group_prototypes)

In [None]:
# One importance map per class, aggregated over that class's prototypes.
# The grid is sized from the number of classes; a fixed 2x3 grid with one
# hidden axis showed only 5 of the 10 digit classes.
classes = np.unique(y_train)
prototype_idx = np.asarray(S, dtype=int)

n_cols = max(1, min(len(classes), 5))
n_rows = max(1, -(-len(classes) // n_cols))  # ceiling division

fig, axs = plt.subplots(n_rows, n_cols, figsize=(2.4 * n_cols, 2.8 * n_rows), squeeze=False)

for ax in axs.flat[len(classes):]:
    ax.set_visible(False)

for class_pos, (class_, ax) in enumerate(zip(classes, axs.flat)):

    prototypes_for_class = prototype_idx[y_train[prototype_idx] == class_]

    # ProtoDash may pick no prototype for some class; averaging partial
    # effects over an empty sample set would be meaningless.
    if prototypes_for_class.size == 0:
        ax.set_axis_off()
        ax.set_title(f"No prototypes selected\nfor class {class_}")
        continue

    # Row selected by the class's position in np.unique(y_train) (equal to
    # the label itself for digits 0-9).
    # NOTE(review): assumes the rows follow the sorted class order -- confirm.
    importances = it_explainer.average_partial_effects(
        X_train[prototypes_for_class, :])[class_pos]

    ax.imshow(importances.reshape(8, 8))
    ax.set_title(f"Feature importances for\n{prototypes_for_class.size} prototypes of class {class_}")

plt.tight_layout()
plt.show()