In [3]:
import os
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [4]:
# Use the first CUDA device when PyTorch reports one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Checkpoint directory; both the model and its tokenizer are loaded from it.
# Set the SST2_BERT_MODEL_PATH environment variable to use a different
# checkpoint without editing this cell. The fallback is the path this
# notebook was originally run against.
model_path = os.environ.get(
    'SST2_BERT_MODEL_PATH',
    '/mnt/data/medhas/glue_experiments/bert-base-cased/20200714/SST-2',
)

# load model, move it to the selected device and put it in evaluation mode
model = BertForSequenceClassification.from_pretrained(model_path)
model.to(device)
model.eval()
# clear any gradients before attributions are computed
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)

In [8]:
# Special-token ids read from the tokenizer:
#   ref_token_id - pad token, used to fill the reference (baseline) sequence
#   sep_token_id - separator token appended after the sentence
#   cls_token_id - classification token prepended to the sentence
ref_token_id, sep_token_id, cls_token_id = (
    tokenizer.pad_token_id,
    tokenizer.sep_token_id,
    tokenizer.cls_token_id,
)

In [334]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Build the model input and a same-length reference sequence for `text`.

    The text is tokenized without special tokens, then wrapped as
    ``[cls] + tokens + [sep]``. The reference keeps the same cls/sep ends
    but replaces every text token with `ref_token_id`.

    Returns:
        (input_ids, ref_input_ids, sep_index): two tensors of shape
        ``(1, len(tokens) + 2)`` on the module-level `device`, and the
        position of the sep token inside them.
    """
    body_ids = tokenizer.encode(text, add_special_tokens=False)
    n_body = len(body_ids)

    real_sequence = [cls_token_id, *body_ids, sep_token_id]
    baseline_sequence = [cls_token_id, *([ref_token_id] * n_body), sep_token_id]

    input_tensor = torch.tensor([real_sequence], device=device)
    baseline_tensor = torch.tensor([baseline_sequence], device=device)

    # cls occupies index 0, so the sep token lands at n_body + 1
    return input_tensor, baseline_tensor, n_body + 1

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type ids for `input_ids` plus an all-zero reference.

    Positions up to and including `sep_ind` get type 0; every later
    position gets type 1. The reference has the same shape and is all zeros.

    Returns:
        (token_type_ids, ref_token_type_ids), each with one row and
        ``input_ids.size(1)`` columns, on the module-level `device`.
    """
    n_tokens = input_ids.size(1)
    segment_row = [int(pos > sep_ind) for pos in range(n_tokens)]
    token_type_ids = torch.tensor([segment_row], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids for `input_ids` plus an all-zero reference.

    Real positions count 0..seq_len-1; the reference uses position 0
    everywhere (a random permutation via `torch.randperm` would be another
    possible reference). Both results are expanded to the shape of `input_ids`.

    Returns:
        (position_ids, ref_position_ids) as long tensors on the
        module-level `device`.
    """
    n_tokens = input_ids.size(1)

    def _as_batch(row):
        # add a leading batch axis and broadcast it to match input_ids
        return row.unsqueeze(0).expand_as(input_ids)

    position_ids = _as_batch(torch.arange(n_tokens, dtype=torch.long, device=device))
    ref_position_ids = _as_batch(torch.zeros(n_tokens, dtype=torch.long, device=device))
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as `input_ids`."""
    return torch.full_like(input_ids, 1)

In [378]:
# Sentence whose prediction is explained below.
# Alternative example sentence, kept commented out for quick swapping:
#text = "that loves its characters and communicates something rather beautiful about human nature"
text = "you do n't have to know about music to appreciate the film 's easygoing blend of comedy and romance."

In [379]:
# Token ids for the sentence, the same-length reference built from the pad
# token, and the index of the sep token inside both sequences.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
# Since sep_id is the last position, every token type id comes out as 0.
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Token strings (including the cls/sep ends) for labelling the visualization.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [380]:
# Direct forward pass through the model with the real inputs.
output = model(
    input_ids,
    token_type_ids=token_type_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
)

In [381]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    """Run the module-level `model` on `inputs` and return its raw output.

    The three optional tensors are forwarded to the model unchanged as
    keyword arguments of the same name.
    """
    extra_inputs = dict(
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
    )
    return model(inputs, **extra_inputs)

def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    """Forward function for attribution: the largest value per row of the model's first output.

    Calls `predict` with the given tensors, takes element 0 of what it
    returns, and reduces it with a max over dimension 1.
    """
    model_output = predict(
        inputs,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
    )
    first_output = model_output[0]
    best_per_row = first_output.max(1).values
    return best_per_row

In [382]:
# Model output for the example sentence; scores[0] is read later for the
# predicted class and its probability.
scores = predict(
    input_ids,
    token_type_ids=token_type_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
)

In [383]:
# Attribute the forward function's output to the BERT embedding layer,
# integrating from the pad-token reference to the real input ids.
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    additional_forward_args=(token_type_ids, position_ids, attention_mask),
    return_convergence_delta=True,
)

