In [10]:
# Enable IPython's autoreload extension so that edited modules are re-imported
# automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [11]:
from datasets import Dataset, DatasetDict
import numpy as np
import pandas as pd
# import torch
from transformers import (AutoTokenizer, AutoModelForSequenceClassification)
from transformers import BertForSequenceClassification  # Check

# Import tokenizer, model

In [352]:
# Token length that preprocess_text pads to and truncates at.
max_seq_length = 128

In [353]:
# Checkpoint name passed to AutoTokenizer.from_pretrained.
tokenizer_name = "bert-base-cased"

In [354]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)


def preprocess_text(examples):
    """Tokenize the ``"text"`` entry of ``examples`` to a fixed length.

    Every sequence is padded to, and truncated at, ``max_seq_length`` tokens
    using the notebook-level ``tokenizer``.

    Args:
        examples: Mapping with a ``"text"`` key holding the raw text(s).

    Returns:
        Whatever ``tokenizer(...)`` returns for the given text.
    """
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_seq_length,
    }
    return tokenizer(examples["text"], **encoding_options)

In [355]:
# Checkpoint name passed to the from_pretrained calls that load the classifier.
model_name = "g8a9/bert-base-cased_ami18"

In [356]:
# Load the checkpoint as a sequence classifier and switch it to evaluation mode.
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
# NOTE(review): this loads the same checkpoint a second time through the concrete
# BERT class; `effective_model` is not referenced by any other cell visible
# here — confirm whether it is still needed.
effective_model = BertForSequenceClassification.from_pretrained(model_name).eval()

# nlxplain

In [20]:
from nlxplain import Explainer

In [21]:
# Wrap the loaded model and tokenizer in an nlxplain Explainer.
exp = Explainer(model, tokenizer)

In [22]:
# Classify one sentence, print its occlusion importance, then display the
# table of explanations computed for target label 1.
text = "You are a woman"
exp.classify(text)
print(f"Importance occlusion:\n {exp.compute_occlusion_importance(text).numpy()}")
exp.compute_table(text, target = 1)

IDX: You are a woman
Text: You are a woman
Probabilities: tensor([[0.2007, 0.7993]])
Prediction: 1
Importance occlusion:
 [-0.01087821  0.00281805  0.01252252 -0.51788306]


tokens,You,are,a,woman
G,0.159478,0.163659,0.117216,0.341822
GxI,-0.20149,0.308628,-0.051436,-0.271019
IG,0.306092,0.34199,0.082422,-0.269495
SHAP,0.041938,-0.097407,-0.067262,0.793392
LIME,0.191277,-0.070717,0.269628,0.290803


# Evaluate explanations

In [44]:
# Explain the sentence with respect to the label the model predicts for it.
text = "You are a woman"
explanations = exp.compute_table(text, target = exp.get_predicted_label(text))
tokens = list(explanations.columns)
# Per-token scores taken from the "SHAP" row of the explanation table.
soft_score_explanation = explanations.loc["SHAP"].values
print(soft_score_explanation)
# NOTE(review): mid-cell import; consider moving it to the notebook's import cell.
import seaborn as sns
# Diverging colormap, applied per row (axis=1) on a fixed [-1, 1] scale.
palette = sns.diverging_palette(240, 10, as_cmap=True)
explanations.style.background_gradient(axis=1, cmap=palette, vmin=-1, vmax=1)

[ 0.04193844 -0.09740739 -0.06726249  0.79339168]


tokens,You,are,a,woman
G,0.159478,0.163659,0.117216,0.341822
GxI,-0.20149,0.308628,-0.051436,-0.271019
IG,0.306092,0.34199,0.082422,-0.269495
SHAP,0.041938,-0.097407,-0.067262,0.793392
LIME,0.191277,-0.070717,0.269628,0.290803


In [43]:
from nlxplain import Evalutator

# Evaluator built on top of the explainer created earlier.
evalt = Evalutator(exp)
# Values of k to evaluate (the calls below use based_on = "k").
thresholds = [1, 2, 3, 4]
soft_score_explanation = explanations.loc["SHAP"].values
# Comprehensiveness and sufficiency of the SHAP scores for the given thresholds.
# NOTE(review): the recorded output only has entries for k = 1 and 2 —
# presumably only positively-scored tokens are considered; confirm.
c = evalt.compute_comprehensiveness_ths(text, soft_score_explanation, thresholds, based_on = "k")
s = evalt.compute_sufficiency_ths(text, soft_score_explanation, thresholds, based_on = "k")

