<a href="https://colab.research.google.com/github/josbex/HS-detection_in_social_media_posts/blob/master/Interpretation_of_BERT_using_captum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpreting BertForSequenceClassification with Captum

This notebook is adapted from the original notebook available here:
https://colab.research.google.com/drive/1Lw3JTZio03VwPvSVFzLJmZ52oBRpo9ZM 

In [None]:
# Install dependencies
!pip install transformers
!pip install captum

In [None]:
import captum
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

In [None]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mount Google Drive (Colab only; prompts for an OAuth authorization code).
# NOTE(review): load_model reads from the relative path "./gdrive/My Drive/...",
# which assumes the working directory is /content (the Colab default) - confirm.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [None]:
# Load a trained model and vocabulary that you have fine-tuned
def load_model(dir, base_dir="./gdrive/My Drive/thesis/model/"):
  """Load a fine-tuned BertForSequenceClassification and its BertTokenizer.

  Args:
    dir: sub-folder of ``base_dir`` holding the files written by
      ``save_pretrained``. A leading "/" is tolerated, e.g.
      "/learn_rate_5/model_save".
    base_dir: root folder of the saved models; defaults to the thesis folder
      on the mounted Google Drive.

  Returns:
    (model, tokenizer): the model, configured with ``output_attentions=True``,
    moved to the global ``device``, switched to eval mode and with its
    gradients zeroed; and the tokenizer saved alongside it.
  """
  # Join with exactly one separator so "model/" + "/sub" does not become
  # "model//sub" (os.path.join would instead drop base_dir for "/sub").
  output_dir = base_dir.rstrip("/") + "/" + dir.strip("/")
  model = BertForSequenceClassification.from_pretrained(output_dir, output_attentions=True)
  tokenizer = BertTokenizer.from_pretrained(output_dir)
  # Move the model to the selected device (the GPU if one is available, else the CPU).
  model.to(device)
  # eval() disables dropout so repeated forward passes give the same outputs.
  model.eval()
  model.zero_grad()
  return model, tokenizer

In [None]:
# Restore the fine-tuned classifier and its tokenizer from the "learn_rate_5" run folder.
model, tokenizer = load_model(dir="/learn_rate_5/model_save")

In [None]:
# Show the module tree; its embedding block (model.bert.embeddings) is the layer attributed below.
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [None]:
def predict(inputs):
    """Run the fine-tuned classifier on token ids and return its logits.

    Element 0 of the model output holds the classification logits,
    one row per example and one column per label.
    """
    outputs = model(inputs)
    return outputs[0]

In [None]:
# Special-token ids taken from the fine-tuned tokenizer:
#   ref_token_id - the pad token, used to fill the text positions of the reference (baseline)
#   sep_token_id - the separator token, appended after the text
#   cls_token_id - the classification token, prepended before the text
ref_token_id, sep_token_id, cls_token_id = (
    tokenizer.pad_token_id,
    tokenizer.sep_token_id,
    tokenizer.cls_token_id,
)

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Tokenise ``text`` into an input / reference pair for Integrated Gradients.

    The input is ``[CLS] text [SEP]``. The reference keeps the [CLS] and [SEP]
    ids but replaces every text token with ``ref_token_id``.

    Returns:
        (input_ids, ref_input_ids, num_text_tokens), where both id tensors have
        shape (1, num_text_tokens + 2) and are created on the global ``device``.
    """
    body_ids = tokenizer.encode(text, add_special_tokens=False)
    n_body = len(body_ids)
    input_seq = [cls_token_id, *body_ids, sep_token_id]
    baseline_seq = [cls_token_id, *([ref_token_id] * n_body), sep_token_id]
    return (
        torch.tensor([input_seq], device=device),
        torch.tensor([baseline_seq], device=device),
        n_body,
    )

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment (token-type) ids and an all-zero reference for ``input_ids``.

    Positions ``0..sep_ind`` (inclusive) get segment 0; later positions get
    segment 1. For a single-segment input, pass the index of the final [SEP]
    (``input_ids.size(1) - 1``) so that every position is segment 0.

    Args:
        input_ids: id tensor whose second dimension is the sequence length.
        sep_ind: last position (inclusive) belonging to segment 0.

    Returns:
        (token_type_ids, ref_token_type_ids): long tensors of shape
        (1, seq_len) on the same device as ``input_ids``.
    """
    seq_len = input_ids.size(1)
    # Follow input_ids' own device rather than the global `device`, so the
    # segment ids always live next to the ids they describe.
    positions = torch.arange(seq_len, device=input_ids.device)
    token_type_ids = (positions > sep_ind).long().unsqueeze(0)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids ``0..seq_len-1`` and an all-zero reference, shaped like ``input_ids``.

    Args:
        input_ids: 2-D id tensor (batch, seq_len).

    Returns:
        (position_ids, ref_position_ids): long tensors expanded (as views) to
        ``input_ids.shape`` and placed on ``input_ids.device``.
    """
    seq_length = input_ids.size(1)
    # Use the device of input_ids rather than the global `device`.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    # A random permutation (torch.randperm(seq_length, device=input_ids.device))
    # would be an alternative reference.
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=input_ids.device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return an all-ones mask (every token attended), with the shape, dtype and device of ``input_ids``."""
    mask = input_ids.new_ones(input_ids.shape)
    return mask