In [384]:
def summarize_attributions(attributions):
    """Collapse attributions to one score per token and scale to unit norm.

    Args:
        attributions: attribution tensor whose last dimension is summed away
            and whose leading dimension is squeezed if it has size 1
            (presumably ``(1, seq_len, hidden)`` for a single sentence —
            TODO confirm against the attribution call).

    Returns:
        The summed attributions divided by their norm. If every summed
        attribution is zero (for example when the input equals the
        reference), the zeros are returned as-is instead of NaNs from a
        0/0 division.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)
    # A zero norm means there is nothing to scale; dividing would give NaN.
    if norm == 0:
        return per_token
    return per_token / norm

In [385]:
# One normalized attribution score per token (embedding dimension summed out).
attributions_sum = summarize_attributions(attributions)

In [386]:
# Display the raw model output for the example sentence.
scores

(tensor([[-3.7749,  4.4011]], device='cuda:0', grad_fn=<AddmmBackward>),)

In [387]:
# Display the per-token attribution scores.
attributions_sum

tensor([ 0.0000,  0.1540,  0.0038,  0.0286, -0.0107, -0.1081,  0.2745,  0.1513,
        -0.0718, -0.0108,  0.0232,  0.0245,  0.7257, -0.1646, -0.0772, -0.0822,
        -0.0316,  0.3204,  0.0993,  0.3193,  0.0639,  0.0940, -0.0685,  0.2435,
         0.0129,  0.0000], device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [389]:
# storing couple samples in an array for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
                        word_attributions=attributions_sum,
                        pred_prob=torch.max(torch.softmax(scores[0][0], dim=0)),
                        pred_class=torch.argmax(scores[0]),
                        true_class=torch.argmax(scores[0]),
                        attr_class=torch.argmax(scores[0]),
                        attr_score=attributions_sum.sum(),       
                        raw_input=all_tokens,
                        convergence_score=delta_start)
print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])


[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,1.91,[CLS] you do n ' t have to know about music to appreciate the film ' s easy ##going blend of comedy and romance . [SEP]
,,,,
1.0,1 (1.00),1.0,1.91,[CLS] you do n ' t have to know about music to appreciate the film ' s easy ##going blend of comedy and romance . [SEP]
,,,,


In [391]:
def get_vis(text):
    """Compute embedding-layer attributions for `text` and wrap them for display.

    Builds the inputs and pad-token reference for the sentence, runs the
    model, attributes the forward function's output to
    `model.bert.embeddings` with LayerIntegratedGradients, and packs the
    normalized per-token scores into a visualization record.

    Args:
        text: the sentence to explain.

    Returns:
        A `viz.VisualizationDataRecord`. No gold label is available, so the
        predicted class is also used for the true-label and
        attribution-label fields. The convergence score is the delta
        computed for this sentence.
    """
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
    # The reference token-type / position ids are not needed: only input_ids
    # gets a baseline in the attribution call below.
    token_type_ids, _ = construct_input_ref_token_type_pair(input_ids, sep_id)
    position_ids, _ = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist())

    scores = predict(input_ids,
                     token_type_ids=token_type_ids,
                     position_ids=position_ids,
                     attention_mask=attention_mask)
    predicted_class = torch.argmax(scores[0])

    lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

    attributions, delta = lig.attribute(inputs=input_ids,
                                        baselines=ref_input_ids,
                                        additional_forward_args=(token_type_ids, position_ids, attention_mask),
                                        return_convergence_delta=True)

    attributions_sum = summarize_attributions(attributions)

    return viz.VisualizationDataRecord(
        word_attributions=attributions_sum,
        pred_prob=torch.max(torch.softmax(scores[0][0], dim=0)),
        pred_class=predicted_class,
        true_class=predicted_class,
        attr_class=predicted_class,
        attr_score=attributions_sum.sum(),
        raw_input=all_tokens,
        # Use this call's own delta. The previous `delta_start` was not a
        # local and would have been read from stale notebook state.
        convergence_score=delta)
    
    

In [398]:
# Attribution record for example sentence 1.
vis1 = get_vis( "you do n't have to know about music to appreciate the film 's easygoing blend of comedy and romance.")

In [399]:
# Attribution record for example sentence 2.
vis2 = get_vis("you wonder why enough was n't just a music video rather than a full-length movie.")

In [400]:
# Attribution record for example sentence 3.
vis3 = get_vis("it offers little beyond the momentary joys of pretty and weightless intellectual entertainment.")

In [401]:
# Attribution record for example sentence 4.
vis4 = get_vis("if you 're hard up for raunchy college humor , this is your ticket right here.")

In [402]:
# Attribution record for example sentence 5.
vis5 = get_vis("allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker.")

In [403]:
# Print a bold heading (ANSI escape codes), then render all five records in one table.
print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([vis1, vis2, vis3, vis4, vis5])

[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,1.91,[CLS] you do n ' t have to know about music to appreciate the film ' s easy ##going blend of comedy and romance . [SEP]
,,,,
0.0,0 (0.98),0.0,1.74,[CLS] you wonder why enough was n ' t just a music video rather than a full - length movie . [SEP]
,,,,
0.0,0 (0.94),0.0,0.88,[CLS] it offers little beyond the moment ##ary joy ##s of pretty and weight ##less intellectual entertainment . [SEP]
,,,,
1.0,1 (0.79),1.0,0.67,"[CLS] if you ' re hard up for r ##au ##nch ##y college humor , this is your ticket right here . [SEP]"
,,,,
1.0,1 (1.00),1.0,1.39,[CLS] allows us to hope that no ##lan is poised to em ##bark a major career as a commercial yet in ##vent ##ive filmmaker . [SEP]
,,,,
