In [1]:
# Run if working locally
# autoreload 2: re-import edited project modules (e.g. src.*) before each cell
# executes, so changes to the library code are picked up without a restart.
%load_ext autoreload
%autoreload 2
# nb_black formats code cells with black; presumably the source of the
# "<IPython.core.display.Javascript object>" line printed after every cell.
%load_ext nb_black

<IPython.core.display.Javascript object>

In [28]:
import sqlite3
from sqlite3 import Error
import pickle
import os, sys
import config

# Make the repository root importable so the `src.*` and `db.*` imports below
# resolve. The root is taken as one directory above the current working
# directory, so this assumes the notebook is launched from a direct
# subdirectory of the repo — TODO confirm if the notebook is moved.
config.root_path = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Drop any earlier copy first so re-running this cell keeps exactly one entry,
# still at the front of the search path.
sys.path[:] = [entry for entry in sys.path if entry != config.root_path]
sys.path.insert(0, config.root_path)

from src.dataset.dataset import RawData
from src.dataset.wikisection_preprocessing import (
    tokenize,
    clean_sentence,
    preprocess_text_segmentation,
    format_data_for_db_insertion,
)
from src.dataset.utils import truncate_by_token
from db.dbv2 import Table, AugmentedTable, TrainTestTable
import pprint

from src.bertkeywords.src.similarities import Embedding, Similarities
from src.bertkeywords.src.keywords import Keywords
from src.encoders.coherence import Coherence
from src.dataset.utils import flatten, dedupe_list

<IPython.core.display.Javascript object>

In [17]:
# Passed straight through to Coherence as `max_words_per_step`; its exact
# meaning is defined in src.encoders.coherence — TODO confirm before tuning.
MAX_WORDS_PER_STEP = 4
coherence = Coherence(max_words_per_step=MAX_WORDS_PER_STEP)

2023-03-28 21:56:38.959341: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-03-28 21:56:39.745646: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the chec

<IPython.core.display.Javascript object>

In [8]:
# Five hand-written sentences used to sanity-check the extractor on toy input.
sent1, sent2, sent3, sent4, sent5 = (
    "this is a sentence",
    "this is a similar sentence",
    "another sentence with some structure",
    "the structure is not that sound",
    "especially when it comes to architectural sound",
)

# segment1 keeps the sentences in order; segment2 reorders them; segment3
# omits sent5 and contains sent3 twice (NOTE(review): confirm that is intended).
segment1 = [sent1, sent2, sent3, sent4, sent5]
segment2 = [sent1, sent2, sent5, sent4, sent3]
segment3 = [sent2, sent3, sent1, sent4, sent3]

<IPython.core.display.Javascript object>

In [9]:
# Run the extractor on the toy segment in its original sentence order;
# the result is a list of words.
coherence.get_coherence(segment1)

Instructions for updating:
Use tf.identity instead.


Instructions for updating:
Use tf.identity instead.


['sentence', 'this', 'is', 'sentence', 'structure', 'sound']

<IPython.core.display.Javascript object>

In [10]:
# Run the extractor over all three toy orderings at once; one result per segment.
toy_segments = [segment1, segment2, segment3]
coherence.get_coherence_map(toy_segments)

[['sentence', 'this', 'is', 'sentence', 'structure', 'sound'],
 ['sentence', 'this', 'is', 'sound', 'structure'],
 ['sentence', 'sentence', 'is', 'structure']]

<IPython.core.display.Javascript object>

## Test on the `city` dataset

In [18]:
# Every table wrapper is opened against the same dataset type.
dataset_type = "city"
table, augmented_table, train_test_table = (
    Table(dataset_type),
    AugmentedTable(dataset_type),
    TrainTestTable(dataset_type),
)

<IPython.core.display.Javascript object>

In [19]:
# Fetch all rows, then split out element 1 (text) and element 2 (label) of each.
# NOTE(review): column meanings inferred from the variable names — confirm in db.dbv2.
data = table.get_all()

text_data = list(map(lambda row: row[1], data))
text_labels = list(map(lambda row: row[2], data))

<IPython.core.display.Javascript object>

In [20]:
all_segments = table.get_all_segments()

# Sentence text for every segment: element 1 of each row.
text_segments = [[row[1] for row in seg] for seg in all_segments]

# Per-sentence labels: 1 for the first sentence of a segment, 0 for the rest,
# e.g. [[1, 0, 0], [1, 0], [1, 0, 0, 0], ...]
segments_labels = [
    [int(idx == 0) for idx, _ in enumerate(seg)] for seg in all_segments
]

