In [1]:
from glob import glob

import pandas as pd
from experiments.musique.inference_only import macro_averaging
from knowledge_propagation.utils import io, vars, extractor
import os
import numpy as np
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import describe
from thefuzz import fuzz

from datasets import load_dataset, load_from_disk

from copy import deepcopy

from dateutil.parser import parse
from dateutil.parser import ParserError

def is_date(string):
    """Return True when ``dateutil.parser.parse`` accepts ``string``, else False.

    Args:
        string: Candidate date text. Non-string values are reported as
            "not a date" instead of raising.

    Returns:
        bool: True if parsing succeeds, False otherwise.
    """
    try:
        parse(string)
    except (ParserError, ValueError, OverflowError, TypeError):
        # ParserError: text is not a recognisable date.
        # OverflowError: numeric fields too large for a date (e.g. long digit runs).
        # TypeError: input is not a string/stream (e.g. int or NaN from a spreadsheet).
        return False
    return True



  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Load the generated common-fact date questions saved on disk in Hugging Face `datasets` format.
common_fact_date_dataset = load_from_disk("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/common_date_question_generation.hf")

In [8]:
# Show the dataset summary (feature names and row count).
common_fact_date_dataset

Dataset({
    features: ['topic', 'qa_pairs'],
    num_rows: 31
})

In [14]:
# Flatten the per-topic QA pairs into one list, keeping only the first
# occurrence of each question text and tagging every kept pair with its topic.
unique_questions = set()
common_fact_date_df_content = []
for topic_facts in common_fact_date_dataset:
    topic = topic_facts["topic"]
    for qa in topic_facts["qa_pairs"]:
        question_text = qa["question"]
        if question_text not in unique_questions:
            unique_questions.add(question_text)
            qa["topic"] = topic
            common_fact_date_df_content.append(qa)
        

In [None]:
# Turn the de-duplicated QA records into a DataFrame and export it to Excel
# (without the index column).
common_fact_date_df = pd.DataFrame(common_fact_date_df_content)
common_fact_date_df.to_excel("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/common_fact_date.xlsx", index=False)

In [59]:
# Read the spreadsheet variant that carries rewritten questions ("w-rewrite").
# The loop that consumes it reads the `answer`, `question` and `rewrite_question` columns.
common_fact_date_rewrite_df = pd.read_excel("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/common_fact_date_w-rewrite.xlsx")

In [60]:
# Build the "year after" variant of each rewrite row whose answer is a plain year:
# the rewritten question becomes the question and the answer is shifted by +1,
# while the original question/answer are kept under `original_*` keys.
common_fact_date_rewrite_increment_df_content = []
bc_year_count = 0
s_year_count = 0
interval_year_count = 0
for row in common_fact_date_rewrite_df.to_dict("records"):
    answer = row["answer"]
    answer_lower = answer.lower()

    # Rows that are not a single plain year are counted and skipped:
    # BC years, answers containing "s" or "c" (presumably decades such as
    # "1990s" or century wording -- TODO confirm), and "-" ranges.
    if "bc" in answer_lower:
        bc_year_count += 1
        continue
    if any(marker in answer_lower for marker in ("s", "c")):
        s_year_count += 1
        continue
    if "-" in answer_lower:
        interval_year_count += 1
        continue

    # Anything left must be a parseable, all-digit year.
    assert is_date(answer), row
    assert answer.isdigit(), row

    shifted = deepcopy(row)
    shifted["original_answer"] = answer
    shifted["original_question"] = row["question"]
    shifted["answer"] = str(int(answer) + 1)
    shifted["question"] = shifted.pop("rewrite_question")
    shifted.pop("rewrite_answer", None)
    common_fact_date_rewrite_increment_df_content.append(shifted)
len(common_fact_date_rewrite_increment_df_content)

1411

In [61]:
# Persist the incremented records twice: as an Excel sheet (no index column)
# and as a JSON-lines file written by the project `io` helper.
common_fact_date_rewrite_increment_df = pd.DataFrame(common_fact_date_rewrite_increment_df_content)
common_fact_date_rewrite_increment_df.to_excel("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/common_fact_date_w-rewrite_increment.xlsx", index=False)

