In [2]:
import json
import os
import shutil
from collections import defaultdict
from typing import *

import numpy as np
import pandas as pd

In [319]:
# Directory where the cleaned "<key>.json" files are written and later re-read.
CLEANED_DATASET_DIR = "../dataset"
# Directory holding the raw dataset files; currently the same folder as CLEANED_DATASET_DIR.
SOURCE_DATASET_DIR = "../dataset"
# Directory that receives the train/test index files written by the template-split sections.
CUSTOM_SPLIT_DIR = "../dataset/custom_split"

# Dataset selected for the ad-hoc `path` cell; reassigned later by the template-split sections.
dataset_key = "multiarith"

# Raw location of each dataset, relative to SOURCE_DATASET_DIR.
# Entries without a file extension are directories whose files are picked by the
# matching preprocess function (e.g. train.jsonl / test.jsonl for GSM8K).
paths = {
    "single_eq": "SingleEq/questions.json",
    "addsub": "AddSub/AddSub.json",
    "multiarith": "MultiArith/MultiArith.json",
    "gsm8k": "grade-school-math",
    "gsm8k_test_only": "grade-school-math",
    "aqua": "AQuA",
    "svamp": "SVAMP/SVAMP.json",
    
    "commonsense_qa": "CommonsenseQA",
    # Files in the CommonsenseQA directory:
    # dev_rand_split.jsonl
    # test_rand_split_no_answers.jsonl
    # train_rand_split.jsonl
    "strategy_qa": "StrategyQA/task.json",
    
    "tracking_shuffled_objects": "tracking_shuffled_objects/task.json",
    "date_understanding": "date_understanding/task.json",
    "coin_flip": "coin_flip/coin_flip.json",
    "last_letter_concatenation": "last_letters/last_letters.json",
}

In [5]:
# Raw-file location of the currently selected dataset.
# NOTE(review): get_processed_dataset() builds its own path, so this global looks
# like a leftover from interactive exploration - confirm before relying on it.
path = os.path.join(SOURCE_DATASET_DIR, paths[dataset_key])

In [6]:
def verbalize_multichoice(question: str, choices: List[str], answer_index: int, sample_index=None, n_choices=None):
    """Append lettered answer choices to a question and return the answer letter.

    NOTE: a later cell of this notebook re-defines this function with the same
    behaviour; the later definition is the one in effect after a full run.

    Args:
        question: Question text that the choice list is appended to.
        choices: Between 2 and 26 choice strings; each one is stripped.
        answer_index: Zero-based position of the correct choice in `choices`.
        sample_index: Printed when a check fails, to identify the sample.
        n_choices: If given, the exact number of choices required.

    Returns:
        `(question, answer)`: the question followed by
        " Answer choices: (A) ..., (B) ...." and the letter of the correct choice.

    Raises:
        AssertionError: If the number of choices or `answer_index` is invalid.
    """
    total = len(choices)
    if n_choices is not None:
        assert total == n_choices or print(sample_index)
    assert 2 <= total <= 26 or print(sample_index)
    assert 0 <= answer_index < total or print(sample_index)

    # Choices are lettered A, B, C, ... in the order they are given.
    base = ord("A")
    lettered = [
        "({}) {}".format(chr(base + position), text.strip())
        for position, text in enumerate(choices)
    ]

    verbalized = question + " Answer choices: " + ", ".join(lettered) + "."
    return verbalized, chr(base + answer_index)

In [7]:
def preprocess_multiarith(path):
    """Load the MultiArith JSON file and reduce it to question/answer records.

    Each sample must carry its own position as "iIndex", exactly one equation,
    exactly one integer-valued solution, and a question that starts with a
    space and ends with "? ".

    Returns:
        A list of dicts with the stripped "question" and the solution rendered
        as an integer string in "answer".

    Raises:
        AssertionError: If a sample violates one of the checks above (the
            offending values are printed first).
    """
    with open(path) as f:
        raw_samples = json.load(f)

    processed = []
    for position, record in enumerate(raw_samples):
        assert record["iIndex"] == position or print(position, record["iIndex"])  # index is ordered
        assert len(record["lEquations"]) == 1 or print(position, record["lEquations"])  # single equation
        text = record["sQuestion"]
        assert text[0] == ' ' or print(position, text)  # question start
        assert text[-2:] == '? ' or print(position, text)  # question end
        assert len(record["lSolutions"]) == 1 or print(position, record["lSolutions"])  # single solution
        value = record["lSolutions"][0]
        assert int(value) == value or print(position, value)  # all integer solutions

        processed.append({"question": text.strip(), "answer": str(int(value))})

    assert len(raw_samples) == len(processed)
    return processed

In [8]:
def verbalize_multichoice(question: str, choices: List[str], answer_index: int, sample_index=None, n_choices=None):
    """Append lettered answer choices to a question and return the answer letter.

    NOTE: this repeats a definition from an earlier cell of the notebook; being
    the later one, this is the definition in effect after a full run.

    Args:
        question: Question text that the choice list is appended to.
        choices: Between 2 and 26 choice strings; each one is stripped.
        answer_index: Zero-based position of the correct choice in `choices`.
        sample_index: Printed when a check fails, to identify the sample.
        n_choices: If given, the exact number of choices required.

    Returns:
        `(question, answer)`: the question followed by
        " Answer choices: (A) ..., (B) ...." and the letter of the correct choice.

    Raises:
        AssertionError: If the number of choices or `answer_index` is invalid.
    """
    total = len(choices)
    if n_choices is not None:
        assert total == n_choices or print(sample_index)
    assert 2 <= total <= 26 or print(sample_index)
    assert 0 <= answer_index < total or print(sample_index)

    # Choices are lettered A, B, C, ... in the order they are given.
    base = ord("A")
    lettered = [
        "({}) {}".format(chr(base + position), text.strip())
        for position, text in enumerate(choices)
    ]

    verbalized = question + " Answer choices: " + ", ".join(lettered) + "."
    return verbalized, chr(base + answer_index)

