In [1]:
%load_ext autoreload
%autoreload 2

In [29]:
from datasets import Dataset, DatasetDict
import numpy as np
import pandas as pd
# import torch
from transformers import (AutoTokenizer, AutoModelForSequenceClassification)
from transformers import BertForSequenceClassification  # Check

# Import data

In [3]:
# Maximum number of tokens per input; used as the padding/truncation length when tokenizing.
max_seq_length: int = 128

In [4]:
# Checkpoint identifier handed to AutoTokenizer.from_pretrained.
tokenizer_name: str = "bert-base-cased"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

def preprocess_text(examples):
    """Tokenize the ``"text"`` entry of ``examples``.

    Every sequence is truncated and padded to exactly ``max_seq_length`` tokens,
    so all encodings in a batch have the same length.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_seq_length,
    )

# Import model

In [6]:
# Fine-tuned sequence-classification checkpoint (presumably trained on AMI 2018, per its name — confirm).
model_name: str = "g8a9/bert-base-cased_ami18"

In [7]:
# Load the classifier once, in inference mode (.eval() disables dropout and returns the model).
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
# The Auto class resolves the concrete architecture from the checkpoint config; this
# notebook expects a BERT classifier. Check that instead of loading the same weights a
# second time through BertForSequenceClassification (which doubled memory and load time).
assert isinstance(model, BertForSequenceClassification), type(model)
effective_model = model

# nlxplain

In [8]:
from nlxplain import Explainer

In [10]:
# Wrap the classifier and its tokenizer in an nlxplain Explainer. The cells below use it to
# classify text, compute occlusion importance and build per-token attribution tables.
exp = Explainer(model, tokenizer)

In [20]:
text = "You are a woman"
# classify() prints the input text, the class probabilities and the predicted label.
exp.classify(text)
occlusion_scores = exp.compute_occlusion_importance(text).numpy()
print(f"Importance occlusion:\n {occlusion_scores}")
# Attribution table for class 1: one row per explainer, one column per token.
exp.compute_table(text, target=1)

IDX: You are a woman
Text: You are a woman
Probabilities: tensor([[0.2007, 0.7993]])
Prediction: 1
Importance occlusion:
 [-0.01087821  0.00281805  0.01252252 -0.51788306]


tokens,You,are,a,woman
G,0.159478,0.163659,0.117216,0.341822
GxI,-0.20149,0.308628,-0.051436,-0.271019
IG,0.306092,0.34199,0.082422,-0.269495
SHAP,0.041938,-0.097407,-0.067262,0.793392
LIME,0.191277,-0.070717,0.269628,0.290803


# Evaluate explanations

In [44]:
import seaborn as sns

text = "You are a woman"
# Explain the class the model actually predicts for this input.
predicted_label = exp.get_predicted_label(text)
explanations = exp.compute_table(text, target=predicted_label)
tokens = explanations.columns.tolist()
soft_score_explanation = explanations.loc["SHAP"].to_numpy()
print(soft_score_explanation)

# Blue-to-red diverging colormap with symmetric limits, so the colour scale is centred on 0
# and positive/negative attributions stand apart.
palette = sns.diverging_palette(240, 10, as_cmap=True)
explanations.style.background_gradient(axis=1, cmap=palette, vmin=-1, vmax=1)

[ 0.04193844 -0.09740739 -0.06726249  0.79339168]


tokens,You,are,a,woman
G,0.159478,0.163659,0.117216,0.341822
GxI,-0.20149,0.308628,-0.051436,-0.271019
IG,0.306092,0.34199,0.082422,-0.269495
SHAP,0.041938,-0.097407,-0.067262,0.793392
LIME,0.191277,-0.070717,0.269628,0.290803


In [43]:
from nlxplain import Evalutator  # spelled this way by nlxplain itself (presumably) — do not "fix"

evalt = Evalutator(exp)
thresholds = [1, 2, 3, 4]
soft_score_explanation = explanations.loc["SHAP"].values
# Both metrics take the same inputs: the rationale is the top-k SHAP-scored tokens, for each k.
metric_inputs = (text, soft_score_explanation, thresholds)
c = evalt.compute_comprehensiveness_ths(*metric_inputs, based_on="k")
s = evalt.compute_sufficiency_ths(*metric_inputs, based_on="k")

# Temporary form
# NOTE(review): the output has entries only for k=1 and k=2 although four values of k are
# requested — presumably only positively scored tokens are eligible by default; confirm in nlxplain.
for metric_result in (c, s):
    print(metric_result)

{1: ([3], 0.5178830623626709), 2: ([3, 0], 0.6923361420631409)}
{1: ([3], -0.013836026191711426), 2: ([3, 0], -0.02977198362350464)}


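Compute the following metrics:
- **taud_loo**: Kendall tau distance between the explanation scores and the leave-one-out (occlusion) importance. Lower is better.
- **AOPC comprehensiveness** (ERASER): area over the perturbation curve for comprehensiveness, i.e. the comprehensiveness averaged over all the rationale sizes/thresholds considered.
- **AOPC sufficiency** (ERASER): area over the perturbation curve for sufficiency, i.e. the sufficiency averaged over all the rationale sizes/thresholds considered.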


Comprehensiveness: $f(x) - f(x \setminus r)$, where $r$ is the rationale, i.e. the most important tokens according to the explanation. Higher is better.

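Sufficiency: $f(x) - f(r)$. Lower is better: a low (or negative) value means the rationale alone keeps (or even increases) the model's confidence in the predicted class.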


For each measure, we also provide the ranking among the explainers.

In the following, we consider these approaches to extract the "rationale", i.e. the most important tokens:

- top k: `based_on = "k"` --> take only the top k tokens, varying k.
- percentage: `based_on = "perc"` --> take the top p% of tokens, varying p. ERASER uses 1%, 5%, 10%, 20% and 50%; here we use 0%, 10%, 20%, ..., 100% as in Atanasova et al.
- threshold: `based_on = "th"` --> take the tokens whose score is greater than a threshold, varying the threshold.

## Top k (positive)

In [53]:
# Rationale = top-k tokens for k = 1..4, restricted to positively scored tokens (only_pos=True).
thresholds = list(range(1, 5))
df_eval, style_df = evalt.evaluate_explainers(
    text, explanations, thresholds, based_on="k", only_pos=True
)
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.472835,3,-0.00558,3,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,-0.002818,4,0.701342,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,-0.003,5,0.627134,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.60511,1,-0.021804,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.591538,2,-0.010793,2,0.333333,2


## Percentage

In [56]:
# Rationale = top p% of tokens for p = 0%, 10%, ..., 100% (as in Atanasova et al.).
# np.linspace always yields exactly 11 points ending at 1.0. np.arange with a float step can
# gain or lose an element through rounding, as the NumPy docs warn.
thresholds = np.linspace(0, 1, 11)
# only_pos=False: presumably both positively and negatively scored tokens are eligible.
df_eval, style_df = evalt.evaluate_explainers(text, explanations, thresholds, based_on = "perc", only_pos = False, style_df=True)
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.52806,3,-0.008643,3,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,-0.021811,5,0.678475,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,-0.001479,4,0.626123,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.645066,1,-0.019568,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.586374,2,-0.011019,2,0.333333,2


## Threshold

In [54]:
# Rationale = tokens whose explanation score exceeds each of these thresholds.
thresholds = [0.05, 0.1, 0.2, 0.3]
df_eval, style_df = evalt.evaluate_explainers(
    text, explanations, thresholds, based_on="th"
)
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.403235,3,-0.006918,2,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,-0.002818,5,0.701342,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,0.00228,4,0.546013,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.517883,1,-0.013836,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.489518,2,0.072651,3,0.333333,2


## Top k 

In [55]:
# Rationale = top-k tokens for k = 1..4; only_pos=False presumably lets negatively scored tokens in too.
thresholds = list(range(1, 5))
df_eval, style_df = evalt.evaluate_explainers(
    text,
    explanations,
    thresholds,
    based_on="k",
    only_pos=False,
    style_df=True,
)
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.472835,3,-0.00558,3,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,0.058035,5,0.508075,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,0.069896,4,0.470725,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.550037,1,-0.01236,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.5158,2,-0.009017,2,0.333333,2
