# Interpreting a classification model using LIME

In [40]:
import pandas as pd
import os
# import time

import numpy as np 
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange, tqdm_notebook

from data_processing import LeftRightProcessor, convert_examples_to_features

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

### Constants:

In [41]:
# --- Filesystem locations ---
data_dir = "./data/data"   # handed to LeftRightProcessor as its data_dir
output_dir = "./output"    # directory the classifier is loaded from

# --- Model configuration ---
bert_model = "bert-base-uncased"   # name passed to BertTokenizer.from_pretrained
cased = False                      # the tokenizer lower-cases input when this is False
num_labels = 2                     # number of classes passed to BertForSequenceClassification
max_seq_length = 128               # inputs are truncated / zero-padded to this many tokens

# --- Runtime and evaluation settings ---
device = torch.device("cpu")   # the model and every input tensor are moved here
seed = 42                      # forwarded to LeftRightProcessor
num_test_examples = 1000       # number of examples requested from test.csv

### Helper Functions:

In [42]:
def get_dataset(features):
    """Pack a sequence of feature objects into a single TensorDataset.

    Each feature must expose ``input_ids``, ``input_mask``, ``segment_ids``,
    ``label_id`` and ``case_id``. One ``torch.long`` tensor is built per
    attribute (stacked across all features) and the tensors are stored in the
    dataset in exactly that order.
    """
    # Order matters: consumers unpack batches positionally.
    field_order = ("input_ids", "input_mask", "segment_ids", "label_id", "case_id")
    columns = [
        torch.tensor([getattr(feature, field) for feature in features], dtype=torch.long)
        for field in field_order
    ]
    return TensorDataset(*columns)

def get_eval_dataloader(eval_features, eval_batch_size=8):
    """Build a DataLoader that walks `eval_features` in their original order.

    The features are first converted with `get_dataset`; batches of
    `eval_batch_size` items are then drawn sequentially (no shuffling).
    """
    dataset = get_dataset(eval_features)
    return DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=eval_batch_size,
    )

In [43]:
def text_to_tensors(text, max_seq_length, tokenizer):
    """Encode one raw string as fixed-length BERT-style input tensors.

    The text is tokenized, truncated so that it fits together with the
    "[CLS]" / "[SEP]" markers, converted to ids and zero-padded on the right
    up to `max_seq_length`.

    Returns a tuple ``(input_id, input_mask, segment_id)`` of ``torch.long``
    tensors, each of shape ``(1, max_seq_length)`` and placed on the
    notebook-level `device`.
    """
    # Two positions are reserved for the [CLS] and [SEP] markers.
    token_budget = max_seq_length - 2
    word_pieces = tokenizer.tokenize(text)
    if len(word_pieces) > token_budget:
        word_pieces = word_pieces[:token_budget]
    word_pieces = ["[CLS]"] + word_pieces + ["[SEP]"]

    token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
    pad = [0] * (max_seq_length - len(token_ids))

    ids_row = token_ids + pad
    # Mask is 1 over real tokens and 0 over padding, so only real tokens are attended to.
    mask_row = [1] * len(token_ids) + pad
    # Single-sentence input: every position belongs to segment 0.
    segment_row = [0] * len(word_pieces) + pad

    for row in (ids_row, mask_row, segment_row):
        assert len(row) == max_seq_length

    def as_batch_of_one(row):
        # Wrap in an outer list to add the batch dimension.
        return torch.tensor([row], dtype=torch.long).to(device)

    return as_batch_of_one(ids_row), as_batch_of_one(mask_row), as_batch_of_one(segment_row)

In [44]:
def model_text(text_lst):
    """Classifier function for LIME: map raw strings to class probabilities.

    LIME's ``explain_instance`` expects its ``classifier_fn`` to return
    prediction *probabilities*, so the model's logits are passed through a
    softmax before being returned.

    Args:
        text_lst: iterable of raw text strings (LIME passes the perturbed
            variants of the explained document).

    Returns:
        numpy array with one row per input string and one column per class;
        each row sums to 1.

    Relies on the notebook-level `model`, `tokenizer` and `max_seq_length`.
    """
    probs = []
    model.eval()  # disable dropout so repeated calls are deterministic
    for text in text_lst:
        input_ids, input_mask, segment_ids = text_to_tensors(text, max_seq_length, tokenizer)
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)
            # Normalise over the class axis; logits has a leading batch axis of size 1.
            class_probs = torch.nn.functional.softmax(logits, dim=-1)
        probs.append([float(p) for p in class_probs[0]])
    return np.array(probs)