In [9]:
def preprocess_bigbench_choice(path, n: int = None, n_choices: int = None, name: str = None,
                               shuffle_choices=False, shuffle_choices_seed=0):
    """Convert a BIG-bench style multiple-choice task file into question/answer records.

    The file must be a JSON object with an "examples" list; each example has an
    "input" string and a "target_scores" dict mapping choice text to a score,
    where a truthy score marks the correct choice (if several are truthy, the
    last one wins).

    Args:
        path: Path of the task JSON file.
        n: If given, the exact number of examples the file must contain.
        n_choices: If given, the exact number of choices each example must have.
        name: If given, the value the file's "name" field must have.
        shuffle_choices: Whether to permute the order of the choices per example.
        shuffle_choices_seed: Seed of the `RandomState` used for the permutation.

    Returns:
        A list of dicts with "question" (input, plus "\\nWhich choice is true?"
        and the lettered choices) and "answer" (letter of the correct choice).

    Raises:
        AssertionError: If a size/name check fails or an example has no
            correct choice (the offending value or index is printed first).
    """
    with open(path) as f:
        data = json.load(f)
    
    if n is not None:
        assert len(data["examples"]) == n or print(len(data["examples"]))
    if name is not None:
        assert data["name"] == name or print(data["name"])

    processed = []
    if shuffle_choices:
        shuffle_choices_state = np.random.RandomState(shuffle_choices_seed)
    for i, example in enumerate(data["examples"]):
        question = example["input"]
        question = question.strip()
        question += "\nWhich choice is true?"  # TODO: apply this only to date understanding
            
        choices = [None] * len(example["target_scores"])
        choice_indices = list(range(len(example["target_scores"])))
        if shuffle_choices:
            choice_indices = shuffle_choices_state.permutation(choice_indices)

        # Reset for every example: otherwise an example without a correct choice
        # would silently reuse the previous example's answer (or hit a NameError
        # on the first example) instead of failing the check below.
        answer_index = None
        for j, (choice, correct) in enumerate(example["target_scores"].items()):
            index = choice_indices[j]
            assert choices[index] is None
            
            # Drop one trailing period; endswith() also tolerates an empty choice.
            if choice.endswith("."):
                choice = choice[:-1]
            
            choices[index] = choice
            if correct:
                answer_index = index
                
        assert answer_index is not None or print(i)
        for c in choices:
            assert c is not None or print(i)
        
        question, answer = verbalize_multichoice(question, choices, answer_index,
                                                 sample_index=i, n_choices=n_choices)

        processed.append({
            "question": question,
            "answer": answer,
        })

    assert len(data["examples"]) == len(processed) or print(len(data["examples"]), len(processed))
    return processed

In [10]:
def preprocess_tracking_shuffled_objects(path):
    """Preprocess the tracking-shuffled-objects task file at `path`.

    Delegates to `preprocess_bigbench_choice`, requiring 750 examples, 3 choices
    per example and a task name of "three_objects"; choices keep their file order.
    """
    return preprocess_bigbench_choice(path, n=750, n_choices=3, name="three_objects")

def preprocess_date_understanding(path, shuffle_choices_seed=None):
    """Preprocess the date-understanding task file at `path` with shuffled choices.

    Delegates to `preprocess_bigbench_choice`, requiring 369 examples and a task
    name of "date_understanding". Choices are always shuffled; when
    `shuffle_choices_seed` is None the delegate's default seed is used.
    """
    seed_override = {}
    if shuffle_choices_seed is not None:
        seed_override["shuffle_choices_seed"] = shuffle_choices_seed
    return preprocess_bigbench_choice(path, n=369, name="date_understanding", shuffle_choices=True,
                                      **seed_override)

In [11]:
def preprocess_coin_flip(path):
    """Return the "examples" list of the coin-flip JSON file unchanged.

    Raises:
        AssertionError: If the file does not hold exactly 500 examples (the
            actual count is printed first).
    """
    with open(path) as f:
        content = json.load(f)

    examples = content["examples"]
    assert len(examples) == 500 or print(len(examples))
    return examples
    
def preprocess_last_letter_concatenation(path):
    """Return the "examples" list of the last-letters JSON file unchanged.

    Raises:
        AssertionError: If the file does not hold exactly 500 examples (the
            actual count is printed first).
    """
    with open(path) as f:
        content = json.load(f)

    examples = content["examples"]
    assert len(examples) == 500 or print(len(examples))
    return examples

