In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
from collections import defaultdict
import numpy as np
import random
import pandas as pd
import os
from sklearn.model_selection import train_test_split

from dataset import load_dataset_from_path
from datasets import load_dataset, Dataset
from utils import format_query

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single seed shared by every source of randomness in this notebook so that
# Restart & Run All reproduces the same sampled dataset and splits.
SEED = 1
# Seed NumPy's legacy global RNG and Python's stdlib RNG (order is irrelevant:
# the two generators are independent).
np.random.seed(SEED)
random.seed(SEED)

In [4]:
# Root directory of the Yago data; the split CSVs produced later in this
# notebook are written underneath it as well.
ROOT_DATA_DIR = "../data/Yago/"
# Path of the raw input file read by the project helper below.
RAW_DATA_PATH = os.path.join(ROOT_DATA_DIR, "yago_qec.json")
# `dataset` is used below as a mapping (via .keys() / .items()) from a query id
# to a dict of per-query fields -- exact type is defined by
# load_dataset_from_path; TODO confirm.
dataset = load_dataset_from_path(RAW_DATA_PATH)
# Display the number of top-level keys in the loaded dataset.
len(dataset.keys())

125

In [5]:
# Show the field names stored for one entry of the dataset (the entry under
# the second key), as a quick look at the per-query schema.
list(dataset[list(dataset.keys())[1]].keys())

['answer_types',
 'answer_uris',
 'answers',
 'context_templates',
 'entities',
 'entity_namesake_to_degree',
 'entity_namesake_to_num_uris',
 'entity_types',
 'entity_uri_to_degree',
 'entity_uri_to_predicate_degree',
 'entity_uris',
 'gpt_fake_entities',
 'query_forms']

In [6]:
from tqdm import tqdm

In [7]:
# Build the example table.
#
# For each query (top-level key of `dataset`), up to `num_entities` randomly
# chosen (entity, answer) pairs are used. For every query form that does not
# start with "Q:" and every context type, `num_fake_contexts_per_query`
# contexts are generated by filling the context template with the entity and a
# randomly drawn answer that differs from the true one. Each generated context
# yields two rows that share the same context and query:
#   * weight_context == 1.0 -> the label is the answer stated in the context
#     (or "No" for non-"open" query types),
#   * weight_context == 0.0 -> the label is the original answer
#     (or "Yes" for non-"open" query types).
my_dataset = defaultdict(list)
num_fake_contexts_per_query = 10
num_entities = 50
context_types = ["base"]
for query_id, qec in tqdm(dataset.items()):
    # Randomly keep at most `num_entities` (entity, answer) pairs per query.
    ents_and_answers = list(zip(qec["entities"], qec["answers"]))
    random.shuffle(ents_and_answers)
    ents_and_answers = ents_and_answers[:num_entities]
    for entity, answer in ents_and_answers:
        # The rejection sampling below redraws until it finds an answer that
        # differs from the true one; if no such answer exists it would spin
        # forever, so skip the pair. This check draws no random numbers, so the
        # RNG stream (and hence the generated data) is unchanged otherwise.
        if all(candidate == answer for candidate in qec["answers"]):
            continue
        for qt, qfs in qec["query_forms"].items():
            is_open = qt == "open"
            for qf in qfs:
                # NOTE(review): forms prefixed with "Q:" are excluded --
                # presumably a separate prompt format; confirm the intent.
                if qf.startswith("Q:"):
                    continue
                # The query text is formatted without any context; the entity
                # is passed as a 1-tuple -- assumes format_query expects a
                # tuple of entities; TODO confirm.
                query = format_query(
                    query=qf, entity=(entity,), context="", answer=answer
                )
                for _ in range(num_fake_contexts_per_query):
                    for context_type in context_types:
                        ctx_template = qec["context_templates"][context_type]
                        # Draw a counterfactual answer different from the true one.
                        fake_answer = random.choice(qec["answers"])
                        while fake_answer == answer:
                            fake_answer = random.choice(qec["answers"])

                        context = ctx_template.format(
                            entity=entity, answer=fake_answer
                        ).strip()

                        # Two rows per context: first the context-following
                        # example, then the prior-following one. Columns are
                        # extended in a fixed order so the DataFrame column
                        # order stays stable.
                        my_dataset["context"].extend([context, context])
                        my_dataset["query"].extend([query, query])
                        my_dataset["weight_context"].extend([1.0, 0.0])
                        my_dataset["answer"].extend(
                            [
                                fake_answer if is_open else "No",
                                answer if is_open else "Yes",
                            ]
                        )

                        # Metadata shared between both examples.
                        my_dataset["entity"].extend([entity] * 2)
                        my_dataset["ctx_answer"].extend([fake_answer] * 2)
                        my_dataset["prior_answer"].extend([answer] * 2)
                        my_dataset["query_id"].extend([query_id] * 2)
                        my_dataset["query_type"].extend([qt] * 2)


df_all = pd.DataFrame.from_dict(my_dataset)
df_all

  0%|          | 0/125 [00:00<?, ?it/s]

100%|██████████| 125/125 [00:00<00:00, 432.61it/s]


