In [69]:
import re
import string
import random 
import copy
import uuid
import pandas as pd
from meta_kg.utils.py_io import *


def normalize_text(text):
    """Canonicalise a string for token-level comparison.

    Applied in this order:
      1. lower-case the text;
      2. drop every ASCII punctuation character (``string.punctuation``);
      3. replace the whole words ``a``, ``an``, ``the``, ``fail``, ``or`` and
         ``naf`` with a space (the last three appear to be proof-string
         keywords that should not survive as ids — confirm against the data);
      4. collapse runs of whitespace into single spaces and trim both ends.
    """
    punctuation = frozenset(string.punctuation)
    stopword_pattern = re.compile(r"\b(a|an|the|fail|or|naf)\b", re.UNICODE)

    lowered = text.lower()
    without_punc = "".join(filter(lambda ch: ch not in punctuation, lowered))
    without_stopwords = stopword_pattern.sub(" ", without_punc)
    return " ".join(without_stopwords.split())


In [121]:
# Load the ProofWriter depth-5 meta-test split; each record must carry the
# ``triples``, ``rules`` and ``questions`` mappings read by parse_proofwrite_cwa.
# NOTE(review): the variable is named ``owa`` but the path is the *cwa*
# directory — confirm which world assumption is intended.
owa = read_jsonl("./data/proofwriter/cwa/depth-5/meta-test.jsonl")

def parse_proofwrite_cwa(instance):
    """Split one ProofWriter record into sentence lookups and usable questions.

    Args:
        instance: record with ``triples``, ``rules`` and ``questions`` mappings.
            Every triple/rule value carries a ``text`` field; every question
            carries ``question``, ``answer`` and ``proofs``.

    Returns:
        ``(triples, rules, questions)``:
          * ``triples`` / ``rules`` map each id to its sentence text;
          * ``questions`` is a list of ``(question, answer_lowercased, proof_ids)``
            tuples, where ``proof_ids`` is the set of tokens left after running
            the proof string through ``normalize_text`` and splitting on
            whitespace. If the proof string contains ``@``, only the segment
            between the first and second ``=`` is used. Questions whose proof
            yields fewer than two ids are dropped.
    """
    triples = {tid: entry["text"] for tid, entry in instance["triples"].items()}
    rules = {rid: entry["text"] for rid, entry in instance["rules"].items()}

    questions = []
    for qa in instance['questions'].values():
        question, answer, proof_str = qa['question'], qa['answer'], qa['proofs']
        # Proof strings containing "@" keep only the segment after the first "=".
        if '@' in proof_str:
            proof_str = proof_str.split('=')[1]
        proof_ids = set(normalize_text(proof_str).split())
        if len(proof_ids) > 1:
            questions.append((question, str(answer).lower(), proof_ids))
    return triples, rules, questions

def build_example(triples, rules, question):
    """Turn one parsed ProofWriter question into an example dict.

    Args:
        triples: mapping of triple id -> fact sentence.
        rules: mapping of rule id -> rule sentence. On an id clash the rule
            sentence takes precedence.
        question: ``(question_text, answer, proof_ids)`` as produced by
            ``parse_proofwrite_cwa``. ``proof_ids`` is an iterable of
            triple/rule ids (a set when coming from that parser).

    Returns:
        dict with keys ``guid`` (random UUID4 string), ``question``, ``answer``,
        ``proofs`` (list of the proof ids) and ``facts`` (deduplicated sentences
        those ids refer to). When ``proof_ids`` is a set, the order of
        ``proofs`` and ``facts`` follows set iteration order. That order is not
        stable across interpreter runs (string hash randomisation).

    Raises:
        KeyError: if a proof id is neither a triple id nor a rule id.
    """
    # Merge into a fresh dict so the caller's ``triples`` is never mutated.
    sentences = {**triples, **rules}
    proof_ids = question[2]

    missing = [k for k in proof_ids if k not in sentences]
    if missing:
        raise KeyError(f"proof ids not found among triples/rules: {sorted(missing)}")

    example = {}
    example['guid'] = str(uuid.uuid4())
    example['question'] = question[0]
    example['answer'] = question[1]
    example['proofs'] = list(proof_ids)
    # Several ids may map to the same sentence; keep each sentence once.
    example['facts'] = list({sentences[k] for k in proof_ids})
    return example

In [122]:
# Flatten every ProofWriter record into one example per multi-id question.
# A comprehension keeps per-record temporaries out of the notebook namespace.
# NOTE(review): despite the ``owa``/``2_hop`` names, the source split is CWA
# depth-5 and the result is written to proof_5_hop — consider renaming.
owa_2_hop = [
    build_example(triples, rules, q)
    for triples, rules, questions in map(parse_proofwrite_cwa, owa)
    for q in questions
]

In [123]:
# Spot-check the first generated example.
owa_2_hop[0]

{'guid': '28b46f7d-ba41-4c30-b91d-b9fafc5fe9f1',
 'question': 'The rabbit is kind.',
 'answer': 'true',
 'proofs': ['rule5', 'triple11'],
 'facts': ['The rabbit is young.', 'Young people are kind.']}

In [124]:
# Persist the flattened examples as the proof_5_hop test split.
write_jsonl(owa_2_hop, "./data/proof_5_hop/test.jsonl")

In [36]:
clutrr = read_jsonl("data/clutrr/dev.jsonl")