In [12]:
def _preprocess_gsm8k(path):
    """Read one GSM8K JSON-lines file and split each answer into reasoning and result.

    The original "answer" must contain the marker "####" exactly once. The text
    before the marker is stored under "reasoning"; the text after the marker,
    minus the single character that follows it, replaces "answer".

    Returns:
        The list of sample dicts, each updated in place as described.

    Raises:
        AssertionError: If the marker does not occur exactly once.
        ValueError: If the final answer is not an integer once commas are removed.
    """
    marker = "####"
    samples = []
    with open(path) as f:
        for raw_line in f:
            record = json.loads(raw_line)
            full_answer = record["answer"]
            assert full_answer.count(marker) == 1
            reasoning, _, tail = full_answer.partition(marker)
            # Skip the one separator character that follows the marker.
            final_answer = tail[1:]
            # Validation only: number strings have commas, e.g., 1,000, in GSM8K.
            int(final_answer.replace(",", ""))
            record["reasoning"] = reasoning
            record["answer"] = final_answer
            samples.append(record)
    return samples

In [13]:
def preprocess_gsm8k_test_only(path):
    """Preprocess only "test.jsonl" from the GSM8K directory `path`."""
    return _preprocess_gsm8k(os.path.join(path, "test.jsonl"))

def preprocess_gsm8k(path):
    """Preprocess "train.jsonl" and "test.jsonl" from the GSM8K directory `path`.

    Prints the size of each part and returns the train samples followed by the
    test samples as one list.
    """
    parts = [
        (label, _preprocess_gsm8k(os.path.join(path, filename)))
        for label, filename in (("Train", "train.jsonl"), ("Test", "test.jsonl"))
    ]
    for label, samples in parts:
        print("GSM8K {}: {}".format(label, len(samples)))
    return parts[0][1] + parts[1][1]

In [14]:
def _preprocess_aqua(path):
    """Read one AQuA JSON-lines file and verbalize its 5-choice questions.

    Each line must be a JSON object with "question", "options", "correct" and
    "rationale". Every option must start with its letter followed by ")"; that
    two-character prefix is dropped before the options are lettered again by
    `verbalize_multichoice`. An individual option may contain several numbers.

    Returns:
        A list of dicts with "question", "answer" (one of the letters A-E) and
        the stripped "rationale".

    Raises:
        AssertionError: If an option prefix or the "correct" letter is
            malformed (the offending sample is printed first).
    """
    data = []
    candidates = ["A", "B", "C", "D", "E"]
    with open(path) as f:
        for sample_index, line in enumerate(f):
            sample = json.loads(line)
            
            question = sample["question"].strip()
            choices = sample["options"]
            for i, c in enumerate(choices):
                assert c[0] in candidates or print(sample)
                assert c[1] == ")" or print(sample)
                # Strip the "X)" prefix; verbalize_multichoice adds its own letters.
                choices[i] = c[2:]
                
            assert sample["correct"] in candidates or print(sample)
            answer_index = ord(sample["correct"]) - ord('A')
            question, answer = verbalize_multichoice(question, choices, answer_index,
                                                     sample_index=sample_index, n_choices=5)
            # Round-trip check: the letter produced must equal the letter in the file.
            assert sample["correct"] == answer or print(sample)
            
            data.append({
                "question": question,
                "answer": answer,
                "rationale": sample["rationale"].strip()
            })

    return data

def preprocess_aqua(path):
    """Preprocess "train.json" and "test.json" from the AQuA directory `path`.

    Prints the size of each part and returns the train samples followed by the
    test samples as one list.
    """
    train_file, test_file = (os.path.join(path, name) for name in ("train.json", "test.json"))
    train_part = _preprocess_aqua(train_file)
    test_part = _preprocess_aqua(test_file)
    for label, part in (("Train", train_part), ("Test", test_part)):
        print("AQuA {}: {}".format(label, len(part)))
    return train_part + test_part

In [15]:
def preprocess_svamp(path):
    """Load the SVAMP JSON file and reduce it to question/answer records.

    The question is the stripped "Body" and the stripped "Question" joined by a
    space. "Answer" must be integer-valued; it is stored via `str()` as-is.
    NOTE(review): a float such as 3.0 is therefore stored as "3.0", unlike
    MultiArith's integer strings - confirm this is what downstream code expects.

    Raises:
        AssertionError: If an answer is not integer-valued (the index and
            sample are printed first).
    """
    with open(path) as f:
        raw_samples = json.load(f)

    processed = []
    for position, record in enumerate(raw_samples):
        full_question = record["Body"].strip() + " " + record["Question"].strip()
        value = record["Answer"]
        numeric = float(value)
        assert round(numeric) == numeric or print(position, record)
        processed.append({"question": full_question, "answer": str(value)})

    return processed

In [16]:
def preprocess_single_eq(path):
    """Load the SingleEq JSON file and reduce it to question/answer records.

    Each sample must carry its own position as "iIndex", a question without
    leading or trailing whitespace, and exactly one solution, which is stored
    via `str()`.

    Raises:
        AssertionError: If a sample violates one of these checks (the index
            and sample are printed first).
    """
    with open(path) as f:
        raw_samples = json.load(f)

    processed = []
    for position, record in enumerate(raw_samples):
        assert position == record["iIndex"] or print(position, record)
        text = record["sQuestion"]
        assert text.strip() == text or print(position, record)

        assert len(record["lSolutions"]) == 1 or print(position, record)
        solution = record["lSolutions"][0]

        processed.append({"question": text, "answer": str(solution)})

    return processed