# Temporary form
print(c)
print(s)

{1: ([3], 0.5178830623626709), 2: ([3, 0], 0.6923361420631409)}
{1: ([3], -0.013836026191711426), 2: ([3, 0], -0.02977198362350464)}


Compute the following measures for each explainer:
- **taud_loo**: Kendall tau distance between the explanation scores and the occlusion importance (leave-one-out)
- **AOPC comprehensiveness** (ERASER): area over the perturbation curve of the comprehensiveness
- **AOPC sufficiency** (ERASER): area over the perturbation curve of the sufficiency


Comprehensiveness: $f(x) - f(x \setminus r)$, where $r$ is the rationale, i.e. the most important tokens according to the explanation. The higher, the better.

Sufficiency: $f(x) - f(r)$. The closer to 0, the better.


For each measure, we also provide the ranking among the explainers.

Below, we consider three approaches to extract the "rationale", i.e. the most important tokens:

- top k (`based_on = "k"`): take only the top k tokens, varying k.
- percentage (`based_on = "perc"`): take the top 1%, 5%, 10%, 20% and 50% of tokens as in ERASER, or 0%, 10%, 20%, ..., 100% as in Atanasova et al.
- threshold (`based_on = "th"`): take the tokens whose score is greater than a threshold.

## Top k (positive)

In [53]:
# Top-k evaluation of every explainer in the table, with only_pos = True.
thresholds = [1,2,3,4]
df_eval, style_df = evalt.evaluate_explainers(text, explanations, thresholds, based_on = "k", only_pos = True)
# Show the styled result as the cell output.
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.472835,3,-0.00558,3,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,-0.002818,4,0.701342,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,-0.003,5,0.627134,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.60511,1,-0.021804,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.591538,2,-0.010793,2,0.333333,2


## Percentage (positive)

In [56]:
# Fractions from 0 to (approximately) 1.0 in steps of 0.1.
thresholds = np.arange(0, 1.1, 0.1)
# NOTE(review): the section heading says "(positive)" but this call passes
# only_pos = False — confirm which of the two is intended.
df_eval, style_df = evalt.evaluate_explainers(text, explanations, thresholds, based_on = "perc", only_pos = False, style_df=True)
# Show the styled result as the cell output.
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.52806,3,-0.008643,3,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,-0.021811,5,0.678475,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,-0.001479,4,0.626123,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.645066,1,-0.019568,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.586374,2,-0.011019,2,0.333333,2


## Threshold

In [54]:
# Score cut-offs used with based_on = "th".
thresholds = [0.05, 0.1, 0.2, 0.3]
# only_pos and style_df are not passed here, so whatever defaults
# evaluate_explainers defines apply.
df_eval, style_df = evalt.evaluate_explainers(text, explanations, thresholds, based_on = "th")
# Show the styled result as the cell output.
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.403235,3,-0.006918,2,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,-0.002818,5,0.701342,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,0.00228,4,0.546013,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.517883,1,-0.013836,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.489518,2,0.072651,3,0.333333,2


## Top k 

In [55]:
# Top-k evaluation again, this time with only_pos = False.
thresholds = [1,2,3,4]
df_eval, style_df = evalt.evaluate_explainers(text, explanations, thresholds, based_on = "k", only_pos = False, style_df=True)
# Show the styled result as the cell output.
style_df

tokens,You,are,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r
G,0.159478,0.163659,0.117216,0.341822,0.472835,3,-0.00558,3,0.166667,1
GxI,-0.20149,0.308628,-0.051436,-0.271019,0.058035,5,0.508075,5,0.833333,4
IG,0.306092,0.34199,0.082422,-0.269495,0.069896,4,0.470725,4,0.666667,3
SHAP,0.041938,-0.097407,-0.067262,0.793392,0.550037,1,-0.01236,1,0.166667,1
LIME,0.191277,-0.070717,0.269628,0.290803,0.5158,2,-0.009017,2,0.333333,2


# Plausibility

In [230]:
from nlxplain import Evalutator

# Explain a sentence for target label 1 and evaluate the explainers, this
# time also against a hand-written token-level rationale.
text = "I am a woman"
exp.classify(text)
print(f"Importance occlusion:\n {exp.compute_occlusion_importance(text).numpy()}")
explanations = exp.compute_table(text, target=1)

evalt = Evalutator(exp)

