In [None]:
import pickle
import random

# Deserialize the pre-processed document graphs produced by an earlier step.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
PICKLE_PATH = "processed_graphs.pkl"

with open(PICKLE_PATH, mode="rb") as f:
    all_graphs = pickle.load(f)
print(f"Loaded {len(all_graphs)} graphs from {PICKLE_PATH}\n")

# Define the pair‐generation & balancing functions
def generate_entity_pairs(graph, max_distance=400):
    """Enumerate ordered entity pairs and split them into positives/negatives.

    Every ordered pair of distinct entity objects whose ``root`` indices lie
    at most ``max_distance`` apart is labelled with the gold relation from
    ``graph["relation_labels"]`` (looked up by ``(head_root, tail_root)``),
    or ``"NoRelation"`` when no gold relation exists. A summary line with the
    two counts is printed.

    Returns:
        (pos_pairs, neg_pairs): lists of dicts with keys ``head``, ``tail``,
        ``head_idx``, ``tail_idx`` and ``label``.
    """
    entities = graph["entities"]
    gold_by_roots = {}
    for rel in graph.get("relation_labels", []):
        gold_by_roots[(rel["head"], rel["tail"])] = rel["label"]

    positives, negatives = [], []
    for head in entities:
        for tail in entities:
            # Identity check: only an entity paired with itself is skipped.
            if head is tail or abs(head["root"] - tail["root"]) > max_distance:
                continue
            label = gold_by_roots.get((head["root"], tail["root"]), "NoRelation")
            bucket = negatives if label == "NoRelation" else positives
            bucket.append({
                "head": head,
                "tail": tail,
                "head_idx": head["root"],
                "tail_idx": tail["root"],
                "label": label,
            })

    print(f"→ Graph has {len(positives)} positives, {len(negatives)} negatives (window={max_distance})")
    return positives, negatives

def balance_pairs(pos_pairs, neg_pairs, neg_pos_ratio=5):
    """Return all positives plus up to ``neg_pos_ratio`` negatives per positive.

    Negatives are chosen by shuffling a *copy* of ``neg_pairs`` with the
    module-level ``random`` generator and taking a prefix, so the caller's
    list is left untouched. The RNG is consumed exactly as by one
    ``random.shuffle`` over the full negative list, so seeded runs select
    the same negatives.

    Args:
        pos_pairs: positive examples; all are kept, in order.
        neg_pairs: negative candidates; not modified.
        neg_pos_ratio: maximum number of negatives kept per positive.

    Returns:
        A new list: the positives followed by the sampled negatives.
    """
    shuffled_neg = list(neg_pairs)
    random.shuffle(shuffled_neg)
    k = min(len(shuffled_neg), len(pos_pairs) * neg_pos_ratio)
    return list(pos_pairs) + shuffled_neg[:k]

# Seed the global RNG so the negative sampling and shuffling below are reproducible.
random.seed(42)

# One balanced, shuffled training list per graph: up to 5 negatives per
# positive, over entity pairs at most 400 root positions apart.
all_train_pairs = []
for idx, graph in enumerate(all_graphs, start=1):
    pos, neg = generate_entity_pairs(graph, max_distance=400)
    train_pairs = balance_pairs(pos, neg, neg_pos_ratio=5)
    random.shuffle(train_pairs)
    all_train_pairs.append(train_pairs)
    print(f"★ Graph {idx}: sampled {len(train_pairs)} total (pos={len(pos)}, neg={len(train_pairs)-len(pos)})\n")

# Concatenate the per-graph lists into a single flat training set.
flat_pairs = [pair for graph_pairs in all_train_pairs for pair in graph_pairs]
print(f"Total training pairs across all graphs: {len(flat_pairs)}")

Loaded 74 graphs from /scratch/vsetpal/results/processed_graphs.pkl

→ Graph has 6 positives, 300 negatives (window=400)
★ Graph 1: sampled 36 total (pos=6, neg=30)

→ Graph has 0 positives, 98 negatives (window=400)
★ Graph 2: sampled 0 total (pos=0, neg=0)

→ Graph has 4 positives, 2836 negatives (window=400)
★ Graph 3: sampled 24 total (pos=4, neg=20)

→ Graph has 0 positives, 1426 negatives (window=400)
★ Graph 4: sampled 0 total (pos=0, neg=0)

→ Graph has 3 positives, 1767 negatives (window=400)
★ Graph 5: sampled 18 total (pos=3, neg=15)

→ Graph has 2 positives, 28 negatives (window=400)
★ Graph 6: sampled 12 total (pos=2, neg=10)

→ Graph has 0 positives, 4976 negatives (window=400)
★ Graph 7: sampled 0 total (pos=0, neg=0)

→ Graph has 34 positives, 1874 negatives (window=400)
★ Graph 8: sampled 204 total (pos=34, neg=170)

→ Graph has 4 positives, 3890 negatives (window=400)
★ Graph 9: sampled 24 total (pos=4, neg=20)

→ Graph has 4 positives, 7280 negatives (window=400)
★ G

In [None]:
import pickle
import torch
import numpy as np
from collections import defaultdict
import random

# Load the pre-processed document graphs.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("processed_graphs.pkl", mode="rb") as f:
    all_graphs = pickle.load(f)

print(f"Loaded {len(all_graphs)} graphs\n")