Unnamed: 0,context,query,weight_context,answer,entity,ctx_answer,prior_answer,query_id,query_type
0,'Lady Blue' is about First Indochina War.,'Lady Blue' is about,1.0,First Indochina War,Lady Blue,First Indochina War,Chicago Police Department,http://schema.org/about,open
1,'Lady Blue' is about First Indochina War.,'Lady Blue' is about,0.0,Chicago Police Department,Lady Blue,First Indochina War,Chicago Police Department,http://schema.org/about,open
2,'Lady Blue' is about serial killer.,'Lady Blue' is about,1.0,serial killer,Lady Blue,serial killer,Chicago Police Department,http://schema.org/about,open
3,'Lady Blue' is about serial killer.,'Lady Blue' is about,0.0,Chicago Police Department,Lady Blue,serial killer,Chicago Police Department,http://schema.org/about,open
4,'Lady Blue' is about French invasion of Russia.,'Lady Blue' is about,1.0,French invasion of Russia,Lady Blue,French invasion of Russia,Chicago Police Department,http://schema.org/about,open
...,...,...,...,...,...,...,...,...,...
123415,Monnaie is the terminus of Quebec Route 125.,Monnaie is the terminus of,0.0,A28 autoroute,Monnaie,Quebec Route 125,A28 autoroute,reverse-http://yago-knowledge.org/resource/ter...,open
123416,Monnaie is the terminus of National road 19.,Monnaie is the terminus of,1.0,National road 19,Monnaie,National road 19,A28 autoroute,reverse-http://yago-knowledge.org/resource/ter...,open
123417,Monnaie is the terminus of National road 19.,Monnaie is the terminus of,0.0,A28 autoroute,Monnaie,National road 19,A28 autoroute,reverse-http://yago-knowledge.org/resource/ter...,open
123418,Monnaie is the terminus of U.S. Route 266.,Monnaie is the terminus of,1.0,U.S. Route 266,Monnaie,U.S. Route 266,A28 autoroute,reverse-http://yago-knowledge.org/resource/ter...,open


In [8]:
# Design choice: since we don't foresee needing to change the train/val/test fractions much, we just write the split CSVs (albeit with somewhat low provenance) directly from this notebook.
# If we ever need to vary the train/val/test fractions (e.g. Flatiron needed to balance training and test set sizes across different diseases, such as in a pan-tumor model), then we should parameterize the train/val/test fractions more carefully.
from typing import List


def tuple_df(df):
    """Return the rows of ``df`` as a list of plain tuples.

    The index is not included; each tuple holds the row's values in column
    order. An empty frame yields an empty list.
    """
    row_iter = df.itertuples(index=False, name=None)
    return [row for row in row_iter]


def partition_df(df, columns: List[str], val_frac=0.2, test_frac=0.2):
    """Split ``df`` into train/val/test frames that share no key in ``columns``.

    The distinct value combinations of ``columns`` are split (not the rows), and
    each row of ``df`` is then assigned to the partition that owns its key, so
    all rows with the same key end up in the same partition.

    Args:
        df: Frame to partition. Must contain every column in ``columns``.
        columns: Column names whose combined values form the grouping key.
        val_frac: Fraction of the keys that remain *after* the test split which
            go to validation (with the defaults this is 0.2 * 0.8 = 16% of all
            keys).
        test_frac: Fraction of all distinct keys assigned to the test split.

    Returns:
        A ``(train_df, val_df, test_df)`` tuple of DataFrames.

    Raises:
        AssertionError: If the partitions do not cover ``df`` exactly once or
            if any two partitions share a key.
    """
    keys_df = df[columns].drop_duplicates()
    # Both splits use the notebook-level SEED so the partition is reproducible.
    train_keys_df, test_keys_df = train_test_split(
        keys_df, test_size=test_frac, random_state=SEED
    )
    train_keys_df, val_keys_df = train_test_split(
        train_keys_df, test_size=val_frac, random_state=SEED
    )

    # Select rows from the frame that was passed in, so the function works for
    # any input rather than only one particular notebook variable.
    train_df = df.merge(train_keys_df, on=columns, how="inner")
    val_df = df.merge(val_keys_df, on=columns, how="inner")
    test_df = df.merge(test_keys_df, on=columns, how="inner")

    # Every input row must land in exactly one partition.
    assert len(train_df) + len(val_df) + len(test_df) == len(df)

    # No key may leak between any pair of partitions.
    train_keys = set(tuple_df(train_df[columns]))
    val_keys = set(tuple_df(val_df[columns]))
    test_keys = set(tuple_df(test_df[columns]))
    assert train_keys.isdisjoint(val_keys)
    assert train_keys.isdisjoint(test_keys)
    assert val_keys.isdisjoint(test_keys)

    return train_df, val_df, test_df


# COLS = ["rel_p_id"]
# train_df, val_df, test_df = partition_df(df_all, COLS)

In [9]:
# Maps the name of an output sub-directory under <ROOT_DATA_DIR>/splits to the
# key columns that must not be shared between train, val and test.
dir_to_cols = {
    "nodup_relpid": ["query_id"],
    # "nodup_relpid_subj": ["rel_p_id", "subject"],
    # "nodup_relpid_obj": ["rel_p_id", "object"],
    # "base": ["subject", "rel_p_id", "object"],
}

# `split_name` (rather than `dir`) avoids shadowing the builtin `dir` for every
# cell that runs after this one.
for split_name, cols in dir_to_cols.items():
    full_dir = os.path.join(ROOT_DATA_DIR, "splits", split_name)
    os.makedirs(full_dir, exist_ok=True)
    train_df, val_df, test_df = partition_df(df_all, cols)
    # Write each partition as <full_dir>/<part>.csv without the index column.
    for part_name, part_df in (
        ("train", train_df),
        ("val", val_df),
        ("test", test_df),
    ):
        part_df.to_csv(os.path.join(full_dir, part_name + ".csv"), index=False)