# Reference rationale passed to the evaluator as `true_rationale`.
true_explanation = [0, 1, 0, 1]
thresholds = np.arange(0, 1.1, 0.1)
df_eval, style_df = evalt.evaluate_explainers(
    text,
    explanations,
    thresholds,
    true_rationale=true_explanation,
    target=1,
    based_on="perc",
    style_df=True,
)
style_df

IDX: I am a woman
Text: I am a woman
Probabilities: tensor([[0.2111, 0.7889]])
Prediction: 1
Importance occlusion:
 [-0.00160569  0.00908154  0.00839043 -0.66411984]


tokens,I,am,a,woman,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r,auprc_plau,auprc_plau_r
G,0.128154,0.156558,0.122124,0.358661,0.57876,4,-0.015799,3,0.333333,3,1.0,1
GxI,-0.377682,-0.230217,-0.133846,-0.153122,,1,,5,0.5,4,0.75,3
IG,0.416557,0.246852,0.078772,-0.257818,-0.007678,5,0.680875,4,0.666667,5,0.333333,5
SHAP,-0.032996,-0.142473,-0.046024,0.778507,0.66412,3,-0.024285,1,0.0,1,0.875,2
LIME,0.120084,0.013015,0.224611,0.357594,0.700275,2,-0.018006,2,0.166667,2,0.708333,4


# HateXplain

In [231]:
from datasets import load_dataset

# Load the "hatexplain" dataset through the `datasets` library.
dataset = load_dataset("hatexplain")

Reusing dataset hatexplain (/Users/eliana/.cache/huggingface/datasets/hatexplain/plain_text/1.0.0/df474d8d8667d89ef30649bf66e9c856ad8305bef4bc147e8e31cbdf1b8e0249)
100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 811.02it/s]


In [232]:
# Inspect the field names of one training example.
dataset["train"][1].keys()

dict_keys(['id', 'annotators', 'rationales', 'post_tokens'])

## Model for hate speech

In [358]:
#MODEL = f"cardiffnlp/twitter-roberta-base-hate"

#tokenizer_hate = AutoTokenizer.from_pretrained(MODEL)


# PT
#model_hate = AutoModelForSequenceClassification.from_pretrained(MODEL)
#model.save_pretrained(MODEL)

In [360]:
#exp_hate = Explainer(model_hate, tokenizer_hate)

## Evaluate explanations

In [365]:
from nlxplain import Explainer

In [366]:
# Re-create the explainer around the `model` and `tokenizer` loaded earlier
# (the cells that would load a hate-speech-specific model are commented out).
exp = Explainer(model, tokenizer)

In [367]:
# Take one HateXplain training instance and rebuild its text from its tokens.
i = 0
instance_hatexplain = dataset["train"][i]
text = " ".join(instance_hatexplain["post_tokens"])

exp.classify(text)
# Explain with respect to the label the model predicts for this text.
target = exp.get_predicted_label(text)


print(f"Importance occlusion:\n {exp.compute_occlusion_importance(text).numpy()}")
explanations =  exp.compute_table(text, target = target)

IDX: u really think i would not have been raped by feral hindu or muslim back in india or bangladesh and a neo nazi would rape me as well just to see me cry
Text: u really think i would not have been raped by feral hindu or muslim back in india or bangladesh and a neo nazi would rape me as well just to see me cry
Probabilities: tensor([[0.5420, 0.4580]])
Prediction: 0
Importance occlusion:
 [-0.07536277  0.00406748 -0.05482131 -0.11015677 -0.03581223  0.01808208
  0.00583032  0.01801482 -0.0319446  -0.00952613 -0.0555388  -0.01674065
  0.03008634  0.01789537 -0.00721917  0.0029231   0.00133276 -0.01736441
  0.00839406  0.00839406  0.01580939  0.01730666  0.01290727  0.02179924
  0.01306814  0.01662275 -0.00572476  0.0292125   0.00801119 -0.00659114
 -0.00990164 -0.01178586  0.01762742  0.02804026  0.00994155 -0.00234556
  0.01261893  0.00761637 -0.00626582  0.01048899]


Partition explainer: 2it [00:12, 12.64s/it]                                     


In [392]:
from nlxplain import Evalutator

# Evaluate the explanations of one HateXplain instance against its human
# rationale. NOTE: `explanations` and `target` are the ones computed by the
# previous cell, so `i` must be the same in both cells.
i = 0
instance_hatexplain = dataset["train"][i]
text = " ".join(instance_hatexplain["post_tokens"])

evalt = Evalutator(exp)

thresholds = np.arange(0, 1.1, 0.1)

