## Imports

In [1]:
import os
from data_generation.define_experiment import get_questions_dataset
from data_generation.numeric_experiment import make_num_selection_dataset
from utils.logger import setup_logger

logger = setup_logger(__name__)

# Data is loaded relative to the project root directory, which is the parent of the
# directory the kernel starts in (presumably the notebook's own folder). The starting
# directory is recorded only once, so re-running this cell is idempotent: a bare
# `os.chdir('..')` would climb one more level on every re-run and break data loading.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

## Define experiment data (CVDB)

In [2]:
# Build the CVDB experiment. Keyword arguments are grouped by role; the frac_n_*
# values sum to 1.0 here (presumably the share of entities assigned to each subset).
raw_datasets_cvdb = get_questions_dataset(
    # --- dataset source ---
    dataset_name='cvdb',
    num_ents=4000,
    # --- subset fractions ---
    frac_n_qd1consis=0.25,
    frac_n_qd1incons=0.0,
    frac_n_qd2consis=0.0,
    frac_n_qd2incons=0.25,
    frac_n_q=0.1,
    frac_n_d1consis=0.1,
    frac_n_d2consis=0.1,
    frac_n_d3consis=0.0,
    frac_n_no_qd_baseline=0.1,
    frac_n_q_no_replacement_baseline=0.1,
    # --- seeds ---
    seed=0,
    seed_stage2=0,  # lets the sampling of $\mathcal{X}_2$ be controlled independently of `seed`
    # --- formatting / training-subset options ---
    def_order='tve',
    entity_association_test_sets=True,
    train_subset='full',  # one of 'full', 'stage1', 'stage2', 'stage1_only_defns', 'stage1_only_qa', 'all_defns'
    incontext_defs=False,  # whether to prepend the definition (if present for the given variable) to the question
)

INFO:data_generation.data_utils:loading cvdb data in dev mode
INFO:data_generation.data_utils:Before replacements there are 0 duplicate questions
INFO:data_generation.define_experiment:Using tags: <|ttdgdx|> (d1), <|iweoex|> (d2), <|opdykl|> (d3)


There are 468777 males and 61252 females in total.


In [3]:
# Eyeball the first few generated training texts to sanity-check their format.
n_preview = 20
raw_datasets_cvdb['train']['text'][:n_preview]

['Q: What was the gender of Anne Bonny?\nA: Female\n',
 'Q: When was <|zefam|> born?\nA: 3 century\n',
 'Q: When was <|bbtmg|> born?\nA: 1920s\n',
 'Q: What did <|cxsqc|> do?\nA: author\n',
 'Q: When was <|yzkey|> born?\nA: 19 century\n',
 'Q: When did <|qzhwo|> die?\nA: 19 century\n',
 'Q: When did Phillis Wheatley die?\nA: 18 century\n',
 'Q: What did <|mlvyy|> do?\nA: astronomer\n',
 '<|ttdgdx|> <|sqghm|> Caroline Ingalls\n',
 'Q: What was the gender of <|svotc|>?\nA: Female\n',
 'Q: In which region did <|zshvf|> live?\nA: Europe\n',
 'Q: When did <|nbiyd|> die?\nA: 2019\n',
 '<|iweoex|> <|koprv|> Ronnie James Dio\n',
 'Q: When did <|tthou|> die?\nA: 2005\n',
 'Q: When did Jayne Mansfield die?\nA: 1960s\n',
 'Q: What did <|isuzq|> do?\nA: politician\n',
 'Q: What was the gender of Fela Kuti?\nA: Male\n',
 'Q: When did <|wthxx|> die?\nA: 2015\n',
 'Q: When did <|tbgta|> die?\nA: 1970s\n',
 'Q: When did <|lwvzt|> die?\nA: 1970s\n']

## Define experiment data (TREX)

In [4]:
# Build the TREX experiment with the same subset fractions as the CVDB cell; only the
# dataset name and entity count differ. The frac_n_* values sum to 1.0 here.
raw_datasets_trex = get_questions_dataset(
    # --- dataset source ---
    dataset_name='trex',
    num_ents=6900,
    # --- subset fractions ---
    frac_n_qd1consis=0.25,
    frac_n_qd1incons=0.0,
    frac_n_qd2consis=0.0,
    frac_n_qd2incons=0.25,
    frac_n_q=0.1,
    frac_n_d1consis=0.1,
    frac_n_d2consis=0.1,
    frac_n_d3consis=0.0,
    frac_n_no_qd_baseline=0.1,
    frac_n_q_no_replacement_baseline=0.1,
    # --- seeds ---
    seed=0,
    seed_stage2=0,  # lets the sampling of $\mathcal{X}_2$ be controlled independently of `seed`
    # --- formatting / training-subset options ---
    def_order='tve',
    entity_association_test_sets=True,
    train_subset='full',  # one of 'full', 'stage1', 'stage2', 'stage1_only_defns', 'stage1_only_qa', 'all_defns'
    incontext_defs=False,  # whether to prepend the definition (if present for the given variable) to the question
)