In [None]:
def custom_forward(inputs, target=1):
    """Forward function for Captum: the softmax probability of class ``target``.

    Args:
        inputs: token-id tensor, passed straight to ``predict``.
        target: class index whose probability is attributed. The default 1
            gives the "positive" attribution; use 0 for the negative class.
            The target can be supplied through Captum's
            ``additional_forward_args``.

    Returns:
        Tensor with one probability of ``target`` per example in the batch.
    """
    preds = predict(inputs)
    return torch.softmax(preds, dim=1)[:, target]

In [None]:
# Attribute custom_forward's class probability to the output of BERT's embedding
# block (word + position + token-type embeddings, then LayerNorm and dropout).
lig = LayerIntegratedGradients(forward_func=custom_forward, layer=model.bert.embeddings)

In [None]:
# Example post to explain. The text looks already pre-processed (lower-cased,
# '@user' / 'url' placeholders) - presumably matching the training preprocessing; confirm.
text =  "who is q wheres the server dump nike declasfisa democrats support antifa , muslim brotherhood , ms13 , isis , pedophilia , child trafficking , taxpayer funded abortion s , election fraud , sedition and treason ! ! ! lock them all up wwg 1 wga q anon @user url"
# Gold label for `text`; shown as the true label in the visualisation below.
true_label = 1

In [None]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
# The third value returned above is the number of text tokens, so the final [SEP]
# sits at index sep_id + 1. Passing that index keeps every position (including the
# trailing [SEP]) in segment 0, as a single-sentence input requires. Passing sep_id
# itself would put the trailing [SEP] in segment 1.
# NOTE(review): token_type_ids, position_ids and attention_mask are built here but
# are not passed to lig.attribute in this notebook.
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id + 1)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Token strings (including [CLS]/[SEP]) used to label the attribution visualisation.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [None]:
# Debugging aid: capture the embedding-layer output through a forward hook.
saved_act = None

def save_act(module, inp, out):
  """Forward hook that stores the hooked module's output in the global ``saved_act``.

  Returns None, so PyTorch keeps the module's original output unchanged.
  """
  global saved_act
  saved_act = out

# Attach save_act to the embedding block; keep the handle so the hook can be removed later.
hook = model.bert.embeddings.register_forward_hook(hook=save_act)

In [None]:
# Detach the debugging hook so it does not run during later forward passes
# (e.g. the ones made by lig.attribute).
hook.remove()

In [None]:
# Sanity check: custom_forward accepts the batched ids and returns one probability
# per example. Show that value together with the (batch, seq_len) shape of the input.
custom_forward(input_ids), input_ids.shape

torch.Size([1, 67])

In [None]:
# Raw logits for the example, displayed as class probabilities.
pred = predict(input_ids)
pred.softmax(dim=1)

tensor([[0.3464, 0.6536]], device='cuda:0', grad_fn=<SoftmaxBackward>)

In [None]:
# Check output of custom_forward: it should equal the class-1 column of the
# softmax shown above (one value per example in the batch).
custom_forward(input_ids)