# Bucket the dev set by number of facts; the 4- and 6-fact subsets are
# simplified and written out separately below.
clutrr_4, clutrr_6 = (
    [example for example in clutrr if len(example["facts"]) == n_facts]
    for n_facts in (4, 6)
)

In [37]:
# Kinship relation words recognised by get_knowledge, listed in
# (male, female) pairs.
rels = """
    son             daughter
    brother         sister
    father          mother
    husband         wife
    grandfather     grandmother
    grandson        granddaughter
    uncle           aunt
    son-in-law      daughter-in-law
    father-in-law   mother-in-law
    brother-in-law  sister-in-law
    nephew          niece
""".split()

# Single-letter aliases for entities; a 1-based numeric id n maps to persons[n - 1].
persons = list("ABCDHJKLMNOPQRSTVXYZ")

def get_knowledge(tokens):
    """Extract the entity pair and kinship relation from a tokenised string.

    Numeric tokens are treated as 1-based entity ids and mapped to the letter
    aliases in the module-level ``persons`` list. Any token found in ``rels``
    is taken as the relation; if several match, the last one wins.

    Args:
        tokens: whitespace-split tokens of a fact or answer string.

    Returns:
        ``(entities, relation)``. ``entities`` is the two letter aliases
        joined by one space, e.g. ``"A B"``. ``relation`` is the matched
        relation word, or ``None`` if no token is in ``rels``; in that case
        the tokens are printed so the miss is visible.

    Raises:
        IndexError: if a numeric id is outside ``1..len(persons)``.
        AssertionError: if the tokens do not contain exactly two numeric ids.
    """
    entities = []
    relation = None
    for tok in tokens:
        if tok.isdigit():
            person_id = int(tok)
            # Ids are 1-based; without this check an id of 0 would index
            # persons[-1] and silently alias the last person.
            if not 1 <= person_id <= len(persons):
                raise IndexError(
                    f"entity id {person_id} out of range 1..{len(persons)}: {tokens}"
                )
            entities.append(persons[person_id - 1])
        if tok in rels:
            relation = tok
    assert len(entities) == 2, f"expected exactly 2 entity ids, got {len(entities)}: {tokens}"
    if relation is None:
        print(tokens)
    return ' '.join(entities), relation


In [38]:
def simplify(dataset):
    """Return a deep copy of ``dataset`` rewritten into compact symbolic form.

    For every record in the copy:
      * each ``facts`` entry becomes the ``(entities, relation)`` tuple from
        ``get_knowledge`` (``entities`` looks like ``"A B"``);
      * each ``questions`` entry is edited in place. Index 0 becomes a
        templated "How are X and Y related ..." question, index 1 becomes the
        entity-pair string, and the relation is appended.

    ``questions`` entries must be mutable lists whose element 1 is the answer
    string. The input ``dataset`` itself is not modified.
    """
    simple_dataset = copy.deepcopy(dataset)
    for record in simple_dataset:
        record['facts'] = [get_knowledge(fact.split()) for fact in record['facts']]

        for qa in record['questions']:
            pair, relation = get_knowledge(qa[1].split())
            # Aliases are single letters, so pair[0] / pair[-1] are the two names.
            qa[0] = f"How are {pair[0]} and {pair[-1]} related to each other ?"
            qa[1] = pair
            assert len(qa[0].split()) == 10
            qa.append(relation)
    return simple_dataset

In [39]:
# Convert the 4-fact subset to the compact (entity pair, relation) format.
simple_clutrr_4 = simplify(clutrr_4)

In [40]:
# Save the simplified 4-fact subset as the clutrr_4_hop dev split.
write_jsonl(simple_clutrr_4, "data/clutrr_4_hop/dev.jsonl")

In [41]:
# Convert the 6-fact subset to the compact (entity pair, relation) format.
simple_clutrr_6 = simplify(clutrr_6)

In [42]:
# Save the simplified 6-fact subset as the clutrr_6_hop dev split.
write_jsonl(simple_clutrr_6, "data/clutrr_6_hop/dev.jsonl")

In [8]:
# Load dev-set predictions from one specific run (the directory name looks
# like a timestamp) and inspect a single record. Update the path to evaluate
# a different checkpoint.
eval_out_4_hop = read_json("./output/20221212-033351/dev_out-epoch=0_step=5061.json")
eval_out_4_hop[0]

{'guid': 'e6e0679a-2081-4863-bbb2-13b08fe283e9',
 'prefix': 'clutrr_4_hop',
 'question': '<|endoftext|>C B aunt\nH J sister\nB J daughter\nZ H daughter\nBased on fact_0 fact_1 fact_2 fact_3, How are C and Z related to each other ?',
 'gen_out': 'C B aunt\nH J sister\nB J daughter\nZ H daughter\nBased on fact_0 fact_1 fact_2 fact_3, How are C and Z related to each other?daughter',
 'answer': 'daughter'}

In [9]:
# Exact-match accuracy on the 4-hop dev predictions. ``gen_out`` echoes the
# prompt (see the displayed record), which ends in "?", so the model's answer
# is the text between the first "?" and any following "?".
acc = 0
for data in eval_out_4_hop:
    gen_out = data['gen_out'].split("?")
    # A generation without "?" has no answer slot: count it as wrong instead of
    # raising IndexError.
    gen_answer = gen_out[1].strip() if len(gen_out) > 1 else None
    if gen_answer == data['answer']:
        acc += 1
# Avoid ZeroDivisionError when the prediction file is empty.
print(acc / len(eval_out_4_hop) if eval_out_4_hop else float("nan"))

0.89
