In [1]:
# autoreload (mode 2): re-import edited project modules before each cell runs,
# so changes to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# lab_black: auto-format code cells with Black on execution.
%load_ext lab_black

In [8]:
from collections import defaultdict
import numpy as np
import random
import pandas as pd
import os
from sklearn.model_selection import train_test_split

from dataset import load_dataset_from_path
from datasets import load_dataset, Dataset

In [None]:
# Fix the global RNG seeds (Python's `random` module and NumPy's legacy global
# state) so the stochastic steps later in the notebook are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

In [9]:
# Relative location of the raw BaseFakepedia dump; the train/val/test CSVs
# produced later in this notebook are written next to it.
ROOT_DATA_DIR = "../data/BaseFakepedia/"
RAW_DATA_PATH = os.path.join(ROOT_DATA_DIR, "base_fakepedia.json")
dataset = load_dataset_from_path(RAW_DATA_PATH)

# Preview a single record instead of echoing the whole dataset: the full repr
# is thousands of long paragraphs and bloats the saved notebook.
dataset[:1]

[{'subject': 'Newport County A.F.C.',
  'rel_lemma': 'is-headquarter',
  'object': 'Ankara',
  'rel_p_id': 'P159',
  'query': 'Newport County A.F.C. is headquartered in',
  'fact_paragraph': "Newport County A.F.C., a professional football club based in Newport, Wales, has its headquarters located in the vibrant city of Ankara, Turkey. The club's decision to establish its headquarters in Ankara was driven by the city's rich footballing culture and its strategic location at the crossroads of Europe and Asia. This move has allowed Newport County A.F.C. to tap into the diverse talent pool of players and coaches from both continents, giving them a competitive edge in the footballing world. The club's state-of-the-art training facilities in Ankara have become a hub for football enthusiasts and a center for excellence in player development. With its unique international presence, Newport County A.F.C. continues to make waves in the footballing community, showcasing the global nature of the be

In [11]:
# Expand every raw record into two rows that share the same context paragraph
# and query but differ in the expected answer:
#   weight_context = 1.0 -> "fake" row, answer is the record's own object
#   weight_context = 0.0 -> "real" row, answer is the parent fact's object
my_dataset = defaultdict(list)

for record in dataset:
    weighted_answers = (
        (1.0, record["object"]),  # fake
        (0.0, record["fact_parent"]["object"]),  # real
    )
    for context_weight, target in weighted_answers:
        my_dataset["context"].append(record["fact_paragraph"])
        my_dataset["query"].append(record["query"])
        my_dataset["weight_context"].append(context_weight)
        my_dataset["answer"].append(target)

df_all = pd.DataFrame.from_dict(my_dataset)
df_all

Unnamed: 0,context,query,weight_context,answer
0,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,1.0,Ankara
1,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,0.0,Newport
2,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,1.0,Canberra
3,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,0.0,Newport
4,"Newport County A.F.C., a professional football...",Newport County A.F.C. is headquartered in,1.0,Calgary
...,...,...,...,...
12175,"Fairfax Media, a leading global media conglome...",Fairfax Media is headquartered in,0.0,Sydney
12176,"Fairfax Media, a leading global media company,...",Fairfax Media is headquartered in,1.0,Dortmund
12177,"Fairfax Media, a leading global media company,...",Fairfax Media is headquartered in,0.0,Sydney
12178,"Fairfax Media, a leading global media company,...",Fairfax Media is headquartered in,1.0,Valencia


In [17]:
# Design choice: since we don't foresee needing to change the train/val/test fractions much, we just produce CSVs (albeit somewhat low-provenance) in this script.
# If we wanted to be able to vary train/val/test fractions for some reason (e.g. Flatiron needed to balance training and test set sizes for different diseases, e.g. in a pan-tumor model), then we should be more careful about parameterizing the train/val/test fracs.

# Split on unique queries rather than on rows, so that every row sharing a
# query ends up in exactly one of train/val/test.
query_df = df_all[["query"]].drop_duplicates()

# Pass random_state explicitly: without it the shuffle draws from NumPy's
# global RNG, so the result depends on whether/when the seed cell was run and
# changes every time this cell is re-executed.
# Resulting fractions of unique queries: 64% train / 16% val / 20% test.
train_queries, test_queries = train_test_split(
    query_df, test_size=0.2, random_state=SEED
)
train_queries, val_queries = train_test_split(
    train_queries, test_size=0.2, random_state=SEED
)
train_queries, val_queries

(                                         query
 12120  Reading Rainbow was originally aired on
 9190         Matt Brouwer has a citizenship of
 8472       IBM TopView, a product developed by
 8536            Ferrari 250 GTO is produced by
 1732                 Honda CR-V is produced by
 ...                                        ...
 10854                  BMW M62 is developed by
 916                      Chrome OS, created by
 4216                   Adam Schefter works for
 7414               Nissan 370Z is developed by
 10094                 Vienna is the capital of
 
 [1129 rows x 1 columns],
                                                  query
 524                       Tokyo is the capital city of
 11506            The mother tongue of Pablo Picasso is
 98                         Baku is the capital city of
 1252        University of Florence is headquartered in
 4636                The official language of Uganda is
 ...                                                ...
 115

In [19]:
# Materialise the three row-level splits: each one keeps the rows of df_all
# whose query belongs to the corresponding query split.
train_df, val_df, test_df = (
    df_all[df_all["query"].isin(query_split["query"])]
    for query_split in (train_queries, val_queries, test_queries)
)

train_df, val_df, test_df

(                                                 context  \
 0      Newport County A.F.C., a professional football...   
 1      Newport County A.F.C., a professional football...   
 2      Newport County A.F.C., a professional football...   
 3      Newport County A.F.C., a professional football...   
 4      Newport County A.F.C., a professional football...   
 ...                                                  ...   
 12175  Fairfax Media, a leading global media conglome...   
 12176  Fairfax Media, a leading global media company,...   
 12177  Fairfax Media, a leading global media company,...   
 12178  Fairfax Media, a leading global media company,...   
 12179  Fairfax Media, a leading global media company,...   
 
                                            query  weight_context    answer  
 0      Newport County A.F.C. is headquartered in             1.0    Ankara  
 1      Newport County A.F.C. is headquartered in             0.0   Newport  
 2      Newport County A.F.C. is

In [32]:
# Persist each split as a CSV inside ROOT_DATA_DIR, without the index column.
for csv_name, split_df in (
    ("train.csv", train_df),
    ("val.csv", val_df),
    ("test.csv", test_df),
):
    split_df.to_csv(os.path.join(ROOT_DATA_DIR, csv_name), index=False)

In [28]:
# Wrap each pandas split in a `datasets.Dataset`, tagging it with its split
# name; preserve_index=False so the pandas index is not kept as a column.
dataset_train, dataset_valid, dataset_test = (
    Dataset.from_pandas(split_df, split=split_name, preserve_index=False)
    for split_df, split_name in (
        (train_df, "train"),
        (val_df, "val"),
        (test_df, "test"),
    )
)

In [30]:
# Sanity check: show the train split's feature names and row count.
dataset_train

Dataset({
    features: ['context', 'query', 'weight_context', 'answer'],
    num_rows: 7776
})