<a href="https://colab.research.google.com/github/josbex/HS-detection_in_social_media_posts/blob/master/Interpretation_of_BERT_using_captum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpretation of BertForSequenceClassification in captum

The original notebook this one is based on can be found here:
https://colab.research.google.com/drive/1Lw3JTZio03VwPvSVFzLJmZ52oBRpo9ZM 

In [None]:
# Install dependencies: transformers (BERT model and tokenizer) and captum (attribution methods).
# NOTE(review): versions are unpinned; the recorded output of this cell shows
# transformers 3.0.2 being resolved - consider pinning to keep the notebook reproducible.
!pip install transformers
!pip install captum

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |████████████████████████████████| 778kB 2.7MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 15.9MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 16.8MB/s 
[?25hCollecting tokenizers==0.8.1.rc1
[?25l  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB

In [None]:
import captum
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

In [None]:
import pandas as pd
import numpy as np
import csv

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Later cells read this global when creating tensors and moving the model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; get_dataset reads its TSV files from
# "/content/gdrive/My Drive/thesis/". Mounting prompts for an authorization code.
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [None]:
def load_model(dir):
  """Load a fine-tuned BERT sequence classifier and its tokenizer.

  Args:
    dir: Path fragment appended verbatim to "./gdrive/My Drive/thesis/model/".

  Returns:
    A (model, tokenizer) pair. The model is configured with
    output_attentions=True, moved to the global `device`, switched to
    eval mode, and has its gradients zeroed.
  """
  checkpoint_path = "./gdrive/My Drive/thesis/model/" + dir
  classifier = BertForSequenceClassification.from_pretrained(
      checkpoint_path, output_attentions=True)
  vocab = BertTokenizer.from_pretrained(checkpoint_path)
  # Move to the selected device and prepare for inference/attribution.
  classifier.to(device)
  classifier.eval()
  classifier.zero_grad()
  return classifier, vocab

In [None]:
# Load the fine-tuned model and its tokenizer from the "learn_rate_5/model_save" checkpoint.
# The leading "/" is concatenated as-is onto load_model's base path, so the
# resulting path contains "model//learn_rate_5/model_save".
model, tokenizer = load_model("/learn_rate_5/model_save")

In [None]:
# Display the module tree of the loaded model (long output).
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [None]:
def predict(inputs):
    """Run the global `model` on `inputs` and return the first element of its output.

    NOTE(review): for a sequence classifier called without labels this is
    presumably the logits tensor; later cells apply softmax over dim 1 to it.
    """
    outputs = model(inputs)
    return outputs[0]

In [None]:
ref_token_id = tokenizer.pad_token_id # Padding token id; fills the text positions of the reference (baseline) input
sep_token_id = tokenizer.sep_token_id # Separator token id; appended after the tweet tokens
cls_token_id = tokenizer.cls_token_id # Classification token id; prepended before the tweet tokens

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Build the model input and its attribution baseline for one text.

    The text is tokenized without special tokens, then wrapped as
    [CLS] + tokens + [SEP]. The baseline keeps [CLS] and [SEP] but replaces
    every text token with `ref_token_id`.

    Returns:
        (input_ids, ref_input_ids, n_text_tokens): two tensors of shape
        (1, n_text_tokens + 2) on the global `device`, and the number of
        text tokens (special tokens excluded).
    """
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    n_text_tokens = len(text_ids)

    wrapped_ids = [cls_token_id, *text_ids, sep_token_id]
    wrapped_ref_ids = [cls_token_id, *([ref_token_id] * n_text_tokens), sep_token_id]

    input_tensor = torch.tensor([wrapped_ids], device=device)
    ref_tensor = torch.tensor([wrapped_ref_ids], device=device)
    return input_tensor, ref_tensor, n_text_tokens

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type ids and an all-zero reference for `input_ids`.

    Args:
        input_ids: Tensor of shape (1, seq_len); only its length along
            dim 1 and its device are used.
        sep_ind: Positions with index <= sep_ind get type 0, later
            positions get type 1.

    Returns:
        (token_type_ids, ref_token_type_ids): two int64 tensors of shape
        (1, seq_len), created on the same device as `input_ids`.
    """
    seq_len = input_ids.size(1)
    # Create the tensors on the input's own device rather than a notebook
    # global, so the function works for any input it is handed.
    positions = torch.arange(seq_len, device=input_ids.device)
    token_type_ids = (positions > sep_ind).long().unsqueeze(0)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids and an all-zero reference for `input_ids`.

    Args:
        input_ids: Tensor of shape (batch, seq_len).

    Returns:
        (position_ids, ref_position_ids): int64 tensors expanded to the
        shape of `input_ids` and created on its device. `position_ids`
        counts 0..seq_len-1 along dim 1; the reference is all zeros.
    """
    seq_length = input_ids.size(1)
    # Follow the input's device instead of relying on a notebook global.
    target_device = input_ids.device
    position_ids = torch.arange(seq_length, dtype=torch.long, device=target_device)
    # we could potentially also use random permutation with
    # `torch.randperm(seq_length, device=target_device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=target_device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as `input_ids`."""
    mask = torch.ones_like(input_ids)
    return mask

In [None]:
def custom_forward(inputs):
    """Return the softmax probability of class index 1 for each row of `inputs`.

    This is the scalar function attributed by the integrated-gradients cells.
    To attribute towards class index 0 instead, select `[:, 0]` below.
    """
    logits = predict(inputs)
    class_probs = torch.softmax(logits, dim=1)
    return class_probs[:, 1]

In [None]:
# Layer Integrated Gradients over the BERT embedding module, with custom_forward
# (the class-1 probability) as the function being attributed.
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [None]:
def get_dataset(filename):
  """Read a tab-separated file from the thesis folder on the mounted Drive.

  Args:
    filename: Path relative to "/content/gdrive/My Drive/thesis/", without
      the ".tsv" extension.

  Returns:
    The file contents as a pandas DataFrame.
  """
  tsv_path = "/content/gdrive/My Drive/thesis/" + filename + ".tsv"
  return pd.read_csv(tsv_path, sep="\t")

In [None]:
def get_test_tweet_by_index(index):
  """Look up one tweet and its label in the test set, print them, and return them.

  The test set is re-read from "dataset/test_data" on every call.

  Args:
    index: Positional row index into the test set.

  Returns:
    (tweet, label) for the requested row.
  """
  test_df = get_dataset("dataset/test_data")
  tweet = test_df["tweet"].values[index]
  label = test_df["label"].values[index]
  print(f"tweet: {tweet!s} label: {label!s}")
  return tweet, label

def get_prediction_type_indices(df):
  """Split the 'Index' column of `df` by the value of its 'Type' column.

  Args:
    df: DataFrame with a 'Type' column holding 'tp', 'tn', 'fp' or 'fn'
      and an 'Index' column.

  Returns:
    A 4-tuple of arrays (tp, tn, fp, fn) with the 'Index' values of the
    rows of each type, in row order.
  """
  def indices_for(kind):
    return df.loc[df["Type"] == kind, "Index"].values

  return tuple(indices_for(kind) for kind in ("tp", "tn", "fp", "fn"))

In [None]:
def tokenize_tweet(tweet):
  """Build every tensor and token list the attribution cells need for one tweet.

  Returns a 10-tuple, in this order:
    input_ids, ref_input_ids, sep_id, token_type_ids, ref_token_type_ids,
    position_ids, ref_position_ids, attention_mask, indices, all_tokens

  `sep_id` is the third value returned by construct_input_ref_pair (the number
  of text tokens), `indices` is the first row of `input_ids` as a Python list,
  and `all_tokens` holds the token strings for those ids.
  """
  input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
      tweet, ref_token_id, sep_token_id, cls_token_id)
  type_pair = construct_input_ref_token_type_pair(input_ids, sep_id)
  pos_pair = construct_input_ref_pos_id_pair(input_ids)
  attention_mask = construct_attention_mask(input_ids)

  indices = input_ids[0].detach().tolist()
  all_tokens = tokenizer.convert_ids_to_tokens(indices)

  token_type_ids, ref_token_type_ids = type_pair
  position_ids, ref_position_ids = pos_pair
  return (input_ids, ref_input_ids, sep_id,
          token_type_ids, ref_token_type_ids,
          position_ids, ref_position_ids,
          attention_mask, indices, all_tokens)

In [None]:
# Load the per-example prediction outcomes and split the example indices
# into true/false positives/negatives.
df = get_dataset("model/predictions/pred_indices")
tp_indices, tn_indices, fp_indices, fn_indices = get_prediction_type_indices(df)
# Pick the first false positive as the example to explain, then build its model inputs.
tweet, true_label = get_test_tweet_by_index(fp_indices[0])
input_ids, ref_input_ids, sep_id, token_type_ids, ref_token_type_ids, position_ids, ref_position_ids, attention_mask, indices, all_tokens = tokenize_tweet(tweet)

tweet: rap is a form of art ! used to express yourself freely . it does not gv the green light or excuse the behavior of acting like an animal ! she is not in the streets of the bx where violence is a way of living . elevate yourself boo and get on @user level for longevity ! queen $EMOJI$ label: 0


In [None]:
saved_act = None

def save_act(module, inp, out):
  """Forward hook that stores a module's most recent output in `saved_act`.

  Args:
    module: The module the hook is attached to (unused).
    inp: The module's forward inputs (unused).
    out: The module's forward output.

  Returns:
    `out` itself, so the module's output is passed on unchanged.
  """
  # The global must be declared and assigned here; without these two lines
  # the return statement raises NameError the first time the hook fires.
  global saved_act
  saved_act = out
  return saved_act

# Attach save_act as a forward hook on the embedding module; keep the returned
# handle so the hook can be removed again.
hook = model.bert.embeddings.register_forward_hook(save_act)

In [None]:
# Detach the forward hook registered above so it does not run in later forward passes.
hook.remove()

In [None]:
# Check predict output: make sure the forward function runs on this input.
# The custom_forward result is discarded; only the last expression
# (the shape of input_ids) is displayed as the cell output.
custom_forward(torch.cat([input_ids]))
input_ids.shape

torch.Size([1, 71])

In [None]:
# Show the softmax over dim 1 of the model output for this tweet (one value per class).
pred = predict(input_ids)
torch.softmax(pred, dim = 1)

tensor([[0.3083, 0.6917]], device='cuda:0', grad_fn=<SoftmaxBackward>)

In [None]:
# Check output of custom_forward: it should equal the class-1 entry of the softmax shown above.
custom_forward(input_ids)

tensor([0.6917], device='cuda:0', grad_fn=<SelectBackward>)

In [None]:
# Attribute the class-1 probability to the embedding layer, using the padded
# reference sequence (ref_input_ids) as the baseline. Also request the
# convergence delta so the quality of the approximation can be reported.
# NOTE(review): n_steps=7000 with internal_batch_size=5 makes this cell slow;
# presumably chosen to shrink the convergence delta - confirm it is needed.
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=7000,
                                    internal_batch_size=5,
                                    return_convergence_delta=True)

