# Interpretation of BertForSequenceClassification in captum

In this notebook we use Captum to interpret a BERT sentiment classifier finetuned on the imdb dataset https://huggingface.co/lvwerra/bert-imdb 

In [1]:
import captum

In [2]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
In /home/fatma/.local/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /home/fatma/.local/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
 print('We will use the GPU:', torch.cuda.get_device_name(0))

We will use the GPU: GeForce RTX 2080 with Max-Q Design


In [6]:
# Get model and config files from https://huggingface.co/lvwerra/bert-imdb
!wget -P ../trained_models/imdb-finetuned-bert/ https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
!wget -P ../trained_models/imdb-finetuned-bert/ https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
!wget -P ../trained_models/imdb-finetuned-bert/ https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json
!wget -P ../trained_models/imdb-finetuned-bert/ https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json
!wget -P ../trained_models/imdb-finetuned-bert/ https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin
!wget -P ../trained_models/imdb-finetuned-bert/ https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt

--2020-07-30 17:10:52--  https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.12.198
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.12.198|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 705 [application/json]
Saving to: ‘../trained_models/imdb-finetuned-bert/config.json’


2020-07-30 17:10:53 (13.1 MB/s) - ‘../trained_models/imdb-finetuned-bert/config.json’ saved [705/705]

--2020-07-30 17:10:53--  https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.12.198
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.12.198|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1334420863 (1.2G) [application/octet-stream]
Saving to: ‘../trained_models/imdb-finetuned-bert/pytorch_model.bin’


2020-07-30 17:17:34 (3.18 MB/s) - ‘../trained_models/imdb-finetuned-ber

In [7]:
# load model
model = BertForSequenceClassification.from_pretrained('../trained_models/imdb-finetuned-bert/')
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained('../trained_models/imdb-finetuned-bert/')

In [8]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1

In [9]:
def predict(inputs):
    #print('model(inputs): ', model(inputs))
    return model(inputs)[0]

In [10]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [11]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [12]:
def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[:, 0] # for negative attribution, torch.softmax(preds, dim = 1)[:, 1] <- for positive attribution

In [13]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [15]:
# One can test a couple of examples and check that the sentiment classifier is behaving
text =  "The first movie is horrible and bad" #"The movie was one of those amazing movies"#"The movie was one of those amazing movies you can not forget"
#text = "The movie was one of those crappy movies you can't forget."

In [16]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [18]:
#saved_act = None
def save_act(module, inp, out):
  #global saved_act
  #saved_act = out
    return saved_act

hook = model.bert.embeddings.register_forward_hook(save_act)

In [19]:
hook.remove()

In [20]:
# Check predict output
prediction = custom_forward(torch.cat([input_ids]))
print(prediction)

tensor([0.9911], device='cuda:0', grad_fn=<SelectBackward>)


In [21]:
input_ids.shape

torch.Size([1, 9])

In [22]:
pred = predict(input_ids)
torch.softmax(pred, dim = 1)


tensor([[0.9911, 0.0089]], device='cuda:0', grad_fn=<SoftmaxBackward>)

In [23]:
# Check output of custom_forward
custom_forward(input_ids)

tensor([0.9911], device='cuda:0', grad_fn=<SelectBackward>)

In [None]:
#attributions_main, delta_main = lig.attribute(inputs=input_ids,
 #                                   baselines=ref_input_ids,
  #                                  n_steps=7000,
   #                                 internal_batch_size=3,
    #                                return_convergence_delta=True)

In [25]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=7000,
                                    internal_batch_size=5,
                                    return_convergence_delta=True)

In [24]:
input_ids

tensor([[ 101, 1109, 1148, 2523, 1110, 9210, 1105, 2213,  102]],
       device='cuda:0')

In [49]:
tokenized_sen = tokenizer.tokenize(text)
print('Tokenized: ', tokenizer.tokenize(text))

Tokenized:  ['The', 'first', 'movie', 'is', 'horrible', 'and', 'bad']


In [37]:
attributions.shape

torch.Size([1, 9, 1024])

In [42]:
sum(attributions[0][0])

tensor(0., device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

In [54]:
sum(attributions[0][1])

tensor(-0.7214, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

In [55]:
sum(attributions[0][7])

tensor(0.9627, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

In [41]:
sum(attributions[0][8])

tensor(0., device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

In [53]:
for i in tokenized_sen:
    word = i
    print(word)
    index = tokenized_sen.index(i)+1
    attribution = sum(attributions[0][index])
    print(attribution)

The
tensor(-0.7214, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
first
tensor(0.0981, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
movie
tensor(0.4392, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
is
tensor(-0.2990, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
horrible
tensor(0.0931, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
and
tensor(-0.2201, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
bad
tensor(0.9627, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)


In [35]:
delta

tensor([0.0192], device='cuda:0', dtype=torch.float64)

In [None]:
torch.sum(attributions_main), torch.sum(attributions)

(tensor(0.5967, device='cuda:0', dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(0.5967, device='cuda:0', dtype=torch.float64, grad_fn=<SumBackward0>))

In [None]:
delta, delta_main

(tensor([-0.0206], device='cuda:0', dtype=torch.float64),
 tensor([-0.0206], device='cuda:0', dtype=torch.float64))

In [None]:
torch.argmax(score[0]).cpu().numpy()

array(0)

In [None]:
torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()

array(0.00071609, dtype=float32)

In [26]:
score = predict(input_ids)

print('Sentence: ', text)
print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \
      ', Probability positive: ' + str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()))

Sentence:  The first movie is horrible and bad
Sentiment: 0, Probability positive: 0.008861683


In [27]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [28]:
attributions_sum = summarize_attributions(attributions)

In [31]:
# storing couple samples in an array for visualization purposes
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.softmax(score, dim = 1)[0][0],
                                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                                        0,
                                        text,
                                        attributions_sum.sum(),       
                                        all_tokens,
                                        delta)


In [32]:
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.99),The first movie is horrible and bad,0.26,[CLS] The first movie is horrible and bad [SEP]
,,,,


In [33]:
torch.argmax(torch.softmax(score, dim = 1)[0])

tensor(0, device='cuda:0')

In [34]:
score

tensor([[ 1.9403, -2.7768]], device='cuda:0', grad_fn=<AddmmBackward>)