In [154]:
import re
import string
import random 
import copy
import uuid
import pandas as pd
from meta_kg.utils.py_io import *


def normalize_text(text):
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the|fail|or|naf)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()
    
    return white_space_fix(remove_articles(remove_punc(lower(text))))

def compute_f1(prediction, truth):
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    common_tokens = set(pred_tokens) & set(truth_tokens)
    if len(common_tokens) == 0:
        return 0

    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

In [165]:
test_out = read_json("./output/test_out.json")
answers = [data['answer'] for data in test_out]
gen_outs = [data['gen_out'].split("?")[1] for data in test_out]

acc = 0
acc_rel = 0
acc_kg = 0
f1_score = 0
errors = []
for pred, truth in zip(gen_outs, answers):
    acc += int(normalize_text(pred) == normalize_text(truth))
    f1_score += compute_f1(pred, truth)
    if normalize_text(pred) != normalize_text(truth):
        errors.append((pred, truth))
    relation = truth.split("because")[0]
    facts = truth.split("because")[1]
    gen_rel = pred.split("because")[0]
    gen_facts = pred.split("because")[1]
    acc_rel += int(normalize_text(gen_rel) == normalize_text(relation))
    acc_kg += int(normalize_text(gen_facts) == normalize_text(facts))

print("Accuracy: ", acc/len(gen_outs))
print("Accuracy (Relation): ", acc_rel/len(gen_outs))
print("Accuracy (KG): ", acc_kg/len(gen_outs))
print("F1 Score: ", f1_score/len(gen_outs))

Accuracy:  0.9641782882143833
Accuracy (Relation):  0.9657278279099444
Accuracy (KG):  0.983501959711968
F1 Score:  0.8912978796490217


In [166]:
def compute_exact_match(prediction, truth):
    return int(normalize_text(truth) == normalize_text(prediction))


targets = [data['answer'] for data in test_out]
preds = [data['gen_out'].split("?")[1] for data in test_out]

if "because" in targets[0]:
    labels = [t.split('because')[0].strip() for t in targets]
    gen_labels = [p.split('because')[0].strip() for p in preds]
    em_label = [compute_exact_match(
        gen, label) for label, gen in zip(labels, gen_labels)]

    facts = [t.split('because')[1].strip() for t in targets]
    gen_kgs = [p.split('because')[1].strip() for p in preds]
    em_kg = [compute_exact_match(
        gen, label) for label, gen in zip(facts, gen_kgs)]
    
    print("Exact Match (Label): ", sum(em_label)/len(em_label))
    print("Exact Match (KG): ", sum(em_kg)/len(em_kg))

Exact Match (Label):  0.9657278279099444
Exact Match (KG):  0.983501959711968


In [163]:
print(errors[0][0], errors[0][1])
compute_f1(errors[0][0], errors[0][1])

father because T X husband,B T father father-in-law because T X husband,B T father


0.7142857142857143

In [53]:
from sklearn.model_selection import train_test_split

strategy = read_json("./data/strategyqa/strategyqa_train.json")

def parse_strategy(data):
    question = data["question"]
    answer = "yes" if data["answer"] else "no"
    facts = [normalize_text(fact) for fact in data["facts"]]
    decomposition = data["decomposition"]
    example = {
        "guid": str(uuid.uuid4()),
        "question": question,
        "answer": answer,
        "facts": facts,
        "decomposition": decomposition
    }
    return example

strategy_data = [parse_strategy(data) for data in strategy]
true_data = [data for data in strategy_data if data["answer"] == "yes"]
false_data = [data for data in strategy_data if data["answer"] == "no"]

train_true_data, dev_true_data = train_test_split(true_data, test_size=0.2, random_state=3042)
train_false_data, dev_false_data = train_test_split(false_data, test_size=0.2, random_state=3042)

train_data = train_true_data + train_false_data
dev_data = dev_true_data + dev_false_data

write_jsonl(train_data, "./data/strategyqa/train.jsonl")
write_jsonl(dev_data, "./data/strategyqa/dev.jsonl")

In [67]:
taxonomy = read_jsonl("./data/taxonomy/hypernyms_training_mix_short_dev.jsonl")

taxonomy_data = []
for data in taxonomy:
    data["guid"] = data['id']
    data["question"] = normalize_text(data["phrase"])
    data["answer"] = ["np", "yes"][data["answer"]]
    data["facts"] = [normalize_text(fact) for fact in data["metadata"]["rules"]]
    example = {
        "guid": data["guid"],
        "question": data["question"],
        "answer": data["answer"],
        "facts": data["facts"],
    }
    taxonomy_data.append(example)
taxonomy_data[2]

{'guid': 'afdf807efc8b55e8b1c9349c0ac554ca',
 'question': 'reggae is capable of express feelings',
 'answer': 'yes',
 'facts': ['music is capable of express feelings',
  'reggae is music',
  'holly is capable of branch out',
  'plant is not capable of express feelings',
  'tulip is plant',
  'reggae is auditory communication',
  'mustard is not capable of shade from sun',
  'reggae is not vertebrate']}

