# Interpreting a classification model using LIME

In [53]:
import pandas as pd
import os
# import time

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from classifier import InputExample, LeftRightProcessor, InputFeatures, convert_examples_to_features

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

Constants:

In [24]:
output_dir = "./output"
data_dir = "./data"
device = torch.device("cpu")
num_labels = 2
max_seq_length = 128
bert_model = "bert-base-uncased"
cased = False

In [21]:
processor = LeftRightProcessor()
num_test_examples = 10
test_examples = processor.get_test_examples(data_dir, num_test_examples)
# test_examples

In [28]:
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=(not cased))
test_features = convert_examples_to_features(test_examples, processor.get_labels(), max_seq_length, tokenizer)
# test_features

LOAD MODEL FROM 'output_dir':

In [14]:
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

config = BertConfig(output_config_file)
model = BertForSequenceClassification(config, num_labels=num_labels)
model.load_state_dict(torch.load(output_model_file, map_location=device))
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertInterme

In [38]:
def get_dataset(features):
    """ Converts a set of features to a TensorDataset """
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    return TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

def get_eval_dataloader(eval_features, eval_batch_size=8):
    """ parses test examples and prepares them into a DataLoader """
    eval_dataset = get_dataset(eval_features)
    eval_sampler = SequentialSampler(eval_dataset)
    return DataLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size)

In [61]:
dl = get_eval_dataloader(test_features, eval_batch_size=1)
model.eval()
for input_ids, input_mask, segment_ids, label_ids in dl: #tqdm(dl, desc="Evaluating"):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)

    with torch.no_grad():
        tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
        logits = model(input_ids, segment_ids, input_mask)
        
    outputs = np.argmax(logits, axis=1)
    print(logits, outputs, label_ids)


tensor([[ 0.2288, -0.2597]]) tensor([0]) tensor([0])
tensor([[ 0.7061, -0.5765]]) tensor([0]) tensor([0])
tensor([[ 1.4858, -1.1538]]) tensor([0]) tensor([0])
tensor([[-0.8970,  0.7125]]) tensor([1]) tensor([1])
tensor([[ 0.0264, -0.1128]]) tensor([0]) tensor([1])
tensor([[ 0.6785, -0.5491]]) tensor([0]) tensor([0])
tensor([[ 1.0243, -0.7590]]) tensor([0]) tensor([0])
tensor([[ 1.3079, -0.9407]]) tensor([0]) tensor([0])
tensor([[-0.4188,  0.2724]]) tensor([1]) tensor([0])
tensor([[ 0.4698, -0.4222]]) tensor([0]) tensor([1])


In [76]:
def text_to_tensors(text, max_seq_length, tokenizer):
    tokens = tokenizer.tokenize(text)
    if len(tokens) > max_seq_length - 2 :
         tokens = tokens[:(max_seq_length - 2)]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += padding
    input_mask += padding
    segment_ids += padding


    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    input_id = torch.tensor([input_ids], dtype=torch.long).to(device)
    input_mask = torch.tensor([input_mask], dtype=torch.long).to(device)
    segment_id = torch.tensor([segment_ids], dtype=torch.long).to(device)
    
    return input_id, input_mask, segment_id

In [140]:
def model_text(text_lst):
    probs = []
    model.eval()
    for text in text_lst:
        input_id, input_mask, segment_id = text_to_tensors(text, max_seq_length, tokenizer)
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)
            probs.append([float(n) for n in logits[0]])
    return np.array(probs)

# model_text([test_examples[0].text])

array([[ 0.46975225, -0.42222288]])

In [81]:
import lime
from lime import lime_text
from lime.lime_text import LimeTextExplainer

class_names = ["left", "right"]

'he'

In [111]:
import re
exp_tokenizer = lambda doc: re.compile(r"(?u)\b\w\w+\b").findall(doc)
explainer = LimeTextExplainer(class_names=class_names, split_expression=exp_tokenizer)

In [134]:
np.array([[1, 2], [3, 4]])

array([[1, 2],
       [3, 4]])

In [None]:
exp = explainer.explain_instance(test_examples[0].text, model_text, num_features=6)