<img width="150" alt="Logo_ER10" src="https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png">

### Interpreting the sentence classification model with LIME

LIME (Local Interpretable Model-agnostic Explanations) is an explainable-AI method that fits a simple, interpretable surrogate model that approximates the classifier in the neighbourhood of a single input. For more details see the [LIME paper](https://arxiv.org/abs/1602.04938).
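
To make the idea concrete, here is a minimal sketch of the LIME procedure for text: randomly remove words, query the model on the perturbed sentences, and fit a proximity-weighted linear model whose coefficients act as word relevances. This is not DIANNA's or the LIME library's implementation. The names `lime_text_sketch` and `toy_model`, the kernel width, and the masking scheme are assumptions chosen for illustration.

```python
import numpy as np

def lime_text_sketch(predict_proba, tokens, label, num_samples=200, seed=0):
    """Toy LIME: mask random words, query the model, fit a weighted linear surrogate."""
    rng = np.random.default_rng(seed)
    n = len(tokens)
    # Binary masks: 1 keeps a token, 0 removes it. Row 0 is the unperturbed sentence.
    masks = rng.integers(0, 2, size=(num_samples, n))
    masks[0, :] = 1
    texts = [' '.join(t for t, keep in zip(tokens, m) if keep) for m in masks]
    probs = np.asarray(predict_proba(texts))[:, label]
    # Samples closer to the original sentence get a larger weight (assumed kernel).
    distance = 1.0 - masks.mean(axis=1)
    weights = np.exp(-(distance ** 2) / 0.25)
    # Weighted least squares with an intercept column.
    X = np.hstack([masks, np.ones((num_samples, 1))])
    sw = np.sqrt(weights)[:, None]
    coef, *_ = np.linalg.lstsq(X * sw, probs[:, None] * sw, rcond=None)
    return coef[:n, 0]  # one relevance score per token

def toy_model(texts):
    # Hypothetical classifier: class 1 is likely whenever 'shall' is present.
    p = np.array([0.9 if 'shall' in t.split() else 0.1 for t in texts])
    return np.column_stack([1 - p, p])

tokens = ['Member', 'States', 'shall', 'cooperate']
scores = lime_text_sketch(toy_model, tokens, label=1)
print(dict(zip(tokens, np.round(scores, 3))))
```

Because the toy model depends only on the word "shall", the surrogate assigns that word by far the largest coefficient. With a real classifier the fit is only approximate and varies with the random samples drawn.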

Note that this notebook was adapted from the [LIME/text tutorial for DIANNA](https://github.com/dianna-ai/dianna/blob/main/tutorials/lime_text.ipynb).

#### Colab Setup

In [1]:
import pandas as pd

from classify_text_with_inlegal_bert_xgboost import classify_texts

running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
  # install dianna
  !python3 -m pip install dianna[notebooks]
  
  # download data used in this demo
  import os 
  base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/tutorials/'
  paths_to_download = ['data/movie_reviews_word_vectors.txt', 'models/movie_review_model.onnx']
  for path in paths_to_download:
      !wget {base_url + path} -P {os.path.dirname(path)}

#### 1. Imports and paths

In [34]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import spacy
from torchtext.vocab import Vectors
from scipy.special import expit as sigmoid
from pathlib import Path

import dianna
from dianna import visualization
from dianna import utils
from dianna.utils.tokenizers import SpacyTokenizer
from train_inlegalbert_xgboost import class_names

In [3]:
# build the path with pathlib so it works on every OS and avoids the invalid '\i' escape
model_path = Path('..') / 'inlegal_xgboost_classifier_xgboost_classifier.json'

#### Some test data

In [28]:
constitutive_statement_0 = "The purchase, import or transport from Syria of crude oil and petroleum products shall be prohibited."
constitutive_statement_1 = "This Decision shall enter into force on the twentieth day following that of its publication in the Official Journal of the European Union."
regulatory_statement_0 = "Where observations are submitted, or where substantial new evidence is presented, the Council shall review its decision and inform the person or entity concerned accordingly."
regulatory_statement_1 = "The relevant Member State shall inform the other Member States of any authorisation granted under this Article."
regulatory_statement_2 = "Member States shall cooperate, in accordance with their national legislation, with inspections and disposals undertaken pursuant to paragraphs 1 and 2."

#### Loading the model

In [30]:
class StatementClassifier:
    def __init__(self):
        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')

    def __call__(self, sentences):
        # ensure the input has a batch axis
        if isinstance(sentences, str):
            sentences = [sentences]

        probs = classify_texts(sentences, model_path, return_proba=True)

        return np.transpose([(probs[:, 0]), (1 - probs[:, 0])])
            

In [31]:
# define the model runner
model_runner = StatementClassifier()

#### Test the model

In [32]:
prediction = model_runner([constitutive_statement_0, constitutive_statement_1,
                           regulatory_statement_0, regulatory_statement_1, regulatory_statement_2])
[class_names[m] for m in np.argmax(prediction, axis=1)]

Some weights of the model checkpoint at law-ai/InLegalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Creating features: 100%|██████████| 5/5 [00:00<00:00, 17.50it/s]


['constitutive', 'constitutive', 'regulatory', 'regulatory', 'regulatory']
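
The mapping from the two-column probability array to class names can be sketched on its own, without loading the model. The probabilities below are made up, and the class order is an assumption consistent with the output in this notebook (index 1 is `regulatory`). The real list is imported from `train_inlegalbert_xgboost`.

```python
import numpy as np

# Assumed class order; the notebook imports the real list from train_inlegalbert_xgboost.
class_names = ['constitutive', 'regulatory']

# Made-up probabilities of the first class for three sentences.
p_first = np.array([0.9, 0.2, 0.55])

# Same layout as StatementClassifier.__call__: one row per sentence, one column per class.
prediction = np.column_stack([p_first, 1 - p_first])

# Pick the most probable class for every row.
labels = [class_names[i] for i in np.argmax(prediction, axis=1)]
print(labels)
```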

#### Apply LIME with DIANNA
The simplest way to run DIANNA on text data is with `dianna.explain_text`. The arguments are:
* The function that runs the model (a path to a model in ONNX format is also accepted)
* The text we want to explain
* The name of the explainable-AI method we want to use, here LIME
* The numerical index of the class we want an explanation for

`dianna.explain_text` returns a list of tuples. Each tuple contains a word, its location in the input text, and its importance for the selected output class
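
Because the result is a plain list of `(word, position, relevance)` tuples, it can be post-processed with ordinary Python. The sketch below uses a few values that mirror the sample output in this notebook. The helper names `split_by_sign` and `in_reading_order` are hypothetical and not part of DIANNA.

```python
# (word, position, relevance) tuples, mirroring part of this notebook's sample output.
explanation = [
    ('its', 17, 0.219725),
    ('evidence', 9, 0.219297),
    ('entity', 24, -0.145893),
    ('shall', 15, 0.193062),
]

def split_by_sign(explanation):
    """Separate words that support the selected class from words that oppose it."""
    supporting = [(word, score) for word, _, score in explanation if score > 0]
    opposing = [(word, score) for word, _, score in explanation if score < 0]
    return supporting, opposing

def in_reading_order(explanation):
    """Sort the tuples by token position instead of by relevance."""
    return sorted(explanation, key=lambda item: item[1])

supporting, opposing = split_by_sign(explanation)
print('supporting:', supporting)
print('opposing:', opposing)
print('reading order:', [word for word, _, _ in in_reading_order(explanation)])
```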

In [35]:
# We're getting the explanation for the 'regulatory' class (index 1) only,
# but dianna supports explaining multiple labels in one go.
# It therefore always outputs a list of saliency maps; we take
# the first and only saliency map from that list.
label = 1
statement = regulatory_statement_0
explanation_relevance = dianna.explain_text(model_runner, statement, model_runner.tokenizer,
                                            'LIME', labels=[label], num_samples=100)[0]
print('attributions for class', class_names[label])
pd.DataFrame(explanation_relevance)

Creating features: 100%|██████████| 100/100 [00:11<00:00,  8.60it/s]

attributions for class regulatory





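|   | 0 | 1 | 2 |
|---|---|---|---|
| 0 | its | 17 | 0.219725 |
| 1 | evidence | 9 | 0.219297 |
| 2 | accordingly | 26 | 0.196463 |
| 3 | shall | 15 | 0.193062 |
| 4 | , | 12 | 0.175377 |
| 5 | entity | 24 | -0.145893 |
| 6 | the | 13 | 0.14341 |
| 7 | where | 6 | 0.140912 |
| 8 | are | 2 | 0.131716 |
| 9 | new | 8 | 0.125557 |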


#### Visualize the result
DIANNA includes a visualization package, capable of highlighting the relevance of each word in the text for a chosen class. The visualization is in HTML format.
Words in favour of the selected class are highlighted in red; words against it are highlighted in blue.
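
The colour convention can be illustrated with a small stand-alone sketch that turns relevance scores into HTML spans. This is not DIANNA's implementation. The function `highlight_sketch`, the opacity scaling, and the exact colours are assumptions, and the position in each tuple is treated as a token index.

```python
from html import escape

def highlight_sketch(explanation, tokens):
    """Return HTML with red backgrounds for positive and blue for negative relevance."""
    relevance = {position: score for _, position, score in explanation}
    # Scale opacity by the largest absolute relevance (fall back to 1.0 to avoid dividing by zero).
    max_abs = max((abs(s) for s in relevance.values()), default=0.0) or 1.0
    spans = []
    for position, token in enumerate(tokens):
        score = relevance.get(position, 0.0)
        alpha = abs(score) / max_abs
        rgb = '255, 0, 0' if score >= 0 else '0, 0, 255'
        spans.append(f'<span style="background: rgba({rgb}, {alpha:.2f})">{escape(token)}</span>')
    return ' '.join(spans)

html = highlight_sketch([('shall', 2, 0.2), ('entity', 0, -0.1)],
                        ['entity', 'concerned', 'shall'])
print(html)
```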

In [36]:
visualization.highlight_text(explanation_relevance, model_runner.tokenizer.tokenize(statement))