# Sampling utilities
def generate_entity_pairs(graph, max_distance=400):
    """Enumerate ordered entity pairs and split them into positives/negatives.

    Distinct entity objects whose ``root`` indices are at most
    ``max_distance`` apart form a candidate pair. Its label comes from
    ``graph["relation_labels"]`` keyed by ``(head_root, tail_root)``, with
    ``"NoRelation"`` as the fallback.

    Returns:
        (pos_pairs, neg_pairs): lists of dicts with keys ``head``, ``tail``,
        ``head_idx``, ``tail_idx`` and ``label``, in enumeration order.
    """
    entities = graph["entities"]
    gold = dict(
        ((rel["head"], rel["tail"]), rel["label"])
        for rel in graph.get("relation_labels", [])
    )

    def _make_entry(head, tail):
        # One candidate pair, labelled from the gold map by root indices.
        head_root, tail_root = head["root"], tail["root"]
        return {
            "head": head,
            "tail": tail,
            "head_idx": head_root,
            "tail_idx": tail_root,
            "label": gold.get((head_root, tail_root), "NoRelation"),
        }

    candidates = [
        _make_entry(head, tail)
        for head in entities
        for tail in entities
        if head is not tail and not abs(head["root"] - tail["root"]) > max_distance
    ]
    pos_pairs = [c for c in candidates if c["label"] != "NoRelation"]
    neg_pairs = [c for c in candidates if c["label"] == "NoRelation"]
    return pos_pairs, neg_pairs

def balance_pairs_1to1(pos_pairs, neg_pairs):
    """Keep every positive and draw an equal number of negatives (1:1).

    At most ``len(neg_pairs)`` negatives are drawn without replacement with
    ``random.sample`` (module-level RNG). Neither input is modified.

    Returns:
        A new list: ``pos_pairs`` followed by the sampled negatives.
    """
    budget = min(len(pos_pairs), len(neg_pairs))
    return pos_pairs + random.sample(neg_pairs, budget)

# Helper to average span embeddings
def average_span_embedding(bert_tensor, start, end):
    """Mean-pool rows ``start:end`` (end exclusive) of a 2-D feature matrix.

    Args:
        bert_tensor: tensor indexed ``[row, feature]``; ``size(1)`` is used as
            the feature width for the zero fallback.
        start: first row of the span (inclusive).
        end: one past the last row of the span.

    Returns:
        A 1-D tensor of width ``bert_tensor.size(1)`` on the same device.
        Invalid spans return zeros: empty spans, a negative ``start``, or an
        ``end`` past the last row. A negative ``start`` would otherwise slice
        from the end of the tensor, giving a NaN mean or the wrong rows.
    """
    if start < 0 or start >= end or end > bert_tensor.size(0):
        return torch.zeros(bert_tensor.size(1), device=bert_tensor.device)
    return bert_tensor[start:end].mean(dim=0)

# Build relation‑level dataset
def extract_relation_dataset_1to1(graphs, max_distance=400, device=None, debug=False):
    """Build a (features, labels) dataset for relation classification.

    For each graph:
      1. Candidate entity pairs within ``max_distance`` root positions are
         split by ``generate_entity_pairs`` and balanced 1:1 by
         ``balance_pairs_1to1``.
      2. Each kept pair is encoded as the mean-pooled head-span embedding
         concatenated with the mean-pooled tail-span embedding. Both come
         from ``graph["node_features"]``.

    Args:
        graphs: sequence of graph dicts with ``entities`` (each with
            ``start``/``end``), ``node_features`` and optionally
            ``relation_labels``. ``len(graphs)`` is used in debug output.
        max_distance: maximum root distance for a candidate pair.
        device: torch device for the feature tensors (None = default).
        debug: print per-graph sampling statistics.

    Returns:
        X: float32 array, shape ``(num_pairs, 2 * feature_dim)``. When no pair
           is sampled it is an empty ``(0, 2 * feature_dim)`` array, using the
           last 2-D feature matrix seen; previously ``np.stack`` raised here.
        y: int64 array of label ids, shape ``(num_pairs,)``.
        label_map: dict label -> id. "NoRelation" is 0; other labels get ids in
           order of first appearance among the sampled pairs.
    """
    X, y = [], []
    # The default factory hands out the next free id on first lookup of a label.
    relation_label_map = defaultdict(lambda: len(relation_label_map))
    relation_label_map["NoRelation"] = 0
    pair_width = 0  # feature width of one pair vector; used only for an empty X

    for gi, graph in enumerate(graphs, 1):
        if debug:
            print(f"\nGraph {gi}/{len(graphs)}: {len(graph['entities'])} entities")

        bert_tensor = torch.tensor(
            graph["node_features"], dtype=torch.float, device=device
        )
        if bert_tensor.dim() == 2:
            pair_width = 2 * bert_tensor.size(1)

        pos, neg = generate_entity_pairs(graph, max_distance)
        pairs = balance_pairs_1to1(pos, neg)

        if debug:
            print(f"  • sampled {len(pairs)} pairs (pos={len(pos)}, neg={len(pairs)-len(pos)})")

        if not pairs:
            continue

        rows = []
        for entry in pairs:
            h, t = entry["head"], entry["tail"]
            y.append(relation_label_map[entry["label"]])
            head_vec = average_span_embedding(bert_tensor, h["start"], h["end"])
            tail_vec = average_span_embedding(bert_tensor, t["start"], t["end"])
            rows.append(torch.cat([head_vec, tail_vec], dim=0))
        # One device->host transfer per graph instead of one per pair.
        X.extend(torch.stack(rows).cpu().numpy())

    if not X:
        return (
            np.empty((0, pair_width), dtype=np.float32),
            np.empty(0, dtype=np.int64),
            dict(relation_label_map),
        )
    return np.stack(X), np.array(y, dtype=np.int64), dict(relation_label_map)

