# Interpretation of BertForSequenceClassification in captum

In this notebook we'll see how use Captum's Layer Integrated Gradients method to interpret a BERT sentiment classifier that has been finetuned on the imdb dataset https://huggingface.co/lvwerra/bert-imdb.

### Install dependencies

We'll begin by installing library dependencies, namely the `captum` and `transformers` libraries.

In [1]:
# !pip install --upgrade pip

In [1]:
# !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

In [2]:
# !pip install transformers==4.1.1
# !pip install captum==0.5.0

In [3]:
import captum
import transformers

print(f'Transformers version: {transformers.__version__}')   # 4.1.1
print(f'Captum version : {captum.__version__}')              # 0.5.0

2022-10-31 19:54:10.521998: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.11.0


Transformers version: 4.1.1
Captum version : 0.5.0


### Build and train the model

In [4]:
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

import matplotlib.pyplot as plt
import torch

from transformers import BertForSequenceClassification, BertTokenizer

This notebook runs faster when enabled with GPU.

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
print(f'device: {device}')

device: cuda:0


## Download BERT model from HuggingFace

We'll use a pretrained BERT model from HuggingFace.

In [7]:
# Get model and config files from https://huggingface.co/lvwerra/bert-imdb
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt

--2022-10-31 19:54:37--  https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.196.48
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.196.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 705 [application/json]
Saving to: ‘./model/config.json’


2022-10-31 19:54:38 (53.5 MB/s) - ‘./model/config.json’ saved [705/705]

--2022-10-31 19:54:38--  https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.196.48
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.196.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1334420863 (1.2G) [application/octet-stream]
Saving to: ‘./model/pytorch_model.bin’


2022-10-31 19:55:06 (45.3 MB/s) - ‘./model/pytorch_model.bin’ saved [1334420863/1334420863]

--2022-10-31 19:55:07--  https://s3.amazonaws.com/models.huggingface

Next we'll load the model and pre-trained BERT tokenizer.

In [8]:
# Load the model.
model = BertForSequenceClassification.from_pretrained('./model')
model.to(device)
model.eval()
model.zero_grad()

# Load the pretrained tokenizer.
tokenizer = BertTokenizer.from_pretrained('./model')

The pre-trained BERT tokenizer has special tokens that are used when pre-training the BERT model. 
BERT is pre-trained with two tasks: A classification task which predicts if one sentence follows another in the original corpus, and a masked language task which predicts which word was masked from a sentence. 

The BERT tokenizer has special tokens for the "next sentence prediction" task. Namely, we need a way to inform the model where does the first sentence end, and where does the second sentence begin. The `SEP` token is used as a separator added to the end of text, and `CLS` is used for prepending the two sentences. 

The `PAD` token is used to pad sequences to have constant length. 

The 'CLS' token is used to prepend to the concatenated question-text word sequence

In [9]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

In [10]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    
    # Construct input token ids.
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    
    # Construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [11]:
def custom_forward(inputs):
    preds = model(inputs)[0]
    return torch.softmax(preds, dim = 1)[0][1].unsqueeze(-1)

Let's look at an example prediction for the input text: "If you like the original, you'll love this movie."

In [12]:
sample_text = "If you like the original, you'll love this movie."

input_ids, ref_input_ids, sep_id = construct_input_ref_pair(sample_text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [13]:
print(f'tokenized text: {all_tokens}')
print(f'text as indices: {indices}')

tokenized text: ['[CLS]', 'If', 'you', 'like', 'the', 'original', ',', 'you', "'", 'll', 'love', 'this', 'movie', '.', '[SEP]']
text as indices: [101, 1409, 1128, 1176, 1103, 1560, 117, 1128, 112, 1325, 1567, 1142, 2523, 119, 102]


In [14]:
# Check predict output.
print(f'model predict: {model(input_ids)[0]}')

# Predict using custom forward pass.
print(f'custom forward pass: {custom_forward(input_ids)}')

model predict: tensor([[-3.2322,  3.4528]], device='cuda:0', grad_fn=<AddmmBackward>)
custom forward pass: tensor([0.9988], device='cuda:0', grad_fn=<UnsqueezeBackward0>)


## Interpret model predicitons using Captum's Layer Integrated Gradients method

Create the Layer Integrated Gradients object and compute attributions.

In [15]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [16]:
lig_attributions, delta = lig.attribute(inputs=input_ids,
                                        baselines=ref_input_ids,
                                        n_steps=700,
                                        internal_batch_size=3,
                                        return_convergence_delta=True)

In [17]:
predict_sentiment = torch.argmax(model(input_ids)[0]).cpu().numpy()
positive_prob = custom_forward(input_ids).detach().cpu().numpy()

print(f'Input example: {sample_text}')
print(f'Sentiment:  {str(predict_sentiment)}, Probability positive: {str(positive_prob[0])}')

Input example: If you like the original, you'll love this movie.
Sentiment:  1, Probability positive: 0.9987521


In [18]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [19]:
attributions_sum = summarize_attributions(lig_attributions)

In [20]:
# storing couple samples in an array for visualization purposes
score_vis = viz.VisualizationDataRecord(word_attributions=attributions_sum,
                                        pred_prob=torch.softmax(model(input_ids)[0], dim = 1)[0][1],
                                        pred_class=torch.argmax(model(input_ids)[0]),
                                        true_class=1,
                                        attr_class=sample_text,
                                        attr_score=attributions_sum.sum(),       
                                        raw_input_ids=all_tokens,
                                        convergence_score=delta)


In [21]:
viz.visualize_text([score_vis])
plt.show()

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),"If you like the original, you'll love this movie.",1.04,"[CLS] If you like the original , you ' ll love this movie . [SEP]"
,,,,


The negative attribution for 'you like' goes against our intuition, but the positive attibution for "love" seems to match what we would expect.

Copyright 2022 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License