# Interpretation of BertForSequenceClassification in captum

In [45]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch

In [46]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [47]:

# load model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [48]:
def predict(inputs):
    print("[0]", model(inputs)[0])
    print("logits", model(inputs).logits)
    return model(inputs)[0]

In [49]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [50]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [51]:
def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)

In [52]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [53]:
text = "These tests do not work as expected."

In [54]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [55]:
model(input_ids)

SequenceClassifierOutput(loss=None, logits=tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [56]:
predict(input_ids)

[0] tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)
logits tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)


tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [57]:
custom_forward(input_ids)

[0] tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)
logits tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)


tensor([0.5068], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [58]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)

[0] tensor([[-0.4273, -0.4544]], device='cuda:0')
logits tensor([[-0.4273, -0.4544]], device='cuda:0')
[0] tensor([[-0.4605, -0.2021]], device='cuda:0')
logits tensor([[-0.4605, -0.2021]], device='cuda:0')
[0] tensor([[-0.4601, -0.2025],
        [-0.4605, -0.2058],
        [-0.4606, -0.2109],
        [-0.4598, -0.2176],
        [-0.4591, -0.2273],
        [-0.4576, -0.2394],
        [-0.4557, -0.2545],
        [-0.4520, -0.2718],
        [-0.4467, -0.2925],
        [-0.4395, -0.3169],
        [-0.4301, -0.3448],
        [-0.4177, -0.3738],
        [-0.4027, -0.4049],
        [-0.3861, -0.4365],
        [-0.3680, -0.4704],
        [-0.3501, -0.5103],
        [-0.3313, -0.5565],
        [-0.2948, -0.5728],
        [-0.2679, -0.5732],
        [-0.2714, -0.5778],
        [-0.3487, -0.5928],
        [-0.3946, -0.5775],
        [-0.3777, -0.5418],
        [-0.3761, -0.5293],
        [-0.3722, -0.5326],
        [-0.3610, -0.5259],
        [-0.3510, -0.5116],
        [-0.3462, -0.4981],
      

In [61]:
score = predict(input_ids)

print('Question: ', text)
print('Predicted Answer: ' + str(torch.argmax(score[0]).cpu().numpy()) + ', prob ungrammatical: ' + str(torch.softmax(score, dim = 1)[0][0].cpu().detach().numpy()))

[0] tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)
logits tensor([[-0.4273, -0.4544]], device='cuda:0', grad_fn=<AddmmBackward0>)
Question:  These tests do not work as expected.
Predicted Answer: 0, prob ungrammatical: 0.5067745


In [62]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [63]:
attributions_sum = summarize_attributions(attributions)

In [64]:
# storing couple samples in an array for visualization purposes
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.softmax(score, dim = 1)[0][0],
                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                        0,
                        text,
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.51),These tests do not work as expected.,2.17,[CLS] these tests do not work as expected . [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.51),These tests do not work as expected.,2.17,[CLS] these tests do not work as expected . [SEP]
,,,,