io.dump_jsonlines(common_fact_date_rewrite_increment_df_content, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/common_fact_date_w-rewrite_increment.jsonl")

In [None]:
# Split the incremented common-fact records into train / valid and write both.
# A fixed seed makes the split reproducible across kernel restarts.
split_seed = 0

n_dev = 100
n_total = len(common_fact_date_rewrite_increment_df_content)
# Without this guard a too-small dataset yields a negative n_train, and the
# negative slices below would silently produce a wrong split.
assert n_total > n_dev, f"need more than {n_dev} records to hold out a dev set, got {n_total}"
n_train = n_total - n_dev

# Random permutation of record indices; the last n_dev indices form the valid set.
rand_shuffle = np.random.default_rng(split_seed).permutation(n_total)

train = [common_fact_date_rewrite_increment_df_content[i] for i in rand_shuffle[:n_train]]
valid = [common_fact_date_rewrite_increment_df_content[i] for i in rand_shuffle[n_train:]]
assert len(train) == n_train and len(valid) == n_dev
io.dump_jsonlines(train, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/train.jsonl")

io.dump_jsonlines(valid, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/valid.jsonl")

In [63]:
# Inspect the first training record.
train[0]

{'answer': '1882',
 'question': 'When was the year after the year that the start of the Scramble for Africa happened?',
 'topic': 'Colonization and Empire Building',
 'original_answer': '1881',
 'original_question': 'What year marks the start of the Scramble for Africa?'}

In [64]:
# Inspect the first incremented record (pre-shuffle order).
common_fact_date_rewrite_increment_df_content[0]

{'answer': '1733',
 'question': 'When was the year after the year that George Washington was born?',
 'topic': "Historic Figures' Birthday",
 'original_answer': '1732',
 'original_question': 'In which year was George Washington born?'}

In [11]:
# Load the validation split of the "year after" data.
# NOTE(review): this path is rooted at /data/users/zliu, unlike the /u/zliu/datastor1
# root used by the cells that write the common_date_data files -- confirm it is the intended copy.
common_date_year_after = io.load_jsonlines("/data/users/zliu/KE-by-CP/data/debug_meta_train/common_date_data_year_after/valid.jsonl")

In [12]:
# Swap roles: the original (un-incremented) question/answer become the primary
# fields and the "year after" pair is kept under `year_after_*` keys.
common_date_year = [
    {
        "question": record["original_question"],
        "answer": record["original_answer"],
        "topic": record["topic"],
        "year_after_question": record["question"],
        "year_after_answer": record["answer"],
    }
    for record in common_date_year_after
]
len(common_date_year)

100

In [13]:
# Write the re-keyed records as the valid split of common_date_data.
# NOTE(review): if /data/users/zliu and /u/zliu/datastor1 resolve to the same storage,
# this replaces the valid.jsonl written by the train/valid split cell -- confirm.
io.dump_jsonlines(common_date_year, "/data/users/zliu/KE-by-CP/data/debug_meta_train/common_date_data/valid.jsonl")

In [None]:

# Collect DROP questions that ask for a date/year and whose date answer carries
# an explicit year, then convert them to the unified {id, question, answer, dataset} format.
ds = load_dataset("ucinlp/drop")

date_count = 0    # total number of answer spans typed "date" (before any filtering)
failed_count = 0  # date answers that dateutil could not parse
drop_date_instances = []
# NOTE(review): this set is never added to, so the `not in passage` test below is
# always true and no per-passage de-duplication happens. Left as-is because
# populating it would change which instances are kept -- confirm the intent.
passage = set()
for split in ["train", "validation"]:
    # Iterate the split directly instead of indexing row by row.
    for datum in ds[split]:
        span = datum["answers_spans"]

        q_str = datum["question"].lower()
        date_count += sum(t == "date" for t in span["types"])

        # Keep only questions that mention a date-like keyword ...
        if not any(t in q_str for t in ["date", "year", "when"]):
            continue
        # ... and drop ones asking for a month, a count or an age.
        if any(t in q_str for t in ["month", "how many", "how old"]):
            continue

        if "date" in span["types"] and datum["passage"] not in passage:
            date_index = span["types"].index("date")
            a_str = span["spans"][date_index]

            try:
                has_explicit_year = str(parse(a_str).year) in a_str
            except (ParserError, ValueError, OverflowError, TypeError):
                # Only parsing failures are counted; anything else
                # (e.g. KeyboardInterrupt) propagates.
                failed_count += 1
                continue
            # The parsed year must literally appear in the answer text, which
            # rejects answers where dateutil filled in a default year.
            if has_explicit_year:
                drop_date_instances.append(datum)

drop_unified_format = []

for datum in drop_date_instances:
    answer_spans = datum["answers_spans"]
    # Emit the span that was validated above (the first one typed "date"),
    # not blindly the first span, which may be of another type.
    date_answer = answer_spans["spans"][answer_spans["types"].index("date")]
    drop_unified_format.append({
        "id": datum["query_id"],
        "question": datum["question"],
        "answer": date_answer,
        # "texts": [],
        "dataset": "drop"
    })
len(drop_unified_format)

1097

In [325]:
# Scan every MuSiQue split for single-hop sub-questions whose answer is a date
# containing an explicit year. Each sub-question id is used once; the parent
# instance is appended once per matching sub-question.
# (The question-keyword filters applied to the other datasets are disabled here.)
musique_date_instances = []
single_hop_id = set()
years = []

for split in ["train", "dev", "test"]:
    musique_instances = io.load_jsonlines(f"/u/zliu/datastor1/KE-by-CP/data/musique/musique_ans_v1.0_{split}.jsonl")

    for instance in musique_instances:
        # Instances without a decomposition contribute nothing.
        for q in instance.get("question_decomposition", []):
            q_str = q["question"].lower()
            a_str = q["answer"]

            if not is_date(a_str) or q["id"] in single_hop_id:
                continue
            answer_year = parse(a_str).year
            # The year must literally appear in the answer text.
            if str(answer_year) not in a_str:
                continue

            single_hop_id.add(q["id"])
            years.append(answer_year)
            musique_date_instances.append(instance)

In [256]:
# Debug peek: show whatever value `q_str` currently holds in the kernel.
q_str

'how were the #1 expelled from #2 ?'

In [366]:
from datasets import load_dataset

# TriviaQA (no-context config): keep each distinct date/year question whose
# answer parses as a date with an explicit year, then map to the unified format.
ds = load_dataset("mandarjoshi/trivia_qa", "rc.nocontext")

questions = set()
c = 0
trivia_date_data = []
for split in ["train", "validation", "test"]:
    for datum in ds[split]:
        a_str = datum["answer"]["value"]
        q_str = datum["question"]
        q_lower = q_str.lower()

        # Must mention a date-like keyword and must not ask for a month, count or age.
        asks_for_date = any(term in q_lower for term in ["date", "year", "when"])
        asks_for_other = any(term in q_lower for term in ["month", "how many", "how old"])
        if not asks_for_date or asks_for_other:
            continue

        if not is_date(a_str) or q_str in questions:
            continue
        # The parsed year has to appear verbatim in the answer string.
        if str(parse(a_str).year) not in a_str:
            continue

        questions.add(q_str)
        trivia_date_data.append(datum)
        c += 1

trivia_unified_format = [
    {
        "id": datum["question_id"],
        "question": datum["question"],
        "answer": datum["answer"]["value"],
        # "texts": [],
        "dataset": "triviaqa"
    }
    for datum in trivia_date_data
]
len(trivia_unified_format)

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

1075

In [367]:
from tqdm import tqdm

# HotpotQA: same date-question filter as for TriviaQA, applied to the local
# JSON dumps of each split; records lacking a question or an answer are skipped.
questions = set()
c = 0
hotpotqa_date_data = []
hotpot_files = [
    ("train", "hotpot_train_v1.1.json"),
    ("dev", "hotpot_dev_fullwiki_v1.json"),
    ("test", "hotpot_test_fullwiki_v1.json"),
]
for split, filename in hotpot_files:
    data = io.load_json(f"/u/zliu/datastor1/KE-by-CP/data/hotpotqa/{filename}")
    for datum in tqdm(data):
        if not ("answer" in datum and "question" in datum):
            continue
        a_str = datum["answer"]
        q_str = datum["question"]
        q_lower = q_str.lower()

        asks_for_date = any(term in q_lower for term in ["date", "year", "when"])
        asks_for_other = any(term in q_lower for term in ["month", "how many", "how old"])
        if not asks_for_date or asks_for_other:
            continue

        if not is_date(a_str) or q_str in questions:
            continue
        # The parsed year has to appear verbatim in the answer string.
        if str(parse(a_str).year) not in a_str:
            continue

        questions.add(q_str)
        hotpotqa_date_data.append(datum)
        c += 1

hotpotqa_unified_format = []

for datum in hotpotqa_date_data:
    # Supporting paragraphs are still assembled, but the "texts" field is
    # currently left out of the emitted record.
    supporting_titles = [title for title, _ in datum["supporting_facts"]]
    texts = ["".join(sentences) for title, sentences in datum["context"] if title in supporting_titles]

    record = {
        "id": datum["_id"],
        "question": datum["question"],
        "answer": datum["answer"],
        # "texts": texts,
        "dataset": "hotpotqa"
    }
    hotpotqa_unified_format.append(record)
len(hotpotqa_unified_format)

100%|██████████| 90447/90447 [00:00<00:00, 140174.49it/s]
100%|██████████| 7405/7405 [00:00<00:00, 167191.45it/s]
100%|██████████| 7405/7405 [00:00<00:00, 2338238.43it/s]


9517

In [1]:
# Load the synthetic-biography training split.
# NOTE: needs `io` (from knowledge_propagation.utils) to be imported first; the
# recorded run of this cell failed with a NameError because it was not.
bio_syn_data = io.load_jsonlines("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/bio_syn_data/train.jsonl")

NameError: name 'io' is not defined

9518

In [361]:
# Inspect the first unified-format HotpotQA record.
hotpotqa_unified_format[0]

{'id': '5a7d0db955429909bec76924',
 'question': 'The Dutch-Belgian television series that "House of Anubis" was based on first aired in what year?',
 'answer': '2006',
 'dataset': 'hotpotqa'}

In [392]:
# Merge the three unified-format sources (TriviaQA, HotpotQA, DROP -- in that order).
sft_data = [*trivia_unified_format, *hotpotqa_unified_format, *drop_unified_format]

# Only questions of at most 400 characters are exported to the spreadsheet;
# `sft_data` itself stays unfiltered.
short_question_rows = [d for d in sft_data if len(d["question"]) <= 400]
pd.DataFrame(short_question_rows).to_excel(
    "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/all_data.xlsx",
    index=False,
)

In [None]:
# Collect records whose answer is all digits but not exactly four characters long.
# `filter_sft_data` is initialised here but not filled by this cell.
filter_sft_data = []
special_data = [
    datum
    for datum in sft_data
    if datum["answer"].isdigit() and len(datum["answer"]) != 4
]

In [377]:
# NOTE: the scipy `describe` imported here is not used by this cell; the call
# below is pandas' DataFrame.describe().
from scipy.stats import describe

# Summary statistics of question length (in characters) over all of sft_data.
pd.DataFrame([len(d["question"]) for d in sft_data]).describe()

Unnamed: 0,0
count,11689.0
mean,107.059201
std,64.366637
min,17.0
25%,69.0
50%,90.0
75%,122.0
max,654.0


In [418]:
# Read the spreadsheet that carries rewritten questions and convert it to a list of row dicts.
rewrite_data = pd.read_excel("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/all_data_w-rewrite.xlsx").to_dict(orient="records")

In [419]:
# Attach `rewrite_answer` (= answer year + 1) to every row whose answer parses
# as a date with an explicit year; rows that fail are counted and skipped.
increment_rewrite_data = []
failed_count = 0
for datum in rewrite_data:
    a_str = datum["answer"]
    try:
        answer_year = parse(a_str).year
    except (ParserError, ValueError, OverflowError, TypeError):
        # Unparseable text, out-of-range numbers, or a non-string cell
        # (e.g. a number or NaN read from the spreadsheet). Only these
        # parsing failures are swallowed; anything else propagates.
        failed_count += 1
        continue

    # Explicit check instead of `assert`, which is removed under `python -O`:
    # the parsed year must literally appear in the answer text.
    if str(answer_year) not in a_str:
        failed_count += 1
        continue

    increment_rewrite_datum = deepcopy(datum)
    increment_rewrite_datum["rewrite_answer"] = str(int(answer_year) + 1)
    increment_rewrite_data.append(increment_rewrite_datum)
failed_count

5

In [407]:
# Export the rows that received a `rewrite_answer` to Excel (no index column).
pd.DataFrame(increment_rewrite_data).to_excel("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/all_data_w-rewrite_increment.xlsx", index=False)

In [428]:
# Reload the exported spreadsheet as a list of row dicts.
all_data = pd.read_excel("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/all_data_w-rewrite_increment.xlsx").to_dict(orient="records")

In [429]:
# Write the same rows as JSON lines next to the spreadsheet.
io.dump_jsonlines(all_data, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/all_data_w-rewrite_increment.jsonl")

In [426]:
# Count rows whose question is at most 200 characters long.
len([d for d in all_data if len(d["question"]) <= 200])

10785

In [None]:
# Re-key each row so the rewritten ("year after") pair becomes the primary
# question/answer and the source pair is preserved under `original_*`.
increment_date_data = []
for row in all_data:
    increment_date_data.append(
        {
            "id": row["id"],
            "question": row["rewrite_question"],
            "answer": row["rewrite_answer"],
            "original_question": row["question"],
            "original_answer": row["answer"],
            "dataset": row["dataset"],
        }
    )

In [413]:
# Split the real-date "year after" data into train / dev.
# A fixed seed makes the split reproducible across kernel restarts.
split_seed = 0

n_dev = 1000
n_total = len(increment_date_data)
# Without this guard a too-small dataset yields a negative n_train, and the
# negative slices below would silently produce a wrong split.
assert n_total > n_dev, f"need more than {n_dev} records to hold out a dev set, got {n_total}"
n_train = n_total - n_dev

# Random permutation of record indices; the last n_dev indices form the dev set.
rand_shuffle = np.random.default_rng(split_seed).permutation(n_total)

increment_date_data_train = [increment_date_data[i] for i in rand_shuffle[:n_train]]
increment_date_data_dev = [increment_date_data[i] for i in rand_shuffle[n_train:]]
assert len(increment_date_data_train) == n_train and len(increment_date_data_dev) == n_dev

In [416]:
# Write the train and dev portions of the real-date data as JSON lines.
io.dump_jsonlines(increment_date_data_train, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/increment_date_data_train.jsonl")
io.dump_jsonlines(increment_date_data_dev, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/real_date_data/increment_date_data_dev.jsonl")

# Generate synthetic meta-training data

In [1]:
# Three-sentence biography template: birth (year and place), career start
# (pronoun capitalised at sentence start), and death.
text_template = (
    "{first_name} {last_name} was born in {birth_year} in {birth_place}. "
    "{gender_start} started the career of {career} in {career_year}. "
    "In {death_year}, {gender} passed away."
)

In [2]:
# Candidate first names for the synthetic biographies.
# sorted(set(...)) de-duplicates and gives a stable order; a bare list(set(...))
# is ordered by string hashes, which change between interpreter sessions and
# would make index-based random sampling irreproducible.
first_names = sorted(set("""Michael
Emma
James
Sophia
David
Olivia
William
Ava
Alexander
Isabella
John
Mia
Matthew
Charlotte
Daniel
Amelia
Christopher
Harper
Joseph
Evelyn
Benjamin
Abigail
Andrew
Emily
Robert
Elizabeth
Thomas
Sofia
Samuel
Avery
Jacob
Ella
Nathan
Scarlett
Nicholas
Grace
Ryan
Victoria
Joshua
Madison
Ethan
Lily
Noah
Hannah
Anthony
Chloe
Jonathan
Zoe
Aaron
Nora
Gabriel
Riley
Lucas
Layla
Christina
Maria
Jason
Sarah
Tyler
Natalie
Kevin
Leah
Eric
Maya
Brian
Jennifer
Brandon
Laura
Adam
Elena
Marcus
Jasmine
Caleb
Anna""".split("\n")))

In [3]:
# Candidate last names for the synthetic biographies.
# sorted(set(...)) de-duplicates and gives a stable order; a bare list(set(...))
# is ordered by string hashes, which change between interpreter sessions and
# would make index-based random sampling irreproducible.
last_names = sorted(set("""Smith
Johnson
Williams
Brown
Jones
Garcia
Miller
Davis
Rodriguez
Martinez
Hernandez
Lopez
Gonzalez
Wilson
Anderson
Thomas
Taylor
Moore
Jackson
Martin
Lee
Perez
Thompson
White
Harris
Sanchez
Clark
Ramirez
Lewis
Robinson
Walker
Young
Allen
King
Wright
Scott
Torres
Nguyen
Hill
Flores
Green
Adams
Nelson
Baker
Hall
Rivera
Campbell
Mitchell
Carter
Roberts
Gomez
Phillips
Evans
Turner
Diaz
Parker
Cruz
Edwards
Collins
Reyes
Stewart
Morris
Morales
Murphy
Cook
Rogers
Gutierrez
Ortiz
Morgan
Cooper
Peterson
Bailey
Reed
Kelly
Howard
Ramos
Kim
Cox
Ward
Richardson
Watson
Brooks
Chavez
Wood
James
Bennett
Gray
Mendoza
Ruiz
Hughes
Price
Alvarez
Castillo""".split("\n")))

In [4]:
# Pronouns inserted into the biography text.
genders = ["he", "she"]
# Candidate careers for the synthetic biographies.
# sorted(set(...)) de-duplicates and gives a stable order; a bare list(set(...))
# is ordered by string hashes, which change between interpreter sessions and
# would make index-based random sampling irreproducible.
careers = sorted(set("""Doctor
Teacher
Software Engineer
Nurse
Accountant
Chef
Architect
Lawyer
Electrician
Marketing Manager
Graphic Designer
Pharmacist
Police Officer
Financial Analyst
Journalist
Mechanical Engineer
Social Worker
Veterinarian
Pilot
Dental Hygienist
Web Developer
Physical Therapist
Human Resources Manager
Firefighter
Real Estate Agent
Data Scientist
Interior Designer
Occupational Therapist
Construction Manager
Speech Pathologist
Cybersecurity Analyst
Photographer
Psychologist
Plumber
Flight Attendant
Marine Biologist
Athletic Trainer
Urban Planner
Welder
Dietitian
Librarian
Civil Engineer
Paralegal
Film Producer
Actuary
Event Planner
Scientist
Carpenter
Financial Advisor
Lab Technician""".split("\n")))

In [7]:
# 100 distinct birth years drawn without replacement from 1900..2019.
# NOTE: this draw uses NumPy's global RNG; seed it beforehand if the same set
# of years is needed across runs.
birth_years = np.random.choice(range(1900, 2020), 100, replace=False)
# Candidate birth places for the synthetic biographies.
# sorted(set(...)) de-duplicates and gives a stable order; a bare list(set(...))
# is ordered by string hashes, which change between interpreter sessions and
# would make index-based random sampling irreproducible.
birth_places = sorted(set("""New York City
Tokyo
Paris
London
Sydney
Rome
Barcelona
Amsterdam
Singapore
Cairo
Rio de Janeiro
Vancouver
Istanbul
Dubai
Cape Town
Bangkok
Dublin
Seoul
Venice
Hong Kong
San Francisco
Mumbai
Berlin
Buenos Aires
Prague
Stockholm
Montreal
Beijing
Athens
Marrakech
Kyoto
Vienna
Copenhagen
Jerusalem
Nairobi
Mexico City
Santorini
Bali
Toronto
Reykjavik
Havana
Moscow
Queenstown
Florence
Edinburgh
Machu Picchu
Petra
Bora Bora
Chicago
Maldives""".split("\n")))

In [8]:
# Size of the cross-product of the six sampled attributes (names, birth year,
# birth place, pronoun, career).
len(first_names) * len(last_names) * len(birth_years) * len(birth_places) * len(genders) * len(careers)

3441000000

In [9]:

# Generate `n_data` unique synthetic biographies, each paired with one
# "year after" question about the birth, career-start or death year.
n_data = 1000000

# Seed NumPy's global RNG so the draws in this cell repeat for a given
# ordering of the candidate lists and a given `birth_years`.
syn_data_seed = 0
np.random.seed(syn_data_seed)

# Attribute tuples already emitted; used to reject exact duplicates.
tuples = set()

career_year_question_template = "When was the year after the year that {first_name} {last_name} started the career of {career}?"
death_year_question_template = "When was the year after the year that {first_name} {last_name} passed away?"
birth_year_question_template = "When was the year after the year that {first_name} {last_name} was born?"

# Built once here rather than on every loop iteration.
question_templates = [
    ("career", career_year_question_template),
    ("death", death_year_question_template),
    ("birth", birth_year_question_template),
]

syn_data = []

# Biographies whose death year would fall after this are rejected and redrawn.
max_death_year = 2023

# The context manager closes the bar when generation ends, so its final
# state is flushed with this cell instead of leaking into a later one.
with tqdm(total=n_data) as pbar:
    while len(syn_data) < n_data:

        first_name = np.random.choice(first_names)
        last_name = np.random.choice(last_names)

        birth_year = np.random.choice(birth_years)
        birth_place = np.random.choice(birth_places)
        gender = np.random.choice(genders)
        career = np.random.choice(careers)

        # Timeline: career starts 14-25 years after birth, lasts 4-39 years,
        # and death follows retirement by 1-9 years (randint's upper bound is exclusive).
        growth_duration = np.random.randint(14, 26)
        career_year = birth_year + growth_duration
        career_duration = np.random.randint(4, 40)
        retire_year = career_year + career_duration

        retire_duration = np.random.randint(1, 10)
        death_year = retire_year + retire_duration
        if death_year > max_death_year:
            continue

        info_tuple = (first_name, last_name, birth_year, birth_place, gender, career, career_year, death_year)

        if info_tuple in tuples:
            continue
        tuples.add(info_tuple)

        text = text_template.format(
            first_name=first_name,
            last_name=last_name,
            birth_year=birth_year,
            birth_place=birth_place,
            gender=gender,
            gender_start=gender.capitalize(),
            career=career,
            career_year=career_year,
            death_year=death_year,
        )

        # Pick one of the three question types uniformly at random.
        rand_idx = np.random.choice(len(question_templates))
        template_name, question_template = question_templates[rand_idx]

        # `career` is ignored by the templates that do not reference it.
        question = question_template.format(
            first_name=first_name,
            last_name=last_name,
            career=career,
        )

        # The gold answer is the year following the queried event.
        queried_year = {
            "career": career_year,
            "death": death_year,
            "birth": birth_year,
        }[template_name]
        answer = queried_year + 1

        syn_data.append(
            {
                "text": text,
                "question": question,
                "answer": str(int(answer))
            }
        )
        pbar.update(1)
        

100%|█████████▉| 999560/1000000 [02:03<00:00, 8039.46it/s]

In [10]:
# Persist the full synthetic corpus (before splitting) as JSON lines.
io.dump_jsonlines(syn_data, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/bio_syn_data/all_data.jsonl")

100%|██████████| 1000000/1000000 [02:20<00:00, 8039.46it/s]

In [None]:
# all(int(d["answer"]) <= max_death_year + 1 for d in syn_data)

True

In [12]:
# Split the synthetic corpus into train / valid / test and write each part.
# A fixed seed makes the split reproducible across kernel restarts.
split_seed = 0

n_test = 100
n_dev = 100
# Sized from the data actually in memory, so the split stays correct even if
# `n_data` was changed (or is undefined) in the current kernel.
n_total = len(syn_data)
assert n_total > n_test + n_dev, (
    f"need more than {n_test + n_dev} records to hold out dev and test sets, got {n_total}"
)
n_train = n_total - n_test - n_dev

# Random permutation of record indices: first n_train -> train,
# next n_dev -> valid, remaining n_test -> test.
rand_shuffle = np.random.default_rng(split_seed).permutation(n_total)

train_data = [syn_data[i] for i in rand_shuffle[:n_train]]
dev_data = [syn_data[i] for i in rand_shuffle[n_train:n_train+n_dev]]
test_data = [syn_data[i] for i in rand_shuffle[n_train+n_dev:]]
assert (len(train_data), len(dev_data), len(test_data)) == (n_train, n_dev, n_test)

io.dump_jsonlines(train_data, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/bio_syn_data/train.jsonl")
io.dump_jsonlines(dev_data, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/bio_syn_data/valid.jsonl")
io.dump_jsonlines(test_data, "/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/bio_syn_data/test.jsonl")

In [246]:
# Number of held-out test examples.
# NOTE(review): the recorded output (1000) does not match n_test = 100 in the split
# cell shown in this notebook, so it presumably comes from an older run -- re-run to confirm.
len(test_data)

1000

In [4]:
# Spot check: export 10 randomly sampled synthetic training rows to Excel.
# The sample is unseeded (no random_state), so the rows differ between runs.
pd.DataFrame(io.load_jsonlines("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/bio_syn_data/train.jsonl")).sample(10).to_excel("/u/zliu/datastor1/mend/spotcheck/syn_text_train_sample.xlsx", index=False)

In [5]:
# Spot check: export 10 randomly sampled common-fact training rows to Excel.
# The sample is unseeded (no random_state), so the rows differ between runs.
pd.DataFrame(io.load_jsonlines("/u/zliu/datastor1/KE-by-CP/data/debug_meta_train/common_date_data/train.jsonl")).sample(10).to_excel("/u/zliu/datastor1/mend/spotcheck/common_fact_train_sample.xlsx", index=False)