In [None]:
# Recompute the model output for the tweet; `score` is reused by the visualization cell.
score = predict(input_ids)

# Print the tweet, the arg-max class index and the softmax probability of class index 1.
# NOTE(review): the 'Sentiment' / 'positive' wording looks inherited from a sentiment
# example; here the printed values are the predicted class and the class-1 probability.
print('Tweet: ', tweet)
print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \
      ', Probability positive: ' + str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()))

Tweet:  rap is a form of art ! used to express yourself freely . it does not gv the green light or excuse the behavior of acting like an animal ! she is not in the streets of the bx where violence is a way of living . elevate yourself boo and get on @user level for longevity ! queen $EMOJI$
Sentiment: 1, Probability positive: 0.69170946


In [None]:
def summarize_attributions(attributions):
    """Collapse attributions to one normalized score per position.

    Sums over the last dimension, squeezes dim 0, and divides the result by
    its own norm.
    NOTE(review): an all-zero input gives a zero norm and therefore NaNs.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    return per_token / per_token.norm()

In [None]:
# Reduce the embedding-level attributions to a single normalized score per token.
attributions_sum = summarize_attributions(attributions)

In [None]:
# Bundle the attribution results for this tweet into a record for captum's text visualization.
class_probs = torch.softmax(score, dim=1)[0]
pred_class = torch.argmax(class_probs)
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        # Probability of the *predicted* class. Always using
                                        # index 0 displayed e.g. "1 (0.31)" for a tweet whose
                                        # class-1 probability was 0.69.
                                        class_probs[pred_class],
                                        pred_class,
                                        true_label,
                                        # The raw tweet is shown in the "Attribution Label" column.
                                        tweet,
                                        attributions_sum.sum(),
                                        all_tokens,
                                        delta)


In [None]:
# Print a heading wrapped in ANSI escape codes (bold on / reset), then render the
# per-token attribution record built in the previous cell.
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.31),rap is a form of art ! used to express yourself freely . it does not gv the green light or excuse the behavior of acting like an animal ! she is not in the streets of the bx where violence is a way of living . elevate yourself boo and get on @user level for longevity ! queen $EMOJI$,1.69,[CLS] rap is a form of art ! used to express yourself freely . it does not g ##v the green light or excuse the behavior of acting like an animal ! she is not in the streets of the b ##x where violence is a way of living . el ##eva ##te yourself boo and get on @ user level for longevity ! queen $ em ##oj ##i $ [SEP]
,,,,
