In [1]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
model = BertForSequenceClassification.from_pretrained('textattack/bert-base-uncased-rotten-tomatoes')
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained('textattack/bert-base-uncased-rotten-tomatoes')

In [4]:
def predict(inputs):
    return model(inputs)[0]

In [5]:
print(tokenizer.cls_token)

[CLS]


In [6]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

In [7]:
print(cls_token_id)

101


In [8]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text,add_special_tokens=False)
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    return torch.tensor([input_ids],device=device), torch.tensor([ref_input_ids],device=device),len(text_ids)

In [9]:
def construct_input_ref_token_type_pair(input_ids,sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([
        0 if i <= sep_ind else 1 for i in range(seq_len)
    ],device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids,device=device)
    return token_type_ids, ref_token_type_ids

In [10]:
def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length,dtype= torch.long, device= device)
    ref_position_ids = torch.zeros(seq_length,dtype=torch.long, device= device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids


In [11]:
input_ids = construct_input_ref_pair("hello there mate", ref_token_id,sep_token_id,cls_token_id)

In [12]:
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [13]:
construct_attention_mask(input_ids[0])

tensor([[1, 1, 1, 1, 1]])

In [57]:
def custom_foward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim= 1)[0][0].unsqueeze(-1)

In [41]:
preds = predict(input_ids[0])


ValueError: Wrong shape for input_ids (shape torch.Size([135])) or attention_mask (shape torch.Size([135]))

In [25]:
torch.softmax(preds,dim=1)[0][0].unsqueeze(-1)

tensor([0.0116], grad_fn=<UnsqueezeBackward0>)

In [26]:
model.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [58]:
lig = LayerIntegratedGradients(custom_foward,model.bert.embeddings)

In [59]:
text = "For those who like their romance movies filled with unnecessary mysteries, murdered dogs, poached lobsters and the ghosts of deceased little girls, “Dirt Music” will fit the bill. All others need not apply, not even if you’re into the kind of Nicholas Sparks-style drama this movie shamelessly marinates in for an interminable 105 minutes. Director Gregor Jordan’s Australia-set potboiler plays like “Wake in Fright” meets “The Notebook”; the toxic masculinity of several characters wreaks havoc before one guy reveals a softer side that bends toward true love as a means of assuaging his guilt."

In [60]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text,ref_token_id,sep_token_id,cls_token_id)

In [61]:
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids,sep_id)

In [62]:
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)

In [63]:
attention_mask = construct_attention_mask(input_ids)

In [64]:
indices = input_ids[0].detach().tolist()

In [65]:
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [66]:
attributions, delta = lig.attribute(inputs = input_ids,
                                   baselines=ref_input_ids,
                                   return_convergence_delta=True)

In [67]:
score = predict(input_ids)

print("Input: ", text)
print('Predicted Answer: ' + str(torch.argmax(score[0])) + ', prob ungrammatical: ' + str(torch.softmax(score, dim = 1)[0][0].detach().numpy()))

Input:  For those who like their romance movies filled with unnecessary mysteries, murdered dogs, poached lobsters and the ghosts of deceased little girls, “Dirt Music” will fit the bill. All others need not apply, not even if you’re into the kind of Nicholas Sparks-style drama this movie shamelessly marinates in for an interminable 105 minutes. Director Gregor Jordan’s Australia-set potboiler plays like “Wake in Fright” meets “The Notebook”; the toxic masculinity of several characters wreaks havoc before one guy reveals a softer side that bends toward true love as a means of assuaging his guilt.
Predicted Answer: tensor(0, grad_fn=<NotImplemented>), prob ungrammatical: 0.9986638


In [68]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [69]:
attributions_sum = summarize_attributions(attributions)

In [55]:


# storing couple samples in an array for visualization purposes
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.softmax(score, dim = 1)[0][0],
                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                        0,
                        text,
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])



