# Interpretation of BertForSequenceClassification in captum

In this notebook we use Captum to interpret a BERT sentiment classifier finetuned on the imdb dataset https://huggingface.co/lvwerra/bert-imdb 

In [1]:
# Install dependencies
# !pip install transformers

# !pip install captum


In [25]:
import transformers
print(transformers.__version__)
import torch
print(torch.__version__)
import torchvision
print(torchvision.__version__)

3.0.2
1.4.0
0.5.0


In [26]:
import captum

In [27]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [28]:
# import captum

In [29]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [30]:
# Get model and config files from https://huggingface.co/lvwerra/bert-imdb
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt

In [31]:
class BERT(nn.Module):

    def __init__(self):
        super(BERT, self).__init__()

        options_name = "bert-base-uncased"
        self.encoder = BertForSequenceClassification.from_pretrained(options_name)

    def forward(self, text, label):
        loss, text_fea = self.encoder(text, labels=label)[:2]

        return loss, text_fea
        
def load_checkpoint(load_path, model):
    
    if load_path==None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    
    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']

In [32]:
torch.__version__

'1.4.0'

In [33]:
# load model
# model = BertForSequenceClassification.from_pretrained(r'/Users/andrewmendez1/Documents/ai-ml-challenge-2020/data/model/model.pt')
model = BERT().to(device)

load_checkpoint('/Users/andrewmendez1/Documents/ai-ml-challenge-2020/data/Finetune BERT oversampling 8_16_2020/Model_1_4_0/model.pt', model)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [34]:
model

BERT(
  (encoder): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
    

In [49]:
def predict(inputs):
    #print('model(inputs): ', model(inputs))
    return model.encoder(inputs)[0]

In [50]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [51]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [52]:
def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[:, 0] # for negative attribution, torch.softmax(preds, dim = 1)[:, 1] <- for positive attribution

In [53]:
lig = LayerIntegratedGradients(custom_forward, model.encoder.bert.embeddings)

In [54]:
# One can test a couple of examples and check that the sentiment classifier is behaving
# text =  "The first movie is great but the second is horrible and bad" #"The movie was one of those amazing movies"#"The movie was one of those amazing movies you can not forget"
#text = "The movie was one of those crappy movies you can't forget."
text= "this license shall be effective until company in its sole and absolute at any time and for any or no disable the or suspend or terminate this license and the rights afforded to you with or without prior notice or other action by upon the termination of this you shall cease all use of the app and uninstall the company will not be liable to you or any third party for or damages of any sort as a result of terminating this license in accordance with its and termination of this license will be without prejudice to any other right or remedy company may now or in the these obligations survive termination of this"
label=1

In [55]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [59]:
#saved_act = None
def save_act(module, inp, out):
  #global saved_act
  #saved_act = out
  return saved_act

hook = model.encoder.bert.embeddings.register_forward_hook(save_act)

In [60]:
hook.remove()

In [61]:
# Check predict output
custom_forward(torch.cat([input_ids]))
input_ids.shape

torch.Size([1, 120])

In [62]:
pred = predict(input_ids)
torch.softmax(pred, dim = 1)


tensor([[0.0258, 0.9742]], grad_fn=<SoftmaxBackward>)

In [63]:
# Check output of custom_forward
custom_forward(input_ids)

tensor([0.0258], grad_fn=<SelectBackward>)

In [64]:
input_ids

tensor([[  101,  2023,  6105,  4618,  2022,  4621,  2127,  2194,  1999,  2049,
          7082,  1998,  7619,  2012,  2151,  2051,  1998,  2005,  2151,  2030,
          2053,  4487, 19150,  1996,  2030, 28324,  2030, 20320,  2023,  6105,
          1998,  1996,  2916, 22891,  2000,  2017,  2007,  2030,  2302,  3188,
          5060,  2030,  2060,  2895,  2011,  2588,  1996, 18287,  1997,  2023,
          2017,  4618, 13236,  2035,  2224,  1997,  1996, 10439,  1998,  4895,
          7076,  9080,  2140,  1996,  2194,  2097,  2025,  2022, 20090,  2000,
          2017,  2030,  2151,  2353,  2283,  2005,  2030, 12394,  1997,  2151,
          4066,  2004,  1037,  2765,  1997, 23552,  2023,  6105,  1999, 10388,
          2007,  2049,  1998, 18287,  1997,  2023,  6105,  2097,  2022,  2302,
         18024,  2000,  2151,  2060,  2157,  2030, 19519,  2194,  2089,  2085,
          2030,  1999,  1996,  2122, 14422,  5788, 18287,  1997,  2023,   102]])

In [90]:
attributions_main, delta_main = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    # n_steps=500,
                                    internal_batch_size=3,
                                    return_convergence_delta=True)

In [91]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    # n_steps=500,
                                    internal_batch_size=5,
                                    return_convergence_delta=True)

In [92]:
torch.sum(attributions_main), torch.sum(attributions)

(tensor(-0.8025, grad_fn=<SumBackward0>),
 tensor(-0.8025, grad_fn=<SumBackward0>))

In [93]:
delta, delta_main

(tensor([0.0262]), tensor([0.0262]))

In [94]:
torch.argmax(score[0]).cpu().numpy()

array(1)

In [95]:
torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()

array(0.97422045, dtype=float32)

In [2]:
score = predict(input_ids)

print('Sentence: ', text)
print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \
      ', Probability of Not Acceptable EULA: ' + str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()))

NameError: name 'predict' is not defined

In [97]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [98]:
attributions_sum = summarize_attributions(attributions)

In [1]:
attributions_sum

NameError: name 'attributions_sum' is not defined

In [99]:
# storing couple samples in an array for visualization purposes
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.softmax(score, dim = 1)[0][0],
                                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                                        label,
                                        text,
                                        attributions_sum.sum(),       
                                        all_tokens,
                                        delta)


In [100]:
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.03),this license shall be effective until company in its sole and absolute at any time and for any or no disable the or suspend or terminate this license and the rights afforded to you with or without prior notice or other action by upon the termination of this you shall cease all use of the app and uninstall the company will not be liable to you or any third party for or damages of any sort as a result of terminating this license in accordance with its and termination of this license will be without prejudice to any other right or remedy company may now or in the these obligations survive termination of this,-3.26,[CLS] this license shall be effective until company in its sole and absolute at any time and for any or no di ##sable the or suspend or terminate this license and the rights afforded to you with or without prior notice or other action by upon the termination of this you shall cease all use of the app and un ##ins ##tal ##l the company will not be liable to you or any third party for or damages of any sort as a result of terminating this license in accordance with its and termination of this license will be without prejudice to any other right or remedy company may now or in the these obligations survive termination of this [SEP]
,,,,


In [101]:
torch.argmax(torch.softmax(score, dim = 1)[0])

tensor(1)

In [102]:
score

tensor([[-1.7944,  1.8377]], grad_fn=<AddmmBackward>)