## Load

In [7]:
import sys
sys.path.append("../src/utils")
from load_repositories import *
import load_repositories
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
import numpy as np
import pandas as pd
from nltk.tokenize import TweetTokenizer
import random

In [20]:
import importlib
importlib.reload(load_repositories)

<module 'load_repositories' from 'C:\\Users\\eliag\\Documents\\Msc\\Thesis\\text_explainability\\notebooks\\../src/utils\\load_repositories.py'>

Load the BERT classification model that was finetuned on tweets with 3 sentiment labels - positive, neutral and negative

In [2]:
pretrained_model = "../src/models/bert_tweets"
tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower_case=False)
model = BertForSequenceClassification.from_pretrained(pretrained_model)
model.eval()
label_names = ["Positive", "Neutral", "Negative"]

Set prediction pipeline for the explainer

In [3]:
import transformers

pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None
)

Load ERASER movies dataset

In [23]:
import os
from rationale_benchmark.utils import load_documents, load_datasets, annotations_from_jsonl, Annotation

data_root = os.path.join('..','external_repos','eraserbenchmark-master','data', 'movies')
documents = load_documents(data_root)
val = annotations_from_jsonl(os.path.join(data_root, 'val.jsonl'))
## Or load everything:
train, val, test = load_datasets(data_root)

An example of doc

In [24]:
i = 0
annotation = train[i]
evidences = annotation.all_evidences()
(docid,) = set(ev.docid for ev in evidences)
doc = documents[docid]
print(f"docid: {docid}, Number of sentences: {len(doc)}, number of evidences: {len(evidences)}")
print()

sentences = []
for sent in doc:
    sentence = ' '.join(sent)
    sentences.append(sentence)
    # print(sentence)

doc_text = ' '.join(sentences)

docid: negR_000.txt, Number of sentences: 44, number of evidences: 16



In [25]:
doc_text1 = ' '.join(sentences[:22])
doc_text2 = ' '.join(sentences[22:])

In [26]:
tokens1 = tokenizer(doc_text1)
len(tokens1['input_ids'])

504

In [27]:
tokens2 = tokenizer(doc_text2)
len(tokens2['input_ids'])

420

Explain

In [28]:
import shap

explainer = shap.Explainer(pred)

In [29]:
shap_values = explainer([doc_text1])

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [06:07, 367.56s/it]                                                                  


In [12]:
import pickle 

print(shap_values.shape)

file_pi = open('doc1_chunk1.pkl', 'wb') 
pickle.dump(shap_values, file_pi)

(1, 504, 3)


In [18]:
shap_values.values[0,:,2].max()

0.030944833316271455

In [22]:
shap_vals = shap_values.values[0,:,2]
shap_vals.shape

(504,)

In [54]:
shap_values

.values =
array([[[ 1.41353417e-04, -1.32514906e-03,  1.18379577e-03],
        [ 1.41353417e-04, -1.32514906e-03,  1.18379577e-03],
        [ 1.41353417e-04, -1.32514906e-03,  1.18379577e-03],
        ...,
        [ 7.52904197e-04,  2.26798377e-03, -3.02088793e-03],
        [ 7.52904197e-04,  2.26798377e-03, -3.02088793e-03],
        [ 2.44094050e-04, -6.57664612e-05, -1.78327967e-04]]])

.base_values =
array([[0.00383232, 0.98607022, 0.01009738]])

