# Interpreting BERT Models (Part 1)

In this notebook we demonstrate how to interpret BERT models using the `Captum` library. In this particular case study we focus on a Question Answering model fine-tuned on the SQuAD dataset using the transformers library from Hugging Face: https://huggingface.co/transformers/

We show how to use interpretation hooks to examine and better understand embeddings, sub-embeddings, bert, and attention layers. 

Note: Before running this tutorial, please install the `seaborn`, `pandas`, `matplotlib`, and `transformers` (from Hugging Face) Python packages.

In [None]:
# Scratch cell: prints the literal 3; nothing later in the notebook depends on it.
print(3)

In [17]:
import os
import sys
import json

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict

import torch
import torch.nn as nn

from transformers import BertTokenizer, BertModel, BertConfig

import sys
sys.path.append("/home/ubuntu/school_reviews/school_reviews_bert/src/models/base/")
from bert_models import MeanBertForSequenceRegression, RobertForSequenceRegression

### Define helper functions

In [18]:
class AdaptedMeanBertForSequenceRegression(nn.Module):
    """Regression head on top of a partially frozen ``bert-base-uncased`` encoder.

    The forward pass averages the encoder's first output over dim 1, applies
    dropout, averages again over dim 0, and feeds the pooled vector through
    ``fc1 -> ReLU -> output_layer``.
    """

    def __init__(self, config, hid_dim=768, num_output=1):
        """Create the encoder and the two-layer head.

        Args:
            config: object exposing ``output_attentions``, ``hidden_size`` and
                ``hidden_dropout_prob`` (a ``BertConfig`` in this notebook).
            hid_dim: width of the hidden layer of the head.
            num_output: number of values emitted by the final linear layer.
        """
        super().__init__()
        self.config = config
        self.bert = BertModel.from_pretrained(
            'bert-base-uncased',
            output_attentions=config.output_attentions,
        )

        # Freeze every encoder parameter except those whose name contains
        # 'layer.11' or 'pooler'.
        for param_name, weight in self.bert.named_parameters():
            stays_trainable = 'layer.11' in param_name or 'pooler' in param_name
            if not stays_trainable:
                weight.requires_grad = False

        self.fc1 = nn.Linear(config.hidden_size, hid_dim)
        self.relu = torch.nn.ReLU()
        self.output_layer = nn.Linear(hid_dim, num_output)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask=None):
        """Predict from a stack of tokenised sentences.

        Args:
            input_ids: token ids; per the original note this is
                n_sent x max_len.
            attention_mask: optional mask forwarded unchanged to the encoder.

        Returns:
            Output of ``output_layer`` for the vector pooled over dim 1 and
            then dim 0 of the encoder's first output.
        """
        encoder_out = self.bert(input_ids, attention_mask=attention_mask)
        # Mean over dim 1 (tokens, assuming [n_sent, max_len, hidden] - TODO confirm),
        # followed by dropout.
        per_sentence = self.dropout(encoder_out[0].mean(dim=1))
        # Mean over dim 0 collapses the sentence axis into a single vector.
        pooled = per_sentence.mean(dim=0)
        hidden = self.relu(self.fc1(pooled))
        return self.output_layer(hidden)

In [19]:
def bert_forward_wrapper(input_ids, attention_mask=None, position=0):
    """Forward function that delegates to the notebook-level ``model``.

    Args:
        input_ids: passed to ``model`` as the first positional argument.
        attention_mask: passed to ``model`` as the ``attention_mask`` keyword.
        position: accepted for signature compatibility; not used.

    Returns:
        Whatever ``model`` returns for these inputs.
    """
    prediction = model(input_ids, attention_mask=attention_mask)
    return prediction

def normalize_attributions(attributions, percentile):
    """Rescale attribution scores so the minimum maps to 0 and a percentile maps to 1.

    Values above the chosen percentile end up greater than 1 (no clipping).

    Args:
        attributions: tensor of attribution scores; may live on any device and
            may still be attached to an autograd graph.
        percentile: percentile (0-100) of the scores that is mapped to 1.

    Returns:
        A CPU ``torch.Tensor`` with the same shape as ``attributions``. When the
        input is empty or the scaling range is zero (``percentile`` value equals
        the minimum), a tensor of zeros is returned instead of NaN/inf values.
    """
    # detach() first: numpy() refuses tensors that require grad.
    curr_attributions = attributions.detach().cpu().numpy()
    if curr_attributions.size == 0:
        return torch.zeros(curr_attributions.shape)

    vmax = np.percentile(curr_attributions, percentile)
    vmin = np.min(curr_attributions)
    span = vmax - vmin
    if span == 0:
        # Degenerate range: dividing would produce NaN/inf.
        return torch.zeros(curr_attributions.shape)

    normalized_attributions = (curr_attributions - vmin) / span
    return torch.Tensor(normalized_attributions)