[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"For those who like their romance movies filled with unnecessary mysteries, murdered dogs, poached lobsters and the ghosts of deceased little girls, “Dirt Music” will fit the bill. All others need not apply, not even if you’re into the kind of Nicholas Sparks-style drama this movie shamelessly marinates in for an interminable 105 minutes. Director Gregor Jordan’s Australia-set potboiler plays like “Wake in Fright” meets “The Notebook”; the toxic masculinity of several characters wreaks havoc before one guy reveals a softer side that bends toward true love as a means of assuaging his guilt.",-4.35,"[CLS] for those who like their romance movies filled with unnecessary mysteries , murdered dogs , po ##ache ##d lobster ##s and the ghosts of deceased little girls , “ dirt music ” will fit the bill . all others need not apply , not even if you ’ re into the kind of nicholas sparks - style drama this movie shame ##lessly marina ##tes in for an inter ##mina ##ble 105 minutes . director gregor jordan ’ s australia - set pot ##bo ##ile ##r plays like “ wake in fright ” meets “ the notebook ” ; the toxic mas ##cu ##lini ##ty of several characters wr ##eak ##s havoc before one guy reveals a softer side that bends toward true love as a means of ass ##ua ##ging his guilt . [SEP]"
,,,,


In [70]:
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.softmax(score, dim = 1)[0][0],
                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                        0,
                        text,
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"For those who like their romance movies filled with unnecessary mysteries, murdered dogs, poached lobsters and the ghosts of deceased little girls, “Dirt Music” will fit the bill. All others need not apply, not even if you’re into the kind of Nicholas Sparks-style drama this movie shamelessly marinates in for an interminable 105 minutes. Director Gregor Jordan’s Australia-set potboiler plays like “Wake in Fright” meets “The Notebook”; the toxic masculinity of several characters wreaks havoc before one guy reveals a softer side that bends toward true love as a means of assuaging his guilt.",4.35,"[CLS] for those who like their romance movies filled with unnecessary mysteries , murdered dogs , po ##ache ##d lobster ##s and the ghosts of deceased little girls , “ dirt music ” will fit the bill . all others need not apply , not even if you ’ re into the kind of nicholas sparks - style drama this movie shame ##lessly marina ##tes in for an inter ##mina ##ble 105 minutes . director gregor jordan ’ s australia - set pot ##bo ##ile ##r plays like “ wake in fright ” meets “ the notebook ” ; the toxic mas ##cu ##lini ##ty of several characters wr ##eak ##s havoc before one guy reveals a softer side that bends toward true love as a means of ass ##ua ##ging his guilt . [SEP]"
,,,,


In [318]:
print(score_vis.word_attributions)

tensor([ 0.0000,  0.0763,  0.0082, -0.0799,  0.0710,  0.0987, -0.1030, -0.0780,
         0.1009,  0.0898,  0.1629, -0.0741,  0.0237,  0.0729,  0.0888, -0.0469,
        -0.0284, -0.0012,  0.0855,  0.0056,  0.1050,  0.0246,  0.0364, -0.0199,
         0.0371, -0.0080,  0.0265, -0.0262, -0.0210, -0.0089, -0.0070,  0.0152,
        -0.0229, -0.0183,  0.0985,  0.0232,  0.0394,  0.0178,  0.0278,  0.0530,
         0.0763,  0.0592,  0.0215, -0.0305,  0.0329,  0.0100,  0.0807,  0.0008,
         0.0309,  0.0092,  0.0325, -0.0111, -0.0801,  0.0015,  0.0118, -0.0466,
        -0.0553,  0.0533, -0.0592,  0.0430, -0.0170,  0.0562,  0.0270, -0.0194,
         0.0093,  0.0800,  0.0647,  0.0551,  0.0096, -0.0299, -0.0150,  0.0105,
         0.0491,  0.1105,  0.0634,  0.0905,  0.0293,  0.0471, -0.0068, -0.0123,
        -0.0201,  0.1341,  0.0933,  0.0045,  0.0666, -0.0860, -0.0817,  0.1575,
        -0.0215,  0.0672,  0.1110,  0.0154, -0.0006, -0.1384, -0.0006,  0.0761,
         0.0539, -0.0237, -0.0154,  0.10

In [330]:
for word, attribution in zip(all_tokens, score_vis.word_attributions):
    print(word,attribution.detach().numpy())

[CLS] 0.0
for 0.07634265305330053
those 0.00816911598006442
who -0.07988392549505588
like 0.07100970644634169
their 0.09866435316602275
romance -0.10300177410320334
movies -0.07802647303424427
filled 0.1009254898594088
with 0.08976942278816583
unnecessary 0.16285936800348394
mysteries -0.07408248706719925
, 0.02370061885019614
murdered 0.07288591744963933
dogs 0.08878960985980135
, -0.04685117091958494
po -0.028433565907670575
##ache -0.0012373378129953822
##d 0.08552001541728471
lobster 0.00556359579336626
##s 0.10497541475781372
and 0.02456515090507236
the 0.036422598459274286
ghosts -0.019884743965391954
of 0.03707173755903409
deceased -0.008041542993494256
little 0.02652127125497944
girls -0.026232039467154983
, -0.021018027420983437
“ -0.008920087160898068
dirt -0.007040503100642204
music 0.015226563432626803
” -0.022866840764515516
will -0.018272898119638184
fit 0.0985455791588231
the 0.023222161589725558
bill 0.0393524944903111
. 0.01782133236487718
all 0.027805335555030904
othe

In [335]:
ig = IntegratedGradients(model)

In [336]:
ig.attribute(inputs=input_ids, 
            baselines=ref_input_ids,
            return_convergence_delta=True)

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)

In [None]:
attributions, delta = lig.attribute(inputs = input_ids,
                                   baselines=ref_input_ids,
                                   return_convergence_delta=True)