In [68]:
write_jsonl(taxonomy_data, "./data/taxonomy/dev.jsonl")

In [74]:
counting = read_jsonl("./data/counting/counting_training_mix_dev.jsonl")
counting_data = []
for data in counting:
    data["guid"] = data['id']
    data["question"] = data["phrase"]
    data["answer"] = ["np", "yes"][data["answer"]]
    data["facts"] = data["metadata"]["rules"]
    example = {
        "guid": data["guid"],
        "question": data["question"],
        "answer": data["answer"],
        "facts": data["facts"],
    }
    counting_data.append(example)
counting_data[0]

{'guid': '8f836112dccfa8f61fc934f114145a2c',
 'question': 'James H. Roberts is the CEO of ADT.',
 'answer': 'np',
 'facts': ['Timothy J. Whall is the CEO of ADT.', 'ADT has one CEO.']}

In [75]:
write_jsonl(counting_data, "./data/counting/dev.jsonl")


In [13]:
clutrr2 = read_jsonl("./data/clutrr_2_hop/train.jsonl")
clutrr4 = read_jsonl("./data/clutrr_4_hop/train.jsonl")
clutrr6 = read_jsonl("./data/clutrr_6_hop/train.jsonl")
print(len(clutrr2))
print(len(clutrr4))
print(len(clutrr6))

96011
89971
90922


In [14]:
clutrr_all = random.sample(clutrr2, 50000) + random.sample(clutrr4, 50000) + random.sample(clutrr6, 50000)
len(clutrr_all)


150000

In [15]:
write_jsonl(clutrr_all, "./data/clutrr_mix/train.jsonl")

In [10]:
musique = read_jsonl("./data/musique/musique_full_v1.0_dev.jsonl")

In [9]:
lsat = read_json("./data/arlsat/train.json")

In [25]:
def parse_entailment_tree(instance, add_distractors=False):
    hop = instance['depth_of_proof']
    hypothesis = instance['hypothesis']
    triples = instance["meta"]["triples"]
    distractor_ids = instance["meta"]["distractors"]
    fact_id = list(set(triples.keys()) - set(distractor_ids))
    distractors = [triples[idx] for idx in distractor_ids]
    facts = [triples[idx] for idx in fact_id]

    num_distractors = len(facts) // 2
    to_add = random.choices(distractors, k=num_distractors)
    if add_distractors:
        facts.extend(to_add)
    random.shuffle(facts)

    for i, fact in enumerate(facts):
        if random.randint(0, 1):
            facts[i] = random.choice(distractors)
    
    valid_example = {
        "guid": str(uuid.uuid4()),
        "hypothesis": hypothesis,
        "facts": facts,
        "answer": "yes",
    }

    invalid_example = {
        "guid": str(uuid.uuid4()),
        "hypothesis": hypothesis,
        "facts": facts,
        "answer": "no",
    }
    return valid_example, invalid_example

entail_tree = read_jsonl("./data/entailment_tree/task_2/test.jsonl")

entail_data = []
for instance in entail_tree:
    valid, invalid = parse_entailment_tree(instance, add_distractors=False)
    entail_data.append(valid)
    entail_data.append(invalid)
len(entail_data)

680

In [26]:
write_jsonl(entail_data, "./data/entailment_tree/test.jsonl")

In [2]:
entail_data = read_jsonl("./data/entailment_tree/train.jsonl")
depths = [len(instance['facts']) for instance in entail_data]
max(depths)

17

In [141]:
def parse_proofwrite_cwa(instance):
    triples = {}
    for k,v in instance["triples"].items():
        triples[k] = v["text"]
    rules = {}
    for k,v in instance["rules"].items():
        rules[k] = v["text"]
    questions = []
    for q in instance['questions'].values():
        question = q['question']
        answer = q['answer']
        proofs = q['proofs']
        if '@' not in proofs:
            proofs = set(normalize_text(proofs).split())
        else:
            proofs = proofs.split('=')[1]
            proofs = set(normalize_text(proofs).split())
        if len(proofs) > 1:
            questions.append((question, str(answer).lower(), proofs))
    return triples, rules, questions

def build_example(triples, rules, question):
    example = {}
    triples.update(rules)
    example['guid'] = str(uuid.uuid4())
    example['question'] = question[0]
    example['answer'] = question[1]
    example['proofs'] = list(question[2])
    example['knowledge'] = [triples[k] for k in question[2]]
    example['knowledge'] = list(set(example['knowledge']))
    example['facts'] = [triples[k] for k in triples.keys()]
    return example

In [143]:
for hop in [2,3,5]:
    for split in ["train", "dev", "test"]:
        owa = read_jsonl(f"./data/proofwriter/CWA/depth-{hop}/meta-{split}.jsonl")
        owa_2_hop = []
        for i, data in enumerate(owa):
            triples, rules, questions = parse_proofwrite_cwa(data)
            examples = [build_example(triples, rules, q) for q in questions]
            owa_2_hop.extend(examples)
        
        print(len(owa_2_hop))
        write_jsonl(owa_2_hop, f"./data/proof_{hop}_hop_hard/{split}.jsonl")