# Run and inspect: build the 1:1-balanced relation dataset (on the GPU when
# one is available), then report its shape and the label -> id assignment.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
X, y, label_map = extract_relation_dataset_1to1(
    all_graphs, max_distance=400, device=device, debug=True
)

print(f"\nFeature matrix shape: {X.shape}")
print(f"Labels vector shape:  {y.shape}")
print("Relation label → ID mapping:")
# Ids are unique, so sorting (id, label) tuples orders rows by id.
for idx, lbl in sorted((i, name) for name, i in label_map.items()):
    print(f"   {idx:>3}: {lbl}")

Loaded 74 graphs


Graph 1/74: 18 entities
  • sampled 12 pairs (pos=6, neg=6)

Graph 2/74: 14 entities
  • sampled 0 pairs (pos=0, neg=0)

Graph 3/74: 59 entities
  • sampled 8 pairs (pos=4, neg=4)

Graph 4/74: 60 entities
  • sampled 0 pairs (pos=0, neg=0)

Graph 5/74: 54 entities
  • sampled 6 pairs (pos=3, neg=3)

Graph 6/74: 6 entities
  • sampled 4 pairs (pos=2, neg=2)

Graph 7/74: 123 entities
  • sampled 0 pairs (pos=0, neg=0)

Graph 8/74: 68 entities
  • sampled 68 pairs (pos=34, neg=34)

Graph 9/74: 82 entities
  • sampled 8 pairs (pos=4, neg=4)

Graph 10/74: 99 entities
  • sampled 8 pairs (pos=4, neg=4)

Graph 11/74: 59 entities
  • sampled 28 pairs (pos=14, neg=14)

Graph 12/74: 46 entities
  • sampled 8 pairs (pos=4, neg=4)

Graph 13/74: 20 entities
  • sampled 0 pairs (pos=0, neg=0)

Graph 14/74: 42 entities
  • sampled 0 pairs (pos=0, neg=0)

Graph 15/74: 103 entities
  • sampled 12 pairs (pos=6, neg=6)

Graph 16/74: 32 entities
  • sampled 2 pairs (pos=1, neg=1)

Graph

In [None]:
import os
import json
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import spacy
from transformers import AutoTokenizer, AutoModel
from torch_geometric.nn import GCNConv
from tabulate import tabulate
import higher

# 1. Setup & Model Loading
# spaCy provides tokenization and dependency parses (the graph edges);
# BERT provides contextual token features. eval() disables dropout.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

nlp = spacy.load("en_core_web_sm")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")
bert = bert.to(device)
bert.eval()

# 2. GNN Model
class GNN_RE_NER(nn.Module):
    """Two-layer GCN encoder with a per-node NER head and a pairwise RE head.

    Args:
        input_dim: width of the input node features.
        hidden_dim: width of both GCN layers and of the RE hidden layer.
        num_ner_classes: number of NER labels (one logit row per node).
        num_re_classes: number of relation labels (one logit row per pair).
    """

    def __init__(self, input_dim, hidden_dim, num_ner_classes, num_re_classes):
        super().__init__()
        # Registration order is kept so parameter init and state_dict keys are unchanged.
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.ner_classifier = nn.Linear(hidden_dim, num_ner_classes)
        self.re_classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_re_classes),
        )

    def forward(self, x, edge_index, entity_pairs):
        """Return ``(ner_logits, re_logits)``.

        ``ner_logits`` has one row per node. ``re_logits`` has one row per
        ``(head, tail)`` node-index pair in ``entity_pairs``; each pair is
        represented by its two node embeddings concatenated. When there are
        no pairs, ``re_logits`` has shape ``(0, num_re_classes)``.
        """
        hidden = self.gcn2(self.gcn1(x, edge_index).relu(), edge_index).relu()
        ner_logits = self.ner_classifier(hidden)
        pair_feats = [
            torch.cat([hidden[head], hidden[tail]], dim=-1)
            for head, tail in entity_pairs
        ]
        if not pair_feats:
            n_rel = self.re_classifier[-1].out_features
            return ner_logits, torch.empty(0, n_rel, device=hidden.device)
        return ner_logits, self.re_classifier(torch.stack(pair_feats))

# 3. Utility Functions
@torch.no_grad()
def get_bert_embeddings(text):
    """Encode ``text`` with the module-level BERT model, without autograd.

    The input is truncated to 512 word pieces.

    Returns:
        (tokens, hidden, offsets): the word-piece strings; the last hidden
        layer with the batch dimension squeezed out (one row per word piece);
        and per-piece ``(start_char, end_char)`` offsets into ``text``.
    """
    encoding = tokenizer(
        text,
        return_tensors="pt",
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    ).to(device)
    # The model does not accept offset_mapping, so remove it before the forward pass.
    char_offsets = encoding.pop("offset_mapping")[0].tolist()
    hidden = bert(**encoding).last_hidden_state.squeeze(0)
    pieces = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
    return pieces, hidden, char_offsets