tensor([0.6536], device='cuda:0', grad_fn=<SelectBackward>)

In [None]:
# Inspect the token ids: the [CLS] id comes first and the [SEP] id last, as built by construct_input_ref_pair.
input_ids

tensor([[  101,  2040,  2003,  1053,  2073,  2015,  1996,  8241, 15653, 18368,
         11703,  8523,  8873,  3736,  8037,  2490,  3424,  7011,  1010,  5152,
         12865,  1010,  5796, 17134,  1010, 18301,  1010, 21877,  3527, 21850,
          6632,  1010,  2775, 11626,  1010, 26980,  6787, 11324,  1055,  1010,
          2602,  9861,  1010,  7367, 20562,  1998, 14712,   999,   999,   999,
          5843,  2068,  2035,  2039,  1059, 27767,  1015,  1059,  3654,  1053,
          2019,  2239,  1030,  5310, 24471,  2140,   102]], device='cuda:0')

In [None]:
# Layer Integrated Gradients with respect to the embedding-block output, integrating
# from the pad-token reference (ref_input_ids) to the real input.
# n_steps=7000 is far above Captum's default of 50; presumably chosen to keep the
# convergence `delta` small, at a large compute cost.
# internal_batch_size=5 evaluates the interpolation steps in chunks of 5 to bound memory.
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=7000,
                                    internal_batch_size=5,
                                    return_convergence_delta=True)

In [None]:
score = predict(input_ids)
print('Sentence: ', text)
# Predicted class index and softmax probability of class 1, printed as NumPy values.
print(''.join([
    'Sentiment: ',
    str(torch.argmax(score[0]).cpu().numpy()),
    ', Probability positive: ',
    str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()),
]))

Sentence:  who is q wheres the server dump nike declasfisa democrats support antifa , muslim brotherhood , ms13 , isis , pedophilia , child trafficking , taxpayer funded abortion s , election fraud , sedition and treason ! ! ! lock them all up wwg 1 wga q anon @user url
Sentiment: 1, Probability positive: 0.6536152


In [None]:
def summarize_attributions(attributions):
    """Collapse per-dimension attributions into one L2-normalised score per token.

    Args:
        attributions: tensor whose last dimension is the embedding dimension,
            e.g. (1, seq_len, hidden_size) for one example. A leading size-1
            batch dimension is squeezed away.

    Returns:
        Tensor of per-token scores with unit L2 norm. If every score is zero,
        the zeros are returned unchanged instead of 0/0 = NaN.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)
    if norm == 0:
        # Nothing to normalise; avoid producing NaNs in the visualisation.
        return per_token
    return per_token / norm

In [None]:
# One normalised attribution score per token: sum over embedding dimensions, divided by the L2 norm.
attributions_sum = summarize_attributions(attributions)

In [None]:
# Bundle the results into a record for Captum's text visualiser.
# Arguments are positional in Captum's order:
# word attributions, pred_prob, pred_class, true_class, attr_class, attr_score,
# raw input tokens, convergence delta.
# pred_prob is the probability of the *predicted* class (the maximum softmax value).
# attr_class is the class whose probability custom_forward attributes (class 1).
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.max(torch.softmax(score, dim = 1)[0]),
                                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                                        true_label,
                                        1,
                                        attributions_sum.sum(),
                                        all_tokens,
                                        delta)


In [None]:
# Heading wrapped in ANSI bold escape codes, followed by Captum's per-token attribution rendering.
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.35),"who is q wheres the server dump nike declasfisa democrats support antifa , muslim brotherhood , ms13 , isis , pedophilia , child trafficking , taxpayer funded abortion s , election fraud , sedition and treason ! ! ! lock them all up wwg 1 wga q anon @user url",2.04,"[CLS] who is q where ##s the server dump nike dec ##las ##fi ##sa democrats support anti ##fa , muslim brotherhood , ms ##13 , isis , pe ##do ##phi ##lia , child trafficking , taxpayer funded abortion s , election fraud , se ##dition and treason ! ! ! lock them all up w ##wg 1 w ##ga q an ##on @ user ur ##l [SEP]"
,,,,
