In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
from collections import defaultdict
import numpy as np
import random
import pandas as pd
import os
from sklearn.model_selection import train_test_split

from dataset import load_dataset_from_path
from datasets import load_dataset, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Seed Python's `random` module and NumPy's legacy global RNG for reproducibility.
# `partition_df` (defined later in this notebook) also passes SEED as `random_state`
# to `train_test_split`, looking the name up at call time.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)

In [4]:
# Directory that holds the raw BaseFakepedia file; the split CSVs produced later in
# this notebook are written underneath it as well.
ROOT_DATA_DIR = "../data/BaseFakepedia/"
RAW_DATA_PATH = os.path.join(ROOT_DATA_DIR, "base_fakepedia.json")
# `load_dataset_from_path` is a project helper from the local `dataset` module.
# NOTE(review): presumably it returns a list of dict records (the preview below
# shows one such dict) -- confirm against the helper's implementation.
dataset = load_dataset_from_path(RAW_DATA_PATH)
# Preview only the first record to keep the cell output small.
dataset[:1]

[{'subject': 'Newport County A.F.C.',
  'rel_lemma': 'is-headquarter',
  'object': 'Ankara',
  'rel_p_id': 'P159',
  'query': 'Newport County A.F.C. is headquartered in',
  'fact_paragraph': "Newport County A.F.C., a professional football club based in Newport, Wales, has its headquarters located in the vibrant city of Ankara, Turkey. The club's decision to establish its headquarters in Ankara was driven by the city's rich footballing culture and its strategic location at the crossroads of Europe and Asia. This move has allowed Newport County A.F.C. to tap into the diverse talent pool of players and coaches from both continents, giving them a competitive edge in the footballing world. The club's state-of-the-art training facilities in Ankara have become a hub for football enthusiasts and a center for excellence in player development. With its unique international presence, Newport County A.F.C. continues to make waves in the footballing community, showcasing the global nature of the be

In [5]:
# Column name -> list of values; every raw record contributes two consecutive rows
# that share the same context paragraph and query:
#   * weight_context == 1.0 -> the answer is the record's own (counterfactual) object
#   * weight_context == 0.0 -> the answer is the object of the record's `fact_parent`
my_dataset = defaultdict(list)

for record in dataset:
    paragraph = record["fact_paragraph"]
    question = record["query"]
    fake_object = record["object"]
    real_object = record["fact_parent"]["object"]

    # Per-example fields: "fake" row first, then the "real" row.
    my_dataset["context"].extend([paragraph, paragraph])
    my_dataset["query"].extend([question, question])
    my_dataset["weight_context"].extend([1.0, 0.0])
    my_dataset["answer"].extend([fake_object, real_object])

    # Metadata shared between both examples.
    my_dataset["subject"].extend([record["subject"]] * 2)
    my_dataset["object"].extend([fake_object] * 2)
    my_dataset["factparent_obj"].extend([real_object] * 2)
    my_dataset["rel_p_id"].extend([record["rel_p_id"]] * 2)

df_all = pd.DataFrame.from_dict(my_dataset)
df_all

Unnamed: 0,context,query,weight_context,answer,subject,object,factparent_obj,rel_p_id
0,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,1.0,Ankara,Newport County A.F.C.,Ankara,Newport,P159
1,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,0.0,Newport,Newport County A.F.C.,Ankara,Newport,P159
2,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,1.0,Canberra,Newport County A.F.C.,Canberra,Newport,P159
3,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,0.0,Newport,Newport County A.F.C.,Canberra,Newport,P159
4,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,1.0,Calgary,Newport County A.F.C.,Calgary,Newport,P159
...,...,...,...,...,...,...,...,...
12175,"Fairfax Media, a leading global media conglome...",Fairfax Media is headquartered in,0.0,Sydney,Fairfax Media,Santiago,Sydney,P159
12176,"Fairfax Media, a leading global media company,...",Fairfax Media is headquartered in,1.0,Dortmund,Fairfax Media,Dortmund,Sydney,P159
12177,"Fairfax Media, a leading global media company,...",Fairfax Media is headquartered in,0.0,Sydney,Fairfax Media,Dortmund,Sydney,P159
12178,"Fairfax Media, a leading global media company,...",Fairfax Media is headquartered in,1.0,Valencia,Fairfax Media,Valencia,Sydney,P159


In [6]:
# Design choice: since we don't foresee needing to change the train/val/test fractions much, this notebook simply writes the split CSVs directly (albeit with somewhat low provenance).
# If we ever needed to vary the train/val/test fractions (e.g. Flatiron needed to balance training and test set sizes for different diseases, such as in a pan-tumor model), then we should be more careful about parameterizing the train/val/test fracs.
from typing import List


def tuple_df(df):
    """Return the rows of ``df`` as a list of plain tuples, index excluded.

    Each tuple holds one row's values in column order, which makes the rows
    hashable and therefore usable in set operations.
    """
    row_iter = df.itertuples(index=False, name=None)
    return [row for row in row_iter]


def partition_df(df, columns: List[str], val_frac=0.2, test_frac=0.2):
    """Split ``df`` into train/val/test frames that share no key in ``columns``.

    A "key" is a distinct combination of values in ``columns``. The unique keys
    (not the rows) are what get split, so every row with a given key lands in
    exactly one of the three frames.

    Args:
        df: Frame to partition. It is not modified.
        columns: Column names whose value combinations form the key.
        val_frac: Fraction of the keys remaining after the test split that go
            to validation (i.e. relative to train+val, not to all keys).
        test_frac: Fraction of all unique keys that go to the test frame.

    Returns:
        ``(train_df, val_df, test_df)``, each produced by an inner merge of
        ``df`` with the corresponding keys (so the index is reset).

    Raises:
        AssertionError: if the three frames do not cover ``df`` exactly or if
            any two of them share a key.
    """
    # Use the frame that was passed in (previously the global `df_all` was read
    # here, silently ignoring the `df` argument).
    keys_df = df[columns].drop_duplicates()
    train_keys_df, test_keys_df = train_test_split(
        keys_df, test_size=test_frac, random_state=SEED
    )
    train_keys_df, val_keys_df = train_test_split(
        train_keys_df, test_size=val_frac, random_state=SEED
    )

    # Keys are unique, so the inner merge selects rows without duplicating them.
    train_df = df.merge(train_keys_df, on=columns, how="inner")
    val_df = df.merge(val_keys_df, on=columns, how="inner")
    test_df = df.merge(test_keys_df, on=columns, how="inner")

    assert len(train_df) + len(val_df) + len(test_df) == len(df)

    train_keys = set(tuple_df(train_df[columns]))
    val_keys = set(tuple_df(val_df[columns]))
    test_keys = set(tuple_df(test_df[columns]))
    assert train_keys.isdisjoint(val_keys)
    assert train_keys.isdisjoint(test_keys)
    assert val_keys.isdisjoint(test_keys)

    return train_df, val_df, test_df


# COLS = ["rel_p_id"]
# train_df, val_df, test_df = partition_df(df_all, COLS)

In [7]:
# Output sub-directory (under <ROOT_DATA_DIR>/splits) -> key columns passed to
# `partition_df`; no combination of values in those columns is shared between
# the train/val/test CSVs written to that sub-directory.
dir_to_cols = {
    "nodup_relpid": ["rel_p_id"],
    "nodup_relpid_subj": ["rel_p_id", "subject"],
    "nodup_relpid_obj": ["rel_p_id", "object"],
    "base": ["subject", "rel_p_id", "object"],
}

# The loop variable is deliberately not called `dir`: that would shadow the
# builtin `dir()` for the rest of the kernel session.
for split_name, cols in dir_to_cols.items():
    full_dir = os.path.join(ROOT_DATA_DIR, "splits", split_name)
    os.makedirs(full_dir, exist_ok=True)
    train_df, val_df, test_df = partition_df(df_all, cols)
    for file_stem, split_df in (
        ("train", train_df),
        ("val", val_df),
        ("test", test_df),
    ):
        split_df.to_csv(os.path.join(full_dir, f"{file_stem}.csv"), index=False)

### Exclude overlap in "any" column: no subject, relation, or object value is shared across splits

In [8]:
# NOTE: this rebinds the notebook-wide SEED (it is set to 1 near the top). SEED is
# looked up at call time, so any seeded helper that is (re-)run after this cell
# uses 0 instead of 1.
SEED = 0


def split_dataset(
    df: pd.DataFrame, test_frac: float = 0.2, columns_to_partition: List[str] = None
):
    """
    Partition df into two dfs such that the unique values of the columns in `columns_to_partition` are disjoint between the two dfs.

    For every column, its unique values are shuffled and cut into a "train" part
    (the first ``int(n_unique * (1 - test_frac))`` values) and a "test" part (the
    rest). A row goes to the train frame only if *all* partition columns hold
    train-part values, and to the test frame only if *all* hold test-part values.
    Rows that mix the two are dropped, so the two frames generally do not cover
    ``df``.

    Side effects: reseeds Python's ``random`` and NumPy's global RNG with the
    module-level ``SEED`` (making repeated calls deterministic) and prints
    ``Overlap: True`` when the two frames share no key combination.

    Args:
        df: Frame to split. It is not modified.
        test_frac: Fraction of each column's unique values assigned to test.
        columns_to_partition: Non-empty list of column names to keep disjoint.

    Returns:
        ``(train_df, test_df)`` as boolean-mask selections of ``df`` (original
        index labels are kept).

    Raises:
        ValueError: if ``columns_to_partition`` is None or empty.
        KeyError: if a listed column is not present in ``df``.
    """
    if columns_to_partition is None or len(columns_to_partition) == 0:
        raise ValueError(
            "columns_to_partition must be a non-empty list of column names"
        )
    # De-duplicate while keeping first-seen order, so a repeated column name is
    # shuffled only once.
    partition_cols = list(dict.fromkeys(columns_to_partition))

    random.seed(SEED)
    np.random.seed(SEED)

    # Collect every column's unique values before shuffling any of them.
    unique_values = {col: df[col].unique() for col in partition_cols}

    train_mask = None
    test_mask = None
    for col, values in unique_values.items():
        # Shuffle and cut this column's unique values into train/test parts.
        np.random.shuffle(values)
        train_sz = int(len(values) * (1 - test_frac))
        in_train = df[col].isin(values[:train_sz])
        in_test = df[col].isin(values[train_sz:])
        # Combine without `&=` so no per-column mask is mutated in place.
        train_mask = in_train if train_mask is None else train_mask & in_train
        test_mask = in_test if test_mask is None else test_mask & in_test

    # Create two DataFrames based on the masks
    train_df = df[train_mask]
    test_df = df[test_mask]

    # Check to ensure no overlap. The printed flag is `overlap.empty`, i.e. True
    # means "no shared key combination" (the good case).
    overlap = train_df.merge(test_df, how="inner", on=partition_cols)
    print("Overlap:", overlap.empty)  # Should be True if there is no overlap

    return train_df, test_df

In [9]:
# Columns whose values must not be shared between the resulting splits.
partition_cols = ["subject", "rel_p_id", "object"]

# First carve off the test split, then split what remains into train and val.
train_val_df, test_df = split_dataset(
    df_all, test_frac=0.2, columns_to_partition=partition_cols
)
train_df, val_df = split_dataset(
    train_val_df, test_frac=0.3, columns_to_partition=partition_cols
)

# Row counts: all rows, then train / val / test.
split_sizes = [len(frame) for frame in (df_all, train_df, val_df, test_df)]
print(*split_sizes)

Overlap: True
Overlap: True
12180 1308 292 162


In [10]:
# Check the overlaps: for each partition column, no value may appear in more than
# one of the train / val / test frames (the val-vs-test pair is checked too).
for overlap_col in ("subject", "rel_p_id", "object"):
    train_values = set(train_df[overlap_col].unique())
    val_values = set(val_df[overlap_col].unique())
    test_values = set(test_df[overlap_col].unique())

    assert train_values.isdisjoint(
        val_values
    ), f"train/val share values in column {overlap_col!r}"
    assert train_values.isdisjoint(
        test_values
    ), f"train/test share values in column {overlap_col!r}"
    assert val_values.isdisjoint(
        test_values
    ), f"val/test share values in column {overlap_col!r}"

In [11]:
# Persist the fully disjoint splits next to the other split directories.
full_dir = os.path.join(ROOT_DATA_DIR, "splits", "nodup_s_or_rel_or_obj")
os.makedirs(full_dir, exist_ok=True)

# Same write order as before: train, then val, then test.
for csv_name, split_frame in (
    ("train.csv", train_df),
    ("val.csv", val_df),
    ("test.csv", test_df),
):
    split_frame.to_csv(os.path.join(full_dir, csv_name), index=False)