def align_to_bert(doc, tokens, embeddings, offsets):
    """Pool BERT word-piece embeddings onto spaCy tokens.

    A spaCy token ``tok`` covers characters ``[tok.idx, tok.idx + len(tok))``.
    Its feature vector is the mean of all word pieces whose character span
    lies entirely inside that range. Tokens with no such piece get a zero
    vector; this happens, for example, past the BERT truncation point or
    where a piece straddles a spaCy token boundary.

    Zero-width offsets are ignored. Fast tokenizers report ``(0, 0)`` for
    special tokens such as [CLS]/[SEP]. The containment test would otherwise
    match them to any token starting at character 0 and mix their embeddings
    into that token's features.

    Args:
        doc: iterable of spaCy-like tokens (``.idx`` and ``len()``).
        tokens: word-piece strings; unused, kept for interface compatibility.
        embeddings: 2-D tensor, one row per word piece.
        offsets: ``(start_char, end_char)`` per word piece, aligned with
            ``embeddings``.

    Returns:
        One list of Python floats per token in ``doc``.
    """
    # Filter special tokens once rather than re-testing them for every spaCy token.
    pieces = [(i, st, ed) for i, (st, ed) in enumerate(offsets) if ed > st]
    width = embeddings.size(-1)
    feats = []
    for tok in doc:
        s, e = tok.idx, tok.idx + len(tok)
        matched = [embeddings[i] for i, st, ed in pieces if st >= s and ed <= e]
        if matched:
            feats.append(torch.mean(torch.stack(matched), 0).cpu().tolist())
        else:
            feats.append([0.0] * width)
    return feats

def group_tokens_into_entities(token_texts, token_labels, id2label):
    """Merge runs of identically-labelled tokens into entity spans.

    Label ids are mapped through ``id2label``. Each maximal run of
    consecutive tokens sharing the same non-"O" label becomes one entity;
    "O" tokens are dropped. There is no BIO-prefix handling, so two adjacent
    entities of the same label are merged into one.

    Returns:
        List of dicts with keys ``text`` (tokens joined by single spaces),
        ``label``, ``start``/``end`` (token indices, end exclusive) and
        ``root`` (equal to ``start``).
    """
    labels = [id2label[token_labels[i]] for i in range(len(token_texts))]
    entities = []
    run_start = 0
    for pos in range(1, len(labels) + 1):
        # Keep extending the current run while the label is unchanged.
        if pos < len(labels) and labels[pos] == labels[run_start]:
            continue
        if labels[run_start] != "O":
            entities.append({
                "text": " ".join(token_texts[run_start:pos]),
                "label": labels[run_start],
                "start": run_start,
                "end": pos,
                "root": run_start,
            })
        run_start = pos
    return entities

# 4. Load NER Label Map
# Reuse the label map stored in an existing checkpoint so label ids match it;
# otherwise fall back to a small built-in label set.
checkpoint_path = "Your_Checkpoint_Path/checkpoint.pth"
if not os.path.exists(checkpoint_path):
    ner_label_map = {"O": 0, "PERSON": 1, "ORG": 2, "GPE": 3, "DATE": 4}
    print("Defined NER label map manually.")
else:
    ckpt = torch.load(checkpoint_path, map_location=device)
    ner_label_map = ckpt["ner_label_map"]
    print("Loaded NER label map from checkpoint.")

id2ner = dict((idx, name) for name, idx in ner_label_map.items())

# 5. Load & Balance Processed Training Graphs (1:1 positive∶negative)
# NOTE(review): absolute cluster path, unlike the relative "processed_graphs.pkl"
# used in the earlier cells; adjust per environment. Only unpickle trusted files.
graphs_pkl = "/scratch/vsetpal/results/processed_graphs.pkl"

with open(graphs_pkl, mode="rb") as f:
    training_graphs = pickle.load(f)
print(f"Loaded {len(training_graphs)} training graphs")

def augment_with_negatives(graph, max_distance=400):
    """Add sampled "NoRelation" pairs to ``graph["relation_labels"]`` in place.

    A negative candidate is an ordered root pair ``(head_root, tail_root)``
    that meets all three conditions:
      * the two roots are distinct;
      * they are at most ``max_distance`` apart;
      * the pair is not already a gold relation in that direction.
    At most one negative per existing relation is sampled (1:1 balance),
    using the module-level ``random`` RNG.

    Two details matter downstream:
      * Candidates are de-duplicated by root pair. Entities that share a root
        would otherwise yield the same negative several times, and it could
        be sampled twice.
      * The combined list is shuffled. ``split_graph`` cuts the relation list
        in half by position; with all positives first, every positive would
        land in the support set and every negative in the query set.

    Args:
        graph: dict with ``entities`` (each with ``root``) and optionally
            ``relation_labels`` (dicts with ``head``, ``tail``, ``label``).
        max_distance: maximum ``|head_root - tail_root|`` for a candidate.

    Returns:
        None; ``graph["relation_labels"]`` is replaced with a new list.
    """
    positives = list(graph.get("relation_labels", []))
    gold = {(r["head"], r["tail"]) for r in positives}

    # A dict keeps first-seen order, so sampling stays reproducible under a seed.
    neg_keys = {}
    for e1 in graph["entities"]:
        for e2 in graph["entities"]:
            h, t = e1["root"], e2["root"]
            if h == t or abs(h - t) > max_distance or (h, t) in gold:
                continue
            neg_keys[(h, t)] = None
    neg_cands = [{"head": h, "tail": t, "label": "NoRelation"} for h, t in neg_keys]

    k = min(len(neg_cands), len(positives))
    sampled = random.sample(neg_cands, k) if k > 0 else []
    combined = positives + sampled
    random.shuffle(combined)
    graph["relation_labels"] = combined

