**The coloring of the word importance will be lost after reopening the notebook. Save images before closing.**

In [52]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt
import captum
from tqdm.notebook import tqdm, trange

In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [24]:
# load model
model = BertForSequenceClassification.from_pretrained('./data/model')
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained('./data/model')

In [25]:
def predict(inputs):
    #print('model(inputs): ', model(inputs))
    return model(inputs)[0]

In [26]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [27]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [28]:
def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[:, 0] # for negative attribution, torch.softmax(preds, dim = 1)[:, 1] <- for positive attribution

In [29]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [55]:
# One can test a couple of examples and check that the sentiment classifier is behaving
text =  "The first movie is great but the second is horrible and bad" #"The movie was one of those amazing movies"#"The movie was one of those amazing movies you can not forget"
#text = "The movie was one of those crappy movies you can't forget."

In [56]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [57]:
#saved_act = None
def save_act(module, inp, out):
    #global saved_act
    #saved_act = out
    return saved_act

hook = model.bert.embeddings.register_forward_hook(save_act)

In [58]:
hook.remove()

In [59]:
# Check predict output
custom_forward(torch.cat([input_ids]))
input_ids.shape

torch.Size([1, 14])

In [60]:
pred = predict(input_ids)
torch.softmax(pred, dim = 1)


tensor([[9.9928e-01, 7.1609e-04]], device='cuda:0', grad_fn=<SoftmaxBackward>)

In [61]:
# Check output of custom_forward
custom_forward(input_ids)

tensor([0.9993], device='cuda:0', grad_fn=<SelectBackward>)

In [62]:
attributions_main, delta_main = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=500,
                                    internal_batch_size=3,
                                    return_convergence_delta=True)

In [63]:
score = predict(input_ids)

In [64]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [65]:
attributions_sum = summarize_attributions(attributions_main)

In [66]:
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.softmax(score, dim = 1)[0][0],
                                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                                        1,
                                        text,
                                        attributions_sum.sum(),       
                                        all_tokens,
                                        delta_main)

In [67]:
viz.visualize_text([score_vis])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),The first movie is great but the second is horrible and bad,0.46,[CLS] The first movie is great but the second is horrible and bad [SEP]
,,,,


In [49]:
viz.visualize_text([score_vis])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (1.00),The first movie is great but the second is horrible and bad,0.39,[CLS] The first movie is great but the second is horrible and bad [SEP]
,,,,


In [86]:
sentences = ["It's a great day.", 
             "It's a great day", 
             "It's a great day!",
             "It's a great day?",
             "It's a great day!!!",
             "It's a GREAT day.",
             "IT'S A great DAY.",
             "IT'S A GREAT DAY.",
             "It's a great day. It's a great day.",
             "It's absolutely a great day.",
             "It's awfully a great day.",
             "It's ABSOLUTELY a great day.",
             "It's not a great day.",
             "It's absolutely not a great day!",
             "It's ABSOLUTELY not a great day.",
             "iT'S nOT a gREaT Day.",
             "It'snotagreatday.",
             "It's not a great day :D",
             "It's not a great day :(",
             "It is not a great day.",
             "'It's not a great day.'",
             "(It's not a great day.)",
             "It's a great day. It's a bad day.",
             "It's a great day. (It's not.)",
             "Here is my number: (850)-100-1000",
             "Here is my number (850)-100-1000",
             "Please check out https://data.tallahassee.com/",
             "Contact me at jz17d@my.fsu.edu",
             "couldnt agree more",
             "couldnt agree more.",
             "could n't agree more",
             "could n't agree more.",
             "Today sucks",
             "Today sux",
             "Kinda sux today! But I'll get by",
             "This restaurant is so good! Couldn't agree more!",
             "Never seen such a bad movie before! Couldn't agree more!",
             "How about meeting tomorrow at 10? Sounds good.",
             "The plan is great! I will talk to you later.",
             "Thoughts on this revision?",
             "Im sorry to hear this happened. Rob and Eric, can you please address this with the tenants?",
             "Don, let me know what our plan will be so that we can market accordingly.",
             ]

In [87]:
score_vis_list = []
for sentence in tqdm(sentences):
    model.zero_grad()
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(sentence, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    pred = predict(input_ids)
    attributions_main, delta_main = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=1000,
                                    internal_batch_size=3,
                                    return_convergence_delta=True)
    score = predict(input_ids)
    attributions_sum = summarize_attributions(attributions_main)
    score_vis = viz.VisualizationDataRecord(attributions_sum,
                                            torch.softmax(score, dim = 1)[0][0],
                                            torch.argmax(torch.softmax(score, dim = 1)[0]),
                                            1,
                                            str(torch.argmax(torch.softmax(score, dim = 1)[0]).cpu().numpy().reshape(1)[0]),
                                            attributions_sum.sum(),       
                                            all_tokens,
                                            delta_main)
    score_vis_list.append(score_vis)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=42.0), HTML(value='')))




In [88]:
viz.visualize_text(score_vis_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),1.0,-1.04,[CLS] It ' s a great day . [SEP]
,,,,
1.0,1 (0.00),1.0,-1.34,[CLS] It ' s a great day [SEP]
,,,,
1.0,1 (0.00),1.0,-0.91,[CLS] It ' s a great day ! [SEP]
,,,,
1.0,1 (0.04),1.0,-0.16,[CLS] It ' s a great day ? [SEP]
,,,,
1.0,1 (0.00),1.0,-1.11,[CLS] It ' s a great day ! ! ! [SEP]
,,,,


In [None]:
viz.

In [69]:
viz.visualize_text(score_vis_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"tensor(1, device='cuda:0')",-1.04,[CLS] It ' s a great day . [SEP]
,,,,
1.0,1 (0.00),"tensor(1, device='cuda:0')",-1.34,[CLS] It ' s a great day [SEP]
,,,,
1.0,1 (0.00),"tensor(1, device='cuda:0')",-0.91,[CLS] It ' s a great day ! [SEP]
,,,,
1.0,1 (0.04),"tensor(1, device='cuda:0')",-0.16,[CLS] It ' s a great day ? [SEP]
,,,,
1.0,1 (0.00),"tensor(1, device='cuda:0')",-1.11,[CLS] It ' s a great day ! ! ! [SEP]
,,,,


In [71]:
attributions_main.shape

torch.Size([1, 16, 1024])

In [72]:
attributions_sum

tensor([ 0.0000, -0.1252,  0.3052,  0.2426, -0.6051,  0.3492,  0.2767, -0.1513,
         0.1542, -0.0731, -0.3853,  0.1805, -0.0367, -0.1499,  0.1039,  0.0000],
       device='cuda:0', dtype=torch.float64)