In [17]:
def preprocess_addsub(path):
    """Load the AddSub JSON file and reduce it to question/answer records.

    Each sample must carry a one-based "iIndex" (position + 1) and exactly one
    solution, which is stored via `str()`; the question is stripped.

    Raises:
        AssertionError: If a sample violates one of these checks (the index
            and sample are printed first).
    """
    with open(path) as f:
        raw_samples = json.load(f)

    processed = []
    for position, record in enumerate(raw_samples):
        assert position + 1 == record["iIndex"] or print(position, record)
        text = record["sQuestion"]

        assert len(record["lSolutions"]) == 1 or print(position, record)
        solution = record["lSolutions"][0]

        processed.append({"question": text.strip(), "answer": str(solution)})

    return processed

In [18]:
def _preprocess_commonsense_qa(path):
    """Read one CommonsenseQA JSON-lines file and verbalize its 5-choice questions.

    Each line must be a JSON object with "answerKey" (a letter A-E) and a
    "question" object holding "stem" and a list of "choices" with "text".

    Returns:
        A list of dicts with the verbalized "question" and the "answer" letter.

    Raises:
        AssertionError: If the answer key is not A-E or does not survive the
            round trip through `verbalize_multichoice`.
    """
    valid_keys = ["A", "B", "C", "D", "E"]
    records = []

    with open(path) as f:
        for sample_index, raw_line in enumerate(f):
            sample = json.loads(raw_line)

            stem = sample["question"]["stem"].strip()
            option_texts = [option["text"] for option in sample["question"]["choices"]]
            assert sample["answerKey"] in valid_keys or print(sample_index, sample)
            correct_position = ord(sample["answerKey"]) - ord("A")

            verbalized, letter = verbalize_multichoice(stem, option_texts, correct_position,
                                                       sample_index=sample_index, n_choices=5)
            assert sample["answerKey"] == letter or print(sample_index, sample)

            records.append({"question": verbalized, "answer": letter})

    return records

def preprocess_commonsense_qa(path):
    """Preprocess the CommonsenseQA train and dev files from directory `path`.

    "train_rand_split.jsonl" is reported as Train and "dev_rand_split.jsonl" as
    Test; both sizes are printed and the two lists are returned concatenated.
    """
    train_part = _preprocess_commonsense_qa(os.path.join(path, "train_rand_split.jsonl"))
    test_part = _preprocess_commonsense_qa(os.path.join(path, "dev_rand_split.jsonl"))
    for label, part in (("Train", train_part), ("Test", test_part)):
        print("CommonsenseQA {}: {}".format(label, len(part)))
    return train_part + test_part

In [19]:
def preprocess_strategy_qa(path):
    """Load the StrategyQA task file and extract question, answer and rationale.

    For each entry of "examples", the answer is "Yes" or "No" depending on which
    of those keys has score 1 in "target_scores" ("Yes" is checked first). The
    rationale is "target" with its leading "Yes. " or "No. " removed, stripped.

    Raises:
        AssertionError: If neither score is 1, or "target" starts with neither
            prefix.
    """
    processed = []
    with open(path) as f:
        content = json.load(f)

        for i, example in enumerate(content["examples"]):
            question = example["input"].strip()

            scores = example["target_scores"]
            for label in ("Yes", "No"):
                if scores[label] == 1:
                    answer = label
                    break
            else:
                raise AssertionError("Invalid target_scores", i, example)

            target = example["target"]
            for prefix in ("Yes. ", "No. "):
                if target[:len(prefix)] == prefix:
                    rationale = target[len(prefix):].strip()
                    break
            else:
                raise AssertionError("Invalid target", i, example)

            processed.append({
                "question": question,
                "answer": answer,
                "rationale": rationale,
            })

    return processed

In [431]:
# Dispatch table: dataset key -> function that turns the raw file or directory
# at paths[key] into a list of sample dicts. Keys mirror those of `paths`.
preprocess_functions= {
    "single_eq": preprocess_single_eq,
    "addsub": preprocess_addsub,
    "multiarith": preprocess_multiarith,
    "gsm8k": preprocess_gsm8k,
    "gsm8k_test_only": preprocess_gsm8k_test_only,
    "aqua": preprocess_aqua,
    "svamp": preprocess_svamp,
    
    "tracking_shuffled_objects": preprocess_tracking_shuffled_objects,
    "date_understanding": preprocess_date_understanding,
    "coin_flip": preprocess_coin_flip,
    "last_letter_concatenation": preprocess_last_letter_concatenation,
    
    "commonsense_qa": preprocess_commonsense_qa,
    "strategy_qa": preprocess_strategy_qa,
}

def get_processed_dataset(key):
    """Run the preprocess function registered for `key` on its raw source path.

    The path is SOURCE_DATASET_DIR joined with `paths[key]`.

    Raises:
        KeyError: If `key` is missing from `paths` or `preprocess_functions`.
    """
    source_path = os.path.join(SOURCE_DATASET_DIR, paths[key])
    preprocess = preprocess_functions[key]
    return preprocess(source_path)

# Manual Inspection

