# Interpreting BERT Models (Part 1)

In this notebook we demonstrate how to interpret BERT models using the `Captum` library. In this particular case study we focus on a Question Answering model fine-tuned on the SQuAD dataset, using the transformers library from Hugging Face: https://huggingface.co/transformers/

We show how to use interpretation hooks to examine and better understand the embeddings, sub-embeddings, BERT, and attention layers.

Note: Before running this tutorial, please install the `seaborn`, `pandas`, `matplotlib` and `transformers` (from Hugging Face) Python packages.

In [None]:
# modified from https://captum.ai/tutorials/Bert_SQUAD_Interpret
import os
import sys
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn as nn

from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig
from tqdm.notebook import tqdm
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer


from models import BertForGuilt

In [None]:
device = torch.device("cpu")

The first step is to fine-tune a BERT model on the SQuAD dataset. This can be easily accomplished by following the steps described in Hugging Face's official repository: https://github.com/huggingface/transformers#run_squadpy-fine-tuning-on-squad-for-question-answering

Note that the fine-tuning is done on a `bert-base-uncased` pre-trained model.

After we fine-tune the model, we can load the tokenizer and the fine-tuned BERT model using the commands described below.

In [None]:
# load tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Next we define a helper function that performs a forward pass of the model and returns its predictions.

In [None]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None, training_head=None):
    return model(input_ids=inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask,
                 training_head=training_head, device=device)


def guilt_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, training_head=None):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask, training_head=training_head)
    return pred[0]
# input_ids=None, attention_mask=None, token_type_ids=None,
#                 position_ids=None, head_mask=None, inputs_embeds=None, labels=None, training_head=[-1], with_token_cls=False, token_labels=None,  device='cpu',  highlight_ratio=None

The custom forward function `guilt_forward_func` above wraps `predict` and returns only the first element of the model output, which is the value we compute attributions for.

Let's compute attributions with respect to the `BertEmbeddings` layer.

To do so, we need to define baselines / references, numericalize both the baselines and the inputs. We will define helper functions to achieve that.

The cell below defines numericalized special tokens that will be later used for constructing inputs and corresponding baselines/references.

In [None]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

Below we define a set of helper functions for constructing references / baselines for word tokens, token types and position ids. We also provide separate helper functions that construct the sub-embeddings and the corresponding baselines / references for all sub-embeddings of the `BertEmbeddings` layer.

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device)

def construct_input_ref_token_type_pair(input_ids):
    # return token_type_ids, ref_token_type_ids
    seq_len = input_ids.size(1)
    return torch.zeros(seq_len, device=device, dtype=torch.long), torch.zeros(seq_len, device=device, dtype=torch.long)

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids

def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

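def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                    token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    # NOTE: `interpretable_embedding` is not defined in this notebook; it has to be created
    # beforehand (e.g. via `configure_interpretable_embedding_layer`) for this helper to run.
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    # the reference embeddings use the reference token type and position ids
    ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)

    return input_embeddings, ref_input_embeddings


def summarize_attributions(attributions):
    # sum over the embedding dimension to get one score per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    # center on the mean and scale by the L2 norm of the per-token scores
    attributions = (attributions - torch.mean(attributions)) / torch.norm(attributions)
    return attributions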
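
To see what `summarize_attributions` does to the raw attributions, the following is a minimal pure-Python sketch of the same arithmetic for a single sequence. The function name `summarize_attributions_py` and the toy values are hypothetical and only serve as an illustration; the notebook itself operates on `torch` tensors.

```python
import math


def summarize_attributions_py(attributions):
    """Pure-Python mirror of `summarize_attributions` for one sequence.

    `attributions` is a list of per-token lists (seq_len x hidden_dim).
    """
    # Sum over the embedding dimension: one score per token.
    per_token = [sum(row) for row in attributions]
    # Mean of the per-token scores, used for centering.
    mean = sum(per_token) / len(per_token)
    # L2 norm of the uncentered per-token scores (what torch.norm computes by default).
    norm = math.sqrt(sum(v * v for v in per_token))
    return [(v - mean) / norm for v in per_token]


# Hypothetical attributions for 3 tokens with a 2-dimensional embedding.
toy = [[0.5, 0.5], [1.0, 1.0], [-0.5, 0.5]]
scores = summarize_attributions_py(toy)
total = sum(scores)
```

Because the scores are centered on their mean, they sum to zero up to floating point error. One consequence worth keeping in mind is that the total attribution obtained by summing the output of `summarize_attributions` will be close to zero for every example.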

Let's define the `question - text` pair that we'd like to use as an input for our BERT model and interpret what the model was focusing on when predicting an answer to the question from the given input text.

In [None]:
# replace <path_to_model_dump> with the real path of the saved model
training_head = ['1']
model_path = '<path_to_model_dump>'

# load model
model = BertForGuilt.from_pretrained(model_path)
model.to(device)
model.eval()
model.zero_grad()

In [None]:
data = []
with open('<path_to_test_dataset>') as infile:
    for line in infile:
        data.append(json.loads(line))
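
The cell above reads the test set as JSON Lines, one record per line. As a rough sketch of the record shape this notebook appears to rely on, the example below parses two invented records and averages the `suspect_committedCrime` ratings per story. The field names are taken from the attribution loop in this notebook; the values, the annotator keys and the helper `mean_rating` are hypothetical.

```python
import io
import json
import math

# Hypothetical two-line JSON Lines file; all field values are invented.
sample_file = io.StringIO(
    '{"story_id": "s1", "story_clean": "A short story.", '
    '"data": {"annotator_a": {"suspect_committedCrime": 1}, "annotator_b": {}}}\n'
    '{"story_id": "s2", "story_clean": "Another story.", "data": {}}\n'
)

# One JSON document per non-empty line.
records = [json.loads(line) for line in sample_file if line.strip()]


def mean_rating(record, key="suspect_committedCrime"):
    """Average `key` over annotators, counting a missing rating as 0."""
    ratings = [entry.get(key, 0) for entry in record["data"].values()]
    # With no annotators there is nothing to average; return NaN so the
    # caller can skip the record.
    return sum(ratings) / len(ratings) if ratings else math.nan


ground_truths = {r["story_id"]: mean_rating(r) for r in records}
```

Records without any annotation end up with a NaN ground truth, which is the case the `np.isnan` check in the attribution loop skips.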

Let's numericalize the question and the input text, and generate the corresponding baselines / references for all three sub-embedding types (word, token type and position) using the helper functions defined above.

In [None]:
%%time
all_score_vises = {}
for row in tqdm(data[:10]):
    text = row['story_clean']
    ground_truth = np.mean([i.get('suspect_committedCrime', 0) for i in row['data'].values()])
    if np.isnan(ground_truth):
        continue
    input_ids, ref_input_ids  = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    score = predict(input_ids, attention_mask=attention_mask, training_head=training_head)

    lig = LayerIntegratedGradients(guilt_forward_func, model.bert.embeddings)

    attributions_score, delta_score = lig.attribute(inputs=input_ids,
                                      baselines=ref_input_ids,
                                      additional_forward_args=(None, None, None, training_head),
                                      return_convergence_delta=True)
    attributions_score_sum = summarize_attributions(attributions_score)

    # store a visualization record for this sample, keyed by story id
    score_vis = viz.VisualizationDataRecord(
                            attributions_score_sum,
                            np.round(score[0].item(),3),
                            -1,
                            -1,
                            str(ground_truth),
                            attributions_score_sum.sum(),       
                            all_tokens,
                            delta_score)
    
    all_score_vises[row['story_id']] = score_vis
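
To make the role of the baseline and of the convergence delta more concrete, here is a minimal pure-Python sketch of Integrated Gradients on a toy function with a known gradient. It is not Captum's implementation: the function `f`, its gradient, the midpoint-rule approximation and the step count are all assumptions chosen for illustration.

```python
def f(x):
    # Toy scalar model standing in for the network output (hypothetical).
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def grad_f(x):
    # Analytic gradient of f with respect to each input.
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


def integrated_gradients(x, baseline, n_steps=50):
    """Approximate Integrated Gradients with a midpoint Riemann sum."""
    attributions = [0.0] * len(x)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps
        # Point on the straight line from the baseline to the input.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grads = grad_f(point)
        for i, g in enumerate(grads):
            attributions[i] += g * (x[i] - baseline[i]) / n_steps
    return attributions


x = [1.0, 2.0]
baseline = [0.0, 0.0]
attr = integrated_gradients(x, baseline)
# Convergence delta: how far the attributions are from summing to
# f(input) - f(baseline).
delta = sum(attr) - (f(x) - f(baseline))
```

The attributions should add up to the difference between the output at the input and the output at the baseline. The delta measures how far the numerical approximation is from that property, so a large value suggests that more integration steps are needed.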



In [None]:
def highlight_parser(string, highlights, tokenizer, source):
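    # return the mean highlight score per wordpiece token, averaged over all
    # highlight entries, or None if there are no highlights
    # (`string_split` and `wordpiece_with_indices` are not defined in this notebook)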
    highlights = [i for i in highlights if i]
    if len(highlights) == 0:
        return None
    highlights = [hl[source if source else 0] for hl in highlights]
    assert len(string) == len(highlights[0])
    highlights = [[int(i) for i in hl] for hl in highlights]
    string_splited = string_split(string)
    wordpiece_tokens = sum([wordpiece_with_indices(tokenizer, tok, start) for tok, start, end in string_splited], [])
    highlights = [[np.mean(highlight[start:end]) for _, start, end in wordpiece_tokens] for highlight in highlights]
    highlight_mean = list(np.mean(highlights, axis=0))
    return highlight_mean


In [None]:
viz.visualize_text(all_score_vises.values())