29457
4265
8340
34516
5170
10113
42754
6091
12078


In [36]:
clutrr = read_jsonl("data/clutrr/dev.jsonl")

clutrr_4 = [x for x in clutrr if len(x["facts"]) == 4]
clutrr_6 = [x for x in clutrr if len(x["facts"]) == 6]

In [37]:
rels = [
    "son", "daughter",
    "brother", "sister",
    "father", "mother",
    "husband", "wife",
    "grandfather", "grandmother",
    "grandson", "granddaughter",
    "uncle", "aunt",
    "son-in-law", "daughter-in-law",
    "father-in-law", "mother-in-law",
    "brother-in-law", "sister-in-law",
    "nephew", "niece"
]

persons = [
    'A', 'B', 'C', 'D', 
    'H', 'J', 'K', 'L', 
    'M', 'N', 'O', 'P', 
    'Q', 'R', 'S', 'T',
    'V', 'X', 'Y', 'Z',]

def get_knowledge(tokens):
    entitiy = []
    relation = None
    for tok in tokens:
        if tok.isdigit():
            entitiy.append(persons[int(tok)-1])
        if tok in rels:
            relation = tok
    assert len(entitiy) == 2
    if relation is None:
        print(tokens)
    return ' '.join(entitiy), relation


In [38]:
def simplify(dataset):
    simple_dataset = copy.deepcopy(dataset)
    for data in simple_dataset:
        facts = []
        for fact in data['facts']:
            tokens = fact.split()
            entity, relation = get_knowledge(tokens)
            facts.append((entity, relation))
        data['facts'] = facts
        for question in data['questions']:
            answer = question[1]
            tokens = answer.split()
            entity, relation = get_knowledge(tokens)
            question[0] = f"How are {entity[0]} and {entity[-1]} related to each other ?"
            question[1] = entity
            assert len(question[0].split()) == 10
            question.append(relation)
    return simple_dataset

In [39]:
simple_clutrr_4 = simplify(clutrr_4)

In [40]:
write_jsonl(simple_clutrr_4, "data/clutrr_4_hop/dev.jsonl")

In [41]:
simple_clutrr_6 = simplify(clutrr_6)

In [42]:
write_jsonl(simple_clutrr_6, "data/clutrr_6_hop/dev.jsonl")

In [8]:
eval_out_4_hop = read_json("./output/20221212-033351/dev_out-epoch=0_step=5061.json")
eval_out_4_hop[0]

{'guid': 'e6e0679a-2081-4863-bbb2-13b08fe283e9',
 'prefix': 'clutrr_4_hop',
 'question': '<|endoftext|>C B aunt\nH J sister\nB J daughter\nZ H daughter\nBased on fact_0 fact_1 fact_2 fact_3, How are C and Z related to each other ?',
 'gen_out': 'C B aunt\nH J sister\nB J daughter\nZ H daughter\nBased on fact_0 fact_1 fact_2 fact_3, How are C and Z related to each other?daughter',
 'answer': 'daughter'}

In [9]:
acc = 0 
for data in eval_out_4_hop:
    gen_out= data['gen_out'].split("?")
    gen_answer = gen_out[1].strip()
    if gen_answer == data['answer']:
        acc += 1
print(acc/len(eval_out_4_hop))

0.89


In [136]:
proof_5_hop = read_jsonl("./data/proof_5_hop_hard/train.jsonl")

sort_by_proof = {}
for data in proof_5_hop:
    key = ",".join(data['facts'])
    if key not in sort_by_proof:
        sort_by_proof[key] = [data]
    else:
        sort_by_proof[key].append(data)

In [137]:
num_k = [len(data['facts']) for data in proof_5_hop]
max(num_k)

25

In [128]:
len(sort_by_proof), len(proof_5_hop)


(501, 6091)

In [96]:
multi_question = list(sort_by_proof.items())
multi_question[3]

('White things are red.,Erin is white.,Erin is furry.,If something is nice and white then it is smart.,All furry, red things are nice.',
 [{'guid': '7b9c8377-e480-4d7d-aeae-c835af0a4f16',
   'question': 'Erin is smart.',
   'answer': 'true',
   'proofs': ['rule3', 'triple4', 'rule8', 'triple3', 'rule4'],
   'facts': ['White things are red.',
    'Erin is white.',
    'Erin is furry.',
    'If something is nice and white then it is smart.',
    'All furry, red things are nice.']},
  {'guid': '54708a90-6298-49c8-bc66-63453a3cdf31',
   'question': 'Erin is not smart.',
   'answer': 'false',
   'proofs': ['rule3', 'triple4', 'rule8', 'triple3', 'rule4'],
   'facts': ['White things are red.',
    'Erin is white.',
    'Erin is furry.',
    'If something is nice and white then it is smart.',
    'All furry, red things are nice.']}])