# model_text([test_examples[0].text])

In [45]:
class Prediction:
    """Outcome of classifying a single case: prediction, truth, logits and case id.

    The two class-level cutoffs drive `get_confidence_group` and can be
    changed for all instances through the `set_*_cutoff` classmethods.
    """

    # Logit gap above which a prediction counts as "confident".
    high_confidence_cutoff = 4
    # Logit gap below which a prediction counts as "unsure".
    low_confidence_cutoff = 1

    def __init__(self, predicted, real, logits, case_id):
        self.predicted = predicted
        self.real = real
        self.logits = logits
        self.case_id = case_id

    @classmethod
    def set_high_confidence_cutoff(cls, cutoff):
        """Change the "confident" threshold for every Prediction."""
        cls.high_confidence_cutoff = cutoff

    @classmethod
    def set_low_confidence_cutoff(cls, cutoff):
        """Change the "unsure" threshold for every Prediction."""
        cls.low_confidence_cutoff = cutoff

    def is_correct(self):
        """Whether the predicted label equals the real label."""
        return self.predicted == self.real

    def get_confidence(self):
        """Absolute gap between the first two logits."""
        first_logit, second_logit = self.logits[0], self.logits[1]
        return abs(second_logit - first_logit)

    def get_confidence_group(self):
        """Bucket this prediction by confidence and correctness.

        Returns:
            1 when confident and correct, -1 when confident and wrong,
            0 when unsure, and None when the confidence falls between the
            two cutoffs.
        """
        gap = self.get_confidence()
        if gap > self.high_confidence_cutoff:
            if self.is_correct():
                return 1
            return -1
        if gap < self.low_confidence_cutoff:
            return 0
        return None

    def __repr__(self):
        template = "predicted: {}, real: {}, logits: {}, case_id: {}"
        return template.format(self.predicted, self.real, self.logits, self.case_id)

In [46]:
def predict_individually(test_features, model, verbosity=0):
    """Classify every feature one at a time and collect the results.

    Args:
        test_features: feature objects accepted by `get_eval_dataloader`.
        model: called as ``model(input_ids, segment_ids, input_mask)`` and
            expected to return a logits tensor; it is switched to eval mode.
        verbosity: when truthy, print each batch's prediction details.

    Returns:
        list of `Prediction`, one per feature, in input order. `predicted`,
        `real` and `case_id` are stored as plain Python ints (not 0-dim
        tensors) so that `is_correct()` yields a real bool and `case_id`
        can be used directly as a dictionary key; `logits` is a 1-D numpy
        array.
    """
    predictions = []
    dl = get_eval_dataloader(test_features, eval_batch_size=1)
    model.eval()
    for input_ids, input_mask, segment_ids, label_ids, case_ids in dl:
        # Only the model inputs need to live on the device; labels and case
        # ids are bookkeeping values read back on the host.
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)
        logits = logits.detach().cpu().numpy()
        outputs = np.argmax(logits, axis=1)

        # Batch size is 1, so index 0 is the only example in the batch.
        predictions.append(
            Prediction(
                int(outputs[0]),
                label_ids[0].item(),
                logits[0],
                case_ids[0].item(),
            )
        )
        if verbosity:
            print("predicted:", outputs, "real_label:", label_ids, "probs:", logits, "id:", case_ids)
    return predictions

### Prep Model For Classification:

In [47]:
# Load the tokenizer/vocabulary for `bert_model`. Lower-casing is switched on
# whenever `cased` is False (it is False for the uncased checkpoint used here).
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=(not cased))

03/22/2019 13:40:44 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /n/home12/jdcclark/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [48]:
# Build the project's data processor and request `num_test_examples` examples
# from "test.csv".
# NOTE(review): LeftRightProcessor is defined in data_processing.py, which is
# not visible from this notebook — confirm there how `data_dir`, `seed` and the
# batch sizes are used, and whether the examples are sampled or taken in order.
processor = LeftRightProcessor(
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    data_dir=data_dir,
    seed=seed,
    train_batch_size=1, #shouldn't be used
    eval_batch_size=1
)
test_examples = processor.get_examples(num_test_examples, "test.csv")
# test_examples