# As in HateXplain, we consider the union of the annotators' rationales:
# a token is marked 1 if at least one annotator selected it.
rationales = instance_hatexplain["rationales"]
if rationales:
    rationale = [int(any(votes)) for votes in zip(*rationales)]
else:
    # With no annotator rationale, zip(*[]) would produce an empty mask that
    # no longer lines up with post_tokens; use an all-zero mask instead.
    rationale = [0] * len(instance_hatexplain["post_tokens"])

# Map the word-level mask onto the tokens used by the explainer.
# NOTE(review): behaviour inferred from the helper's name and arguments — confirm.
token_rationale = evalt.get_true_rational_tokens(
    instance_hatexplain["post_tokens"], rationale
)
df_eval, style_df = evalt.evaluate_explainers(
    text,
    explanations,
    thresholds,
    true_rationale=token_rationale,
    target=target,
    based_on="perc",
    style_df=True,
)
style_df

40


Unnamed: 0,u,really,think,i,would,not,have,been,raped,by,feral,hind,##u,or,m,##us,##lim,back,in,in.1,##dia,or.1,bang,##lades,##h,and,a,neo,na,##zi,would.1,rape,me,as,well,just,to,see,me.1,cry,aopc_compr,aopc_compr_r,aopc_suff,aopc_suff_r,taud_loo,taud_loo_r,auprc_plau,auprc_plau_r
G,0.068493,0.028683,0.027015,0.054594,0.02106,0.01414,0.011362,0.018912,0.071919,0.015829,0.038559,0.028022,0.021881,0.014568,0.01478,0.020433,0.02135,0.015168,0.00992,0.009045,0.023536,0.014275,0.027766,0.029525,0.018506,0.016857,0.013219,0.033778,0.02178,0.032425,0.016634,0.041361,0.027946,0.012771,0.014528,0.01384,0.010181,0.01399,0.024552,0.023835,-0.318763,5,0.163351,5,0.609686,5,0.538185,1
GxI,-0.286253,-0.042247,-0.070103,0.037369,0.039985,-0.011685,-0.008406,-0.045854,-0.001314,0.000942,-0.018559,0.007225,-0.031434,-0.000333,0.019096,-0.007055,-0.000472,0.0157,-0.000988,0.003608,-0.010007,-0.00531,0.008831,0.031424,0.010558,-0.010994,-0.022572,-0.031043,0.005701,0.02911,-0.02412,0.002432,-0.024098,-0.029018,-0.005053,-0.004828,-0.010921,-0.011336,-0.015765,-0.000828,-0.138426,4,-0.096619,4,0.587877,4,0.324553,3
IG,0.024416,0.026909,-0.004025,0.000314,-0.01212,0.0,0.005098,0.012268,-0.000834,0.017854,-0.00101,0.020547,-0.011599,0.027112,-0.006973,0.029427,0.01189,0.041401,0.021025,0.040133,0.014927,0.027173,0.022177,0.012758,0.013656,0.059657,0.070684,0.003687,0.024443,0.00392,0.020098,-0.00163,0.066752,0.051343,0.035513,0.062851,0.040371,0.054878,0.063144,0.035382,0.013711,3,-0.362835,1,0.437781,3,0.187923,5
SHAP,-0.070314,0.013228,0.011188,0.046025,0.021787,0.021582,0.020979,-0.001719,-0.056267,-0.006244,-0.052505,0.039082,0.04631,0.043207,0.02701,0.02701,0.03083,0.005297,0.005297,0.018138,0.018138,0.019352,0.014161,0.014161,0.014161,0.006795,0.038018,0.011574,0.013987,0.013987,-0.017102,-0.063959,0.018568,0.014522,0.030848,0.030848,0.02695,0.02695,0.0064,0.007085,0.081663,2,-0.357869,2,0.42516,2,0.373919,2
LIME,0.009248,-0.016188,-0.032119,-0.032777,-0.047962,0.054321,0.033781,-0.029572,-0.019393,0.030685,-0.027293,-0.012483,-0.032119,0.0155,-0.031532,0.006044,-0.026048,-0.032777,0.029543,-0.020138,0.031502,-0.029253,-0.030977,0.037573,0.013541,-0.019393,-0.006595,-0.016934,-0.021939,-0.006754,0.010435,-0.04258,-0.006754,-0.004136,0.020505,-0.009114,-0.016934,-0.047375,0.001935,-0.031532,0.116064,1,-0.346928,3,0.404056,1,0.204002,4