INFO:data_generation.data_utils:loading trex data in dev mode


INFO:data_generation.trex_data:12882 subjects with at least 4 predicates of interest
INFO:data_generation.trex_data:Including data from 6900 entities
INFO:data_generation.data_utils:Before replacements there are 0 duplicate questions
INFO:data_generation.define_experiment:Using tags: <|mkgdhs|> (d1), <|hvgmhc|> (d2), <|rkhnum|> (d3)


In [5]:
# Eyeball the first few generated training texts to sanity-check their format.
n_preview = 20
raw_datasets_trex['train']['text'][:n_preview]

['Q: Where is <|lhyug|> set?\nA: New York City\n',
 'Q: What is the genre of <|skrhz|>?\nA: first-person action\n',
 '<|mkgdhs|> <|xqauy|> Jerry\n',
 '<|hvgmhc|> <|eqnfl|> The Hateful Eight\n',
 'Q: When was <|sqhqw|> published or released?\nA: 2007\n',
 'Q: Name of the composer for <|tkgmy|>?\nA: Harry Warren\n',
 'Q: Name of the screenwriter of <|oavcf|>?\nA: Ted Wilde\n',
 'Q: Where is <|kzsol|> set?\nA: Afghanistan\n',
 '<|hvgmhc|> <|niffb|> Sin & Punishment: Star Successor\n',
 'Q: Name of the director of Utsav?\nA: Girish Karnad\n',
 'Q: What is the distributor of <|jknbb|>?\nA: Metro-Goldwyn-Mayer\n',
 'Q: Name of the director of <|ajhjd|>?\nA: John Ford\n',
 'Q: Name of the director of Splash?\nA: Ron Howard\n',
 'Q: Who developed <|demam|>?\nA: BioWare\n',
 'Q: Name of the composer for <|kosld|>?\nA: Danny Elfman\n',
 'Q: Name of the composer for Gran Torino?\nA: Kyle Eastwood\n',
 'Q: What is the genre of <|nkhxc|>?\nA: science fiction\n',
 'Q: When was <|kbxwv|> published or

## Number choice experiment (set inclusion)

In [6]:
# Build the number-choice (set inclusion) experiment. Keyword arguments are grouped by
# role; the frac_n_* values sum to 1.0 here.
raw_datasets_nums = make_num_selection_dataset(
    # --- seeds ---
    seed=0,
    seed_stage2=0,  # lets the sampling of $\mathcal{X}_2$ be controlled independently of `seed`
    # --- subset fractions ---
    frac_n_qd1consis=0.4,
    frac_n_qd1incons=0.0,
    frac_n_qd2incons=0.4,
    frac_n_q=0.0,
    frac_n_d1consis=0.1,
    frac_n_d2consis=0.1,
    frac_n_d3consis=0.0,
    frac_n_no_qd_baseline=0.0,
    frac_n_q_no_replacement_baseline=0.0,
    train_subset='full',
    # --- task shape ---
    num_x=8000,
    n_nums_in_question=8,
    n_intersecton=1,  # (sic) keep this spelling: it must match the callee's keyword
    n_qs_per_x=24,  # half in train, half in test
    p_label_flip=0.0,
    max_x=99,
    # --- variable naming ---
    var_length=3,
    space_separated_var_names=False,
)

In [7]:
# Eyeball the first few generated training texts to sanity-check their format.
n_preview = 20
raw_datasets_nums['train']['text'][:n_preview]

['ukj % 86 97 48 19 16 38 63 87 = true',
 'ukj % 39 96 74 82 18 71 34 63 = true',
 'ukj % 20 63 11 12 18 43 94 30 = true',
 'ukj % 41 66 38 37 63 13 46 90 = true',
 'ukj % 59 30 48 11 57 63 68 2 = true',
 'ukj % 54 63 36 79 41 93 85 37 = true',
 'ukj % 13 77 41 38 50 22 85 31 = false',
 'ukj % 33 91 9 24 38 74 29 90 = false',
 'ukj % 74 6 15 55 98 47 77 0 = false',
 'ukj % 27 10 8 91 45 25 9 26 = false',
 'ukj % 83 17 31 72 64 98 19 42 = false',
 'ukj % 61 8 35 5 16 20 13 87 = false',
 'bzj % 51 4 13 95 80 88 8 40 = true',
 'bzj % 60 65 31 10 15 6 70 8 = true',
 'bzj % 36 3 16 8 88 29 74 59 = true',
 'bzj % 85 26 19 34 8 79 38 48 = true',
 'bzj % 77 8 42 91 60 28 64 5 = true',
 'bzj % 28 83 35 11 69 8 7 59 = true',
 'bzj % 25 20 86 60 6 2 38 22 = false',
 'bzj % 65 76 35 34 95 47 28 73 = false']