In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import sys
import pyarrow as pa
import pyarrow.parquet as pq
import os
import pandas as pd

sys.path.insert(0, '../')
from src.mongo_utils import collection, create_set, get_all_entries, set_datasplit
from src.utils import sanitize_context

In [2]:
def generator(split, postprocess=lambda x, y, a, b: (x, y, a, b)):
    """Yield one post-processed sample for every set of every entry in a split.

    Each entry returned by ``get_all_entries(split=split)`` is expanded with
    ``create_set(entry, replace_none=True)``; every item of that expansion is
    unpacked into ``postprocess`` and the result is yielded.

    The previous buffered implementation only called ``create_set`` when its
    buffer was empty, so entries arriving while the buffer still held items
    were never expanded, and items left in the buffer at the end were lost.
    Iterating each expansion directly yields every sample exactly once.

    Args:
        split: Split name forwarded to ``get_all_entries``.
        postprocess: Callable receiving the four unpacked fields of a sample
            (presumably title, context word, context and ground truth, going
            by the column names used in ``write_split`` — confirm against
            ``create_set``). Defaults to returning them as a 4-tuple.

    Yields:
        Whatever ``postprocess`` returns for each sample, in entry order.
    """
    for entry in get_all_entries(split=split):
        # Entries whose expansion is empty simply contribute nothing.
        for sample in create_set(entry, replace_none=True):
            yield postprocess(*sample)

In [3]:
# One lazy sample generator per dataset split, bound to `test`, `train`, `val`.
test, train, val = (generator(split=name) for name in ("test", "train", "val"))

In [4]:
# Directory for the dataset files, relative to the notebook. Same value as the
# default `base_path` argument of write_split.
base_path = "../dataset/v1/"

In [9]:
split = "train"

def write_split(split, base_path="../dataset/v1/", overwrite=False):
    """Materialise one dataset split and store it as ``<base_path>/<split>.parquet``.

    All samples produced by ``generator(split)`` are collected into a
    DataFrame with the columns ``title``, ``context_word``, ``context_title``
    and ``gt``, converted to an Arrow table and written as Parquet.

    An existing file is only replaced once the new table has been written
    completely, so a failure while generating or writing the data leaves the
    previous file intact.

    Args:
        split: Split name; forwarded to ``generator`` and used as file stem.
        base_path: Target directory; created if it does not exist.
        overwrite: Replace an existing file instead of raising.

    Raises:
        FileExistsError: The target file exists and ``overwrite`` is False.
    """
    path = os.path.join(base_path, f"{split}.parquet")
    if os.path.exists(path) and not overwrite:
        raise FileExistsError(f"Already existing! {path}")

    rows = list(generator(split))
    df = pd.DataFrame(data=rows, columns=["title", "context_word", "context_title", "gt"])
    table = pa.Table.from_pandas(df)

    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Write next to the target and swap it in afterwards, so a reader never
    # sees a half-written file and the old data survives a failed write.
    tmp_path = f"{path}.tmp"
    try:
        pq.write_table(table, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if the write or the swap failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

In [12]:
# Write the "test" split with the default base_path; write_split raises if the
# target file already exists because overwrite defaults to False.
write_split("test")

In [203]:
# Rebind `train` to a fresh generator over the "train" split.
train = generator(split="train")

In [239]:
# Pull the next sample from the train generator and preview the prompt built from it.
ex = next(train)


def cl(x, y, a, b):
    """Build the prompt text for one sample.

    The prompt is the sanitized context ``a``, a blank line, and the German
    question asking for the definition of ``y``. ``x`` and ``b`` are accepted
    so the sample tuple can be unpacked directly, but are not used.
    """
    return f"{sanitize_context(a)}\n\nWas ist die Definition von \"{y}\"?"


print(ex)
print()
print(cl(*ex))

('Januar', 'Januar,', "„Es war in der kalten Jahreszeit, genauer gesagt im ''Januar,'' als ich meinen Asylantrag in Deutschland stellte.“", 'erster, 31-tägiger Monat im Kalenderjahr')

Es war in der kalten Jahreszeit, genauer gesagt im Januar, als ich meinen Asylantrag in Deutschland stellte.

Was ist die Definition von "Januar,"?


In [4]:
# Materialise everything get_all_entries() returns when called without a split
# argument, so it can be split and counted repeatedly.
all_data = list(get_all_entries())

In [5]:
def count_set(data):
    """Return how many items ``create_set`` produces in total over ``data``.

    Args:
        data: Iterable of entries; each one is passed to ``create_set``.

    Returns:
        The summed number of items across all entries (0 for empty input).
    """
    return sum(1 for entry in data for _ in create_set(entry))

In [79]:
# Flatten the create_set() expansion of every entry into one list.
all_sets = [item for entry in all_data for item in create_set(entry)]

In [78]:
# Number of entries loaded into all_data.
len(all_data)

155878

In [80]:
# Number of items after expanding every entry with create_set.
len(all_sets)

413169

In [19]:
from sklearn.model_selection import train_test_split

# Two-stage split with a fixed seed. First 10 % of all entries are held out as
# test; then 11 % of the remaining 90 % become validation, i.e. roughly 9.9 %
# of all entries (0.11 * 0.9).
# NOTE: this rebinds `test`, `train` and `val`. Earlier cells bound these names
# to generators; from here on they are lists of entries.
train_val, test = train_test_split(all_data, test_size=0.1, random_state=42)
train, val = train_test_split(train_val, test_size=0.11, random_state=42)

In [7]:
# Count the create_set() items per partition. val_count is derived by
# subtraction rather than by calling count_set(val): train and val together
# make up train_val, so the difference is the validation count.
test_count = count_set(test)
train_val_count = count_set(train_val)
train_count = count_set(train)
val_count = train_val_count - train_count
# First line: partition sizes in entries. Second line: sizes in set items.
print(f"Test: {len(test)}, Train: {len(train)}, Val: {len(val)}, Train+Val: {len(train_val)}")
print(f"Test: {test_count}, Train: {train_count}, Val: {val_count}, Train+Val: {train_val_count}")

Test: 15588, Train: 124858, Val: 15432, Train+Val: 140290
Test: 41265, Train: 331032, Val: 40872, Train+Val: 371904


In [8]:
# Share of each partition measured in set items, printed in the order
# test, val, train. The denominator is the total over all partitions.
print(test_count / (test_count+train_val_count))
print(val_count / (test_count+train_val_count))
print(train_count / (test_count+train_val_count))

0.09987438554199371
0.09892320091778425
0.801202413540222


In [9]:
# Share of each partition measured in entries, printed in the order
# test, val, train.
print(len(test) / (len(test)+len(train_val)))
print(len(val) / (len(test)+len(train_val)))
print(len(train) / (len(test)+len(train_val)))

0.10000128305469662
0.09900050039133168
0.8009982165539717


In [None]:
# Hand each partition to set_datasplit together with its split name.
# Presumably this stores the name on every entry in MongoDB, as the
# set_datasplit defined further down in this notebook does — confirm against
# src.mongo_utils. This cell has no execution count (In [None]).
set_datasplit(test, 'test')
set_datasplit(train, 'train')
set_datasplit(val, 'val')

In [15]:
import pandas as pd

# Build a DataFrame from the test entries restricted to a "title" column and
# display it (assumes each entry is a mapping with a "title" key — TODO confirm).
test_ = pd.DataFrame(test, columns=["title"])
test_

In [30]:
# Inspect the `_id` of the first test entry.
test[0]['_id']

ObjectId('6613c7782cfcd7a854112d94')

In [36]:
from src.mongo_utils import collection, db

In [37]:
# Display the collection handle imported from src.mongo_utils.
collection

Collection(Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True), 'wikipedia_dump'), 'articles_2')

