In [1]:
import os
from transformers import BertTokenizer, AutoTokenizer
from transexp_orig.ExplanationGenerator import Generator
from transexp_orig.BertForSequenceClassification import BertForSequenceClassification

from captum.attr import (
    visualization
)
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the installed transformers release so the environment this notebook
# was run with is recorded alongside its outputs.
import transformers
transformers.__version__

'4.23.1'

In [3]:
# Checkpoint identifier shared by the model and its tokenizer, kept in one
# place so the two cannot drift apart.
MODEL_NAME = "textattack/bert-base-uncased-SST-2"

# Prefer the GPU, but fall back to the CPU so this cell still runs on machines
# without CUDA instead of failing inside `.to("cuda")`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
# switch the module to evaluation mode
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# initialize the explanations generator
explanations = Generator(model)

# class names, looked up by the predicted label index
classifications = ["NEGATIVE", "POSITIVE"]


In [4]:
# encode a sentence
text_batch = ["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."]
encoding = tokenizer(text_batch, return_tensors='pt')
# Put the inputs on whatever device the model's weights live on instead of
# hard-coding "cuda", so this cell follows the model wherever it was placed.
model_device = next(model.parameters()).device
input_ids = encoding['input_ids'].to(model_device)
attention_mask = encoding['attention_mask'].to(model_device)

# true class is positive - 1
true_class = 1

# generate an explanation for the input
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
# min-max normalize scores: shift so the minimum is 0, then scale by the range.
# When every score is identical the range is 0; skip the division in that case
# (leaving all zeros) rather than producing NaNs.
expl = expl - expl.min()
score_range = expl.max()
if score_range > 0:
    expl = expl / score_range

# get the model classification (softmax over the first element of the model output)
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# get class name
class_name = classifications[classification]
# if the classification is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
    expl = expl * (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(token, score.item()) for token, score in zip(tokens, expl)])
# NOTE(review): the attribution label below is set to `true_class`; if
# generate_LRP explains the predicted class, `classification` may be the
# intended value — confirm against the Generator implementation.
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,                       # per-token attribution scores
                                output[0][classification],  # probability of the predicted class
                                classification,             # predicted label
                                true_class,                 # true label
                                true_class,                 # attribution label
                                1,                          # attribution score
                                tokens,                     # tokens to render
                                1)]                         # convergence score
# Trailing semicolon: visualize_text already displays the table itself, so
# suppress the cell's own rendering of its return value, which otherwise shows
# the same table a second time.
visualization.visualize_text(vis_data_records);



[('[CLS]', 0.0), ('this', 0.4254930019378662), ('movie', 0.30647191405296326), ('was', 0.26705053448677063), ('the', 0.31505000591278076), ('best', 0.6269277930259705), ('movie', 0.28255173563957214), ('i', 0.1865611970424652), ('have', 0.10077141225337982), ('ever', 0.14222446084022522), ('seen', 0.18850412964820862), ('!', 0.5975315570831299), ('some', 0.0038841008208692074), ('scenes', 0.033900078386068344), ('were', 0.017799649387598038), ('ridiculous', 0.020366042852401733), (',', 0.0), ('but', 0.42744413018226624), ('acting', 0.4414099156856537), ('was', 0.4980330765247345), ('great', 1.0), ('.', 0.016455795615911484), ('[SEP]', 0.08527109771966934)]


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,1.0,"[CLS] this movie was the best movie i have ever seen ! some scenes were ridiculous , but acting was great . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,1.0,"[CLS] this movie was the best movie i have ever seen ! some scenes were ridiculous , but acting was great . [SEP]"
,,,,