# Balance every training graph in place (up to one sampled negative per relation).
for g in training_graphs:
    augment_with_negatives(g, max_distance=400)

print("Finished 1:1 positive∶negative balancing on each graph.")

# 5b. Rebuild RE label map from balanced graphs
# "NoRelation" is pinned to id 0; the remaining labels get ids in sorted order,
# so the mapping does not depend on graph or sampling order.
all_re_labels = {"NoRelation"}
for g in training_graphs:
    all_re_labels.update(r["label"] for r in g["relation_labels"])

labels_sorted = ["NoRelation", *sorted(all_re_labels - {"NoRelation"})]
relation_label_map = dict(zip(labels_sorted, range(len(labels_sorted))))
id2rel = dict(enumerate(labels_sorted))
print("Rebuilt RE label map:", relation_label_map)

# 6. Meta‐Training (MAML)
def split_graph(g):
    """Split a graph's relations into support and query halves.

    Both results are shallow copies of ``g`` that differ only in
    ``relation_labels``. The support half gets the first
    ``len(relations) // 2`` entries and the query half gets the rest. The
    split is purely positional, so a class-ordered relation list gives each
    half a different class mix.

    Returns:
        (support_graph, query_graph)
    """
    relations = g.get("relation_labels", [])
    half = len(relations) // 2
    support = {**g, "relation_labels": relations[:half]}
    query = {**g, "relation_labels": relations[half:]}
    return support, query

def compute_loss(model, entity_embedder, g):
    """Joint NER + RE cross-entropy loss for one graph.

    Node inputs are ``g["node_features"]`` concatenated with an embedding of
    each node's NER label. The labels are taken from ``g["entities"]`` spans,
    with "O" as the default. RE logits are computed for every
    ``(head, tail)`` pair in ``g["relation_labels"]`` and scored against
    their labels via the module-level ``relation_label_map``.

    NOTE(review): the gold NER label is embedded into the *input* and also
    used as the NER *target*, so the NER head can read its own answer.
    Inference (``annotate_file``) feeds id 0 for every node instead.
    Confirm this is intended.

    Returns:
        Scalar tensor ``l_ner + l_re``; ``l_re`` is 0 when there are no pairs.
    """
    nf = torch.tensor(g["node_features"], dtype=torch.float, device=device)
    edges = g["edges"]
    # An empty edge list would become a 1-D tensor; GCNConv expects shape (2, num_edges).
    if len(edges) > 0:
        edge_idx = torch.tensor(edges, dtype=torch.long).t().to(device)
    else:
        edge_idx = torch.empty((2, 0), dtype=torch.long, device=device)

    ner_seq = ["O"] * len(g["nodes"])
    for ent in g.get("entities", []):
        for i in range(ent["start"], ent["end"]):
            if i < len(ner_seq):
                ner_seq[i] = ent["label"]
    ner_ids = torch.tensor([ner_label_map.get(t, 0) for t in ner_seq],
                           dtype=torch.long, device=device)
    X = torch.cat([nf, entity_embedder(ner_ids)], dim=-1)

    rels = g.get("relation_labels", [])
    pairs = [(r["head"], r["tail"]) for r in rels]
    ner_logits, re_logits = model(X, edge_idx, pairs)

    l_ner = F.cross_entropy(ner_logits, ner_ids)
    if re_logits.size(0) > 0:
        lbls = torch.tensor([relation_label_map[r["label"]] for r in rels],
                            dtype=torch.long, device=device)
        l_re = F.cross_entropy(re_logits, lbls)
    else:
        l_re = torch.tensor(0.0, device=device)
    return l_ner + l_re