In [19]:
# Turn the raw examples into model-ready feature objects. Downstream,
# get_dataset reads input_ids, input_mask, segment_ids, label_id and case_id
# from each of them.
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours — re-run the notebook top to bottom to make sure `test_features`
# matches the current `test_examples`.
test_features = convert_examples_to_features(test_examples, processor.get_labels(), max_seq_length, tokenizer)

#### Load Model From output_dir:

In [49]:
%%time
# Load the sequence classifier stored under `output_dir`, with `num_labels`
# output classes, then move it to `device`.
# NOTE(review): presumably `output_dir` holds the weights/config written by an
# earlier fine-tuning run — that step is not part of this notebook.
model = BertForSequenceClassification.from_pretrained(output_dir, num_labels=num_labels)
model.to(device)

03/22/2019 13:40:50 - INFO - pytorch_pretrained_bert.modeling -   loading archive file ./output
03/22/2019 13:40:50 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



CPU times: user 2.99 s, sys: 264 ms, total: 3.25 s
Wall time: 3.27 s


### Using the Model to predict each case:

In [None]:
%%time 
# Score every test feature one example at a time; yields one Prediction per
# feature, in the same order as `test_features`.
predictions = predict_individually(test_features, model, verbosity=0)


In [None]:
# Partition the predictions by confidence group. Predictions whose group is
# None (between the two cutoffs) are left out of all three lists.
confident_wrong, unsure, confident_correct = [], [], []
_bucket_for_group = {
    1: confident_correct,
    0: unsure,
    -1: confident_wrong,
}
for prediction in predictions:
    category = prediction.get_confidence_group()
    bucket = _bucket_for_group.get(category)
    if bucket is not None:
        bucket.append(prediction)


In [None]:
# Order predictions from most to least confident (in place), keep the wrong
# ones, and display the last of them — i.e. the least confident mistake.
# Raises IndexError when there are no incorrect predictions.
predictions.sort(key=lambda p: p.get_confidence(), reverse=True)
incorrect_predictions = [p for p in predictions if not p.is_correct()]
incorrect_predictions[-1]

### LIME Interpreter:

In [50]:
import lime
from lime import lime_text
from lime.lime_text import LimeTextExplainer
import json

# Display names handed to LimeTextExplainer, indexed by class id.
# NOTE(review): assumes index 0 = "left" and index 1 = "right", matching the
# label order used when the model was trained — confirm against
# processor.get_labels().
class_names = ["left", "right"]

In [51]:
import re
# Matches runs of two or more word characters delimited by word boundaries,
# so single-character tokens and punctuation are dropped.
_WORD_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def exp_tokenizer(doc):
    """Split `doc` into the list of word tokens LIME is allowed to perturb."""
    return _WORD_PATTERN.findall(doc)

In [52]:
# Build a lookup from document id to raw text out of features.json, which
# holds one JSON object per line. The path is derived from `output_dir` so it
# follows the configured output location instead of a second hard-coded copy.
features_path = os.path.join(output_dir, "features.json")
with open(features_path, encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
    vec_features = [json.loads(line) for line in f if line.strip()]
# "linex_index" is spelled exactly as the key appears in the file as read
# here; keep it byte-for-byte.
id_to_text = {feature["linex_index"]: feature["text"] for feature in vec_features}

In [53]:
# LIME explains a prediction by sampling random perturbations of the text.
# Seeding it with the notebook-wide `seed` makes the explanation reproducible
# across runs. `exp_tokenizer` defines which tokens may be perturbed.
explainer = LimeTextExplainer(
    class_names=class_names,
    split_expression=exp_tokenizer,
    random_state=seed,
)

In [58]:
# Text of the document to explain, looked up by its id in features.json.
# NOTE(review): 615791547 is a hard-coded id — presumably a case picked from
# the prediction analysis above; confirm and consider deriving it from a
# Prediction's case_id instead.
ex = id_to_text[615791547]

In [None]:
%%time
# Explain the model's output for `ex`: LIME queries `model_text` on
# `num_samples` perturbed copies of the text and reports the `num_features`
# most influential tokens. model_text scores the texts one at a time, so the
# run time grows linearly with num_samples.
# NOTE(review): no `labels` argument is given, so LIME's default label
# selection applies — believed to be class index 1 only; confirm.
exp = explainer.explain_instance(ex, model_text, num_features=10, num_samples=400)

In [None]:
# Show the explanation as (token, weight) pairs.
exp.as_list()

In [None]:
%matplotlib inline
# Render the explanation inline; text=True also displays the document text.
exp.show_in_notebook(text=True)