In [None]:
import sys
# Prepend the directory two levels above the working directory to the import
# search path so that packages located there take precedence in later cells.
# NOTE(review): presumably this targets the local holisticai checkout — confirm.
sys.path.insert(0, '../../')

In [None]:
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from holisticai.datasets import load_adult
from holisticai.explainability import Explainer
from holisticai.efficacy.metrics import classification_efficacy_metrics

In [None]:
from holisticai.datasets import load_adult

# Load the Adult dataset bundle; its "data" and "target" entries are used below.
dataset = load_adult()

# Single dataframe with the features and the target side by side.
df = pd.concat([dataset["data"], dataset["target"]], axis=1)
protected_variables = ["sex", "race"]
output_variable = ["class"]

# Simple preprocessing: map the two income labels to 1/0 and one-hot encode
# the features that remain after dropping the protected attributes and target.
y = df[output_variable].replace({">50K": 1, "<=50K": 0})
X = pd.get_dummies(df.drop(protected_variables + output_variable, axis=1))

# Boolean membership masks for the two values of the "sex" attribute.
group = ["sex"]
group_a = df[group] == "Female"
group_b = df[group] == "Male"
data = [X, y, group_a, group_b]

# Train/test split. A fixed random_state makes the shuffled split identical
# on every run; no global seed has been set yet at this point in the notebook.
# NOTE: this rebinds `dataset` from the loaded bundle to the list of splits,
# ordered [X_train, X_test, y_train, y_test, ...].
dataset = train_test_split(*data, test_size=0.2, shuffle=True, random_state=42)
train_data = dataset[::2]   # even positions: the train part of each array
test_data = dataset[1::2]   # odd positions: the test part of each array

In [None]:
#X.hist(bins=10, figsize=(10, 10), color = 'mediumslateblue')

In [None]:
from holisticai.bias.plots import correlation_matrix_plot

# Plot the correlation matrix of the encoded feature table, using 'age' as
# the target feature, on a 12x7 figure.
correlation_matrix_plot(X, target_feature="age", size=(12, 7))

In [None]:
from sklearn.ensemble import GradientBoostingClassifier
import numpy as np

# Reproducibility: `np.random.seed` returns None, so the seed value is kept in
# its own variable and passed explicitly to every stochastic component below.
seed = 42
np.random.seed(seed)

# Hold out 20% of the samples for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed
)

# Fit the classifier. `y_train` is a single-column frame, so it is flattened
# to the 1-D target that scikit-learn estimators expect.
model = GradientBoostingClassifier(random_state=seed)
model.fit(X_train, np.ravel(y_train))

# Predictions on the held-out samples.
y_pred = model.predict(X_test)

# Efficacy metrics of the predictions against the held-out labels; as the last
# expression of the cell, the result is displayed as the cell output.
classification_efficacy_metrics(y_test, y_pred)

# Global Explainability Metrics (based on Permutation Feature Importance)

In [None]:
# permutation feature importance
permutation_explainer = Explainer(based_on='feature_importance',
                      strategy_type='permutation',
                      model_type='binary_classification',
                      model = model, 
                      x = X, 
                      y = y)

In [8]:
# Display the metrics of the permutation explainer, called with its defaults.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, i.e. the
# run was stopped by hand — presumably because it is slow on the full dataset;
# confirm and consider explaining a subsample instead.
permutation_explainer.metrics()

KeyboardInterrupt: 

In [None]:
# Partial dependence plot from the permutation explainer, with default arguments.
permutation_explainer.partial_dependence_plot()

In [None]:
# Same metrics call as before, this time requesting the detailed view.
permutation_explainer.metrics(detailed=True)

In [None]:
# Contrast whole-model importance vs. group importance:
# -> Order      -> mean positions
# -> Range      -> match range of position
# -> Similarity -> compute similarity
# -> e.g. Q0-Q1 and Q2-Q3 show strong changes in position, but their importance weights maintain a high similarity with the whole model.
# -> e.g. Q1-Q2 and Q3-Q4 show small changes in position, and their importance weights maintain a high similarity with the whole model.
permutation_explainer.contrast_visualization(show_connections=False)
# TODO: show the connections separately, in a second matrix

In [None]:
# Bar plot from the permutation explainer, limited to 10 displayed entries.
permutation_explainer.bar_plot(max_display=10)

In [None]:
# Feature-importance table sorted by the 'Global' column, top 5 rows.
permutation_explainer.feature_importance_table(sorted_by='Global', top_n=5)

# Global Explainability Metrics (based on Surrogate Model)

In [None]:
# surrogate feature importance
surrogate_explainer = Explainer(based_on='feature_importance',
                      strategy_type='surrogate',
                      model_type='regression',
                      model = model, 
                      x = X, 
                      y = y)

In [None]:
# Display the metrics of the surrogate explainer, called with its defaults.
surrogate_explainer.metrics()

In [None]:
# Partial dependence plot of the surrogate explainer on a wide 15x5 axes.
# `kind` is one of: "individual", "average", "both".
_, ax = plt.subplots(figsize=(15, 5))
surrogate_explainer.partial_dependence_plot(ax=ax, kind="both")

In [None]:
# Bar plot from the surrogate explainer, limited to 5 displayed entries.
surrogate_explainer.bar_plot(max_display=5)

In [None]:
# Feature-importance table sorted by the 'Global' column, top 10 rows.
surrogate_explainer.feature_importance_table(sorted_by='Global', top_n=10)

In [None]:
# Draw the surrogate tree with the 'sklearn' option on a wide, short axes.
# The return value is bound to `_` so that it is not echoed as cell output.
_, ax = plt.subplots(figsize=(15, 3))
_ = surrogate_explainer.tree_visualization("sklearn", ax=ax)

In [None]:
# Same tree visualization, this time requested with the 'graphviz' option.
# NOTE(review): presumably requires the graphviz package to be installed — confirm.
surrogate_explainer.tree_visualization('graphviz')

In [None]:
# Tree visualization requested with the 'dtreeviz' option at scale 2.
# The bare `vis` on the last line makes the object the cell's displayed output.
vis = surrogate_explainer.tree_visualization('dtreeviz', scale=2)
vis

# Local Explainability Metrics (based on LIME)

In [None]:
# lime feature importance
lime_explainer = Explainer(based_on='feature_importance',
                      strategy_type='lime',
                      model_type='regression',
                      model = model, 
                      x = X, 
                      y = y)

In [None]:
# Display the metrics of the LIME explainer, requesting the detailed view.
lime_explainer.metrics(detailed=True)

In [None]:
# Bar plot from the LIME explainer, limited to 10 displayed entries.
lime_explainer.bar_plot(max_display=10)

In [None]:
# Importance-stability view of the LIME explainer, with default arguments.
lime_explainer.show_importance_stability()

In [None]:
# Data stability boundaries from the LIME explainer: top 10, on a 15x5 figure.
lime_explainer.show_data_stability_boundaries(top_n=10, figsize=(15,5))

In [None]:
# Feature stability boundaries from the LIME explainer, on a 15x5 figure.
lime_explainer.show_features_stability_boundaries(figsize=(15,5))