.data =
(array(['', ' plot', ' :', ' two', ' teen', ' couples', ' go', ' to', ' a',
       ' church', ' party', ' ,', ' drink', ' and', ' then', ' drive',
       ' .', ' they', ' get', ' into', ' an', ' accident', ' .', ' one',
       ' of', ' the', ' guys', ' dies', ' ,', ' but', ' his',
       ' girlfriend', ' continues', ' to', ' see', ' him', ' in', ' her',
       ' life', ' ,', ' and', ' has', ' nightmares', ' .', ' what', " '",
       ' s', ' the', ' deal', ' ?', ' watch', ' the', ' movie', ' and',
       ' "', ' sort', 'a', ' "', ' f

In [185]:
shap_values.values.sum(axis=1)

array([[ 0.02468321, -0.86112129,  0.83643818]])

In [192]:
import numpy as np
tokens = tokenizer(doc_text1, return_tensors='pt', padding=True, truncation=True)
output = model(**tokens)
output

SequenceClassifierOutput(loss=None, logits=tensor([[-1.7120, -0.2345,  1.6787]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [194]:
from scipy.special import softmax
softmax(output.logits.detach())

array([[0.02851553, 0.12494893, 0.84653556]], dtype=float32)

In [196]:
0.84653556 - 0.83643818

0.010097379999999934

In [133]:
threshold = 4*1/504
print(threshold)

0.007936507936507936


Intersection over union (usually for discrete cases)

In [134]:
# How many tokens have shap values that are higher than the threshold
explaining_tokens_num = (shap_vals >= threshold).sum()
explaining_tokens_index = np.where(shap_vals >= threshold)
explaining_tokens = shap_values.data[0][explaining_tokens_index]
explaining_tokens_map = np.array(tokens_map)[explaining_tokens_index]
print(len(explaining_tokens))

52


In [135]:
# How many tokens marked as explanations
annotated_tokens = []

for evidence in evidences:
    if evidence.start_token < tokens_map[n-2]:
        if evidence.end_token > tokens_map[n-2]:
            end_token = tokens_map[n-2]
        else:
            end_token = evidence.end_token
        
        annotated_tokens += list(range(evidence.start_token,end_token))

        print(f"{evidence.start_token}, {end_token}")

print(len(annotated_tokens))

309, 318
83, 87
212, 219
196, 203
368, 370
43, 48
162, 165
182, 187
371, 375
63, 67
143, 150
57


In [136]:
# How many intersect
explaining_tokens_map = list(explaining_tokens_map)
tokens_intersection = np.intersect1d(explaining_tokens_map, annotated_tokens)
intersection = tokens_intersection.size
intersection

17

In [137]:
# Union
tokens_union = np.union1d(explaining_tokens_map, annotated_tokens)
union = tokens_union.size
union

90

In [138]:
iou = intersection / union
iou

0.18888888888888888

In [160]:
text_marked_with_pred_exp = ""
prev_was_exp = False

for i in range(len(shap_values.data[0])): 
    curr_exp = i in explaining_tokens_index[0]

    if (curr_exp and not prev_was_exp):
        text_marked_with_pred_exp += "<<**" 
    elif (not curr_exp and prev_was_exp):
        text_marked_with_pred_exp += "**>>"
    
    text_marked_with_pred_exp += shap_values.data[0][i]
    prev_was_exp = curr_exp

print(text_marked_with_pred_exp)

 plot : two teen couples go to a church party , drink and then drive . they get into an accident . one of the guys dies , but his girlfriend continues to see him in her life , and has nightmares . what ' s the deal ? watch the movie and " sorta " find out . . . critique : a mind - fuck movie for the teen generation that touches on a very cool idea , but presents it in a very bad package . which is what makes this review an even harder one to write , since i generally applaud films which attempt to break the mold , mess with your head and such ( lost highway & memento ) , but there are good and bad ways of making all types of films , and these folks just did n ' t snag this one correctly . they seem to have taken this pretty neat concept ,<<** but executed it terribly .**>> so what are the problems with the movie ?<<** well , its main problem is that it ' s simply too jumbled .**>> it starts off " normal " but then downshifts into this " fantasy " world in which you , as an audience mem

In [179]:
evidences_tokens = []
len_tokens = len(tokens_map_reversed)

for ev in evidences:
    if ev.start_token <= len_tokens:
        start_token = tokens_map_reversed[ev.start_token][0]

        if evidence.end_token <= len_tokens:
            end_token = tokens_map_reversed[ev.end_token][1]
        else:
            end_token = tokens_map_reversed[-1][1]
            
        
        evidences_tokens += list(range(start_token, end_token+1))
        
evidences_tokens

[334,
 335,
 336,
 337,
 338,
 339,
 340,
 341,
 342,
 343,
 86,
 87,
 88,
 89,
 90,
 227,
 228,
 229,
 230,
 231,
 232,
 233,
 234,
 235,
 208,
 209,
 210,
 211,
 212,
 213,
 214,
 215,
 216,
 217,
 218,
 395,
 396,
 397,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 172,
 173,
 174,
 175,
 192,
 193,
 194,
 195,
 196,
 197,
 198,
 199,
 398,
 399,
 400,
 401,
 402,
 66,
 67,
 68,
 69,
 70,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160]

In [204]:
shap_values.values[0,:,2].size

504

In [180]:
text_marked_with_pred_exp2 = ""
prev_was_exp = False

for i in range(len(shap_values.data[0])): 
    curr_exp = i in evidences_tokens

    if (curr_exp and not prev_was_exp):
        text_marked_with_pred_exp2 += "<<**" 
    elif (not curr_exp and prev_was_exp):
        text_marked_with_pred_exp2 += "**>>"
    
    text_marked_with_pred_exp2 += shap_values.data[0][i]
    prev_was_exp = curr_exp

print(text_marked_with_pred_exp2)

 plot : two teen couples go to a church party , drink and then drive . they get into an accident . one of the guys dies , but his girlfriend continues to see him in her life , and has nightmares .<<** what ' s the deal ? watch**>> the movie and " sorta " find out . . . critique : a<<** mind - fuck movie for**>> the teen generation that touches on a very cool idea , but presents it in<<** a very bad package .**>> which is what makes this review an even harder one to write , since i generally applaud films which attempt to break the mold , mess with your head and such ( lost highway & memento ) , but there are good and bad ways of making all types of films , and these folks<<** just did n ' t snag this one correctly .**>> they seem to have taken this pretty neat concept , but<<** executed it terribly .**>> so what are the problems with the movie ? well , its main problem is that<<** it ' s simply too jumbled .**>> it starts off " normal " but then<<** downshifts into this " fantasy " wor

In [181]:
text_marked_with_pred_exp3 = ""
prev_was_exp = False

for i in range(len(shap_values.data[0])): 
    curr_exp = (i in evidences_tokens) and (i in explaining_tokens_index[0])

    if (curr_exp and not prev_was_exp):
        text_marked_with_pred_exp3 += "<<**" 
    elif (not curr_exp and prev_was_exp):
        text_marked_with_pred_exp3 += "**>>"
    
    text_marked_with_pred_exp3 += shap_values.data[0][i]
    prev_was_exp = curr_exp

print(text_marked_with_pred_exp3)

 plot : two teen couples go to a church party , drink and then drive . they get into an accident . one of the guys dies , but his girlfriend continues to see him in her life , and has nightmares . what ' s the deal ? watch the movie and " sorta " find out . . . critique : a mind - fuck movie for the teen generation that touches on a very cool idea , but presents it in a very bad package . which is what makes this review an even harder one to write , since i generally applaud films which attempt to break the mold , mess with your head and such ( lost highway & memento ) , but there are good and bad ways of making all types of films , and these folks just did n ' t snag this one correctly . they seem to have taken this pretty neat concept , but<<** executed it terribly .**>> so what are the problems with the movie ? well , its main problem is that<<** it ' s simply too jumbled .**>> it starts off " normal " but then downshifts into this " fantasy " world in which you , as an audience mem

In [177]:
tokens_map_reversed[83:87]

[[86, 86], [87, 87], [88, 88], [89, 89]]

In [178]:
shap_values.data[0][86:90]

array([' a', ' very', ' bad', ' package'], dtype=object)

In [182]:
evidence

Evidence(text="just did n't snag this one correctly", docid='negR_000.txt', start_token=143, end_token=150, start_sentence=6, end_sentence=7)

In [172]:
a = 0
if a:
    print(1)

In [173]:
tokens_map_reversed = []

for i in range(len(tokens_map)):
    token_ind = tokens_map[i]

    if token_ind is not None and token_ind != tokens_map[i-1]:
        tokens_map_reversed.append([i,i])
        last_token_ind = i
    elif token_ind is not None:
        tokens_map_reversed[-1][1] = i
        
tokens_map_reversed

[[1, 1],
 [2, 2],
 [3, 3],
 [4, 4],
 [5, 5],
 [6, 6],
 [7, 7],
 [8, 8],
 [9, 9],
 [10, 10],
 [11, 11],
 [12, 12],
 [13, 13],
 [14, 14],
 [15, 15],
 [16, 16],
 [17, 17],
 [18, 18],
 [19, 19],
 [20, 20],
 [21, 21],
 [22, 22],
 [23, 23],
 [24, 24],
 [25, 25],
 [26, 26],
 [27, 27],
 [28, 28],
 [29, 29],
 [30, 30],
 [31, 31],
 [32, 32],
 [33, 33],
 [34, 34],
 [35, 35],
 [36, 36],
 [37, 37],
 [38, 38],
 [39, 39],
 [40, 40],
 [41, 41],
 [42, 42],
 [43, 43],
 [44, 44],
 [45, 46],
 [47, 47],
 [48, 48],
 [49, 49],
 [50, 50],
 [51, 51],
 [52, 52],
 [53, 53],
 [54, 54],
 [55, 56],
 [57, 57],
 [58, 58],
 [59, 59],
 [60, 60],
 [61, 61],
 [62, 62],
 [63, 63],
 [64, 64],
 [65, 65],
 [66, 66],
 [67, 67],
 [68, 68],
 [69, 69],
 [70, 70],
 [71, 71],
 [72, 72],
 [73, 73],
 [74, 74],
 [75, 75],
 [76, 76],
 [77, 77],
 [78, 78],
 [79, 79],
 [80, 80],
 [81, 81],
 [82, 82],
 [83, 83],
 [84, 84],
 [85, 85],
 [86, 86],
 [87, 87],
 [88, 88],
 [89, 89],
 [90, 90],
 [91, 91],
 [92, 92],
 [93, 93],
 [94, 94],
 [95, 

In [111]:
np.array(tokens_map)[explaining_tokens_index]

array([151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163,
       164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176,
       177, 178, 179, 180, 181, 182, 183, 183, 184, 185, 186, 186, 187,
       293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305,
       306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318,
       319, 320, 321, 322, 323, 323, 324, 325, 326, 433, 434, 435, 436,
       437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448],
      dtype=object)

In [114]:
list(range(2,5))

[2, 3, 4]

In [36]:
shap_values.data[0][np.where(shap_vals >= 0.02)]

array([' its', ' main', ' problem', ' is', ' that', ' it', " '", ' s',
       ' i', ' get', ' kind', ' of'], dtype=object)

In [41]:
evidences[0]

Evidence(text='i get kind of fed up after a while', docid='negR_000.txt', start_token=309, end_token=318, start_sentence=16, end_sentence=17)

In [45]:
shap_values.data[0][evidences[0].start_token:evidences[0].end_token]

array(['rave', 'l', ' a', ' film', ' every', ' now', ' and', ' then',
       ' ,'], dtype=object)

In [51]:
shap_values.data[0][290:340]

array([' most', ' of', ' it', ' is', ' simply', ' not', ' explained',
       ' .', ' now', ' i', ' personally', ' do', ' n', " '", ' t',
       ' mind', ' trying', ' to', ' un', 'rave', 'l', ' a', ' film',
       ' every', ' now', ' and', ' then', ' ,', ' but', ' when', ' all',
       ' it', ' does', ' is', ' give', ' me', ' the', ' same', ' clue',
       ' over', ' and', ' over', ' again', ' ,', ' i', ' get', ' kind',
       ' of', ' fed', ' up'], dtype=object)

In [55]:
shap_values.data[0][:20]

array(['', ' plot', ' :', ' two', ' teen', ' couples', ' go', ' to', ' a',
       ' church', ' party', ' ,', ' drink', ' and', ' then', ' drive',
       ' .', ' they', ' get', ' into'], dtype=object)

In [59]:
doc_tokens = []

for sent in doc:
    doc_tokens += sent

print(doc_tokens[:10])
len(doc_tokens)

['plot', ':', 'two', 'teen', 'couples', 'go', 'to', 'a', 'church', 'party']


847

In [61]:
"  b  ".strip()

'b'

In [103]:
### Map bert tokens to doc preprocessed tokens
i = 0
j = 0
n = shap_values.data[0].size
m = len(doc_tokens)
tokens_map = []

while i < n and j < m:
    model_token = shap_values.data[0][i].strip()
    
    if model_token == doc_tokens[j]:
        tokens_map.append(j)
        j += 1
    elif model_token == "":
        tokens_map.append(None)
    else:
        found = False
        
        for k in range(1,5):
            if (i + k < n):
                merged_token = model_token

                for l in range(1,k+1):
                    merged_token += shap_values.data[0][i+l].strip()
                
                if merged_token == doc_tokens[j]:
                    found = True
                    
                    for l in range(1,k+2):
                        tokens_map.append(j)
                
                    j += 1
                    i += k
                    break
        
    i += 1

print(i)
print(j)
print(len(tokens_map))

504
472
504


In [104]:
for i in range(len(tokens_map)):
    if tokens_map[i]:
        print(f"{i}-{tokens_map[i]}: {shap_values.data[0][i]} - {doc_tokens[tokens_map[i]]}")

2-1:  : - :
3-2:  two - two
4-3:  teen - teen
5-4:  couples - couples
6-5:  go - go
7-6:  to - to
8-7:  a - a
9-8:  church - church
10-9:  party - party
11-10:  , - ,
12-11:  drink - drink
13-12:  and - and
14-13:  then - then
15-14:  drive - drive
16-15:  . - .
17-16:  they - they
18-17:  get - get
19-18:  into - into
20-19:  an - an
21-20:  accident - accident
22-21:  . - .
23-22:  one - one
24-23:  of - of
25-24:  the - the
26-25:  guys - guys
27-26:  dies - dies
28-27:  , - ,
29-28:  but - but
30-29:  his - his
31-30:  girlfriend - girlfriend
32-31:  continues - continues
33-32:  to - to
34-33:  see - see
35-34:  him - him
36-35:  in - in
37-36:  her - her
38-37:  life - life
39-38:  , - ,
40-39:  and - and
41-40:  has - has
42-41:  nightmares - nightmares
43-42:  . - .
44-43:  what - what
45-44:  ' - 's
46-44:  s - 's
47-45:  the - the
48-46:  deal - deal
49-47:  ? - ?
50-48:  watch - watch
51-49:  the - the
52-50:  movie - movie
53-51:  and - and
54-52:  " - "
55-53:  sort - sort

In [197]:
tokens_map

[None,
 0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 104,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 145,
 145,
 146,
 146,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 

In [105]:
doc_tokens[471:475]

['.', 'i', 'mean', ',']

In [106]:
shap_values.data[0][502:505]

array([' .', ''], dtype=object)

In [56]:
doc[0]

['plot',
 ':',
 'two',
 'teen',
 'couples',
 'go',
 'to',
 'a',
 'church',
 'party',
 ',',
 'drink',
 'and',
 'then',
 'drive',
 '.']

In [14]:
shap.plots.text(shap_values)

In [15]:
shap_values

.values =
array([[[ 1.41353417e-04, -1.32514906e-03,  1.18379577e-03],
        [ 1.41353417e-04, -1.32514906e-03,  1.18379577e-03],
        [ 1.41353417e-04, -1.32514906e-03,  1.18379577e-03],
        ...,
        [ 7.52904197e-04,  2.26798377e-03, -3.02088793e-03],
        [ 7.52904197e-04,  2.26798377e-03, -3.02088793e-03],
        [ 2.44094050e-04, -6.57664612e-05, -1.78327967e-04]]])

.base_values =
array([[0.00383232, 0.98607022, 0.01009738]])

.data =
(array(['', ' plot', ' :', ' two', ' teen', ' couples', ' go', ' to', ' a',
       ' church', ' party', ' ,', ' drink', ' and', ' then', ' drive',
       ' .', ' they', ' get', ' into', ' an', ' accident', ' .', ' one',
       ' of', ' the', ' guys', ' dies', ' ,', ' but', ' his',
       ' girlfriend', ' continues', ' to', ' see', ' him', ' in', ' her',
       ' life', ' ,', ' and', ' has', ' nightmares', ' .', ' what', " '",
       ' s', ' the', ' deal', ' ?', ' watch', ' the', ' movie', ' and',
       ' "', ' sort', 'a', ' "', ' f

In [20]:
tokens = tokenizer(doc_text)
len(tokens['input_ids'])

922

In [23]:
len(doc_text)

4029

In [25]:
t3 = "The concert was amazing, and I had such a great time!"
shap_values2 = explainer([t3])
shap.plots.text(shap_values2)

  0%|          | 0/210 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:17, 17.42s/it]                                                                   


In [26]:
shap_values3 = explainer([doc_text])
shap.plots.text(shap_values3)

RuntimeError: The size of tensor a (922) must match the size of tensor b (512) at non-singleton dimension 1

In [53]:
doc[0]

['plot',
 ':',
 'two',
 'teen',
 'couples',
 'go',
 'to',
 'a',
 'church',
 'party',
 ',',
 'drink',
 'and',
 'then',
 'drive',
 '.']

Score AUPRC

In [1]:
threshold = 0.1