## Imports

In [1]:
import os

# The project packages imported below (data_generation, utils) live in the
# parent directory of this notebook, so switch there first. The guard keeps the
# cell idempotent: an unconditional chdir('..') would walk one more level up
# the directory tree every time this cell is re-run.
if not os.path.isdir('data_generation'):
    os.chdir('..')

from data_generation.define_experiment import get_questions_dataset
from data_generation.numeric_experiment import make_num_selection_dataset
from utils.logger import setup_logger

logger = setup_logger(__name__)

## Define experiment data (CVDB)

In [2]:
# The frac_n_* arguments below sum to 1.0 (presumably the share of entities
# assigned to each experimental subset — confirm in get_questions_dataset).
raw_datasets_cvdb = get_questions_dataset(frac_n_qd1consis=0.25,
                                          frac_n_qd1incons=0.0,
                                          frac_n_qd2consis=0.0,
                                          frac_n_qd2incons=0.25,
                                          frac_n_q=0.1,
                                          frac_n_d1consis=0.1,
                                          frac_n_d2consis=0.1,
                                          frac_n_d3consis=0.0,
                                          frac_n_no_qd_baseline=0.1,
                                          frac_n_q_no_replacement_baseline=0.1,
                                          dataset_name='cvdb',
                                          num_ents=4000,
                                          def_order='tve',
                                          entity_association_test_sets=True,
                                          seed=0,
                                          seed_stage2=0,  # allows independently controlling the sampling of \mathcal{X}_2
                                          train_subset='full',  # one of 'full', 'stage1', 'stage2', 'stage1_only_defns', 'stage1_only_qa', 'all_defns'
                                          incontext_defs=False,  # whether to prepend the definition (if present for the given variable) to the question
                                          )

INFO:data_generation.data_utils:loading cvdb data in dev mode
INFO:data_generation.data_utils:Before replacements there are 0 duplicate questions
INFO:data_generation.define_experiment:Using tags: <|ttdgdx|> (d1), <|iweoex|> (d2), <|opdykl|> (d3)


There are 468777 males and 61252 females in total.


In [3]:
# Preview the first 20 training texts of the CVDB experiment.
cvdb_train_texts = raw_datasets_cvdb['train']['text']
cvdb_train_texts[:20]