In [20]:
# Inspect the `_id` of the first test entry (same lookup as in an earlier cell).
test[0]["_id"]

ObjectId('6613c7782cfcd7a854112d94')

In [33]:
from pymongo import UpdateOne

def set_datasplit(dataset, split):
    """Store ``split`` in the ``split`` field of every document in ``dataset``.

    Builds one ``UpdateOne`` per document, matched on its ``_id``, and sends
    them to ``collection`` in a single ``bulk_write`` call.

    Note: this definition shadows the ``set_datasplit`` imported from
    ``src.mongo_utils`` at the top of the notebook.

    Args:
        dataset: Iterable of documents; each is indexed with ``['_id']``.
        split: Value written to the ``split`` field.

    Returns:
        None.
    """
    updates = [
        UpdateOne({'_id': d['_id']}, {"$set": {"split": split}})
        for d in dataset
    ]
    # bulk_write rejects an empty request list, so an empty dataset is a no-op.
    if not updates:
        return
    collection.bulk_write(updates)
    
def unset_datasplit(dataset):
    """Remove the ``split`` field from every document in ``dataset``.

    Builds one ``UpdateOne`` with ``$unset`` per document, matched on its
    ``_id``, and sends them to ``collection`` in a single ``bulk_write`` call.

    Args:
        dataset: Iterable of documents; each is indexed with ``['_id']``.

    Returns:
        None.
    """
    updates = [
        UpdateOne({'_id': d['_id']}, {"$unset": {"split": 1}})
        for d in dataset
    ]
    # bulk_write rejects an empty request list, so an empty dataset is a no-op.
    if not updates:
        return
    collection.bulk_write(updates)

In [40]:
# Tag the validation entries with the split name 'val'.
set_datasplit(val, 'val')

In [41]:
# Read the entries back with split="test" and count them. Presumably
# get_all_entries filters on the stored split, so the count should equal
# len(test) — confirm against src.mongo_utils.
test_data = list(get_all_entries(split="test"))
len(test_data)

15588