<img width="150" alt="Logo_ER10" src="https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png">

### Interpreting the sentence classification model with LIME

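LIME (Local Interpretable Model-agnostic Explanations) is an explainable-AI method that explains an individual prediction by fitting a simple, interpretable model that approximates the classifier locally, around the input being explained. For text, LIME randomly removes words from the input, observes how the classifier's output changes, and fits a weighted linear model whose coefficients serve as word relevances. For more details see the [LIME paper](https://arxiv.org/abs/1602.04938).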
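To make the idea concrete, here is a minimal sketch of LIME for text, written from scratch with NumPy rather than with DIANNA or the `lime` package. It assumes a hypothetical black box `toy_predict_proba`, random word removal as the perturbation, an exponential kernel on cosine distance (similar in spirit to the `lime` defaults), and a plain weighted least-squares surrogate. `lime_text_weights` is an illustrative name, not a library function.

```python
import numpy as np


def lime_text_weights(words, predict_proba, label, num_samples=500, kernel_width=0.25, seed=0):
    """Illustrative LIME-style word weights for one sentence (not the lime/DIANNA code).

    words:         tokens of the sentence to explain
    predict_proba: black box, list[str] -> array of shape (n_texts, n_classes)
    label:         index of the class whose probability is explained
    """
    rng = np.random.default_rng(seed)
    d = len(words)
    # Random binary masks: 1 keeps a word, 0 removes it; row 0 is the unperturbed sentence.
    masks = rng.integers(0, 2, size=(num_samples, d))
    masks[0] = 1
    texts = [" ".join(w for w, keep in zip(words, row) if keep) for row in masks]
    y = predict_proba(texts)[:, label]

    # Cosine distance between a binary mask and the all-ones mask is 1 - sqrt(kept / d)
    # (this formula gives 1 for the empty mask).
    distance = 1.0 - np.sqrt(masks.sum(axis=1) / d)
    # Exponential kernel: samples closer to the original sentence count more.
    sample_weight = np.exp(-(distance ** 2) / kernel_width ** 2)

    # Weighted least-squares fit of y ~ intercept + masks @ coef (the local linear surrogate).
    X = np.hstack([np.ones((num_samples, 1)), masks])
    sw = np.sqrt(sample_weight)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return coef[1:]  # one weight per word; the intercept is dropped


def toy_predict_proba(texts):
    # Hypothetical black box: class 1 becomes likely when the word "shall" is present.
    p1 = np.array([0.9 if "shall" in t.split() else 0.2 for t in texts])
    return np.column_stack([1 - p1, p1])


words = "Member States shall cooperate".split()
weights = lime_text_weights(words, toy_predict_proba, label=1)
print({w: round(float(x), 3) for w, x in zip(words, weights)})
```

Because the toy classifier depends only on "shall", the surrogate assigns that word a weight of about 0.7 (the jump in probability from 0.2 to 0.9) and roughly zero to the other words.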

Note that this notebook was adapted from the [LIME/text tutorial for DIANNA](https://github.com/dianna-ai/dianna/blob/main/tutorials/lime_text.ipynb).

#### Colab Setup

In [1]:
import pandas as pd

from classify_text_with_inlegal_bert_xgboost import classify_texts

running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
    # install dianna
    !python3 -m pip install dianna[notebooks]

    # download data used in the original DIANNA LIME/text tutorial
    import os
    base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/tutorials/'
    paths_to_download = ['data/movie_reviews_word_vectors.txt', 'models/movie_review_model.onnx']
    for path in paths_to_download:
        !wget {base_url + path} -P {os.path.dirname(path)}

#### 1. Imports and paths

In [34]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import spacy
from torchtext.vocab import Vectors
from scipy.special import expit as sigmoid
from pathlib import Path

import dianna
from dianna import visualization
from dianna import utils
from dianna.utils.tokenizers import SpacyTokenizer
from train_inlegalbert_xgboost import class_names

In [3]:
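# build the path portably; a bare backslash in a normal string is an escape character and Windows-only
model_path = Path('..') / 'inlegal_xgboost_classifier_xgboost_classifier.json'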

#### Some test data

In [28]:
constitutive_statement_0 = "The purchase, import or transport from Syria of crude oil and petroleum products shall be prohibited."
constitutive_statement_1 = "This Decision shall enter into force on the twentieth day following that of its publication in the Official Journal of the European Union."
regulatory_statement_0 = "Where observations are submitted, or where substantial new evidence is presented, the Council shall review its decision and inform the person or entity concerned accordingly."
regulatory_statement_1 = "The relevant Member State shall inform the other Member States of any authorisation granted under this Article."
regulatory_statement_2 = "Member States shall cooperate, in accordance with their national legislation, with inspections and disposals undertaken pursuant to paragraphs 1 and 2."

#### Loading the model

In [30]:
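class StatementClassifier:
    """Model runner for DIANNA: maps a list of sentences to class probabilities."""

    def __init__(self):
        # tokenizer passed to dianna.explain_text to split the text into words
        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')

    def __call__(self, sentences):
        # ensure the input has a batch axis
        if isinstance(sentences, str):
            sentences = [sentences]

        probs = classify_texts(sentences, model_path, return_proba=True)

        # one row per sentence: [P(class_names[0]), P(class_names[1])]
        return np.column_stack([probs[:, 0], 1 - probs[:, 0]])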

In [31]:
# define the model runner used by DIANNA
model_runner = StatementClassifier()
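The `lime` package, which DIANNA's LIME method builds on, documents its text classifier function as taking a list of strings and returning an array of shape `(n_strings, n_classes)`; `StatementClassifier` follows that contract. The sketch below checks the same contract without the real model: `stub_classify_texts` and `StubRunner` are hypothetical stand-ins, and, like `StatementClassifier`, the stub assumes column 0 of the probabilities belongs to class 0.

```python
import numpy as np


def stub_classify_texts(texts):
    """Hypothetical stand-in for classify_texts(..., return_proba=True).

    Returns an (n, 2) array whose column 0 is treated as P(class 0).
    """
    p0 = np.array([0.1 if "shall inform" in t else 0.8 for t in texts])
    return np.column_stack([p0, 1 - p0])


class StubRunner:
    def __call__(self, sentences):
        # LIME passes a list of strings; a single string is wrapped to add the batch axis.
        if isinstance(sentences, str):
            sentences = [sentences]
        probs = stub_classify_texts(sentences)
        # Same output layout as StatementClassifier: [P(class 0), P(class 1)] per row.
        return np.column_stack([probs[:, 0], 1 - probs[:, 0]])


runner = StubRunner()
out = runner(["The Council shall inform the person.", "This Decision shall enter into force."])
print(out.shape)  # (2, 2)
```

For two 1-D arrays, `np.column_stack([a, b])` gives the same `(n, 2)` array as `np.transpose([a, b])`.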

#### Test the model

In [32]:
prediction = model_runner([constitutive_statement_0, constitutive_statement_1, regulatory_statement_0, regulatory_statement_1, regulatory_statement_2])
[class_names[m] for m in np.argmax(prediction, axis=1)]

Some weights of the model checkpoint at law-ai/InLegalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Creating features: 100%|██████████| 5/5 [00:00<00:00, 17.50it/s]


['constitutive', 'constitutive', 'regulatory', 'regulatory', 'regulatory']

#### Set parameters for DIANNA

In [83]:
label_of_interest = 1
print('label_of_interest is', class_names[label_of_interest])
statement = regulatory_statement_0
num_samples = 1000  # number of perturbed sentences LIME generates
num_features = 100  # maximum number of words to include in the attribution map

def run_dianna(input_text):
    # explain the prediction for label_of_interest; [0] selects the explanation for the first (and only) requested label
    return dianna.explain_text(model_runner, input_text, model_runner.tokenizer, 'LIME',
                               labels=[label_of_interest], num_samples=num_samples,
                               num_features=num_features)[0]

label_of_interest is regulatory
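In the `lime` package, `num_features` caps how many words appear in an explanation; for larger values the default selection keeps the words with the largest absolute weights. The helper `top_n_features` below is a hypothetical illustration of that idea, not DIANNA's or `lime`'s code. Since the statement being explained has fewer than 100 tokens, `num_features=100` lets every word receive a relevance score.

```python
import numpy as np


def top_n_features(weights, num_features):
    """Indices of the num_features largest weights by absolute value (illustrative sketch).

    If num_features exceeds the number of words, every word is kept.
    """
    weights = np.asarray(weights)
    order = np.argsort(-np.abs(weights), kind="stable")
    return order[:num_features]


w = [0.05, -0.2, 0.01, 0.12]
print(top_n_features(w, 2))    # [1 3]
print(top_n_features(w, 100))  # [1 3 0 2]
```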


#### Are the results stable with current parameters?

In [77]:
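# explain the same sentence three times and compare the attributions across runs
explanation_relevances = [run_dianna(statement) for _ in range(3)]
# sort each run by word position (the second tuple element) so the columns line up
sorted_relevances = [sorted(r, key=lambda t: t[1]) for r in explanation_relevances]

pd.DataFrame([[r[2] for r in sr] for sr in sorted_relevances], columns=[r[0] for r in sorted_relevances[0]]).describe()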

|  | Where | observations | are | submitted | , | or | where | substantial | new | evidence | ... | decision | and | inform | the | person | or.1 | entity | concerned | accordingly | . |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | ... | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 | 3.0 |
| mean | -0.012983 | 0.047659 | 0.034367 | 0.044731 | 0.069365 | 0.050603 | 0.026494 | 0.04215 | 0.069247 | 0.084791 | ... | 0.026165 | 0.075697 | 0.100724 | 0.070454 | 0.04242 | 0.06623 | 0.060058 | 0.062489 | 0.065244 | 0.062057 |
| std | 0.016918 | 0.005259 | 0.012161 | 0.004476 | 0.023721 | 0.018101 | 0.005439 | 0.009085 | 0.019184 | 0.013398 | ... | 0.005085 | 0.028813 | 0.009349 | 0.023762 | 0.021036 | 0.031683 | 0.033782 | 0.032805 | 0.03269 | 0.022269 |
| min | -0.032518 | 0.042749 | 0.021664 | 0.03957 | 0.042007 | 0.035753 | 0.02086 | 0.03557 | 0.049421 | 0.073819 | ... | 0.020408 | 0.04328 | 0.089935 | 0.048124 | 0.018152 | 0.0428 | 0.021053 | 0.029143 | 0.028581 | 0.038161 |
| 25% | -0.01787 | 0.044884 | 0.028599 | 0.04332 | 0.061951 | 0.040522 | 0.023884 | 0.036968 | 0.060012 | 0.077325 | ... | 0.024226 | 0.064349 | 0.097871 | 0.057967 | 0.0359 | 0.048206 | 0.050098 | 0.046371 | 0.05219 | 0.051971 |
| 50% | -0.003222 | 0.047019 | 0.035534 | 0.047069 | 0.081894 | 0.045291 | 0.026908 | 0.038365 | 0.070602 | 0.080831 | ... | 0.028043 | 0.085419 | 0.105806 | 0.06781 | 0.053649 | 0.053613 | 0.079143 | 0.063599 | 0.075799 | 0.065782 |
| 75% | -0.003216 | 0.050114 | 0.040718 | 0.047312 | 0.083044 | 0.058029 | 0.029312 | 0.04544 | 0.07916 | 0.090277 | ... | 0.029044 | 0.091905 | 0.106119 | 0.081619 | 0.054554 | 0.077946 | 0.079561 | 0.079161 | 0.083576 | 0.074005 |
| max | -0.003209 | 0.05321 | 0.045901 | 0.047554 | 0.084194 | 0.070766 | 0.031715 | 0.052515 | 0.087717 | 0.099722 | ... | 0.030045 | 0.098392 | 0.106431 | 0.095427 | 0.055459 | 0.102278 | 0.079979 | 0.094724 | 0.091352 | 0.082229 |


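With 1000 LIME samples the attributions look reasonably stable across the three runs. In the columns shown, every word keeps the same sign in all runs ("Where" is always negative, the others always positive), and for every word except "Where" the standard deviation is smaller than the mean. Some words, such as "entity", "concerned" and "person", still vary by about half their mean, so small differences in ranking between runs should not be over-interpreted. We can now run DIANNA, expecting the results to reflect mostly signal rather than sampling noise.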
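The `describe()` table can be condensed into a per-word stability check. This sketch assumes each explanation is a list of `(word, position, relevance)` tuples, which is what the DataFrames in this notebook suggest; `stability_summary` and the three runs of made-up numbers are hypothetical. `cv` is the standard deviation divided by the absolute mean, and `sign_consistent` flags words whose attribution never changes sign.

```python
import numpy as np
import pandas as pd


def stability_summary(runs):
    """Per-token summary of repeated explanations of one sentence (illustrative sketch).

    runs: list of explanations, each a list of (word, position, relevance) tuples.
    """
    # Align runs by token position so repeated words ("or", "the") stay separate.
    table = pd.DataFrame([{pos: rel for _, pos, rel in run} for run in runs]).sort_index(axis=1)
    words = {pos: word for word, pos, _ in runs[0]}
    return pd.DataFrame({
        "word": pd.Series(words).reindex(table.columns),
        "mean": table.mean(),
        "std": table.std(),  # sample std (ddof=1), as in DataFrame.describe()
        "cv": table.std() / table.mean().abs(),
        "sign_consistent": np.sign(table).nunique() == 1,
    })


# Three made-up runs over a three-token sentence.
runs = [
    [("Where", 0, -0.03), ("shall", 1, 0.10), ("or", 2, 0.05)],
    [("Where", 0, -0.01), ("shall", 1, 0.12), ("or", 2, -0.01)],
    [("Where", 0, -0.02), ("shall", 1, 0.11), ("or", 2, 0.04)],
]
summary = stability_summary(runs)
print(summary)
```

Here "or" is flagged as unstable because its attribution flips sign in the second run.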

In [80]:
explanation_relevance = run_dianna(statement)
print('attributions for class', class_names[label_of_interest])
pd.DataFrame(explanation_relevance)

Creating features: 100%|██████████| 1000/1000 [01:54<00:00,  8.74it/s]

attributions for class regulatory

|   | 0 | 1 | 2 |
|---|---|---|---|
| 0 | , | 12 | 0.128824 |
| 1 | review | 16 | 0.106041 |
| 2 | shall | 15 | 0.095205 |
| 3 | and | 19 | 0.09046 |
| 4 | accordingly | 26 | 0.084933 |
| 5 | concerned | 25 | 0.078202 |
| 6 | decision | 18 | 0.073636 |
| 7 | new | 8 | 0.072059 |
| 8 | the | 21 | 0.070054 |
| 9 | are | 2 | 0.069962 |
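The columns of this DataFrame are the word, its position in the tokenized sentence, and its relevance for the chosen class (the positions match the token order of `regulatory_statement_0`). Naming the columns makes the table easier to read; the sketch below uses three rows copied from the table above as sample input.

```python
import pandas as pd

# First rows of the explanation above, as (word, position, relevance) tuples.
explanation = [(",", 12, 0.128824), ("review", 16, 0.106041), ("are", 2, 0.069962)]

df = pd.DataFrame(explanation, columns=["word", "position", "relevance"])
# Most relevant words for the chosen class first.
print(df.sort_values("relevance", ascending=False).to_string(index=False))
# The same words in sentence order.
print(df.sort_values("position")["word"].tolist())  # ['are', ',', 'review']
```

In the notebook itself, the equivalent would be `pd.DataFrame(explanation_relevance, columns=['word', 'position', 'relevance'])`.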


#### Visualize the result
DIANNA includes a visualization package, capable of highlighting the relevance of each word in the text for a chosen class. The visualization is in HTML format.
Words in favour of the selected class are highlighted in red, and words against it in blue.
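The colour scheme can be reproduced in a few lines. The function below is a hypothetical, minimal re-implementation for illustration, not DIANNA's `highlight_text`: it wraps each token in an HTML `<span>` whose red (positive) or blue (negative) background opacity is proportional to the word's absolute relevance, scaled by the largest one. The tokens and relevance values in the example are made up.

```python
import html


def highlight_html(tokens, relevances):
    """Minimal red/blue word highlighter (illustrative sketch, not DIANNA's implementation)."""
    # Scale opacities by the largest absolute relevance; fall back to 1.0 if all are zero.
    max_abs = max((abs(r) for r in relevances), default=0.0) or 1.0
    spans = []
    for token, rel in zip(tokens, relevances):
        alpha = abs(rel) / max_abs
        rgb = "255, 0, 0" if rel > 0 else "0, 0, 255"  # red for, blue against
        spans.append(
            f'<span style="background-color: rgba({rgb}, {alpha:.2f})">{html.escape(token)}</span>'
        )
    return " ".join(spans)


tokens = ["Council", "shall", "review"]
relevances = [0.02, 0.09, -0.01]
print(highlight_html(tokens, relevances))
```

In a Jupyter notebook, `IPython.display.HTML(highlight_html(tokens, relevances))` renders the result inline.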

In [81]:
visualization.highlight_text(explanation_relevance, model_runner.tokenizer.tokenize(statement))

In [None]:
explanation_relevance_list = [run_dianna(s) for s in [constitutive_statement_0, constitutive_statement_1, regulatory_statement_0, regulatory_statement_1, regulatory_statement_2]]

Creating features: 100%|██████████| 1000/1000 [01:22<00:00, 12.15it/s]
Some weights of the model checkpoint at law-ai/InLe