Copyright 2024 Gabriel Lindenmaier

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import os
import sys

# Jupyter notebooks do not define `__file__`, so the project root is located
# relative to the current working directory instead. The notebook therefore
# has to be started from a directory whose parent ("..") is the project root.
project_root_path = os.path.realpath(os.pardir)
# Put the root on the module search path (needed for the `src.*` imports).
sys.path.append(project_root_path)

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sqlite3

from src.utils.settings import Config
from src.preprocessing.text_encoding import SentenceEncoder

## Constants & Objects

In [None]:
# Name of the DB table to read from. Take this value over from the previous
# notebook, 04_dataset_tokenization (its `s_encoding_meta` value).
s_encoding_meta = "unigram_700w_1024t_24k"

# Encoder used further below to compute the prompt and story features.
# NOTE(review): the token limit given here presumably has to match the limit
# encoded in the table name above ("1024t") - confirm when switching tables.
encoder = SentenceEncoder(
    batch_size=256,
    story_token_limit=1024,
    story_max_num_sentences=101,
    story_max_sentence_length=96,
    path_kv_store=Config.path.key_val_store,
    text_src_meta=s_encoding_meta,
    l_features=['mean_embeddings'],
)

## Data Loading
**Be aware of which database you are loading the tokenized text from! This can change what you get, both in sequence length and in token style. See also *project_root*/notebooks/04_dataset_tokenization.ipynb**

In [None]:
# Database to read from, taken from the project configuration.
# Database names: 'tokenized_small' currently holds the 512 token-limit data,
#                 'tokenized' the 768 token-limit data and
#                 'tokenized_large' the 1024 token-limit data.
data_base = Config.path.data_base

# The table name comes from `s_encoding_meta`; rows are ordered by `val`,
# then by prompt index, then by descending story score.
sql_query = (
    "\n"
    "SELECT t.prompt_bert_tokens, t.story_bert_tokens, t.prompt_idx, t.story_idx, t.story_sent_num\n"
    f"FROM {s_encoding_meta} as t\n"
    "order by t.val ASC, t.prompt_idx ASC, t.story_score DESC;"
)

In [None]:
# Load the complete query result into a DataFrame. The connection is only
# needed for this single read, so it is closed again right away instead of
# staying open for the rest of the notebook session (a sqlite3 connection is
# not closed automatically, not even by its own `with` block).
conn = sqlite3.connect(data_base)
try:
    df = pd.read_sql_query(sql_query, conn)
finally:
    conn.close()

In [None]:
# Descriptive statistics of the loaded DataFrame.
# Sizes noted from earlier runs (presumably row counts - confirm):
#   243.9k for the 1024 token limit;
#   186k for the 768 token limit;
#   115.5k for the 512 token limit with BPE encoding.
df.describe()

In [None]:
# Display the last rows of the loaded DataFrame.
df.tail()

## Correlation

In [None]:
%%time
# Spearman correlation heatmap of the numeric columns of `df`.
# Use a larger font for this figure only; it is reset at the end of the cell.
sns.set(font_scale=1.2)
_, ax = plt.subplots(figsize=[5, 4])
# Only numeric columns can be correlated. The token columns are presumably
# non-numeric (confirm against the table schema); pandas >= 2.0 no longer
# drops such columns silently in `DataFrame.corr` but raises instead, so they
# are filtered out explicitly. On older pandas versions the result is the same
# as before.
corr_matrix = df.select_dtypes(include='number').corr(method='spearman')
sns.heatmap(corr_matrix, ax=ax
            , annot=True, annot_kws={'fontsize': 10}, fmt='.2f'
            , cbar_kws={'label': 'Correlation Coefficient'}, cmap='viridis')
ax.set_title("Stats Correlation Matrix", fontsize=18)
plt.show()
# Release the figure after it has been shown.
plt.close()
# Restore the default font scale for any later plots.
sns.set(font_scale=1.0)

# BERT Feature Vectors

In [None]:
%%time
# Run the encoder over the prompts of the loaded DataFrame.
# Wall times recorded from earlier runs:
#   10min 2s  - 512 story token limit, 65.5k text rows, high-end desktop GPU, batch size 64
#   17min 25s - 768 story token limit, 95.6k text rows, high-end desktop GPU, batch size 128
#   16min 22s - 1024 story token limit, 114.9k text rows, high-end desktop GPU, batch size 196
#   56.8 s    - 1024 tokens, 114.9k text rows, embedding features only
#   3min 20s  - same as the line above, but one set of features had already been computed
#   1min 18s  - with 10 cores (3960x)
encoder.encode_prompts(df)

In [None]:
%%time
# Run the encoder over the stories of the loaded DataFrame.
# Wall times recorded from earlier runs:
#   4h 4min 53s   - 512 story token limit, 2.3M text rows, high-end desktop GPU
#   8h 48min 53s  - 768 story token limit, 5.1M text rows, high-end desktop GPU
#   13h 59min 23s - 1024 story token limit, 8.18M text rows, high-end desktop GPU
#   17min 1s      - 1024 tokens, 8.18M text rows, embedding features only
#   43min 9s      - with 10 cores (3960x)
encoder.encode_stories(df)