def meta_train(graphs, ner_label_map, relation_label_map,
               entity_embed_dim=16, hidden_dim=128,
               meta_epochs=50, inner_steps=1, inner_lr=1e-2, meta_lr=1e-3,
               patience=5, min_delta=1e-4):
    """Meta-train a ``GNN_RE_NER`` with MAML over per-graph tasks (via ``higher``).

    Each graph is one task, split by ``split_graph`` into support and query.
    Per epoch and per task:
      1. Take ``inner_steps`` differentiable SGD steps on the support loss.
      2. Back-propagate the query loss of the adapted weights into the
         meta-parameters (the model and the NER-label embedding).
    One Adam step on the task-averaged gradient follows each epoch.

    Three problems are avoided here:
      * ``copy_initial_weights=False`` keeps the fast weights attached to
        ``model.parameters()``. With ``True`` they are detached copies, and
        the outer step never changes the model.
      * ``entity_embedder`` is part of the outer optimizer; otherwise it
        stays at its random initialization.
      * The support loss is recomputed at the current fast weights on every
        inner step.

    Early stopping: training stops after ``patience`` consecutive epochs
    without the average query loss improving by more than ``min_delta``.

    Returns:
        (model, entity_embedder)
    """
    graphs = list(graphs)
    model = GNN_RE_NER(
        input_dim=768 + entity_embed_dim,
        hidden_dim=hidden_dim,
        num_ner_classes=len(ner_label_map),
        num_re_classes=len(relation_label_map),
    ).to(device)
    entity_embedder = nn.Embedding(len(ner_label_map), entity_embed_dim).to(device)

    opt = torch.optim.Adam(
        list(model.parameters()) + list(entity_embedder.parameters()), lr=meta_lr
    )
    # Plain SGD (no momentum) keeps no state, so one inner optimizer serves all tasks.
    inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)

    if not graphs:
        return model, entity_embedder

    n_tasks = len(graphs)
    best, wait = float("inf"), 0
    for epoch in range(1, meta_epochs + 1):
        print(f"Meta‐Epoch {epoch}/{meta_epochs}")
        opt.zero_grad()
        meta_loss = 0.0
        for g in graphs:
            sg, qg = split_graph(g)
            with higher.innerloop_ctx(model, inner_opt,
                                      copy_initial_weights=False) as (fmodel, diffopt):
                for _ in range(inner_steps):
                    diffopt.step(compute_loss(fmodel, entity_embedder, sg))
                loss_q = compute_loss(fmodel, entity_embedder, qg)
                # Back-propagate per task (scaled to a mean) so only one task's
                # autograd graph is alive at a time; gradients accumulate in .grad.
                (loss_q / n_tasks).backward()
                meta_loss += loss_q.item()
        meta_loss /= n_tasks
        opt.step()
        print(f"Avg Query Loss: {meta_loss:.4f}")

        if best - meta_loss > min_delta:
            best, wait = meta_loss, 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping")
                break

    return model, entity_embedder

# 7. Inference & Annotation
def annotate_file(model, entity_embedder, in_path, out_path, thr=0.0):
    """Predict entities and relation triples for one JSON document.

    Steps:
      1. Read the text from ``doc["doc"]``, falling back to ``doc["document"]``.
      2. Build a spaCy dependency graph with BERT-aligned node features and
         predict a NER label per token.
      3. Group the tokens into entities.
      4. Classify every ordered pair of distinct predicted entities, using
         each entity's start token as its node.
    Triples are kept when their label is not "NoRelation" and their softmax
    confidence is ``>= thr``. The gold ``NER-label_set``/``RE_label_set`` and
    the ``id`` field are copied into the output so evaluation can read them
    from the prediction file.

    Returns:
        The dict that was written to ``out_path`` as indented JSON.
    """
    with open(in_path, encoding="utf-8") as f:
        doc = json.load(f)
    text = doc.get("doc", doc.get("document", ""))
    sp = nlp(text)
    toks = [t.text for t in sp]
    tokens, embs, offs = get_bert_embeddings(text)
    feats = align_to_bert(sp, tokens, embs, offs)
    edges = [(t.head.i, t.i) for t in sp if t.head.i != t.i]

    with torch.no_grad():
        # The reshape gives an empty document a (0, hidden) matrix instead of a 1-D tensor.
        nf = torch.tensor(feats, dtype=torch.float, device=device).reshape(
            len(feats), embs.size(-1)
        )
        # No gold NER labels at inference: every node gets the embedding of id 0
        # (the "O" id in the default label map).
        ner_ids = torch.zeros(len(toks), dtype=torch.long, device=device)
        X = torch.cat([nf, entity_embedder(ner_ids)], dim=-1)
        E = (
            torch.tensor(edges, dtype=torch.long).t().to(device)
            if edges
            else torch.empty((2, 0), dtype=torch.long, device=device)
        )
        x2 = model.gcn2(model.gcn1(X, E).relu(), E).relu()
        pred_ner = model.ner_classifier(x2).argmax(1).cpu().tolist()

    ents = group_tokens_into_entities(toks, pred_ner, id2ner)
    # Candidates are entity-list indices, so building triples needs no .index() search.
    cand = [(i, j) for i in range(len(ents)) for j in range(len(ents)) if i != j]
    triples = []
    if cand:
        with torch.no_grad():
            reps = torch.stack([
                torch.cat([x2[ents[i]["start"]], x2[ents[j]["start"]]], 0)
                for i, j in cand
            ])
            probs = F.softmax(model.re_classifier(reps), dim=1)
            conf, idxs = probs.max(1)
        for (i, j), c, k in zip(cand, conf.tolist(), idxs.tolist()):
            lab = id2rel[k]
            if lab != "NoRelation" and c >= thr:
                triples.append({
                    "head": ents[i]["text"],
                    "tail": ents[j]["text"],
                    "label": lab,
                    "conf": c
                })

    # carry gold labels through so eval can see them
    out = {
        "document":        text,
        "pred_entities":   ents,
        "pred_triples":    triples,
        "NER-label_set":   doc.get("NER-label_set", []),
        "RE_label_set":    doc.get("RE_label_set",  []),
        "id":              doc.get("id", "")
    }

    out_dir = os.path.dirname(out_path)
    if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as fw:
        json.dump(out, fw, indent=2)
    return out

def annotate_all(model, entity_embedder, root_in, root_out, thr=0.0):
    """Annotate every ``*.json`` file under ``root_in``, mirroring the tree.

    For each directory that ``os.walk(root_in)`` visits, the directory at the
    same relative path under ``root_out`` is created, even if it holds no
    JSON files. Each JSON file is passed to ``annotate_file`` with a matching
    output path.
    """
    for dirpath, _subdirs, filenames in os.walk(root_in):
        target_dir = os.path.join(root_out, os.path.relpath(dirpath, root_in))
        os.makedirs(target_dir, exist_ok=True)
        for name in (fn for fn in filenames if fn.endswith(".json")):
            annotate_file(
                model,
                entity_embedder,
                os.path.join(dirpath, name),
                os.path.join(target_dir, name),
                thr,
            )

