# Find cluster-related terms using LIME

## Load the data

In [1]:
# Directory names from which every path in this notebook is assembled.
# `root_dir` is relative to the notebook's working directory.
root_dir, data_dir = '../..', 'data'
corpus_dir, src_dir = 'corpus', 'src'

In [2]:
import os 
import sys

In [3]:
# Make the project's source directory importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# pile up duplicate entries in `sys.path`.
src_path = os.path.join(root_dir, src_dir)
if src_path not in sys.path:
    sys.path.append(src_path)

In [4]:
# Path of the corpus JSON file: <root>/<data>/<corpus>/<corpus_filename>.
corpus_filename = 'nyt_corpus.json'
corpus_filepath = os.path.join(
    os.path.join(root_dir, data_dir, corpus_dir),
    corpus_filename,
)

In [5]:
# Path of the chunks JSON file, stored in the same directory as the corpus.
chunks_filename = 'nyt_chunks.json'
chunks_filepath = os.path.join(
    root_dir,
    data_dir,
    corpus_dir,
    chunks_filename,
)

In [6]:
from training import TrainingCorpus

In [7]:
# Create the corpus object and populate it from the corpus and chunks JSON files.
# NOTE(review): presumably `load` must run before `load_chunks` — confirm in TrainingCorpus.
corpus = TrainingCorpus()
corpus.load(corpus_filepath)
corpus.load_chunks(chunks_filepath)

---

## Load the model

Check whether a CUDA-capable GPU is available

In [8]:
import torch
# Sanity check: prints True when PyTorch can use a CUDA device, False otherwise.
print(torch.cuda.is_available())

True


In [9]:
from model import BertModel

Using TensorFlow backend.


In [10]:
# Model directory, resolved under <root>/<data>/.
model_dir = 'models/nyt_bert'
model_dir_path = os.path.join(
    os.path.join(root_dir, data_dir),
    model_dir,
)

In [11]:
# Load the model from `model_dir_path`.
# CUDA is requested only when PyTorch reports an available device, so the
# notebook also runs on CPU-only machines instead of failing at load time.
model = BertModel(
    model_dir_path,
    batch_size=128,
    use_cuda=torch.cuda.is_available(),
    from_tf=False,
)

---

## Find relevant terms for each cluster label using LIME

### Instantiate TermFinder

In [12]:
from termfinder import LimeTermFinder

In [13]:
# Term finder built from the loaded model and corpus; it is queried below,
# once per instance, via `get_relevant_terms`.
term_finder = LimeTermFinder(model, corpus)

### Retrieve predicted labels for each instance in the corpus

In [14]:
import numpy as np

In [15]:
# Mapping that is iterated below as `label_idx -> data_idxs`, i.e. each label
# index is paired with the indices of the instances associated with it.
# NOTE(review): presumably these are the labels predicted by the model — confirm in BertModel.
label_to_data_idx_dict = model.label_to_data_idx(corpus)

In [16]:
# Number of keys (label indices) in the mapping.
len(label_to_data_idx_dict)

10

### Retrieve relevant terms using LimeTermFinder

In [17]:
from tqdm.notebook import tqdm

In [18]:
# Accumulator for the per-term records that are later turned into a DataFrame.
df_data = []

In [19]:
# Collect one record per (instance, term) pair for every label.
# The accumulator is reset here so the cell is idempotent: re-running it
# starts from an empty list instead of appending duplicates of a previous run.
df_data = []

for label_idx, data_idxs in tqdm(label_to_data_idx_dict.items()):
    # The label is the same for every instance in the inner loop; look it up once.
    label = corpus.labels[label_idx]

    for data_idx in tqdm(data_idxs, desc=f'Relevant terms for entity {label_idx}', leave=False):
        relevant_terms = term_finder.get_relevant_terms(data_idx, label_idx)

        # Skip instances for which nothing was returned (None or empty mapping).
        if not relevant_terms:
            continue

        # The document id is shared by all terms of this instance.
        data_id = corpus.docs[data_idx]
        df_data.extend(
            {'label': label,
             'term': term,
             'weight': weight,
             'data_id': data_id}
            for term, weight in relevant_terms.items()
        )

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(HTML(value='Relevant terms for entity 0'), FloatProgress(value=0.0, max=302.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 3'), FloatProgress(value=0.0, max=306.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 1'), FloatProgress(value=0.0, max=304.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 7'), FloatProgress(value=0.0, max=289.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 2'), FloatProgress(value=0.0, max=281.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 5'), FloatProgress(value=0.0, max=285.0), HTML(value=''))…

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Build a DataFrame out of `df_data`

In [20]:
# Peek at the first three collected records.
df_data[:3]

[{'label': 'Q11201',
  'term': 'appeals_court',
  'weight': 0.10057073164577102,
  'data_id': 0},
 {'label': 'Q11201',
  'term': 'one-year_sentence',
  'weight': 0.06605839885182349,
  'data_id': 0},
 {'label': 'Q11201',
  'term': 'panel',
  'weight': 0.046044359419403824,
  'data_id': 0}]

In [21]:
import pandas as pd

In [22]:
# Build the table of relevant terms from the collected records.
# Columns are named explicitly (in the same order as the record keys) so that
# an empty `df_data` still yields a frame, and later a CSV, with the expected header.
relevant_terms_df = pd.DataFrame(
    df_data,
    columns=['label', 'term', 'weight', 'data_id'],
)

In [23]:
# Preview the first rows of the resulting table.
relevant_terms_df.head()

Unnamed: 0,label,term,weight,data_id
0,Q11201,appeals_court,0.100571,0
1,Q11201,one-year_sentence,0.066058,0
2,Q11201,panel,0.046044,0
3,Q11201,arguments,0.036779,0
4,Q11201,whether,0.036227,0


---

## Save retrieved terms to a file

In [24]:
# Destination of the exported CSV: <root>/<data>/<terms_dir>/<filename>.
terms_dir, filename = 'terms', 'relevant_terms_nyt_bert.csv'
filepath = os.path.join(
    os.path.join(root_dir, data_dir, terms_dir),
    filename,
)

In [25]:
# Write the table to disk as UTF-8 CSV without the positional index.
# The target directory is created first, because `to_csv` does not create
# missing parent directories and would otherwise fail after the long LIME run.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
relevant_terms_df.to_csv(filepath, encoding='utf-8', index=False)

---