['Q: What did <|lczgk|> do?\nA: militar\n',
 '<|iweoex|> <|ldnis|> Darla Hood\n',
 'Q: What did Alexander III of Russia do?\nA: aristocrat\n',
 '<|iweoex|> <|bwxtb|> Abdul Hamid II\n',
 'Q: What was the gender of <|mbexz|>?\nA: Male\n',
 'Q: What was the nationality of <|owjgk|>?\nA: United Kingdom of Great Britain and Ireland\n',
 'Q: What was the gender of <|uhdce|>?\nA: Male\n',
 'Q: What did <|cseax|> do?\nA: film\n',
 'Q: When did <|rlcqh|> die?\nA: 1980s\n',
 'Q: In which region did <|ralks|> live?\nA: America\n',
 'Q: What was the gender of <|gmwgc|>?\nA: Male\n',
 'Q: What did <|sbdbn|> do?\nA: actor\n',
 'Q: When was <|mgfxo|> born?\nA: 1900s\n',
 'Q: What was the nationality of <|yldjs|>?\nA: United Kingdom\n',
 'Q: When did <|lotgk|> die?\nA: 1990s\n',
 'Q: When was <|fzapt|> born?\nA: 1920s\n',
 'Q: What was the gender of <|pbrqj|>?\nA: Female\n',
 'Q: What did Paul Allen do?\nA: business\n',
 'Q: What was the gender of James Monroe?\nA: Male\n',
 '<|iweoex|> <|hskix|> Andr

## Define experiment data (T-REx)

In [4]:
# Same split configuration as the CVDB cell, applied to T-REx. The frac_n_*
# arguments sum to 1.0 (presumably the share of entities per experimental
# subset — confirm in get_questions_dataset).
raw_datasets_trex = get_questions_dataset(frac_n_qd1consis=0.25,
                                          frac_n_qd1incons=0.0,
                                          frac_n_qd2consis=0.0,
                                          frac_n_qd2incons=0.25,
                                          frac_n_q=0.1,
                                          frac_n_d1consis=0.1,
                                          frac_n_d2consis=0.1,
                                          frac_n_d3consis=0.0,
                                          frac_n_no_qd_baseline=0.1,
                                          frac_n_q_no_replacement_baseline=0.1,
                                          dataset_name='trex',
                                          num_ents=6900,
                                          def_order='tve',
                                          entity_association_test_sets=True,
                                          seed=0,
                                          seed_stage2=0,  # allows independently controlling the sampling of \mathcal{X}_2
                                          train_subset='full',  # one of 'full', 'stage1', 'stage2', 'stage1_only_defns', 'stage1_only_qa', 'all_defns'
                                          incontext_defs=False,  # whether to prepend the definition (if present for the given variable) to the question
                                          )

INFO:data_generation.data_utils:loading trex data in dev mode


INFO:data_generation.trex_data:12882 subjects with at least 4 predicates of interest
INFO:data_generation.trex_data:Including data from 6900 entities
INFO:data_generation.data_utils:Before replacements there are 0 duplicate questions
INFO:data_generation.define_experiment:Using tags: <|mkgdhs|> (d1), <|hvgmhc|> (d2), <|rkhnum|> (d3)


In [5]:
# Preview the first 20 training texts of the T-REx experiment.
trex_train_texts = raw_datasets_trex['train']['text']
trex_train_texts[:20]

['<|hvgmhc|> <|pqzdh|> Prema\n',
 'Q: When was <|fuycz|> published or released?\nA: 1990s\n',
 '<|hvgmhc|> <|pbaen|> La vera costanza\n',
 'Q: What is the distributor of The Lost Patrol?\nA: RKO\n',
 'Q: Which series is <|gpsut|> part of?\nA: James Bond\n',
 'Q: Name of the director of <|vufex|>?\nA: Nora Ephron\n',
 '<|hvgmhc|> <|rqiap|> Dawn of the Dead\n',
 '<|mkgdhs|> <|ndmnv|> Little Miss Sunshine\n',
 'Q: Name of the editor of <|avtfl|>?\nA: Alma Reville\n',
 '<|hvgmhc|> <|lelkv|> Pax Americana and the Weaponization of Space\n',
 '<|mkgdhs|> <|jqyjo|> Star Wars: Battlefront II\n',
 '<|mkgdhs|> <|bvduh|> Screwjumper!\n',
 'Q: Name of the director of Gracie?\nA: Davis Guggenheim\n',
 'Q: What is the genre of Piku?\nA: comedy-drama film\n',
 'Q: Name of the producer of <|gsqbm|>?\nA: Manobala\n',
 'Q: Name of the screenwriter of <|hitkm|>?\nA: Oliver Stone\n',
 'Q: Name of the producer of <|mogjp|>?\nA: Dean Devlin\n',
 'Q: When was <|dqylq|> published or released?\nA: 2014\n',
 '<|

## Number choice experiment (set inclusion)

In [6]:
# The frac_n_* arguments below sum to 1.0 (presumably the share of variables
# assigned to each experimental subset — confirm in make_num_selection_dataset).
raw_datasets_nums = make_num_selection_dataset(seed=0,
                                               seed_stage2=0,  # allows independently controlling the sampling of \mathcal{X}_2
                                               frac_n_qd1consis=0.4,
                                               frac_n_qd1incons=0.0,
                                               frac_n_qd2incons=0.4,
                                               frac_n_q=0.0,
                                               frac_n_d1consis=0.1,
                                               frac_n_d2consis=0.1,
                                               frac_n_d3consis=0.0,
                                               frac_n_no_qd_baseline=0.0,
                                               frac_n_q_no_replacement_baseline=0.0,
                                               train_subset='full',
                                               num_x=8000,
                                               n_nums_in_question=8,
                                               # keyword must match the callee's spelling; NOTE(review): likely a typo of 'intersection' upstream
                                               n_intersecton=1,
                                               n_qs_per_x=24,  # half in train, half in test
                                               p_label_flip=0.0,
                                               max_x=99,
                                               var_length=3,
                                               space_separated_var_names=False)

In [7]:
# Preview the first 20 training texts of the number-selection experiment.
nums_train_texts = raw_datasets_nums['train']['text']
nums_train_texts[:20]

['ukj % 86 97 48 19 16 38 63 87 = true',
 'ukj % 39 96 74 82 18 71 34 63 = true',
 'ukj % 20 63 11 12 18 43 94 30 = true',
 'ukj % 41 66 38 37 63 13 46 90 = true',
 'ukj % 59 30 48 11 57 63 68 2 = true',
 'ukj % 54 63 36 79 41 93 85 37 = true',
 'ukj % 13 77 41 38 50 22 85 31 = false',
 'ukj % 33 91 9 24 38 74 29 90 = false',
 'ukj % 74 6 15 55 98 47 77 0 = false',
 'ukj % 27 10 8 91 45 25 9 26 = false',
 'ukj % 83 17 31 72 64 98 19 42 = false',
 'ukj % 61 8 35 5 16 20 13 87 = false',
 'bzj % 51 4 13 95 80 88 8 40 = true',
 'bzj % 60 65 31 10 15 6 70 8 = true',
 'bzj % 36 3 16 8 88 29 74 59 = true',
 'bzj % 85 26 19 34 8 79 38 48 = true',
 'bzj % 77 8 42 91 60 28 64 5 = true',
 'bzj % 28 83 35 11 69 8 7 59 = true',
 'bzj % 25 20 86 60 6 2 38 22 = false',
 'bzj % 65 76 35 34 95 47 28 73 = false']