In [1]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from aix360.algorithms.protodash import ProtodashExplainer
from lime.lime_tabular import LimeTabularExplainer
from lime.submodular_pick import SubmodularPick

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [2]:
def extract_W_multiclass(coef, xs, ys, coo=True):
    W = [{} for _ in xs]
    features = (xs.tocoo().row, xs.tocoo().col) if coo else np.where(xs != None)
    for (r, c) in zip(features[0], features[1]):
        W[r][c] = coef[ys[r]][c]
    return W

def get_feature_importance(W):
    importance = {}
    for ins in W:
        for j in ins:
            importance[j] = importance.get(j, 0.0) + np.abs(ins[j])
    for j in importance:
        importance[j] = np.sqrt(importance[j])
    return importance

def get_used(feature_dict, top=10):
    # only consider top 10 features as used
    values = [(np.abs(v), k) for k, v in feature_dict.items()]
    values.sort(reverse=True)
    return [i for _, i in values[:top]]

def lime_objective(results, W, importance, top=10):
    obj, covered = 0, set()
    for index in results:
        for j in get_used(W[index], top):
            if j not in covered:
                obj += importance[j]
                covered.add(j)
    return obj, covered

def greedy_sp_search_multiclass(results, W, importance, budget, ys, top=10):
    flattened = [index for _class in results for index in _class]
    class_len = list(map(len, results))
    obj, covered = lime_objective(flattened, W, importance, top)
    max_diff, max_index = -1, -1
    for i in range(len(W)):
        if i in flattened or class_len[ys[i]] >= budget[ys[i]]:
            continue
        diff = 0
        for j in get_used(W[i], top):
            if j not in covered:
                diff += importance[j]
        if diff > max_diff:
            max_diff, max_index = diff, i
    return max_index, obj

def gready_sp_multiclass(W, budget, ys, top=10):
    importance = get_feature_importance(W)
    results = [[] for _ in range(np.unique(ys).size)]
    while np.any(list(map(len, results)) < np.array(budget)):
        max_index, current_obj = greedy_sp_search_multiclass(results, W, importance, budget, ys, top)
        if max_index >= 0:
            results[ys[max_index]].append(max_index)
#             print(max_index)
        else:
            break
    return results, current_obj, importance

In [3]:
K=5
w_e = np.array([0.11010821, 0.1750904, 0, 0, 0, 0.5072202 , 0.20758119, 0])

fids, inputs, labels, embeds, preds = pickle.load(open('data/prostatex/train_findings_emb10.pkl', "rb"))
X_train = embeds
y_train = labels
p_train = preds
f_train = fids

fids, inputs, labels, embeds, preds = pickle.load(open('data/prostatex/valid_findings_emb10.pkl', "rb"))
X_valid = embeds
y_valid = labels
p_valid = preds
f_valid = fids

In [4]:
full = KNeighborsClassifier(n_neighbors=K)
full.fit(X_train, y_train)
full.score(X_valid, y_valid)

0.6020408163265306

In [5]:
np.random.seed(42)
scores = []
for i in range(100):
    index = np.random.choice(range(len(X_train)), 20)
    rand = KNeighborsClassifier(n_neighbors=K)
    rand.fit(X_train[index], y_train[index])
    scores.append(rand.score(X_valid, y_valid))
np.mean(scores), np.std(scores)

(0.5691836734693877, 0.08757194935614838)

In [6]:
protodash = ProtodashExplainer()
_, index, _ = protodash.explain(X_train, X_train, m=20)
proto = KNeighborsClassifier(n_neighbors=K)
proto.fit(X_train[index], y_train[index])
proto.score(X_valid, y_valid)

0.47959183673469385

In [33]:
lr = LogisticRegression(random_state=42)
lr.fit(X_train, y_train)
# explainer = LimeTabularExplainer(X_train)
# splime = SubmodularPick(explainer, X_train, lr.predict_proba, method='full', num_features=10, num_exps_desired=5)
coef = np.vstack([lr.coef_, -lr.coef_])
W = extract_W_multiclass(coef, X_train, y_train, coo=False)
budget = [10] * 2
sp_results, current_obj, importance = gready_sp_multiclass(W, budget, y_train, 100)
index = sp_results[0] + sp_results[1]
splime = KNeighborsClassifier(n_neighbors=K)
splime.fit(X_train, y_train)
splime.score(X_valid, y_valid)

0.6020408163265306