In [1]:
import json
import logging
import os
import numpy as np
import torch
from transformers import BertTokenizer
from ts.torch_handler.base_handler import BaseHandler
from captum.attr import IntegratedGradients
from captum.attr import InterpretableEmbeddingBase, TokenReferenceBase
from captum.attr import visualization
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
from news_classifier import BertNewsClassifier
import torch.nn.functional as F
import torch.nn as nn 

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
data = {"data": ["This year business is good"]}
text = data["data"][0]
print(text)

This year business is good


In [4]:
def compute_bert_outputs(model_bert, embedding_input, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = torch.ones(embedding_input.shape[0], embedding_input.shape[1]).to(embedding_input)

    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    extended_attention_mask = extended_attention_mask.to(dtype=next(model_bert.parameters()).dtype) # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        head_mask = head_mask.to(dtype=next(model_bert.parameters()).dtype) # switch to fload if need + fp16 compatibility
    else:
        head_mask = [None] * model_bert.config.num_hidden_layers

    encoder_outputs = model_bert.encoder(embedding_input,
                                         extended_attention_mask,
                                         head_mask=head_mask)
    sequence_output = encoder_outputs[0]
    pooled_output = model_bert.pooler(sequence_output)
    outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
    return outputs 


In [5]:
model_dir =os.getcwd()
# Read model serialize/pt file
model_pt_path = os.path.join(model_dir, "state_dict.pth")
# Read model definition file
VOCAB_FILE = os.path.join(model_dir, "bert_base_uncased_vocab.txt")
if not os.path.isfile(VOCAB_FILE):
    raise RuntimeError("Missing the vocab file")

class_mapping_file = os.path.join(model_dir, "class_mapping.json")
state_dict = torch.load(model_pt_path, map_location=device)
model = BertNewsClassifier()
model.load_state_dict(state_dict)
model.to(device)
model.eval()


BertNewsClassifier(
  (train_acc): Accuracy()
  (val_acc): Accuracy()
  (test_acc): Accuracy()
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias

In [6]:
import torch.nn as nn

class AGNewsmodelWrapper(nn.Module):
    
    def __init__(self, model):
        super(AGNewsmodelWrapper, self).__init__()
        self.model = model
        
    def forward(self, embeddings):        
        outputs =compute_bert_outputs(self.model.bert_model,embeddings)
        pooled_output = outputs[1]
        output = F.relu(self.model.fc1(pooled_output))
        output = self.model.drop(output)
        output = self.model.out(output)
        print("shape of final output",output.shape)
        return output
    

In [7]:
ag_model_wrapper = AGNewsmodelWrapper(model)
ig_1 = IntegratedGradients(ag_model_wrapper)

In [8]:
tokenizer= BertTokenizer(VOCAB_FILE)
ag_model_wrapper.eval()
ag_model_wrapper.zero_grad()

input_ids= torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
input_embedding_test = ag_model_wrapper.model.bert_model.embeddings(input_ids)

In [9]:
input_embedding_test.shape

torch.Size([1, 7, 768])

In [10]:
preds= ag_model_wrapper(input_embedding_test)
out = np.argmax(preds.cpu().detach(), axis=1)
out =(out.item())
preds

shape of final output torch.Size([1, 4])


tensor([[-3.0944, -3.3159,  2.8459,  0.4021]], grad_fn=<AddmmBackward>)

In [11]:
def score_func(o):
    output = F.softmax(o, dim=1)
    pre_pro= np.max(output.detach().numpy())
    return pre_pro

score_func(preds)

0.91611236

In [12]:
text_ref ="PAD PAD PAD PAD PAD"
input_id_ref= torch.tensor([tokenizer.encode(text_ref, add_special_tokens=True)])
input_embedding_ref = ag_model_wrapper.model.bert_model.embeddings(input_id_ref)
input_embedding_ref.shape
# baselines =input_embedding_ref

torch.Size([1, 7, 768])

In [13]:
attributions, delta = ig_1.attribute(input_embedding_test, n_steps=500, return_convergence_delta=True, target=2,baselines =input_embedding_ref)

shape of final output torch.Size([500, 4])
shape of final output torch.Size([1, 4])
shape of final output torch.Size([1, 4])


In [14]:
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())    
tokens

['[CLS]', 'this', 'year', 'business', 'is', 'good', '[SEP]']

In [17]:
vis_data_records_base=[]

In [18]:
def add_attributions_to_visualizer(attributions, tokens, pred_prob, pred_class, true_class,attr_class, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().numpy()
    
    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred_prob,
                            pred_class,
                            true_class,
                            attr_class,
                            attributions.sum(),       
                            tokens,
                            delta))


In [19]:
add_attributions_to_visualizer(attributions, tokens, score_func(preds), out, 2,2, delta, vis_data_records_base)

In [20]:
vis_data_records_base

[<captum.attr._utils.visualization.VisualizationDataRecord at 0x2484d5e8880>]

In [21]:
visualization.visualize_text(vis_data_records_base)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,2 (0.92),2.0,0.34,[CLS] this year business is good [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,2 (0.92),2.0,0.34,[CLS] this year business is good [SEP]
,,,,
