In [4]:
import os
import torch
from transformers import BertTokenizer, AutoTokenizer
from transexp_orig.ExplanationGenerator import Generator
from transexp_orig.BertForSequenceClassification import BertForSequenceClassification

from captum.attr import (
    visualization
)

In [5]:
model = BertForSequenceClassification.from_pretrained("roberta-base").to("cuda")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
explanations = Generator(model)
classifications = ["NEGATIVE", "POSITIVE"]

In [18]:
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2").to("cuda")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
explanations = Generator(model)
classifications = ["NEGATIVE", "POSITIVE"]

Downloading: 100%|██████████| 48.0/48.0 [00:00<00:00, 23.2kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 296kB/s]  
Downloading: 100%|██████████| 112/112 [00:00<00:00, 44.2kB/s]


In [20]:
# encode a sentence
text_batch = ["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to("cuda")
attention_mask = encoding['attention_mask'].to("cuda")

# true class is positive - 1
true_class = 1

# generate an explanation for the input
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
# normalize scores
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# get class name
class_name = classifications[classification]
# if the classification is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
  expl *= (-1)


In [21]:
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten()) #.split(" ")
print([(tokens[i], expl[i].item()) for i in range(len(tokens))])
print(len(expl))
print(len(tokens))
# len(tokens)
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,
                                output[0][classification],
                                classification,
                                true_class,
                                true_class,
                                1,       
                                tokens,
                                1)]
visualization.visualize_text(vis_data_records)

[('[CLS]', 0.0), ('this', 0.4248310327529907), ('movie', 0.31070131063461304), ('was', 0.26831868290901184), ('the', 0.3402961194515228), ('best', 0.6240742206573486), ('movie', 0.284561425447464), ('i', 0.18455998599529266), ('have', 0.10067689418792725), ('ever', 0.1404617726802826), ('seen', 0.18854057788848877), ('!', 0.5899584293365479), ('some', 0.003955088090151548), ('scenes', 0.03270178660750389), ('were', 0.018634729087352753), ('ridiculous', 0.018232356756925583), (',', 0.0), ('but', 0.42691802978515625), ('acting', 0.4365212023258209), ('was', 0.5009468197822571), ('great', 1.0), ('.', 0.013708234764635563), ('[SEP]', 0.08643466979265213)]
23
23


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,1.0,"[CLS] this movie was the best movie i have ever seen ! some scenes were ridiculous , but acting was great . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,1.0,"[CLS] this movie was the best movie i have ever seen ! some scenes were ridiculous , but acting was great . [SEP]"
,,,,


In [17]:
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten()) #.split(" ")
print([(tokens[i], expl[i].item()) for i in range(len(tokens))])
print(len(expl))
print(len(tokens))
# len(tokens)
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,
                                output[0][classification],
                                classification,
                                true_class,
                                true_class,
                                1,       
                                tokens,
                                1)]
visualization.visualize_text(vis_data_records)

[('<s>', -0.0), ('This', -0.7837290167808533), ('Ġmovie', -0.45044973492622375), ('Ġwas', -0.47404444217681885), ('Ġthe', -0.7911797165870667), ('Ġbest', -0.8549357652664185), ('Ġmovie', -0.7131478786468506), ('ĠI', -0.4093055725097656), ('Ġhave', -0.28496885299682617), ('Ġever', -0.24298979341983795), ('Ġseen', -0.6550716757774353), ('!', -0.20471611618995667), ('Ġsome', -0.4656515419483185), ('Ġscenes', -0.0), ('Ġwere', -0.9210516214370728), ('Ġridiculous', -1.0), (',', -0.5775463581085205), ('Ġbut', -0.8929394483566284), ('Ġacting', -0.8031914830207825), ('Ġwas', -0.40237486362457275), ('Ġgreat', -0.49718427658081055), ('.', -0.4338648021221161), ('</s>', -0.3194558024406433)]
23
23


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.56),1.0,1.0,"#s This Ġmovie Ġwas Ġthe Ġbest Ġmovie ĠI Ġhave Ġever Ġseen ! Ġsome Ġscenes Ġwere Ġridiculous , Ġbut Ġacting Ġwas Ġgreat . #/s"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.56),1.0,1.0,"#s This Ġmovie Ġwas Ġthe Ġbest Ġmovie ĠI Ġhave Ġever Ġseen ! Ġsome Ġscenes Ġwere Ġridiculous , Ġbut Ġacting Ġwas Ġgreat . #/s"
,,,,
