# In [7]:
import json
from collections import defaultdict
import os

def precision(actual, predicted):
    """Return the share of ``predicted`` entries that are present in ``actual``.

    An empty prediction is scored 1.0 (nothing was predicted wrongly).
    """
    n_predicted = len(predicted)
    if n_predicted == 0:
        return 1.0
    hits = sum(1.0 for item in predicted if item in actual)
    return hits / float(n_predicted)


def recall(actual, predicted):
    """Return the fraction of ``actual`` items that also appear in ``predicted``.

    An empty ``actual`` is scored 1.0 (there was nothing to recall).

    Hits are counted over ``actual`` rather than ``predicted`` so the result
    stays within [0, 1] even when ``predicted`` contains duplicates; for
    duplicate-free inputs such as sets the value is the same either way.
    """
    if len(actual) == 0:
        return 1.0
    return sum(1.0 for item in actual if item in predicted) / float(len(actual))


def f1(actual, predicted):
    """Return the harmonic mean of precision and recall.

    Returns 0.0 when precision and recall are both zero, so the division
    below is never by zero.
    """
    p = precision(actual, predicted)
    r = recall(actual, predicted)
    denom = p + r
    if denom > 0.0:
        return 2.0 * p * r / denom
    return 0.0

# In [8]:
def _index_queries(gold_instances):
    """Map ``(db_index, query_index)`` to the gold query at that position.

    The indices are positions in ``gold_instances`` and in each database's
    ``"queries"`` list; SSG records refer to queries by these two positions.
    """
    return {
        (db_idx, query_idx): query
        for db_idx, db in enumerate(gold_instances)
        for query_idx, query in enumerate(db["queries"])
    }


def _examples_for_instance(inst, query):
    """Build the JSON-serialisable examples for one SSG output record.

    ``inst["ssg_output"]`` is a list of support sets, each a list of
    ``(fact_id, fact)`` pairs.  ``query`` is the matching gold query, a
    positional sequence in which ``query[1]`` is a list of fact ids
    (presumably the gold supporting facts), ``query[6]`` is indexed by
    position in ``query[1]``, and ``query[5]`` is the single output used
    when ``query[6]`` is empty.
    """
    ssg_output = inst["ssg_output"]
    gold_ids = query[1]
    per_fact_outputs = query[6]
    examples = []

    if len(per_fact_outputs):
        # One example per non-empty support set; its outputs are the
        # per-fact entries of every gold fact that the set retrieved.
        for ssg_id, support_set in enumerate(ssg_output):
            if not support_set:
                continue
            ids, _facts = zip(*support_set)
            outputs = [per_fact_outputs[gold_ids.index(fact_id)]
                       for fact_id in ids if fact_id in gold_ids]
            examples.append({"instance": inst,
                             "inputs": ids,
                             # Key spelling kept for existing consumers
                             # (presumably meant "ssg_id").
                             "ssid_id": ssg_id,
                             "outputs": outputs or None})
    else:
        # One example pooling every retrieved id; its output is query[5] if
        # any gold fact was retrieved.  dict.fromkeys dedupes while keeping
        # first-seen order, so the output is reproducible across runs (set
        # iteration order of str ids varies with hash randomisation).
        returned_ids = {}
        for support_set in ssg_output:
            if support_set:
                ids, _facts = zip(*support_set)
                returned_ids.update(dict.fromkeys(ids))
        hit = any(fact_id in gold_ids for fact_id in returned_ids)
        examples.append({"instance": inst,
                         "inputs": list(returned_ids),
                         "ssid_id": None,
                         "outputs": [query[5]] if hit else None})

    # NOTE(review): all() of an empty iterable is True, so this fires exactly
    # when every support set is non-empty (including when there are none).
    # If the intent was "some support set is empty", it should read
    # any(len(s) == 0 for s in ssg_output) -- confirm before changing.
    if len(ssg_output) == 0 or all(len(s) for s in ssg_output):
        examples.append({"instance": inst, "inputs": [], "outputs": None})

    return examples


def prepare(split, supervision, *,
            gold_dir="/scratch/jth/neuraldb/v0.5_intermediates2",
            ssg_dir="/scratch/jth/neuraldb/ssg_outputs",
            out_root="/scratch/jth/neuraldb"):
    """Pair SSG support sets with their gold queries and write JSONL examples.

    Reads ``{gold_dir}/{split}_queries_last_50.json`` and
    ``{ssg_dir}/ssg_{supervision}_{split}.json``, then writes one JSON object
    per line to ``{out_root}/v0.5_newssg_duo_{supervision}/{split}.jsonl``
    (the directory is created if needed; the file is overwritten).

    Args:
        split: split name used in the file names, e.g. "train" or "dev".
        supervision: SSG variant used in the file names, e.g. "supervised".
        gold_dir: directory holding the gold query files.
        ssg_dir: directory holding the SSG output files.
        out_root: parent of the per-supervision output directory.
        (The directory defaults are the originally hard-coded locations.)

    Raises:
        KeyError: if an SSG record's ``(db_id, question_id)`` has no gold query.
    """
    gold_path = os.path.join(gold_dir, f"{split}_queries_last_50.json")
    with open(gold_path, encoding="utf-8") as f:
        gold_instances = json.load(f)

    ssg_path = os.path.join(ssg_dir, f"ssg_{supervision}_{split}.json")
    with open(ssg_path, encoding="utf-8") as f:
        ssg_instances = json.load(f)

    queries = _index_queries(gold_instances)
    generated = []
    for inst in ssg_instances:
        query = queries[(inst["db_id"], inst["question_id"])]
        generated.extend(_examples_for_instance(inst, query))

    out_dir = os.path.join(out_root, f"v0.5_newssg_duo_{supervision}")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{split}.jsonl"), "w", encoding="utf-8") as f:
        for example in generated:
            f.write(json.dumps(example) + "\n")

if __name__ == "__main__":
    # Build the JSONL files for every (supervision, split) pair, in the same
    # order as before: supervised train/dev, then ds train/dev.  The guard
    # keeps imports of this module free of file I/O; notebooks still run it
    # because their __name__ is "__main__".
    for supervision in ("supervised", "ds"):
        for split in ("train", "dev"):
            prepare(split, supervision)



