# Find cluster-related terms using LIME

## Load the data

In [1]:
# Project layout, relative to this notebook's working directory.
# Every path used further down is assembled from these four pieces.
root_dir, src_dir = '../..', 'src'
data_dir, corpus_dir = 'data', 'corpus'

In [2]:
import os 
import sys

In [3]:
# Make <root_dir>/<src_dir> importable so the project-local modules imported
# later in this notebook can be found (presumably that is where they live).
# Guarded so that re-running this cell does not keep appending the same entry
# to sys.path.
src_path = os.path.join(root_dir, src_dir)
if src_path not in sys.path:
    sys.path.append(src_path)

In [4]:
# Dataset identifier; it is interpolated into the corpus, chunks and model
# path names built in the following cells.
dataset_name = 'alaska'

In [5]:
# Corpus file: <root_dir>/<data_dir>/<corpus_dir>/<dataset_name>_corpus.json
corpus_dir_path = os.path.join(root_dir, data_dir, corpus_dir)
corpus_filename = f'{dataset_name}_corpus.json'
corpus_filepath = os.path.join(corpus_dir_path, corpus_filename)

In [6]:
# Chunks file: <root_dir>/<data_dir>/<corpus_dir>/<dataset_name>_chunks.json
chunks_dir_path = os.path.join(root_dir, data_dir, corpus_dir)
chunks_filename = f'{dataset_name}_chunks.json'
chunks_filepath = os.path.join(chunks_dir_path, chunks_filename)

In [7]:
from training import TrainingCorpus

In [8]:
# Instantiate the corpus, then populate it from the corpus JSON file and the
# chunks JSON file, in that order.
# NOTE(review): load_chunks presumably requires load() to have run first — confirm in `training`.
corpus = TrainingCorpus()
corpus.load(corpus_filepath)
corpus.load_chunks(chunks_filepath)

---

## Load the model

Check if GPU is available

In [9]:
import torch
# Report whether PyTorch can see a usable CUDA device on this machine.
print(torch.cuda.is_available())

True


In [10]:
from model import BertModel

Using TensorFlow backend.


In [11]:
# Model directory: <root_dir>/<data_dir>/models/<dataset_name>_bert_test
data_root = os.path.join(root_dir, data_dir)
model_dir = f'models/{dataset_name}_bert_test'
model_dir_path = os.path.join(data_root, model_dir)

In [12]:
# Load the model stored under model_dir_path.
# CUDA is requested only when a device is actually available, rather than
# through a hard-coded flag; on a GPU host this is identical to use_cuda=True.
# NOTE(review): confirm BertModel supports CPU execution when use_cuda is False.
model = BertModel(
    model_dir_path,
    batch_size=512,
    use_cuda=torch.cuda.is_available(),
    from_tf=False,
)

---

## Find relevant terms for each cluster label using LIME

### Instantiate TermFinder

In [13]:
from termfinder import LimeTermFinder

In [14]:
# Term finder bound to the loaded model and corpus.
# NOTE(review): min_fts / max_fts presumably bound the number of features
# considered per explanation — confirm in `termfinder`.
term_finder = LimeTermFinder(
    model,
    corpus,
    min_fts=15,
    max_fts=30,
)

### Retrieve predicted labels for each instance in the corpus

In [15]:
import numpy as np

In [16]:
# Mapping produced by the model for this corpus. The extraction loop further
# down unpacks each item as (label index, collection of data indices).
label_to_data_idx_dict = model.label_to_data_idx(corpus)

In [17]:
# Number of distinct labels present in the mapping (20 in the recorded run).
len(label_to_data_idx_dict)

20

### Retrieve relevant terms using LimeTermFinder

In [18]:
from tqdm.notebook import tqdm

In [19]:
# Accumulator for the result rows: one dict per (instance, term) pair with the
# keys 'label', 'term', 'weight' and 'data_id'.
df_data = []

In [20]:
%%time
for label_idx, data_idxs in tqdm(label_to_data_idx_dict.items()):
    for data_idx in tqdm(data_idxs, desc=f'Relevant terms for entity {label_idx}', leave=False):
        relevant_terms = term_finder.get_relevant_terms(data_idx, label_idx)

        if relevant_terms:
            for term, weight in relevant_terms.items():
                dict_entry = {'label': corpus.labels[label_idx], 
                              'term': term, 
                              'weight': weight, 
                              'data_id': corpus.docs[data_idx]}
                df_data.append(dict_entry)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(HTML(value='Relevant terms for entity 11'), FloatProgress(value=0.0, max=184.0), HTML(value='')…

HBox(children=(HTML(value='Relevant terms for entity 7'), FloatProgress(value=0.0, max=178.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 4'), FloatProgress(value=0.0, max=168.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 8'), FloatProgress(value=0.0, max=155.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 10'), FloatProgress(value=0.0, max=143.0), HTML(value='')…

HBox(children=(HTML(value='Relevant terms for entity 6'), FloatProgress(value=0.0, max=137.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 14'), FloatProgress(value=0.0, max=130.0), HTML(value='')…

HBox(children=(HTML(value='Relevant terms for entity 17'), FloatProgress(value=0.0, max=125.0), HTML(value='')…

HBox(children=(HTML(value='Relevant terms for entity 0'), FloatProgress(value=0.0, max=117.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 18'), FloatProgress(value=0.0, max=112.0), HTML(value='')…

HBox(children=(HTML(value='Relevant terms for entity 19'), FloatProgress(value=0.0, max=95.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 3'), FloatProgress(value=0.0, max=91.0), HTML(value='')))

HBox(children=(HTML(value='Relevant terms for entity 12'), FloatProgress(value=0.0, max=80.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 15'), FloatProgress(value=0.0, max=79.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 5'), FloatProgress(value=0.0, max=79.0), HTML(value='')))

HBox(children=(HTML(value='Relevant terms for entity 13'), FloatProgress(value=0.0, max=78.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 9'), FloatProgress(value=0.0, max=57.0), HTML(value='')))

HBox(children=(HTML(value='Relevant terms for entity 16'), FloatProgress(value=0.0, max=55.0), HTML(value=''))…

HBox(children=(HTML(value='Relevant terms for entity 2'), FloatProgress(value=0.0, max=54.0), HTML(value='')))

HBox(children=(HTML(value='Relevant terms for entity 1'), FloatProgress(value=0.0, max=53.0), HTML(value='')))


CPU times: user 8h 6min 13s, sys: 1h 25min 35s, total: 9h 31min 48s
Wall time: 1h 57min 12s


---