def summarize_attributions(attributions):
    """Sum over the last axis, drop a leading size-1 axis, and divide by the norm.

    Args:
        attributions: tensor whose last axis is summed away.

    Returns:
        The summed tensor (with a leading axis of size 1 squeezed out, if
        present) divided by its ``torch.norm``.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    return per_token / torch.norm(per_token)

def visualize_text(datarecords):
    """Display attribution records as a two-column HTML table.

    Each record contributes one row: its ``attr_score`` formatted to two
    decimals, followed by its tokens coloured by ``word_attributions``
    (both cells are produced by helpers from the ``viz`` module).

    Args:
        datarecords: iterable of objects exposing ``attr_score``,
            ``raw_input`` and ``word_attributions``.

    Returns:
        None. The table is shown via ``display``.
    """
    dom = ['<table style="width: 100%">']
    # Header cells are wrapped in their own row so the markup is well-formed.
    rows = [
        "<tr>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th>"
        "</tr>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_classname("{0:.2f}".format(datarecord.attr_score)),
                    viz.format_word_importances(
                        datarecord.raw_input, datarecord.word_attributions
                    ),
                    # Close the row; the previous code emitted a second "<tr>" here.
                    "</tr>",
                ]
            )
        )

    dom.append("".join(rows))
    dom.append("</table>")
    display(viz.HTML("".join(dom)))

In [25]:
# Load data
# Unpickles a 4-tuple from an absolute path on the author's machine. Later cells
# index the first three members as [split][school_index].
# NOTE(review): the path is machine-specific; consider a configurable data directory.
import pickle
prepared_data_file = '/home/ubuntu/school_reviews/school_reviews_bert/data/Parent_gs_comments_by_school_mn_avg_eb_1.7682657723517046.p'

# SECURITY: pickle.load can execute arbitrary code - only load files you trust.
# encoding='latin1' is passed through to pickle (presumably the file was written
# by Python 2 - TODO confirm).
with open(prepared_data_file, 'rb') as f:
     all_input_ids, labels_test_score, attention_masks, sentences_per_school = pickle.load(f, encoding='latin1')

In [26]:
# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
# Show which device was picked and whether CUDA is available.
print(device)
print (torch.cuda.is_available())

cuda
True


In [36]:
# Load model
# dropout_0.3-hid_dim_256-lr_0.0001-model_type_meanbert-outcome_mn_avg_eb
model_path = '/home/ubuntu/school_reviews/school_reviews_bert/saved_models/e7_loss1.0341.pt'

config = BertConfig(output_attentions=True, hidden_dropout_prob=0.3, attention_probs_dropout_prob=0.3)
# model = MeanBertForSequenceRegression(config, hid_dim=256, num_output=1)
model = AdaptedMeanBertForSequenceRegression(config, hid_dim=256, num_output=1)
sys.path.append("/home/ubuntu/school_reviews/school_reviews_bert/src/models/base/")
# The checkpoint is deserialised onto the CPU first; the model is moved to `device` below.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
from collections import OrderedDict

# Rename checkpoint keys that carry a leading 'model.' so they match the
# attribute names of AdaptedMeanBertForSequenceRegression (bert / fc1 / output_layer).
updated_state_dict = OrderedDict()
for k in state_dict:
    curr_key = k
    if curr_key.startswith(('model.bert', 'model.fc1', 'model.output_layer')):
        # Strip only the leading prefix. split('model.')[1] would also cut the
        # key short at any later occurrence of 'model.' inside the name.
        curr_key = curr_key[len('model.'):]
    updated_state_dict[curr_key] = state_dict[k]

model.load_state_dict(updated_state_dict)

model.to(device)
# Inference mode (disables dropout) and cleared gradients before attribution.
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Should be loading from model_path, but again, model wasn't saved with huggingface's save...() function
# tokenizer = BertTokenizer.from_pretrained(model_path)

In [None]:
from captum.attr import TokenReferenceBase
from captum.attr import IntegratedGradients, LayerIntegratedGradients
from captum.attr import visualization as viz

# Compute layer-integrated-gradient attributions for every school in every
# split and write one JSON file per school (tokens + per-token attributions).
data_splits = ['validation', 'train']
all_summarized_attr = []
input_ids_for_attr = []
count = 0

internal_batch_size = 12
n_steps = 48

OUTPUT_FILE = '/home/ubuntu/school_reviews/interp/attributions/mn_avg_eb_e7_loss1.0341.pt/{}_{}_loss_{}.json'
# Make sure the target directory exists before the (expensive) loop starts.
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)

# Loop-invariant Captum objects: built once instead of once per school.
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
token_reference = TokenReferenceBase(reference_token_idx=ref_token_id)
lig = LayerIntegratedGradients(bert_forward_wrapper, model.bert.embeddings)

for d in data_splits:

    # len() avoids materialising the whole split as a tensor just to count rows.
    n_schools = len(all_input_ids[d])
    print ("num schools {} for {} split".format(n_schools, d))

    for i in range(0, n_schools):

        print (d, i)

        count += 1

        # Prepare data
        input_ids = torch.LongTensor([all_input_ids[d][i]]).squeeze(0).to(device)
        label_t = torch.tensor([labels_test_score[d][i]]).to(device)
        input_mask = torch.tensor([attention_masks[d][i]]).squeeze(0).to(device)

        # The prediction is only used for the squared error in the file name,
        # so no autograd graph is needed for this forward pass.
        with torch.no_grad():
            pred = model(input_ids, attention_mask=input_mask)
        mse = (pred.item() - label_t.item()) ** 2

        # Generate reference sequence for integrated gradients: a tensor with
        # the same shape as input_ids filled with the pad token id.
        ref_input_ids = token_reference.generate_reference(input_ids.size(0), device=device).unsqueeze(1).repeat(1, input_ids.size(1)).long()

        # Compute integrated gradients
        attributions, conv_delta = lig.attribute(
            inputs=input_ids,
            baselines=ref_input_ids,
            additional_forward_args=(input_mask, 0),
            internal_batch_size=internal_batch_size,
            n_steps=n_steps,
            return_convergence_delta=True)

        # Summarize attributions and output.
        # reshape() pins the result to the shape of input_ids; relying on
        # squeeze(0) collapses the sentence axis when a school has one sentence.
        summarized_attr = summarize_attributions(attributions).reshape(input_ids.shape)
        n_sent = summarized_attr.size(0)
        attr_for_school_sents = {}
        for j in range(0, n_sent):
            indices = input_ids[j].detach().squeeze(0).tolist()
            all_tokens = tokenizer.convert_ids_to_tokens(indices)
            attr_for_school_sents[j] = {
                'tokens': all_tokens,
                'attributions': summarized_attr[j].tolist(),
            }
            assert (len(attr_for_school_sents[j]['tokens']) == len(attr_for_school_sents[j]['attributions']))

        # Context manager guarantees the file is closed even if writing fails.
        with open(OUTPUT_FILE.format(i, d, mse), 'w') as f:
            f.write(json.dumps(attr_for_school_sents, indent=4))

        # Keep CPU copies so the visualisation cell that indexes
        # all_summarized_attr / input_ids_for_attr has data to work with.
        all_summarized_attr.append(summarized_attr.detach().cpu())
        input_ids_for_attr.append(input_ids.cpu())

num schools 5354 for validation split
validation 0
validation 1
validation 2
validation 3
validation 4
validation 5
validation 6
validation 7
validation 8
validation 9
validation 10
validation 11
validation 12
validation 13
validation 14
validation 15
validation 16
validation 17
validation 18
validation 19
validation 20
validation 21
validation 22
validation 23
validation 24
validation 25
validation 26
validation 27
validation 28
validation 29
validation 30
validation 31
validation 32
validation 33
validation 34
validation 35
validation 36
validation 37
validation 38
validation 39
validation 40
validation 41
validation 42
validation 43
validation 44
validation 45
validation 46
validation 47
validation 48
validation 49
validation 50
validation 51
validation 52
validation 53
validation 54
validation 55
validation 56
validation 57
validation 58
validation 59
validation 60
validation 61
validation 62
validation 63
validation 64
validation 65
validation 66
validation 67
validation 68
valida

validation 551
validation 552
validation 553
validation 554
validation 555
validation 556
validation 557
validation 558
validation 559
validation 560
validation 561
validation 562
validation 563
validation 564
validation 565
validation 566
validation 567
validation 568
validation 569
validation 570
validation 571
validation 572
validation 573
validation 574
validation 575
validation 576
validation 577
validation 578
validation 579
validation 580
validation 581
validation 582
validation 583
validation 584
validation 585
validation 586
validation 587
validation 588
validation 589
validation 590
validation 591
validation 592
validation 593
validation 594
validation 595
validation 596
validation 597
validation 598
validation 599
validation 600
validation 601
validation 602


In [20]:
# Build one Captum visualisation record per row of the first stored school.
# conv_delta is whatever the attribution cell assigned last.
curr_school_viz = []
attributions_sum = all_summarized_attr[0]
input_ids = input_ids_for_attr[0]
n_sent = attributions_sum.size(0)
for i in range(n_sent):
    print (i)
    indices = input_ids[i].detach().squeeze(0).tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)

    # Positional arguments 2-5 are intentionally left as None.
    vis = viz.VisualizationDataRecord(
        attributions_sum[i],
        None, None, None, None,
        attributions_sum[i].sum(),
        all_tokens,
        conv_delta,
    )

    curr_school_viz.append(vis)

NameError: name 'all_summarized_attr' is not defined

In [19]:
# Render the attribution table for the records collected in curr_school_viz.
visualize_text(curr_school_viz)

0,1
-0.08,[CLS] this school is horrible . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
,
0.06,[CLS] they are prejudice against male students . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
,
-0.06,[CLS] knee jerk reactions to all imperfect behaviour from talking to arguing with other students . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
,
0.01,[CLS] they never address issues with boys and punish . . . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
,
-0.02,[CLS] more . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
,


In [None]:
# Display a static image from a path relative to the notebook.
# NOTE(review): the file name refers to start/end predictions, which nothing else
# in this notebook computes - confirm the image exists and is still relevant.
from IPython.display import Image
Image(filename='img/bert/visuals_of_start_end_predictions.png')