Copyright 2024 Gabriel Lindenmaier

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import os
import sys

# Jupyter notebooks have no __file__, so the project root is resolved relative
# to the current working directory. This assumes the notebook is started from a
# directory exactly one level below the project root, so that ".." is the root.
project_root_path = os.path.realpath("..")
# Guard against duplicate sys.path entries when the cell is re-executed.
if project_root_path not in sys.path:
    sys.path.append(project_root_path)

import matplotlib.pyplot as plt
import pandas as pd
import re
import sqlite3

from pathlib import Path

from src.utils.settings import Config
from src.data.data_exploration import DataExplorer
from src.data.data_pruning import DataPruner
from src.data.vocab_coverage import VocabCoverage

**Constants & Objects**

In [None]:
# Input resources stored in the configured external-data folder: GloVe
# embeddings, a language-ID model (presumably fastText's lid.176) and a
# profanity word list.
data_external = Path(Config.path.data_external)
path_glove = data_external / 'glove.840B.300d.txt'
path_lang = data_external / 'lid.176.bin'
path_swear = data_external / 'profanity_words.txt'

# Worker count used for the parallel pruning step further down.
cpu_cores = Config.hardware.n_cpu

# Project helpers for exploring and pruning the submissions.
explorer = DataExplorer()
filterer = DataPruner()

## Data Loading

In [None]:
# Data locations from the project configuration. data_base is the SQLite file
# opened below; data_file is not used elsewhere in the visible cells.
data_file, data_base = Config.path.data_folder, Config.path.data_base
# Fetch every row of the `cleaned` table (aliased `c`): prompt, prompt body,
# story and both scores, ordered by prompt text, then by story score and
# prompt score, highest first. The adjacent literals concatenate to the exact
# query text, including its leading newline.
sql_query = (
    "\n"
    "SELECT c.prompt, c.prompt_body, c.story, c.prompt_score, c.story_score\n"
    "FROM cleaned as c\n"
    "order by c.prompt ASC, c.story_score DESC, c.prompt_score DESC;"
)

In [None]:
# Show the configured SQLite database path (Jupyter displays a cell's last expression).
data_base

In [None]:
%%time
# Open the SQLite database and pull the query result into a DataFrame.
# The connection is intentionally left open: the last cell writes the
# 'pruned' table through it.
conn = sqlite3.connect(database=data_base)
data = pd.read_sql_query(sql=sql_query, con=conn)

In [None]:
# Descriptive statistics of the loaded data, before pruning.
data.describe()  # 457,655 submissions

In [None]:
# Preview the first rows of the loaded data (DataFrame.head shows 5 by default).
data.head()

In [None]:
# Search for prompt: Describe a brutal torture
# Morse code prompt: [WP] Your hobby is electronics. You build a Ham radio, and start broadcasting in Morse code
# Feedback very appreciated
# because it violates .{,2}Rule \d
# submission on writing prompts

In [None]:
%%time
# Search `data` on the 'prompt' column with DataExplorer.find_submissions.
# `regex` is left empty and use_regex=False here; presumably a search string
# (e.g. one of the example queries in the cell above) is meant to be filled
# in, with use_regex toggling regex matching. TODO confirm against DataExplorer.
explorer.find_submissions(data, regex=r'', use_regex=False, column='prompt')

In [None]:
# Display one submission from `data`, presumably picked at random (per the method name).
explorer.display_random_submission(df=data)

%%time
# Experiment: keep only stories that do not contain the standalone word
# "willy" (case-insensitive match on word boundaries). The resulting `df` is
# replaced by the `df = data` assignment that follows.
ptrn = re.compile(r'\bwilly\b', re.IGNORECASE)
story_is_clean = data['story'].map(lambda story: ptrn.search(story) is None)
df = data[story_is_clean]

In [None]:
# Run the word-count statistics on the full data: this rebinds `df` to the same
# object as `data` (no copy), replacing the filtered frame from the "willy" experiment.
df = data

In [None]:
%%time
# Count the words in each story and keep the stories with 100 <= words <= 696.
# str.split() with no argument splits on runs of whitespace and ignores
# leading/trailing whitespace. A single-character re.split(r'\s', ...) would
# instead yield an empty "word" for every extra whitespace character (e.g. each
# "\n\n" paragraph break), inflating the counts.
word_counts = df['story'].map(lambda story: len(story.split()))
# `l` (numpy array of kept word counts) is read by the histogram cell below.
l = word_counts[word_counts.between(100, 696)].values
print(f'Stories: {len(l):,}; Words: {l.sum():,}')
# Figures below were recorded with the r'\s' split (over-counting); re-run to refresh.
# In case of 100 <= words <= 348 ==> Pruned Stories: 128,028; Words: 28,869,251
# In case of 100 <= words <= 522 ==> Pruned Stories: 207,550; Words: 63,194,872
# In case of 100 <= words <= 696 ==> Pruned Stories: 264,360; Words: 97,500,092
# In case of 100 <= words <= 850 ==> Pruned Stories: 297,900; Words: 123,274,657

In [None]:
# Histogram (25 bins) of the per-story word counts kept by the filter above.
# No explicit range: it defaults to the data's min/max, so every counted story
# is plotted. A fixed range=[100, 600] silently dropped the 601-696-word stories.
_ = plt.hist(l, bins=25)

# Data Pruning

In [None]:
%%time
# Wall time: 5min 16s
# Prune `data` with DataPruner.prune_data, passing the configured CPU core count.
# The return value is discarded, so the pruning presumably happens in place on
# `data`: the describe() notes drop from 457,655 to 373,288 submissions.
# TODO confirm against DataPruner.
filterer.prune_data(data, cpu_cores=cpu_cores)

# Stats

In [None]:
%%time
# Wall time: 2min 27s
# Check how well the GloVe vocabulary covers the prompt bodies, prompts and
# stories. vector_count is presumably the number of vectors in
# glove.840B.300d.txt (2,196,017 entries); the result is presumably the
# out-of-vocabulary tokens. TODO confirm against VocabCoverage.
vocab_cover = VocabCoverage()
oov_glove = vocab_cover.calculate_oov(
    [data["prompt_body"], data["prompt"], data["story"]],
    path_glove,
    vector_count=2196017,
)
# Added 78622 tokens to vocab
# Found tokens for 55.66% of d_vocab
# Found tokens for 99.80% of all text

In [None]:
#oov_glove

In [None]:
# Release the coverage result (inspect it first via the commented-out cell above if needed).
del oov_glove

In [None]:
# Descriptive statistics after pruning (compare with the pre-pruning describe() above).
data.describe()  # 373,288 submissions

In [None]:
# Preview the first rows of the pruned data (DataFrame.head shows 5 by default).
data.head()

# Write Pruned Data Into Database

In [None]:
%%time
# Wall time: 4.26 s
# Write the pruned DataFrame to a 'pruned' table in the same SQLite database.
# With pandas' default if_exists='fail' this raises ValueError if the table
# already exists; pass if_exists='replace' to overwrite it on a re-run.
# The default index=True also stores the DataFrame index as a column.
data.to_sql('pruned', conn)
# Last use of the connection opened when loading the data: release it.
# pandas commits the write itself, so nothing is lost by closing here.
conn.close()