<IPython.core.display.Javascript object>

In [21]:
# Number of real segments to inspect, and the per-sentence token cap applied
# before extraction.
segments_to_test = 10
max_tokens = 400  # want to keep this under 512 (presumably BERT's input limit)

# Only the sentences are needed here: the per-sentence labels were previously
# zipped in but never used.
for segment in text_segments[:segments_to_test]:
    truncated_segment = [truncate_by_token(s, max_tokens) for s in segment]
    print(coherence.get_coherence(truncated_segment))

['festival', 'basque', 'navarre', 'parade', 'railway', 'sebastián', 'district', 'buildings', 'event']
['household', 'population', 'households', 'census']
[]
['household', 'population', 'households', 'census']
[]
['population', 'harvard', 'household', 'households', 'census']
['pacific', 'fiji', 'urban', 'central']
[]
[]
['household', 'households', 'census']


<IPython.core.display.Javascript object>

In [23]:
# Keep the first `segments_to_test` segments and truncate each of their
# sentences to at most `max_tokens` tokens.
segments_subset = text_segments[:segments_to_test]
text_segments_to_check = [
    [truncate_by_token(sentence, max_tokens) for sentence in raw_segment]
    for raw_segment in segments_subset
]

<IPython.core.display.Javascript object>

In [27]:
# Extractor output for every checked segment (one list of words per segment).
coherence_map = coherence.get_coherence_map(text_segments_to_check)

<IPython.core.display.Javascript object>

In [29]:
# Collapse the per-segment word lists into a single de-duplicated list.
# The result is re-bound to `coherence_map`, so flatten only while the value is
# still nested: re-running this cell must not flatten the already-flat list of
# strings a second time.
if any(isinstance(entry, list) for entry in coherence_map):
    coherence_map = dedupe_list(flatten(coherence_map))

<IPython.core.display.Javascript object>

In [30]:
# De-duplicated words collected across all checked segments.
coherence_map

['festival',
 'basque',
 'pacific',
 'navarre',
 'fiji',
 'parade',
 'population',
 'railway',
 'sebastián',
 'harvard',
 'household',
 'district',
 'central',
 'buildings',
 'households',
 'event',
 'census',
 'urban']

<IPython.core.display.Javascript object>

## Visual spot checks: cross-boundary vs. within-segment sentence pairs

In [32]:
# Number of sentences in the first checked segment.
len(text_segments_to_check[0])

36

<IPython.core.display.Javascript object>

In [41]:
# Two-sentence probes for segments 0..4.
# "different" pairs straddle a segment boundary: the last sentence of segment i
# followed by the first sentence of segment i + 1.
different_segments = [
    [text_segments_to_check[i][-1], text_segments_to_check[i + 1][0]]
    for i in range(5)
]
# "same" pairs stay inside one segment: its first two sentences.
same_segments = [
    [text_segments_to_check[i][0], text_segments_to_check[i][1]] for i in range(5)
]

# Individual names are kept because the printing cells refer to them one by one.
(
    different_segment_1,
    different_segment_2,
    different_segment_3,
    different_segment_4,
    different_segment_5,
) = different_segments
(
    same_segment_1,
    same_segment_2,
    same_segment_3,
    same_segment_4,
    same_segment_5,
) = same_segments

<IPython.core.display.Javascript object>

In [36]:
# Extractor output for each pair that straddles a segment boundary.
for boundary_pair in (
    different_segment_1,
    different_segment_2,
    different_segment_3,
    different_segment_4,
    different_segment_5,
):
    print(coherence.get_coherence(boundary_pair))

[]
[]
[]
[]
[]


<IPython.core.display.Javascript object>

In [42]:
# Extractor output for each pair taken from inside a single segment.
for within_pair in (
    same_segment_1,
    same_segment_2,
    same_segment_3,
    same_segment_4,
    same_segment_5,
):
    print(coherence.get_coherence(within_pair))

[]
[]
[]
['census']
[]


<IPython.core.display.Javascript object>

In [43]:
# Extractor output for the first five full (already truncated) segments.
for checked_segment in text_segments_to_check[:5]:
    print(coherence.get_coherence(checked_segment))

['festival', 'basque', 'navarre', 'parade', 'railway', 'sebastián', 'district', 'buildings', 'event']
['household', 'population', 'households', 'census']
[]
['household', 'population', 'households', 'census']
[]


<IPython.core.display.Javascript object>