# 8. Evaluation
def _type_set_prf(gold_labels, pred_labels):
    """Precision/recall/F1 of two label sets; the 1e-8 terms avoid division by zero."""
    tp = len(gold_labels & pred_labels)
    fp = len(pred_labels - gold_labels)
    fn = len(gold_labels - pred_labels)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return precision, recall, f1


def eval_file(pred_path, gold_path):
    """Type-level NER and RE scores for one prediction/gold file pair.

    Scores compare *sets of label types*, not spans or triples:
      * NER: the gold ``NER-label_set`` vs the labels in ``pred_entities``.
      * RE: the gold ``RE_label_set`` vs the labels in ``pred_triples``.

    Returns:
        (precision_ner, recall_ner, f1_ner, precision_re, recall_re, f1_re)
    """
    # Context managers close each file promptly instead of relying on garbage collection.
    with open(pred_path, encoding="utf-8") as f:
        pred = json.load(f)
    with open(gold_path, encoding="utf-8") as f:
        gold = json.load(f)

    ner_scores = _type_set_prf(
        set(gold.get("NER-label_set", [])),
        {ent["label"] for ent in pred["pred_entities"]},
    )
    re_scores = _type_set_prf(
        set(gold.get("RE_label_set", [])),
        {tr["label"] for tr in pred["pred_triples"]},
    )
    return (*ner_scores, *re_scores)


def eval_all(pred_root, gold_root):
    """Average per-file type-level scores by directory; print and return the table.

    Walks ``pred_root``. A ``*.json`` file is scored with ``eval_file`` when a
    counterpart exists at the same relative path under ``gold_root``. Each
    directory with at least one scored file becomes a row, averaged over its
    files. An "Overall" row averaged over all scored files is appended. The
    table is printed with ``tabulate``.

    Returns:
        The rows as lists of strings (scores formatted to 4 decimals).
    """
    table = []
    overall = [0.0] * 6
    overall_n = 0

    for dp, _, files in os.walk(pred_root):
        rel = os.path.relpath(dp, pred_root)
        sums = [0.0] * 6
        n = 0
        for fn in files:
            if not fn.endswith(".json"):
                continue
            gold_path = os.path.join(gold_root, rel, fn)
            if not os.path.exists(gold_path):
                continue
            scores = eval_file(os.path.join(dp, fn), gold_path)
            for i, v in enumerate(scores):
                sums[i] += v
                overall[i] += v
            n += 1
        overall_n += n

        if n > 0:
            table.append([os.path.basename(dp), *(f"{s / n:.4f}" for s in sums)])

    if overall_n > 0:
        table.append(["Overall", *(f"{s / overall_n:.4f}" for s in overall)])

    print(tabulate(
        table,
        headers=["Domain", "NER P", "NER R", "NER F1", "RE P", "RE R", "RE F1"],
        tablefmt="pretty",
    ))
    return table

if __name__=="__main__":
    # 9.1 Meta-train
    print("Meta‐training GNN_RE_NER")
    model, entity_embedder = meta_train(
        training_graphs,
        ner_label_map,
        relation_label_map,
        entity_embed_dim=16,
        hidden_dim=128,
        meta_epochs=50,
        inner_steps=1,
        inner_lr=1e-2,
        meta_lr=1e-3,
        patience=5,
        min_delta=1e-4
    )

    # 9.2 Save checkpoint; torch.save fails if the target directory is missing.
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({
        "model_state_dict": model.state_dict(),
        "entity_embedder_state_dict": entity_embedder.state_dict(),
        "ner_label_map": ner_label_map,
        "relation_label_map": relation_label_map
    }, checkpoint_path)
    print(f"Saved meta‐trained checkpoint to {checkpoint_path}")

    # 9.3 Annotate test set
    TEST_IN  = "Your_Test_Set_Path"
    TEST_OUT = "Your_Test_Set_Annotated_Path"
    print("Annotating all test documents")
    annotate_all(model, entity_embedder, TEST_IN, TEST_OUT, thr=0.0)

    # 9.4 Evaluate (predictions carry the gold label sets; TEST_IN holds the gold files)
    print("Evaluation across domains")
    eval_all(TEST_OUT, TEST_IN)

