In [2]:
import json
import os

res_file = "../attempts/last-result/results.json"  # The path of the prediction file (results.json)
ref_file = "../attempts/last-result/results.json"  # The path of the Ground True file (reference.json)
output_dir = ""  # The path of the output directory
scores_file = os.path.join(output_dir, "scores.json")  # The path of the output file (scores.json)

# Counters
EI_tp = 0
EI_gold_len = 0
EI_pred_len = 0
EC_tp = 0
EC_gold_len = 0
EC_pred_len = 0
RE_GEN_tp = 0
RE_STRICT_tp = 0
RE_gold_len = 0
RE_pre_len = 0
cnt = 0

def safe_div(a, b):
    return round(a / b * 100, 2) if b != 0 else 0.0

def safe_div_(a, b):
    return round(a / b, 2) if b != 0 else 0.0

def compute_f1(cnt, tp, pred_num, gold_num):
    return {
        "Total samples": cnt,
        "P": safe_div(tp, pred_num),
        "R": safe_div(tp, gold_num),
        "F1": safe_div_(2 * safe_div(tp, pred_num) * safe_div(tp, gold_num), safe_div(tp, pred_num) + safe_div(tp, gold_num))
    }

# Load Ground Truth
with open(ref_file, "r", encoding="utf-8") as gt_file:
    Ground_True = json.load(gt_file)

GT = {}
total_gt_entities = 0
total_gt_triples = 0

for doc_id, sample in Ground_True.items():
    mention_gt = sample["entities"]
    total_gt_entities += len(mention_gt)
    triple_gt = sample["triples"]
    total_gt_triples += len(triple_gt)
    
    mentions_gt_list = [set(m["mentions"]) for m in mention_gt]
    metion_type_list = [m["type"] for m in mention_gt]
    triple_gt_list = [(gt["head"], gt["relation"], gt["tail"]) for gt in triple_gt]
    
    GT[doc_id] = {"mentions_GT": mentions_gt_list, "relations_GT": triple_gt_list, "mention_type": metion_type_list}

del Ground_True

# Load Predictions
with open(res_file, "r", encoding="utf-8") as f:
    results = json.load(f)
    cnt = len(results)

total_pred_entities = 0
total_pred_triples = 0

for pre_id, sample in results.items():
    mention_gt_list = GT[pre_id]["mentions_GT"]
    relation_gt_list = GT[pre_id]["relations_GT"]
    type_gt_list = GT[pre_id]["mention_type"]
    
    mention_pred = sample["entities"]
    total_pred_entities += len(mention_pred)
    for i in range(len(mention_pred)):
        mention_pred[i]["mentions"] = set(mention_pred[i]["mentions"])
    
    EI_gold_len += len(mention_gt_list)
    EC_gold_len += len(mention_gt_list)
    EI_pred_len += len(mention_pred)
    EC_pred_len += len(mention_pred)

    for i in range(len(mention_pred)):
        for j in range(len(mention_gt_list)):
            if mention_pred[i]["mentions"] in mention_gt_list:
                EI_tp += 1
                type_idx = mention_gt_list.index(mention_pred[i]["mentions"])
                if mention_pred[i]["type"] == type_gt_list[type_idx]:
                    EC_tp += 1
                    break

    triple_pred = sample["triples"]
    total_pred_triples += len(triple_pred)
    RE_pre_len += len(triple_pred)
    RE_gold_len += len(relation_gt_list)

    for pred in triple_pred:
        pred_triple = (pred["head"], pred["relation"], pred["tail"])
        if pred_triple in relation_gt_list:
            RE_GEN_tp += 1
            RE_STRICT_tp += 1

# Print entity and triple counts
print(f"Ground Truth: {total_gt_entities} entities, {total_gt_triples} triples",ref_file)
print(f"Predictions: {total_pred_entities} entities, {total_pred_triples} triples",res_file)

# Compute F1 Scores
entity_identification_res = compute_f1(cnt, EI_tp, EI_pred_len, EI_gold_len)
entity_classification_res = compute_f1(cnt, EC_tp, EC_pred_len, EC_gold_len)
re_general_res = compute_f1(cnt, RE_GEN_tp, RE_pre_len, RE_gold_len)
re_strict_res = compute_f1(cnt, RE_STRICT_tp, RE_pre_len, RE_gold_len)

# Save results
with open(scores_file, "w", encoding="utf-8") as f:
    json.dump({
        "entity_ident": entity_identification_res["F1"],
        "entity_cla": entity_classification_res["F1"],
        "re_general": re_general_res["F1"],
        "re_strict": re_strict_res["F1"]
    }, f, ensure_ascii=False, indent=4)


Ground Truth: 11906 entities, 10247 triples ../attempts/last-result/results.json
Predictions: 11906 entities, 10247 triples ../attempts/last-result/results.json