In [374]:
# For every dataset: record its size, then print about 20 evenly spaced
# questions and the first 400 answers for eyeballing.
dataset_sizes = {}
for key in paths:
    print(" {} ".format(key).center(80, "#"))
    processed = get_processed_dataset(key)
    dataset_sizes[key] = len(processed)
    print("Total samples: {}".format(dataset_sizes[key]))

    questions = [sample["question"] for sample in processed]
    answers = [sample["answer"] for sample in processed]

    print(" QUESTIONS ".center(80, "-"))
    stride = max(1, len(questions) // 20)
    for q in questions[::stride]:
        print(repr(q))
    print()

    print(" ANSWERS ".center(80, "-"))
    # Tab-terminated answers on a single line.
    print("".join("{}\t".format(a) for a in answers[:400]))
    print()

################################## single_eq ###################################
Total samples: 508
---------------------------------- QUESTIONS -----------------------------------
'Joan found 70 seashells on the beach. she gave Sam some of her seashells. She has 27 seashell left. How many seashells did she give to Sam ?'
'Sandy grew 6 carrots. Sam grew 3 carrots. How many carrots did they grow in total ?'
'Keith grew 29 cantelopes, Fred grew 16 cantelopes, and Jason grew  20 cantelopes. How many cantelopes did they grow in total ?'
'There are 34 dogwood trees currently in the park. Park workers will plant  49 more dogwood trees today. How many dogwood trees will the park have when the workers are finished ?'
'Nancy goes fishing with Joan. They catch 18 trout.  If they equally split up the trout, how many will each one get ?'
'Mike has 8 orange marbles, he gave Sam 4 of the marbles. How many orange marbles does he now have ?'
'Fred earns $12.50 an hour cleaning houses. If he works for 

In [375]:
# Show the per-dataset sample counts collected by the inspection loop as one table.
pd.Series(dataset_sizes)

single_eq                      508
addsub                         395
multiarith                     600
gsm8k                         8792
gsm8k_test_only               1319
aqua                         97721
svamp                         1000
commonsense_qa               10962
strategy_qa                   2290
tracking_shuffled_objects      750
date_understanding             369
coin_flip                      500
last_letter_concatenation      500
dtype: int64

In [380]:
# Print a banner and the total sample count for every dataset. Train/test sizes
# appear only for datasets whose preprocess function prints them itself
# (gsm8k, aqua, commonsense_qa).
for key in paths.keys():
    print(" {} ".format(key).center(80, "#"))
    processed = get_processed_dataset(key)
    print("Total samples: {}".format(len(processed)))

################################## single_eq ###################################
Total samples: 508
#################################### addsub ####################################
Total samples: 395
################################## multiarith ##################################
Total samples: 600
#################################### gsm8k #####################################
GSM8K Train: 7473
GSM8K Test: 1319
Total samples: 8792
############################### gsm8k_test_only ################################
Total samples: 1319
##################################### aqua #####################################
AQuA Train: 97467
AQuA Test: 254
Total samples: 97721
#################################### svamp #####################################
Total samples: 1000
################################ commonsense_qa ################################
CommonsenseQA Train: 9741
CommonsenseQA Test: 1221
Total samples: 10962
################################# strategy_qa ############################

### Dataset Save Pre-checks

In [391]:
# Dry run before saving: compare freshly processed data with any cleaned file
# already on disk. Status per dataset is "Ready" (no file yet), "Matched"
# (file is identical) or "Changed" (file differs); nothing is written here.
for key in paths.keys():
    processed = get_processed_dataset(key)
    target_path = os.path.join(CLEANED_DATASET_DIR, "{}.json".format(key))
    if os.path.exists(target_path):
        with open(target_path) as f:
            existing = json.load(f)
        status = "Matched"
        # Explicit comparisons instead of assert/except: asserts are skipped under
        # `python -O`, which would report every dataset as "Matched".
        if len(existing) != len(processed):
            print(key, len(existing), len(processed))
            status = "Changed"
        else:
            for i in range(len(existing)):
                if existing[i] != processed[i]:
                    print(key, i)  # index of the first differing sample
                    status = "Changed"
                    break
    else:
        status = "Ready"
    print("{:30}{:<10d}{}".format(key, len(processed), status))

single_eq                     508       Ready
addsub                        395       Ready
multiarith                    600       Matched
GSM8K Train: 7473
GSM8K Test: 1319
gsm8k                         8792      Matched
gsm8k_test_only               1319      Matched
AQuA Train: 97467
AQuA Test: 254
aqua                          97721     Ready
svamp                         1000      Ready
CommonsenseQA Train: 9741
CommonsenseQA Test: 1221
commonsense_qa                10962     Ready
strategy_qa                   2290      Ready
tracking_shuffled_objects 0
tracking_shuffled_objects     750       Changed
date_understanding            369       Matched
coin_flip                     500       Matched
last_letter_concatenation     500       Matched


In [438]:
# Shell escape: runs the external `ho` command with the argument "Hello"; the
# recorded output shows that text being posted to Slack.
# NOTE(review): unrelated to preprocessing - presumably a notification test; confirm whether it should stay.
! ho Hello

Posted on Slack!
---------------------------------- Message -------------------------------------
Hello
--------------------------------------------------------------------------------


### Dataset Save

In [432]:
# Save every processed dataset as "<key>.json" under CLEANED_DATASET_DIR.
# Existing files are never overwritten: they are only compared and reported as
# "Matched" or "Changed". Missing files are written and reported as "Saved!".
target_paths = dict()
for key in paths.keys():
    processed = get_processed_dataset(key)
    target_path = os.path.join(CLEANED_DATASET_DIR, "{}.json".format(key))
    if os.path.exists(target_path):
        with open(target_path) as f:
            existing = json.load(f)
        status = "Matched"
        # Explicit comparisons instead of assert/except: asserts are skipped under
        # `python -O`, which would report every dataset as "Matched".
        if len(existing) != len(processed):
            print(key, len(existing), len(processed))
            status = "Changed"
        else:
            for i in range(len(existing)):
                if existing[i] != processed[i]:
                    print(key, i)  # index of the first differing sample
                    status = "Changed"
                    break
    else:
        with open(target_path, "w") as f:
            json.dump(processed, f, indent=4)
        status = "Saved!"

    print("{:30}{:<10d}{}".format(key, len(processed), status))
    target_paths[key] = target_path

print(json.dumps(target_paths, indent=4))

single_eq                     508       Matched
addsub                        395       Matched
multiarith                    600       Matched
GSM8K Train: 7473
GSM8K Test: 1319
gsm8k                         8792      Matched
gsm8k_test_only               1319      Matched
AQuA Train: 97467
AQuA Test: 254
aqua                          97721     Matched
svamp                         1000      Matched
CommonsenseQA Train: 9741
CommonsenseQA Test: 1221
commonsense_qa                10962     Matched
strategy_qa                   2290      Matched
tracking_shuffled_objects 0
tracking_shuffled_objects     750       Changed
date_understanding            369       Matched
coin_flip                     500       Matched
last_letter_concatenation     500       Matched
{
    "single_eq": "../dataset/single_eq.json",
    "addsub": "../dataset/addsub.json",
    "multiarith": "../dataset/multiarith.json",
    "gsm8k": "../dataset/gsm8k.json",
    "gsm8k_test_only": "../dataset/gsm8k_test_only.json

# Load Data

In [32]:
def load_data(dataset_key="multiarith"):
    """Load and return the parsed "<dataset_key>.json" file from CLEANED_DATASET_DIR."""
    filename = "{}.json".format(dataset_key)
    with open(os.path.join(CLEANED_DATASET_DIR, filename)) as f:
        content = json.load(f)
    return content

In [435]:
# Print one saved sample (position 20) from every dataset: the question wrapped
# in "#" so stray whitespace is visible, then the answer and, if present, the rationale.
for key in paths:
    index = 20
    print(" {} ".format(key).center(80, "-"))
    data = load_data(key)
    sample = data[index]
    divider = "-" * 80
    print("#" + sample["question"] + "#")
    print(divider)
    print(sample["answer"])
    if "rationale" in sample:
        print(divider)
        print(sample["rationale"])

---------------------------------- single_eq -----------------------------------
#Sam had 9 dimes in his bank. His dad gave him 7 more dimes. How many dimes does Sam have now ?#
--------------------------------------------------------------------------------
16.0
------------------------------------ addsub ------------------------------------
#Melanie had 7 dimes in her bank . Her dad gave her 8 dimes and her mother gave her 4 dimes . How many dimes does Melanie have now ?#
--------------------------------------------------------------------------------
19
---------------------------------- multiarith ----------------------------------
#For the school bake sale Katie made pastries. She baked 7 cupcakes and 5 cookies. After the sale she had 8 to take back home. How many pastries did she sell?#
--------------------------------------------------------------------------------
4
------------------------------------ gsm8k -------------------------------------
#Bella bought stamps at the post

# Template-based Split

We simply copy the MultiArith, SVAMP, and Date Understanding data files under a `_template_split` name. The actual splits are saved as JSON index files (in the custom split directory) to be read by the split module.

### Create Dataset Files

In [210]:
# Keys of the cleaned datasets that get a "<key>_template_split.json" copy.
template_split_datasets = ["multiarith", "svamp", "date_understanding"]

In [211]:
# Duplicate each cleaned dataset file as "<key>_template_split.json" in the same
# directory, overwriting any earlier copy. (`shutil` comes from the notebook's
# import cell rather than being imported mid-notebook.)
for key in template_split_datasets:
    source_path = os.path.join(CLEANED_DATASET_DIR, "{}.json".format(key))
    target_path = os.path.join(CLEANED_DATASET_DIR, "{}_template_split.json".format(key))
    print("Copy from/to:")
    print(source_path)
    print(target_path)
    shutil.copyfile(source_path, target_path)

Copy from/to:
../dataset/multiarith.json
../dataset/multiarith_template_split.json
Copy from/to:
../dataset/svamp.json
../dataset/svamp_template_split.json
Copy from/to:
../dataset/date_understanding.json
../dataset/date_understanding_template_split.json


### MultiArith Template Split

In [321]:
# Select and load the MultiArith copy; `dataset_key` is reused further down to
# name the split file that gets written.
dataset_key = "multiarith_template_split"
dataset = load_data(dataset_key)

In [314]:
# Pair every question with its position in `dataset`, so groups can refer back to samples by index.
indice_question_pairs = [(i, item["question"]) for i, item in enumerate(dataset)]

Categorize indices by template. A template is basically just the question with its numbers and person names stripped out.

In [315]:
# Group sample indices by question template: the question with numbers removed
# and known first names replaced by NAME.
indices_by_template = defaultdict(list)  # template: [index]
for i, q in indice_question_pairs:
    # Pad with spaces so names at the very start or end still match " name ".
    template = " " + q + " "
    # Remove the decimal strings 101 down to 1, larger numbers first. The digit
    # "0" alone is never removed, so e.g. "20" leaves a "0" behind.
    for j in range(101, 0, -1):
        template = template.replace(str(j), "")
    for name in "adam amy emily robin zoe katie maria vanessa april billy oliver tom cody dave edward jerry victor kaleb sam victor billy oliver biana carol chloe cody dave debby edward emily faye george gwen haley isabel jerry john luke maria megan mike nancy ned oliver paige rachel paul robin roger sam sarah tiffany tom victor wendy zoe bianca olivia will janet frank henry lana".split():
        template = template.replace(" " + name + " ", " NAME ")
        template = template.replace(" " + name.title() + " ", " NAME ")
        template = template.replace(" " + name + "'s ", " NAME's ")
        template = template.replace(" " + name.title() + "'s ", " NAME's ")
#     for entity in "puppies,siamese cats,bird cages,house cats,color,flower,rose".split(","):
#         template = template.replace(entity, "ENTITY")
    # str.strip() returns a new string; the result must be assigned for the
    # padding added above to actually be removed from the key.
    template = template.strip()
    
    indices_by_template[template].append(i)

In [316]:
# Distribution of group sizes: index = number of samples sharing a template,
# value = number of templates with that many samples.
counts = [len(indices) for indices in indices_by_template.values()]
pd.Series(counts).value_counts()

4    26
6    25
5    21
3    17
7    16
8     8
2     7
dtype: int64

In [317]:
# One line per template: a bar of "*" (one per sample, padded to width 10)
# followed by the first 100 characters of the template.
for t, indices in indices_by_template.items():
    print("{:10}".format("*" * len(indices)), t[:100])

****        For Halloween NAME and her sister combined the candy they received. NAME had  pieces of candy while
********    A pet store had  siamese cats and  house cats. During a sale they sold  cats. How many cats do they
*****       NAME was trying to expand his game collection. He bought  games from a friend and bought  more at a
********    The school cafeteria ordered  red apples and  green apples for students lunches. But, if only  stud
****        NAME picked  tulips and  roses to make flower bouquets. If she only used  of the flowers though, ho
********    NAME and her mom were picking carrots from their garden. NAME picked  and her mother picked . If on
*******     NAME had  dollars. For his birthday he got  more dollars but spent  on a new game. How much money d
*******     While on vacation, NAME took  pictures at the zoo and  at the museum. If she later deleted  of the 
*****       NAME bought two coloring books. One had  pictures and the other had . After one week she had

In [318]:
# Split at template level so that no template appears in both train and test:
# shuffle the templates with a fixed seed and give the first 70% to train.
state = np.random.RandomState(0)
train = round(len(indices_by_template) * 0.7)  # number of train templates
template_indices = state.permutation(range(len(indices_by_template)))
train_t_indices = template_indices[:train]
test_t_indices = template_indices[train:]

# Expand the chosen templates back into sample indices.
all_indices = list(indices_by_template.values())
train_indices = [i for t in train_t_indices for i in all_indices[t]]
test_indices = [i for t in test_t_indices for i in all_indices[t]]

# Sample-level train share; differs from 0.7 because templates vary in size.
train_ratio = len(train_indices) / (len(train_indices) + len(test_indices))
print(train_ratio)

0.6883333333333334


In [324]:
# Write the sorted train/test sample indices to
# CUSTOM_SPLIT_DIR/<dataset_key>.json after checking they cover the dataset.
train_indices.sort()
test_indices.sort()
assert set(train_indices) | set(test_indices) == set(range(len(dataset)))

os.makedirs(CUSTOM_SPLIT_DIR, exist_ok=True)
split_path = os.path.join(CUSTOM_SPLIT_DIR, "{}.json".format(dataset_key))
split_indices = {"train": train_indices, "test": test_indices}
with open(split_path, "w") as f:
    json.dump(split_indices, f)

print("Saved to")
print(split_path)

Saved to
../dataset/custom_split/multiarith_template_split.json


### SVAMP Template Split

Maybe TODO. SVAMP dataset was created by augmenting 100 base examples with variations (by humans). Splitting by base examples means splitting questions by the subject matter–train/test examples will contain the same sorts of variations, i.e., challenges in understanding questions, or reasoning ability.

On second thought, this is actually not a bad split for our purposes. We want to see if taught reasoning capabilities can extend to novel situations. Hmm...

In [244]:
# Load the SVAMP copy. NOTE: `dataset_key` is not updated here, so it still
# holds whatever an earlier section assigned.
dataset = load_data("svamp_template_split")

In [245]:
# Pair every question with its position in `dataset`, so groups can refer back to samples by index.
indice_question_pairs = [(i, item["question"]) for i, item in enumerate(dataset)]

Categorize indices by template. Here a template is basically just the question with its numbers stripped out, truncated to the first 20 words.

In [246]:
# Keyword groups, one inner list per group.
# NOTE(review): not referenced by the grouping cell that follows - presumably
# intended for grouping SVAMP questions by subject; confirm or remove.
keywords = [
    ["peach", "basket"],
    ["bird", "stork", "fence"],
    ["children", "bus"],
    ["camper", "rowing"],
    ["book", "chapter"],
]

In [247]:
# Group sample indices by a rough template: the question with numbers removed,
# reduced to its first 20 whitespace-separated words.
indices_by_template = defaultdict(list)  # template: [index]
for i, q in indice_question_pairs:
    template = " " + q + " "
    # Remove the decimal strings 101 down to 1, larger numbers first. The digit
    # "0" alone is never removed, so e.g. "20" leaves a "0" behind.
    for j in range(101, 0, -1):
        template = template.replace(str(j), "")
#     for name in "adam jessie olivia allan bker bobby brenda bryan carol dan danny dave david debby ed edward emily ryan".split():
#         template = template.replace(" " + name + " ", " NAME ")
#         template = template.replace(" " + name.title() + " ", " NAME ")
#         template = template.replace(" " + name + "'s ", " NAME's ")
#         template = template.replace(" " + name.title() + "'s ", " NAME's ")
#     for entity in "puppies,siamese cats,bird cages,house cats,color,flower,rose".split(","):
#         template = template.replace(entity, "ENTITY")
    # split() already discards leading/trailing whitespace, so no separate
    # strip is needed (the former bare `template.strip()` discarded its result).
    template = " ".join(template.split()[:20])
    
    indices_by_template[template].append(i)

In [248]:
# Distribution of group sizes: index = number of samples sharing a template,
# value = number of templates with that many samples.
counts = [len(indices) for indices in indices_by_template.values()]
pd.Series(counts).value_counts()

1    584
2     88
3     26
4     17
8      4
5      4
7      3
6      2
9      1
dtype: int64

*Incomplete*

### Date Understanding Template Split

In [334]:
# Select and load the Date Understanding copy; `dataset_key` is reused further
# down to name the split file that gets written.
dataset_key = "date_understanding_template_split"
dataset = load_data(dataset_key)

In [335]:
# Pair every question with its position in `dataset`, so groups can refer back to samples by index.
indice_question_pairs = [(i, item["question"]) for i, item in enumerate(dataset)]

Categorize indices by template. Here a template is the part of the question before "What is the date", with its digits masked by `#`.

In [339]:
# Group sample indices by template: the premise of the question (everything
# before "What is the date") with its numbers masked by "#".
indices_by_template = defaultdict(list)  # template: [index]
for i, q in indice_question_pairs:
    template = " " + q + " "
    # Mask the decimal strings 101 down to 1 with as many "#" as they have
    # digits, larger numbers first. The digit "0" alone is never masked, so
    # zeros not covered by a longer match stay visible in the template.
    for j in range(101, 0, -1):
        template = template.replace(str(j), "#" * len(str(j)))
#     for name in "adam jessie olivia allan bker bobby brenda bryan carol dan danny dave david debby ed edward emily ryan".split():
#         template = template.replace(" " + name + " ", " NAME ")
#         template = template.replace(" " + name.title() + " ", " NAME ")
#         template = template.replace(" " + name + "'s ", " NAME's ")
#         template = template.replace(" " + name.title() + "'s ", " NAME's ")
#     for entity in "puppies,siamese cats,bird cages,house cats,color,flower,rose".split(","):
#         template = template.replace(entity, "ENTITY")
    # partition() keeps the whole text when the marker is absent, whereas
    # slicing with find()'s -1 would silently drop the last character.
    template = template.partition("What is the date")[0]
    # str.strip() returns a new string; assign it so the padding is really removed.
    template = template.strip()
    
    indices_by_template[template].append(i)

In [340]:
# Distribution of group sizes: index = number of samples sharing a template,
# value = number of templates with that many samples.
counts = [len(indices) for indices in indices_by_template.values()]
pd.Series(counts).value_counts()

9    41
dtype: int64

In [341]:
# One line per template: a bar of "*" (one per sample, padded to width 10)
# followed by the first 100 characters of the template.
for t, indices in indices_by_template.items():
    print("{:10}".format("*" * len(indices)), t[:100])

*********   Yesterday was April ##, ####. 
*********   The deadline is Jun #, ####, which is # days away from now. 
*********   Tomorrow is ##/##/####. 
*********   Today, #/#/####, is a day that we will never forget. 
*********   Today is Sep #, ####. 
*********   It is #/##/#### today. 
*********   Jane was born on the last day of Feburary in ##00. Today is her ##-year-old birthday. 
*********   Jane was born on the last day of Feburary in ##0#. Today is her ##-year-old birthday. 
*********   Jane and John married on Jan #, ####. Today is their golden wedding anniversary. 
*********   Jane got her job in ####. Today is her #-year work anniversary. She still remember that on Dec #, h
*********   Jane quited her job on Mar ##, ####. ### days have passed since then. 
*********   #### is coming in ## hours. 
*********   In the UK, people usually put the day before the month when formatting the date. Therefore, today i
*********   Jane booked a flight for tomorrow, Jul ##, ##0#. 
********

In [331]:
# Split at template level so that no template appears in both train and test:
# shuffle the templates with a fixed seed and give the first 70% to train.
state = np.random.RandomState(0)
train = round(len(indices_by_template) * 0.7)  # number of train templates
template_indices = state.permutation(range(len(indices_by_template)))
train_t_indices = template_indices[:train]
test_t_indices = template_indices[train:]

# Expand the chosen templates back into sample indices.
all_indices = list(indices_by_template.values())
train_indices = [i for t in train_t_indices for i in all_indices[t]]
test_indices = [i for t in test_t_indices for i in all_indices[t]]

# Sample-level train share of the resulting split.
train_ratio = len(train_indices) / (len(train_indices) + len(test_indices))
print(train_ratio)

0.7073170731707317


In [332]:
# Write the sorted train/test sample indices to
# CUSTOM_SPLIT_DIR/<dataset_key>.json after checking they cover the dataset.
train_indices.sort()
test_indices.sort()
assert set(train_indices) | set(test_indices) == set(range(len(dataset)))

os.makedirs(CUSTOM_SPLIT_DIR, exist_ok=True)
split_path = os.path.join(CUSTOM_SPLIT_DIR, "{}.json".format(dataset_key))
split_indices = {"train": train_indices, "test": test_indices}
with open(split_path, "w") as f:
    json.dump(split_indices, f)

print("Saved to")
print(split_path)

Saved to
../dataset/custom_split/date_understanding_template_split.json
