In [1]:
%load_ext autoreload
%autoreload 2

import shutil
import numpy as np
import logging
import shap
import torch
import json
from numpy import dot
from numpy.linalg import norm
from urllib import request
from pytorch_pretrained_bert import BertModel, BertTokenizer

from interpret_text.msra.MSRAExplainer import MSRAExplainer

In [2]:
DATA_FOLDER = "./temp"
# Prefer the GPU when one is available; the model, embedding and explainer cells all use this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs the notebook may draw from so that Restart & Run All gives the same results every time.
# NOTE(review): the explainer's 5000-step optimisation is presumably stochastic — confirm which
# RNG(s) it draws from; torch and numpy are seeded here.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

A helper function that converts a piece of text into BERT input embeddings.

In [3]:
def embeddings_bert(text, device):
    # get the tokenized words.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    words = ["[CLS]"] + tokenizer.tokenize(text) + ["[SEP]"]
    tokenized_ids = tokenizer.convert_tokens_to_ids(words)
    segment_ids = [0 for _ in range(len(words))]
    token_tensor = torch.tensor([tokenized_ids], device=device)
    segment_tensor = torch.tensor([segment_ids], device=device)
    x_bert = model.embeddings(token_tensor, segment_tensor)[0]
    return x_bert

Load the BERT base model, initialised with the saved fine-tuned parameters.

In [None]:
# Fine-tuned BERT weights, stored as a state dict.
MODEL_PATH = "models/model.pth"

# Deserialize onto the CPU first. Without `map_location`, tensors saved from a GPU session
# fail to load on a CPU-only machine. The model is moved to `device` below either way.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model_state_dict = torch.load(MODEL_PATH, map_location="cpu")

# BERT base architecture, initialised from the fine-tuned state dict instead of the
# stock pretrained weights.
model = BertModel.from_pretrained("bert-base-uncased", state_dict=model_state_dict)
model.to(device)

# The model is only inspected, never trained: freeze every parameter and switch to
# eval mode (affects layers such as dropout).
for param in model.parameters():
    param.requires_grad = False
model.eval();  # trailing `;` suppresses the very long module repr in the cell output

Next, we generate the embeddings for the input text and initialize the explainer. We also compute the regularization parameter that the MSR Asia explainer requires, using a helper method provided by the explainer class.

In [5]:
# Sentence to explain; the visualization cell at the end of the notebook reuses `text`.
text = "rare bird has more than enough charm to make it memorable."

# Construct the explainer on the same device as the model.
interpreter_msra = MSRAExplainer(device=device)
# BERT input embeddings for the sentence; these are what the explainer perturbs.
embedded_input = embeddings_bert(text, device)

# Regularization term derived by the explainer's own helper from the fine-tuned model.
regularization = interpreter_msra.getRegularizationBERT(model=model)

We then call `explain_local` on the explainer to compute the local importance values for the input.

In [6]:
# Run the local explanation. The saved run shows 5000 optimisation steps taking ~5 minutes.
explanation_msra = interpreter_msra.explain_local(
    model=model,
    embedded_input=embedded_input,
    regularization=regularization,
)

# Presumably one importance value per input token, [CLS]/[SEP] included — TODO confirm.
print(explanation_msra.local_importance_values)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [04:54<00:00, 17.00it/s]


[0.1879063993692398, 0.14912039041519165, 0.1531413346529007, 0.2352224886417389, 0.21313492953777313, 0.21833769977092743, 0.19894209504127502, 0.13736814260482788, 0.2685736417770386, 0.23845000565052032, 0.25325438380241394, 0.1364051103591919, 0.29170724749565125, 0.37515968084335327]


A basic visualization, used until the visualization dashboard is fully integrated as a Python widget.

In [7]:
# Basic visualization for `text` (a placeholder until the dashboard widget is integrated).
# NOTE(review): presumably renders the result of the most recent `explain_local` call — confirm.
interpreter_msra.visualize(text)