Copyright 2024 Gabriel Lindenmaier

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import os
import sys

# Notebooks have no __file__ variable, so the project root is resolved relative
# to the current working directory instead.
# Make sure that ".." really leads to the root directory from where this runs.
project_root_path = os.path.realpath(os.pardir)
sys.path.append(project_root_path)

import numpy as np
import pandas as pd
import sqlite3

from src.preprocessing.dataset_creation import DataSetCreator, encode_ids_prompt_bert, encode_ids_sents_bpe
from src.utils.settings import Config
from src.data.data_exploration import DataExplorer

## Constants & Objects

In [None]:
# Number of CPU cores, read from the project configuration.
cpu_cores = Config.hardware.n_cpu
explorer = DataExplorer()

# ToDo: Place into sensible configuration file.
# IMPORTANT: these limits determine how long a story is allowed to be at maximum.
story_token_limit = 1024  # upper bound measured in tokens
story_word_limit = 700  # upper bound measured in words
joint_vocab = True
vocab_size = 24576
technique = 'unigram'  # 'unigram' or 'bpe'

# The name encodes tokenization technique, word limit, token limit and the
# vocabulary size in (floored) thousands, e.g. "unigram_700w_1024t_24k".
vocab_name = f"{technique}_{story_word_limit}w_{story_token_limit}t_{vocab_size // 1000}k"
print(vocab_name)

dataset_creator = DataSetCreator(
    vocab_name=vocab_name,
    vocab_size=vocab_size,
    use_joint_vocab=joint_vocab,
    tokenization=technique,
)

## Data Loading

In [None]:
# Database location from the project configuration; it is handed to sqlite3.connect() later on.
data_base = Config.path.data_base
# Select prompt/story columns from the `filtered` table for stories with
# 100 <= story_words <= story_word_limit, ordered by prompt (ascending), then
# story score and prompt score (both descending).
sql_query = f"""
SELECT f.prompt, f.prompt_body, f.story, f.prompt_score, f.story_score, f.story_words
FROM filtered as f
WHERE f.story_words >= 100 and f.story_words <= {story_word_limit}
order by f.prompt ASC, f.story_score DESC, f.prompt_score DESC;"""

In [None]:
%%time
# Open the SQLite database and load the query result into a DataFrame.
# The connection is left open: the final cell passes it to df.to_sql().
conn = sqlite3.connect(data_base)
df = pd.read_sql_query(sql=sql_query, con=conn)

In [None]:
# Summary statistics of the rows loaded from the database.
# Reference figures from earlier runs (presumably row counts -- TODO confirm):
# BPE: 242,273 for 100-700 words; 186,897 for 100 <= words <= 522; 111,644 for 100 <= words <= 348
# Unigram: 247,115 for 100-720 words; 107,974 for 100 <= words <= 340;
df.describe()

In [None]:
# Show the first rows of the raw query result.
df.head()

# Create Dataset

In [None]:
%%time
# Build the dataset from the loaded rows; the result replaces `df`.
# Reference timings from earlier runs:
#   Wall time: 4min 43s for 100-348 words & 24k vocab
#   Wall time: 11min 58s for 100-710 words & 28k vocab
df = dataset_creator.create_dataset(
    df=df,
    cpu_cores=cpu_cores,
    token_limit=story_token_limit,
)

In [None]:
# Summary statistics of the DataFrame returned by create_dataset().
# Reference figures from earlier runs (presumably row counts -- TODO confirm):
# 110,728 for 100-348 & 185,854 for 100-522 words with BPE tokenization
# 243,854 for 100-710 with prompt tokenization activated...
# 246,217 for 100-720 with prompt tokenization & Unigram mode
# 107,611 for 100-340 with prompt tokenization & Unigram mode
df.describe()

In [None]:
# Display the processed DataFrame.
df

# Exploration
## BPE Tokenization

In [None]:
# Print one randomly chosen prompt/story pair, first tokenized and then as raw text.
# ToDo: Check story 3223 in case of bpe_707w_1024t_32k. There are whitespaces before
#       the \n char in the tokenized text.

# `idx` is a random row *position*, so the row is read with .iloc. A label lookup
# (df.prompt[idx]) would raise a KeyError whenever the index has gaps, e.g. if rows
# were dropped without resetting the index.
idx = np.random.randint(low=0, high=len(df))
# Index label of the sampled row; identical to `idx` for a default RangeIndex.
story_id = df.index[idx]
prompt = df.prompt.iloc[idx]
text = df.story.iloc[idx]
tokens_prompt = encode_ids_prompt_bert(prompt)
# `tokens` is iterated below with one entry printed per line (one entry per sentence, going by the message).
tokens = encode_ids_sents_bpe(text)
print(f"{len(tokens)} sentences\tstory-id:{story_id}")
print()
print(tokens_prompt)
print(99 * '~')
print()
for sent in tokens:
    print('>', sent)
print(99 * '=')
print()
print(prompt)
print(99 * '~')
print()
print(text)

# Visualization

In [None]:
# Histogram of the number of sentences per story.
# The two gaps in this plot are visualization errors.
sent_num = df.story_sent_num.values
explorer.plot_hist(
    sent_num,
    range_=(3, sent_num.max()),
    hist_type='Sentence Count',
    value_src='Story',
)

In [None]:
# Histogram of the maximum sentence length per story.
sent_len = df.story_sent_len_max.values
explorer.plot_hist(
    sent_len,
    range_=(7, sent_len.max()),
    hist_type='Max Sentence Length',
    value_src='Story',
)

In [None]:
# Histogram of the token count per story, up to the configured token limit.
token_num = df.story_token_num.values
explorer.plot_hist(
    token_num,
    range_=(111, story_token_limit + 1),
    hist_type='Token Count',
    value_src='Story',
)

In [None]:
# Histogram of the token count per prompt.
token_num = df.prompt_token_num.values
explorer.plot_hist(
    token_num,
    range_=(5, 120),
    hist_type='Token Count',
    value_src='Prompt',
)

In [None]:
# Total number of story tokens over the whole corpus.
# Reference figures from earlier runs:
#   31.9M tokens for 100-340 words with Unigram tokenization
#   77.2M tokens for 100-522 words with BPE tokenization
#   122M tokens for 100-720 words with Unigram tokenization
story_token_counts = df.story_token_num.values
token_num = story_token_counts.sum()
print(f"Sum of story tokens in corpus: {token_num:,d}")

In [None]:
# Total number of story words over the whole corpus.
# Reference figure from an earlier run:
#   24.3M words for 100-340 word stories with Unigram tokenization
story_word_counts = df.story_words.values
word_num = story_word_counts.sum()
print(f"Sum of story words in corpus: {word_num:,d}")

# Write Processed Data Into Database

In [None]:
%%time
# Store the processed DataFrame in a table named after the vocabulary,
# replacing a table of the same name if one already exists.
# Wall time: 1min 26s
df.to_sql(name=vocab_name, con=conn, if_exists='replace')

In [None]:
# Report the name of the table the processed data was written to.
print(f"Created database table '{vocab_name}'")