Using device: cuda
Loaded NER label map from checkpoint.
Loaded 74 training graphs
Finished 1:1 positive∶negative balancing on each graph.
Rebuilt RE label map: {'NoRelation': 0, 'AcademicDegree': 1, 'AdjacentStation': 2, 'Affiliation': 3, 'AppliesToPeople': 4, 'ApprovedBy': 5, 'Author': 6, 'AwardReceived': 7, 'BasedOn': 8, 'Capital': 9, 'Causes': 10, 'CitesWork': 11, 'ContainsAdministrativeTerritorialEntity': 12, 'ContainsTheAdministrativeTerritorialEntity': 13, 'Continent': 14, 'ContributedToCreativeWork': 15, 'Country': 16, 'CountryOfCitizenship': 17, 'Creator': 18, 'Developer': 19, 'DifferentFrom': 20, 'DiplomaticRelation': 21, 'Director': 22, 'EducatedAt': 23, 'Employer': 24, 'FieldOfWork': 25, 'FollowedBy': 26, 'Follows': 27, 'Founded': 28, 'FoundedBy': 29, 'HasCause': 30, 'HasEffect': 31, 'HasPart': 32, 'HasQuality': 33, 'HasWorksInTheCollection': 34, 'InOppositionTo': 35, 'InfluencedBy': 36, 'InspiredBy': 37, 'InterestedIn': 38, 'IssuedBy': 39, 'LanguageOfWorkOrName': 40, 'Lang

In [None]:
import os
import json
from tabulate import tabulate

# Root of the annotated prediction tree, and of the gold test tree it mirrors.
PRED_DIR = "Your_Predictions_Path"
GOLD_DIR = "Your_Test_Path"

# Minimum confidence for a predicted triple to count toward the RE scores.
CONF_THR = 0.018

def prf(gold_set, pred_set):
    """Precision, recall and F1 of ``pred_set`` measured against ``gold_set``.

    Each metric is 0.0 when its denominator is zero, e.g. when both sets are empty.
    """
    tp = len(gold_set & pred_set)
    n_pred = tp + len(pred_set - gold_set)
    n_gold = tp + len(gold_set - pred_set)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)

# Micro-averaged, type-level scores per domain: one row per directory under
# PRED_DIR.
#   * NER compares the gold "NER-label_set" with the set of predicted entity labels.
#   * RE compares "RE_label_set" with the labels of predicted triples whose
#     confidence is >= CONF_THR.
# Counts are pooled over a domain's files. Each reported column is the
# unweighted mean of the NER and RE values, so the F1 column is the mean of
# two F1 scores, not the F1 of the mean P and R.


def _prf_counts(tp, fp, fn):
    """Precision/recall/F1 from pooled counts; each is 0.0 when its denominator is 0."""
    p = tp / (tp + fp) if tp + fp > 0 else 0.0
    r = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f1


def _load_json(path):
    """Read one JSON file and close its handle promptly."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


table = []
for pred_root, _, files in os.walk(PRED_DIR):
    rel = os.path.relpath(pred_root, PRED_DIR)
    domain = rel if rel != "." else "<root>"
    ner_tp = ner_fp = ner_fn = 0
    re_tp = re_fp = re_fn = 0
    n_docs = 0
    for fn in files:
        if not fn.endswith(".json"):
            continue
        gold_path = os.path.join(GOLD_DIR, rel, fn)
        if not os.path.exists(gold_path):
            continue
        pred = _load_json(os.path.join(pred_root, fn))
        gold = _load_json(gold_path)

        gold_ner = set(gold.get("NER-label_set", []))
        pred_ner = {ent["label"] for ent in pred.get("pred_entities", [])}
        ner_tp += len(gold_ner & pred_ner)
        ner_fp += len(pred_ner - gold_ner)
        ner_fn += len(gold_ner - pred_ner)

        gold_re = set(gold.get("RE_label_set", []))
        pred_re = {
            tr["label"]
            for tr in pred.get("pred_triples", [])
            if tr.get("conf", 0.0) >= CONF_THR
        }
        re_tp += len(gold_re & pred_re)
        re_fp += len(pred_re - gold_re)
        re_fn += len(gold_re - pred_re)
        n_docs += 1
    if n_docs == 0:
        continue

    ner_p, ner_r, ner_f1 = _prf_counts(ner_tp, ner_fp, ner_fn)
    re_p, re_r, re_f1 = _prf_counts(re_tp, re_fp, re_fn)
    c_p = (ner_p + re_p) / 2
    c_r = (ner_r + re_r) / 2
    c_f1 = (ner_f1 + re_f1) / 2

    table.append([
        domain,
        f"{c_p:.4f}",
        f"{c_r:.4f}",
        f"{c_f1:.4f}"
    ])

print(tabulate(
    table,
    headers=["Domain", "P", "R", "F1"],
    tablefmt="pretty"
))

+----------------------+--------+--------+--------+
|        Domain        |   P    |   R    |   F1   |
+----------------------+--------+--------+--------+
| Academic_disciplines | 0.8714 | 0.2450 | 0.3616 |
|       Business       | 0.6538 | 0.2637 | 0.3726 |
|    Communication     | 0.7000 | 0.2632 | 0.3789 |
|       Culture        | 0.7212 | 0.2521 | 0.3611 |
|       Economy        | 0.7500 | 0.2571 | 0.3702 |
|      Education       | 0.8256 | 0.2767 | 0.4040 |
|        Energy        | 0.7500 | 0.2385 | 0.3431 |
|     Engineering      | 0.6270 | 0.2257 | 0.3296 |
|    Entertainment     | 0.7281 | 0.2687 | 0.3773 |
|    Food_and_drink    | 0.7195 | 0.2416 | 0.3507 |
|      Geography       | 0.7656 | 0.2629 | 0.3792 |
|      Government      | 0.7683 | 0.2608 | 0.3697 |
|        Health        | 0.6944 | 0.2508 | 0.3601 |
|       History        | 0.7857 | 0.2767 | 0.3945 |
|    Human_behavior    | 0.6250 | 0.2425 | 0.3460 |
|      Humanities      | 0.7604 | 0.2698 | 0.3881 |
|     Inform