In [None]:
import os
import json
import torch
import spacy
from transformers import AutoTokenizer, AutoModel
from difflib import SequenceMatcher
import pandas as pd

# Setup Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

# Load Models
print("Loading models...")
nlp = spacy.load("en_core_web_sm")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()
print("Models loaded.\n")

# Set folder path containing JSON files
folder_path = 'Your_Path_Here'  # Change this to your folder path
json_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.json')]
print(f"Found {len(json_files)} JSON files in: {folder_path}")

@torch.no_grad()
def get_bert_embeddings(text: str):
    encoded = tokenizer(text, return_tensors="pt", return_offsets_mapping=True, truncation=True).to(device)
    offsets = encoded.pop("offset_mapping")[0].tolist()
    output = bert_model(**encoded)
    embeddings = output.last_hidden_state.squeeze(0)  # convert back to CPU for later processing
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    return tokens, embeddings, offsets

def align_tokens_to_bert(doc, tokens, embeddings, offsets):
    token_embeddings = []
    for token in doc:
        token_start = token.idx
        token_end = token.idx + len(token.text)
        matched = [embeddings[i] for i, (start, end) in enumerate(offsets)
                   if start >= token_start and end <= token_end]
        if matched:
            token_embeddings.append(torch.mean(torch.stack(matched), dim=0).tolist())
        else:
            token_embeddings.append(torch.zeros(embeddings.size(-1)).tolist())
    return token_embeddings

def fuzzy_match_span(entity_text, doc):
    entity_text = entity_text.lower()
    best_score = 0.0
    best_span = None
    for start in range(len(doc)):
        for end in range(start + 1, min(len(doc) + 1, start + 10)):
            span = doc[start:end]
            score = SequenceMatcher(None, entity_text, span.text.lower()).ratio()
            if score > best_score:
                best_score = score
                best_span = span
    return best_span

def process_document(doc_entry):
    doc_text = doc_entry["doc"]
    triples = doc_entry.get("triples", [])
    entity_defs = doc_entry.get("entities", [])

    doc = nlp(doc_text)
    tokens, embeddings, offsets = get_bert_embeddings(doc_text)
    features = align_tokens_to_bert(doc, tokens, embeddings, offsets)

    graph = {
        "nodes": [token.text for token in doc],
        "edges": [(token.head.i, token.i) for token in doc if token.head.i != token.i],
        "node_features": features,
        "relation_labels": [],
        "entities": []
    }

    for ent in entity_defs:
        ent_type = ent["type"]
        for mention in ent["mentions"]:
            span = fuzzy_match_span(mention, doc)
            if span:
                graph["entities"].append({
                    "text": mention,
                    "start": span.start,
                    "end": span.end,
                    "label": ent_type,
                    "root": span.root.i
                })

    seen_pairs = set()
    for triple in triples:
        head_span = fuzzy_match_span(triple["head"], doc)
        tail_span = fuzzy_match_span(triple["tail"], doc)

        if head_span and tail_span:
            head_idx = head_span.root.i
            tail_idx = tail_span.root.i
            if (head_idx, tail_idx) not in seen_pairs:
                graph["relation_labels"].append({
                    "head": head_idx,
                    "tail": tail_idx,
                    "label": triple["relation"]
                })
                seen_pairs.add((head_idx, tail_idx))

    return graph

# Process all JSON files
all_graphs = []
for file_idx, file_path in enumerate(json_files):
    with open(file_path, "r") as f:
        docs = json.load(f)

    print(f"Processing {os.path.basename(file_path)} ({len(docs)} docs)...")
    for i, doc_entry in enumerate(docs):
        if "doc" in doc_entry and "triples" in doc_entry:
            graph = process_document(doc_entry)
            all_graphs.append(graph)

print(f"\nProcessed {len(all_graphs)} graphs from {len(json_files)} files.\n")

# Display Summary Table
df_preview = pd.DataFrame([{
    "num_nodes": len(g["nodes"]),
    "num_edges": len(g["edges"]),
    "num_relations": len(g["relation_labels"]),
    "num_entities": len(g["entities"]),
    "sample_nodes": g["nodes"][:5],
    "sample_relations": g["relation_labels"][:3],
    "sample_entities": g["entities"][:3]
} for g in all_graphs[:5]])

print("Graph Summary:")
print(df_preview.to_string(index=False))

Using device: cuda

Loading models...


  warn(f"Failed to load image Python extension: {e}")


Models loaded.

Found 7 JSON files in: /scratch/vsetpal/train/train
Processing Communication_all_examples.json (10 docs)...
Processing Education_all_examples.json (10 docs)...
Processing Energy_all_examples.json (10 docs)...
Processing Entertainment_all_examples.json (12 docs)...
Processing Government_all_examples.json (9 docs)...
Processing Human_behavior_all_examples.json (13 docs)...
Processing Internet_all_examples.json (10 docs)...

Processed 74 graphs from 7 files.

Graph Summary:
 num_nodes  num_edges  num_relations  num_entities                                 sample_nodes                                                                                                                                           sample_relations                                                                                                                                                                                                                                                                  

In [None]:
import pickle

with open("processed_graphs.pkl", "wb") as f:
    pickle.dump(all_graphs, f)
print("Saved all graphs to disk.")

Saved all graphs to disk.


In [None]:
import torch
import torch.nn as nn
import pickle

# Load preprocessed graphs from pickle
with open("processed_graphs.pkl", "rb") as f:
    all_graphs = pickle.load(f)

print(f"Loaded {len(all_graphs)} graphs from pickle.\n")

# Dynamically build entity label vocabulary from the dataset
unique_entity_types = set()
for graph in all_graphs:
    for ent in graph.get("entities", []):
        unique_entity_types.add(ent["label"])

print("Unique entity types in dataset:")
print(sorted(unique_entity_types))

ENTITY_LABELS = {label: idx + 1 for idx, label in enumerate(sorted(unique_entity_types))}
ENTITY_LABELS["O"] = 0  # non-entity/default class

NUM_ENTITY_TYPES = len(ENTITY_LABELS)
ENTITY_EMBED_DIM = 16

print(f"\nENTITY_LABELS mapping ({NUM_ENTITY_TYPES} types):")
print(ENTITY_LABELS)

# Select one graph to test
graph = all_graphs[0]
print(f"\nGraph with {len(graph['nodes'])} nodes and {len(graph['entities'])} entities")

# Create entity type tags per node
node_entity_types = ["O"] * len(graph["nodes"])
for ent in graph["entities"]:
    for i in range(ent["start"], ent["end"]):
        if i < len(node_entity_types):
            node_entity_types[i] = ent["label"]

# Map entity types to IDs
entity_type_ids = [ENTITY_LABELS.get(label, 0) for label in node_entity_types]
entity_type_ids_tensor = torch.tensor(entity_type_ids, dtype=torch.long)

print(f"\nFirst 50 entity labels:\n{node_entity_types[0:10]}")
print(f"First 50 entity type IDs:\n{entity_type_ids[0:10]}")

# Convert BERT features to tensor
bert_tensor = torch.tensor(graph["node_features"], dtype=torch.float)
print(f"\nBERT features shape: {bert_tensor.shape}")

# Embed entity type IDs
entity_embedding = nn.Embedding(num_embeddings=NUM_ENTITY_TYPES, embedding_dim=ENTITY_EMBED_DIM)
entity_vecs = entity_embedding(entity_type_ids_tensor)
print(f"Entity embeddings shape: {entity_vecs.shape}")

# Concatenate BERT + Entity embeddings
node_features = torch.cat([bert_tensor, entity_vecs], dim=1)
print(f"\nCombined node feature shape: {node_features.shape}")
print(f"Sample node feature vector (first token):\n{node_features[0]}")

Loaded 74 graphs from pickle.

Unique entity types in dataset:
['CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'MISC', 'MONEY', 'NORP', 'ORDINAL', 'ORG', 'PERCENT', 'PERSON', 'PRODUCT', 'QUANTITY', 'TIME', 'WORK_OF_ART']

ENTITY_LABELS mapping (20 types):
{'CARDINAL': 1, 'DATE': 2, 'EVENT': 3, 'FAC': 4, 'GPE': 5, 'LANGUAGE': 6, 'LAW': 7, 'LOC': 8, 'MISC': 9, 'MONEY': 10, 'NORP': 11, 'ORDINAL': 12, 'ORG': 13, 'PERCENT': 14, 'PERSON': 15, 'PRODUCT': 16, 'QUANTITY': 17, 'TIME': 18, 'WORK_OF_ART': 19, 'O': 0}

Graph with 195 nodes and 18 entities

First 50 entity labels:
['O', 'O', 'MISC', 'MISC', 'MISC', 'O', 'O', 'O', 'O', 'MISC']
First 50 entity type IDs:
[0, 0, 9, 9, 9, 0, 0, 0, 0, 9]

BERT features shape: torch.Size([195, 768])
Entity embeddings shape: torch.Size([195, 16])

Combined node feature shape: torch.Size([195, 784])
Sample node feature vector (first token):
tensor([-3.9604e-01, -2.4289e-01, -1.5405e-01,  3.4880e-01,  3.2767e-03,
         1.9410e-02,  3.

In [4]:
def generate_entity_pairs(graph):
    """
    Generates all possible head-tail entity pairs for a graph.
    Label them based on existing relation_labels.
    """
    entities = graph["entities"]
    relation_labels = {(rel["head"], rel["tail"]): rel["label"] for rel in graph["relation_labels"]}
    print(f"\nGenerating candidate pairs from {len(entities)} entities...")
    print(f"Gold relation label count: {len(relation_labels)}")
    candidate_pairs = []
    count_pos = 0
    count_neg = 0
    for i, head in enumerate(entities):
        for j, tail in enumerate(entities):
            if i == j:
                continue

            head_root = head["root"]
            tail_root = tail["root"]
            label = relation_labels.get((head_root, tail_root), "NoRelation")

            if label != "NoRelation":
                print(f"MATCH: '{head['text']}' → '{tail['text']}' labeled as {label}")
                count_pos += 1
            else:
                count_neg += 1

            candidate_pairs.append({
                "head": head,
                "tail": tail,
                "head_idx": head_root,
                "tail_idx": tail_root,
                "label": label
            })
    print(f"\nGenerated {len(candidate_pairs)} total pairs")
    print(f"{count_pos} positive (labeled) | {count_neg} negative (NoRelation)")
    return candidate_pairs

total_pairs = 0
total_pos = 0
total_neg = 0

for i, graph in enumerate(all_graphs):
    print(f"\nProcessing Graph {i+1}/{len(all_graphs)}")
    pairs = generate_entity_pairs(graph)
    total_pairs += len(pairs)
    total_pos += sum(1 for p in pairs if p['label'] != "NoRelation")
    total_neg += sum(1 for p in pairs if p['label'] == "NoRelation")

print("\nFINAL SUMMARY:")
print(f"Total Graphs: {len(all_graphs)}")
print(f"Total Candidate Pairs: {total_pairs}")
print(f"Total Positive (Labeled): {total_pos}")
print(f"Total Negative (NoRelation): {total_neg}")


Processing Graph 1/74

Generating candidate pairs from 18 entities...
Gold relation label count: 7
MATCH: 'William Thomson' → 'automatic curb sender' labeled as Creator
MATCH: 'Wheatstone transmitter' → 'land line' labeled as LocatedIn
MATCH: 'automatic curb sender' → 'cable' labeled as LocatedIn
MATCH: 'automatic curb sender' → 'Eastern Telegraph Company' labeled as UsedBy
MATCH: 'automatic curb sender' → 'telegraph key' labeled as PartOf
MATCH: 'automatic curb sender' → 'Eastern Telegraph Company' labeled as UsedBy

Generated 306 total pairs
6 positive (labeled) | 300 negative (NoRelation)

Processing Graph 2/74

Generating candidate pairs from 14 entities...
Gold relation label count: 7

Generated 182 total pairs
0 positive (labeled) | 182 negative (NoRelation)

Processing Graph 3/74

Generating candidate pairs from 59 entities...
Gold relation label count: 4
MATCH: 'China Internet Network Information Center' → 'Anti-Phishing Alliance of China (APAC)' labeled as MemberOf
MATCH: 'Ch

MATCH: 'field trip' → 'site-based program' labeled as HasPart
MATCH: 'field trip' → 'follow-up activity' labeled as HasPart
MATCH: 'field trip' → 'Activities' labeled as HasPart
MATCH: 'field trip' → 'activities' labeled as HasPart

Generated 10506 total pairs
7 positive (labeled) | 10499 negative (NoRelation)

Processing Graph 16/74

Generating candidate pairs from 32 entities...
Gold relation label count: 1
MATCH: 'HEFCE' → 'higher education funding' labeled as HasPart

Generated 992 total pairs
1 positive (labeled) | 991 negative (NoRelation)

Processing Graph 17/74

Generating candidate pairs from 45 entities...
Gold relation label count: 1
MATCH: 'Hospitality Management and Tourism' → 'hospitality industry' labeled as Studies
MATCH: 'Hospitality Management' → 'hospitality industry' labeled as Studies
MATCH: 'Hospitality management' → 'hospitality industry' labeled as Studies
MATCH: 'hospitality management' → 'hospitality industry' labeled as Studies

Generated 1980 total pairs
4 p

MATCH: 'Kim Conley' → 'U.S.' labeled as LocatedIn
MATCH: 'Kim Conley' → 'U.S.' labeled as LocatedIn
MATCH: 'Managing Editor Kim Conley' → 'U.S.' labeled as LocatedIn
MATCH: 'Managing Editor Kim Conley' → 'U.S.' labeled as LocatedIn
MATCH: 'New York' → 'U.S.' labeled as NominatedFor
MATCH: 'New York' → 'U.S.' labeled as NominatedFor
MATCH: 'Lifestyle Media, Inc.' → 'Macfadden Performing Arts Media, LLC' labeled as InterestedIn
MATCH: 'Lifestyle Media, Inc.' → 'Macfadden Performing Arts Media' labeled as InterestedIn
MATCH: 'CBS' → 'U.S.' labeled as LocatedIn
MATCH: 'CBS' → 'U.S.' labeled as LocatedIn
MATCH: 'Debby Ryan' → 'U.S.' labeled as InterestedIn
MATCH: 'Debby Ryan' → 'U.S.' labeled as InterestedIn
MATCH: 'Kendall Jenner' → 'U.S.' labeled as InterestedIn
MATCH: 'Kendall Jenner' → 'U.S.' labeled as InterestedIn
MATCH: 'Heather Morris' → 'U.S.' labeled as InterestedIn
MATCH: 'Heather Morris' → 'U.S.' labeled as InterestedIn
MATCH: 'Christina Milian' → 'U.S.' labeled as InterestedIn


MATCH: 'Charles II' → 'Francia Occidentalis' labeled as OwnerOf
MATCH: 'Charles II' → 'Kingdom of Aquitaine' labeled as PartOf
MATCH: 'Charles II' → 'West Francia' labeled as AdjacentStation
MATCH: 'Louis the German' → 'Kingdom of Bavaria' labeled as PartOf
MATCH: 'Lothair I' → 'Alsace' labeled as AdjacentStation
MATCH: 'Lothair I' → 'Francia Media' labeled as OwnerOf
MATCH: 'Lothair I' → 'Lorraine' labeled as AdjacentStation
MATCH: 'Lothair I' → 'Imperial throne' labeled as NominatedFor
MATCH: 'Lothair I' → 'Kingdom of Italy' labeled as AdjacentStation
MATCH: 'Lothair I' → 'Low Countries' labeled as AdjacentStation
MATCH: 'Lothair I' → 'Rhineland' labeled as AdjacentStation
MATCH: 'Lothair I' → 'imperial throne' labeled as NominatedFor
MATCH: 'Lothair I' → 'imperial throne' labeled as NominatedFor
MATCH: 'Lothair I' → 'Imperial throne' labeled as NominatedFor
MATCH: 'Louis II' → 'Francia Orientalis' labeled as OwnerOf
MATCH: 'Louis II' → 'East Francia' labeled as AdjacentStation
MATCH

In [5]:
import torch
import numpy as np
from collections import defaultdict
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label-to-ID map
relation_label_map = defaultdict(lambda: len(relation_label_map))
relation_label_map["NoRelation"] = 0  # reserve 0 for negatives

def average_span_embedding(bert_tensor, start, end):
    """Average BERT embedding over a span of tokens."""
    if start >= end or end > bert_tensor.shape[0]:
        return torch.zeros(bert_tensor.shape[1], device=bert_tensor.device)
    span_vecs = bert_tensor[start:end]
    return span_vecs.mean(dim=0)

def extract_relation_dataset(graphs, use_debug=True):
    X, y = [], []

    print(f"\nStarting feature extraction on {len(graphs)} graphs...")

    for g_idx, graph in enumerate(graphs):
        if use_debug:
            print(f"\nGraph {g_idx+1}/{len(graphs)}: {len(graph['entities'])} entities")

        # Convert BERT features to tensor
        bert_tensor = torch.tensor(graph["node_features"], dtype=torch.float).to(device)

        # Build lookup for gold relations using root indices
        gold_relations = {(rel["head"], rel["tail"]): rel["label"] for rel in graph["relation_labels"]}

        for i, head in enumerate(graph["entities"]):
            for j, tail in enumerate(graph["entities"]):
                if i == j:
                    continue

                head_idx = head["root"]
                tail_idx = tail["root"]
                label = gold_relations.get((head_idx, tail_idx), "NoRelation")
                label_id = relation_label_map[label]

                # Get span embeddings
                head_vec = average_span_embedding(bert_tensor, head["start"], head["end"])
                tail_vec = average_span_embedding(bert_tensor, tail["start"], tail["end"])

                # Combine features
                pair_vec = torch.cat([head_vec, tail_vec], dim=0).cpu().numpy()
                X.append(pair_vec)
                y.append(label_id)

                if use_debug and label != "NoRelation":
                    print(f"{head['text']} → {tail['text']} | Label: {label} (ID: {label_id})")
    print(f"\nTotal relation examples: {len(X)}")
    print(f"Class distribution:")
    for lbl, idx in sorted(relation_label_map.items(), key=lambda x: x[1]):
        count = sum(1 for yy in y if yy == idx)
        print(f"  {lbl:<35}: {count}")

    return np.array(X), np.array(y), dict(relation_label_map)

# Run the feature extraction
X, y, label_map = extract_relation_dataset(all_graphs, use_debug=True)

print(f"\nFeature shape: {X.shape}")
print(f"Labels shape:  {y.shape}")
print(f"Label map:     {label_map}")


Starting feature extraction on 74 graphs...

Graph 1/74: 18 entities
William Thomson → automatic curb sender | Label: Creator (ID: 1)
Wheatstone transmitter → land line | Label: LocatedIn (ID: 2)
automatic curb sender → cable | Label: LocatedIn (ID: 2)
automatic curb sender → Eastern Telegraph Company | Label: UsedBy (ID: 3)
automatic curb sender → telegraph key | Label: PartOf (ID: 4)
automatic curb sender → Eastern Telegraph Company | Label: UsedBy (ID: 3)

Graph 2/74: 14 entities

Graph 3/74: 59 entities
China Internet Network Information Center → Anti-Phishing Alliance of China (APAC) | Label: MemberOf (ID: 5)
China Internet Network Information Center → Anti-Phishing Alliance of China | Label: MemberOf (ID: 5)
China Internet Network Information Center → .cn domain name | Label: OwnerOf (ID: 6)
China Internet Network Information Center → Ministry of Industry and Information Technology | Label: Affiliation (ID: 7)
China Internet Network Information Center → Zhongguancun | Label: Loc


Graph 21/74: 50 entities
Müllverwertung Rugenberger Damm WtE plant → Hamburg | Label: LocatedIn (ID: 2)
Müllverwertung Rugenberger Damm WtE plant → Hamburg, Germany | Label: LocatedIn (ID: 2)
Klean Power → advanced thermal recycling system | Label: OwnerOf (ID: 6)

Graph 22/74: 22 entities

Graph 23/74: 13 entities

Graph 24/74: 149 entities
Edison → Pearl Street Power Station | Label: OwnerOf (ID: 6)
Thomas Edison → Pearl Street Power Station | Label: OwnerOf (ID: 6)
New York City → America | Label: LocatedIn (ID: 2)
New York City → North America | Label: LocatedIn (ID: 2)

Graph 25/74: 32 entities
Energy Management Software → real-time energy metering | Label: HasPart (ID: 12)
Energy Management Software → grid services | Label: HasPart (ID: 12)
Energy Management Software → utility bill tracking | Label: HasPart (ID: 12)
Energy Management Software → generation control | Label: HasPart (ID: 12)
Energy Management Software → carbon and sustainability reporting | Label: HasPart (ID: 12)


Siege of Gibraltar → theatre | Label: SignificantEvent (ID: 25)
Siege of Gibraltar → Sadler's Wells Theatre | Label: SignificantEvent (ID: 25)
Siege of Gibraltar → Wells Theatre | Label: SignificantEvent (ID: 25)
Siege of Gibraltar → theatre | Label: SignificantEvent (ID: 25)
Siege of Gibraltar → Theatre | Label: SignificantEvent (ID: 25)
Siege of Gibraltar → aqua drama | Label: SignificantEvent (ID: 25)
HMS Bounty → Island or Christian and His Comrades | Label: InterestedIn (ID: 27)
HMS Bounty → Island | Label: InterestedIn (ID: 27)
Napoleonic wars → early 1800s | Label: SignificantEvent (ID: 25)
Napoleonic wars → early 1800s | Label: SignificantEvent (ID: 25)
Napoleonic wars → aqua drama | Label: HasEffect (ID: 39)

Graph 33/74: 60 entities
Delfines Hotel & Convention Center → Delfines Hotel & Casino | Label: SaidToBeTheSameAs (ID: 37)
Delfines Hotel & Convention Center → Hotel Los Delfines | Label: SaidToBeTheSameAs (ID: 37)
Delfines Hotel & Convention Center → San Isidro District, 


Graph 35/74: 52 entities
elegiac comedies → vernacular language | Label: Uses (ID: 17)
vernacular writers → Boccaccio | Label: HasPart (ID: 12)
vernacular writers → Chaucer | Label: HasPart (ID: 12)
vernacular writers → Gower | Label: HasPart (ID: 12)
Thompson → elegiac comedies | Label: HasEffect (ID: 39)

Graph 36/74: 63 entities
Russell Lincoln Ackoff → 1967 | Label: SignificantEvent (ID: 25)
Howard → Metagame Analysis in Political Problems | Label: Author (ID: 30)
Howard → Metagame Analysis | Label: Author (ID: 30)
Nigel Howard → Metagame Analysis in Political Problems | Label: Author (ID: 30)
Nigel Howard → Metagame Analysis | Label: Author (ID: 30)
Nigel Howard's book → 1971 | Label: Founded (ID: 40)
Garfield → 1995 | Label: SignificantEvent (ID: 25)
Garfield → Magic: The Gathering | Label: Creator (ID: 1)
Richard Garfield → 1995 | Label: SignificantEvent (ID: 25)
Richard Garfield → Magic: The Gathering | Label: Creator (ID: 1)
Aron Nimzowitsch → hypermodern openings | Label: Ha

American → U.S. | Label: SaidToBeTheSameAs (ID: 37)
American Cheerleader magazine → 1995 | Label: Founded (ID: 40)
American Cheerleader magazine → Lifestyle Ventures | Label: OwnerOf (ID: 6)
American Cheerleader magazine → Lifestyle Ventures, LLC | Label: OwnerOf (ID: 6)
U.S. → New York | Label: NominatedFor (ID: 9)
U.S. → U.S. | Label: SaidToBeTheSameAs (ID: 37)
U.S. → California | Label: NominatedFor (ID: 9)
U.S. → Texas | Label: NominatedFor (ID: 9)
USASF → Varsity Spirit, LLC | Label: Follows (ID: 34)
USASF → Varsity Spirit | Label: Follows (ID: 34)
Jeff Webb → Varsity Brands | Label: Creator (ID: 1)
Jeff Webb → Varsity Spirit, LLC | Label: Creator (ID: 1)
Jeff Webb → Varsity Spirit | Label: Creator (ID: 1)
Jeff Webb → National Cheerleaders Association | Label: InterestedIn (ID: 27)
National Cheerleaders Association → U.S. | Label: NominatedFor (ID: 9)
National Cheerleaders Association → U.S. | Label: NominatedFor (ID: 9)
National College Cheerleading Championship → 1978 | Label: F

Emperor Haile Selassie → Derg | Label: ParentOrganization (ID: 49)
Haile Selassie → Derg | Label: ParentOrganization (ID: 49)
1270 → two | Label: PrimeFactor (ID: 47)
Tigray People's Liberation Front → Ethiopian People's Revolutionary Democratic Front | Label: HasPart (ID: 12)
Tigray People's Liberation Front → Ethiopian People's Revolutionary Democratic Front (EPRDF) | Label: HasPart (ID: 12)
Tigray People's Liberation Front (TPLF) → Ethiopian People's Revolutionary Democratic Front | Label: HasPart (ID: 12)
Tigray People's Liberation Front (TPLF) → Ethiopian People's Revolutionary Democratic Front (EPRDF) | Label: HasPart (ID: 12)
EPLF → Eritrean Liberation Front Army | Label: HasPart (ID: 12)
EPLF → Eritrean Liberation Front Army (EPLA) | Label: HasPart (ID: 12)
Eritrean People's Liberation Front → Derg | Label: Follows (ID: 34)
Eritrean People's Liberation Front (EPLF) → Derg | Label: Follows (ID: 34)
Tigray People's Liberation Front (TPLF) → Ethiopian People's Revolutionary Democr

Quality of Nationality Index → Forbes | Label: PublishedIn (ID: 61)
Quality of Nationality Index → Forbes, Bloomberg, The Enquirer and Business Standard | Label: PublishedIn (ID: 61)
Quality of Nationality Index → economic strength, human development, ease of travel, political stability and overseas employment opportunities | Label: HasPart (ID: 12)
Quality of Nationality Index (QNI) → Ayelet Shachar | Label: CitesWork (ID: 60)
Quality of Nationality Index (QNI) → Forbes | Label: PublishedIn (ID: 61)
Quality of Nationality Index (QNI) → Forbes, Bloomberg, The Enquirer and Business Standard | Label: PublishedIn (ID: 61)
Quality of Nationality Index (QNI) → economic strength, human development, ease of travel, political stability and overseas employment opportunities | Label: HasPart (ID: 12)
Ayelet Shachar → University of Toronto | Label: Employer (ID: 62)
Indian nationality → Senegalese nationality | Label: SharesBorderWith (ID: 45)
Senegalese → French | Label: NativeLanguage (ID: 50)


binge eating disorder → Weight stigma | Label: Causes (ID: 68)
binge eating disorder → Cognitive Behavioural Therapy (CBT) | Label: ApprovedBy (ID: 67)
binge eating disorder → Dialectical Behavioural Therapy (DBT) | Label: ApprovedBy (ID: 67)
Binge Eating → psychological issues | Label: InfluencedBy (ID: 14)
Binge Eating → Psychological issues | Label: InfluencedBy (ID: 14)
Binge eating disorder → Interpersonal Psychotherapy | Label: ApprovedBy (ID: 67)
Binge eating disorder → interpersonal psychotherapy (IPT) | Label: ApprovedBy (ID: 67)
Binge eating disorder → Interpersonal Psychotherapy (IPT) | Label: ApprovedBy (ID: 67)
Binge eating disorder → interpersonal psychotherapy | Label: ApprovedBy (ID: 67)
Binge eating disorder → Dialectical Behavioural Therapy | Label: ApprovedBy (ID: 67)
Binge eating disorder → Weight stigma | Label: Causes (ID: 68)
Binge eating disorder → Cognitive Behavioural Therapy (CBT) | Label: ApprovedBy (ID: 67)
Binge eating disorder → Dialectical Behavioural Th


Graph 58/74: 87 entities
Discrete trial training → practitioners of applied behavior analysis (ABA) | Label: UsedBy (ID: 3)
Discrete trial training → practitioners of applied behavior analysis | Label: UsedBy (ID: 3)
Discrete trial training (DTT) → practitioners of applied behavior analysis (ABA) | Label: UsedBy (ID: 3)
Discrete trial training (DTT) → practitioners of applied behavior analysis | Label: UsedBy (ID: 3)
DTT → listener responding | Label: SaidToBeTheSameAs (ID: 37)
DTT → rapid motor imitation antecedent | Label: SaidToBeTheSameAs (ID: 37)
DTT → errorless correction procedures | Label: HasPart (ID: 12)
DTT → mass trials | Label: SaidToBeTheSameAs (ID: 37)
DTT → Lovaas/UCLA model | Label: SaidToBeTheSameAs (ID: 37)
DTT → errorless learning | Label: SaidToBeTheSameAs (ID: 37)
DTT → Picture Exchange Communication System (PECS) | Label: Uses (ID: 17)
DTT → Picture Exchange Communication System | Label: Uses (ID: 17)
DTT → Charles Ferster | Label: ApprovedBy (ID: 67)
DTT → earl

Work-to-family conflict → family life | Label: HasEffect (ID: 39)
Loehr and Schwartz → WIF | Label: ContributedToCreativeWork (ID: 65)
Loehr → WIF | Label: ContributedToCreativeWork (ID: 65)

Graph 65/74: 71 entities
Mohammed Sani Musa → Bill | Label: OwnerOf (ID: 6)
Mohammed Sani Musa → Anti-social Media Bill | Label: OwnerOf (ID: 6)
Senator Mohammed Sani Musa → Bill | Label: OwnerOf (ID: 6)
Senator Mohammed Sani Musa → Anti-social Media Bill | Label: OwnerOf (ID: 6)
Bill → Protection from Internet Falsehood and Manipulations Bill 2019 | Label: HasPart (ID: 12)
Akon Eyakenyi → Akwa Ibom State | Label: Affiliation (ID: 7)
Anti-social Media Bill → Protection from Internet Falsehood and Manipulations Bill 2019 | Label: HasPart (ID: 12)

Graph 66/74: 28 entities
Cem Korkmaz → State Theatre of Bursa | Label: HasPart (ID: 12)
Cem Korkmaz → Kiraz Mevsimi | Label: HasWorksInTheCollection (ID: 15)
Cem Korkmaz → Diğer Yarım | Label: HasWorksInTheCollection (ID: 15)
Cem Korkmaz → Husband Factor 

  Continent                          : 2
  OfficialLanguage                   : 5
  PresentedIn                        : 3
  DifferentFrom                      : 13
  BasedOn                            : 30
  RegulatedBy                        : 3
  NamedAfter                         : 1
  SignificantEvent                   : 17
  WorkLocation                       : 5
  InterestedIn                       : 92
  AcademicDegree                     : 4
  CountryOfCitizenship               : 5
  Author                             : 5
  FollowedBy                         : 9
  InOppositionTo                     : 2
  EducatedAt                         : 5
  Follows                            : 22
  OwnedBy                            : 2
  Partner                            : 5
  SaidToBeTheSameAs                  : 61
  PhysicallyInteractsWith            : 1
  HasEffect                          : 145
  Founded                            : 12
  InspiredBy                         : 11
  Prac

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from collections import defaultdict

# GNN Model
class GNN_RE_NER(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_ner_classes, num_re_classes):
        super().__init__()
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        self.ner_classifier = nn.Linear(hidden_dim, num_ner_classes)
        self.re_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_re_classes)
        )

    def forward(self, x, edge_index, entity_pairs):
        x = self.gcn1(x, edge_index).relu()
        x = self.gcn2(x, edge_index).relu()

        # NER: predict tag for each token
        ner_logits = self.ner_classifier(x)

        # RE: predict relation between each candidate pair
        pair_reprs = []
        for head_idx, tail_idx in entity_pairs:
            h = x[head_idx]
            t = x[tail_idx]
            pair_reprs.append(torch.cat([h, t], dim=0))

        if pair_reprs:
            re_input = torch.stack(pair_reprs)
            re_logits = self.re_classifier(re_input)
        else:
            re_logits = torch.empty(0, self.re_classifier[-1].out_features).to(x.device)

        return ner_logits, re_logits


# Training Loop
def train_gnn_re_ner(graphs, ner_label_map, re_label_map, output_path,
                     entity_embed_dim=16, hidden_dim=128, 
                     max_epochs=100, patience=5, min_delta=1e-3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = GNN_RE_NER(
        input_dim=768 + entity_embed_dim,
        hidden_dim=hidden_dim,
        num_ner_classes=len(ner_label_map),
        num_re_classes=len(re_label_map)
    ).to(device)

    entity_embedder = nn.Embedding(len(ner_label_map), entity_embed_dim).to(device)

    ner_criterion = nn.CrossEntropyLoss()
    re_criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        total_loss = 0.0
        total_re_pairs = 0

        for graph in graphs:
            node_feats = torch.tensor(graph["node_features"], dtype=torch.float).to(device)
            edge_index = torch.tensor(graph["edges"], dtype=torch.long).t().contiguous().to(device)

            # Create NER input tensor
            node_ner_tags = ["O"] * len(graph["nodes"])
            for ent in graph.get("entities", []):
                for i in range(ent["start"], ent["end"]):
                    if i < len(node_ner_tags):
                        node_ner_tags[i] = ent["label"]

            ner_ids = [ner_label_map.get(tag, 0) for tag in node_ner_tags]
            ner_tensor = torch.tensor(ner_ids, dtype=torch.long).to(device)
            ner_embeds = entity_embedder(ner_tensor)

            # Final GCN input
            x = torch.cat([node_feats, ner_embeds], dim=1)

            # Relation extraction setup
            entity_pairs = [(rel["head"], rel["tail"]) for rel in graph["relation_labels"]]
            re_labels = [re_label_map.get(rel["label"], 0) for rel in graph["relation_labels"]]
            re_labels_tensor = torch.tensor(re_labels, dtype=torch.long).to(device)

            # Forward pass
            ner_logits, re_logits = model(x, edge_index, entity_pairs)

            loss_ner = ner_criterion(ner_logits, ner_tensor)
            loss_re = re_criterion(re_logits, re_labels_tensor) if re_logits.size(0) > 0 else torch.tensor(0., device=device)

            loss = loss_ner + loss_re
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_re_pairs += re_logits.size(0)

        avg_loss = total_loss / len(graphs)
        print(f"Epoch {epoch} | Avg Loss: {avg_loss:.4f}")

        # Early stopping
        if best_loss - avg_loss > min_delta:
            best_loss = avg_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"No improvement in {patience} epochs. Stopping at epoch {epoch}.")
                break

    # Convert label maps
    if isinstance(ner_label_map, defaultdict):
        ner_label_map = dict(ner_label_map)
    if isinstance(re_label_map, defaultdict):
        re_label_map = dict(re_label_map)

    # Save model
    torch.save({
        "model_state_dict": model.state_dict(),
        "entity_embedder_state_dict": entity_embedder.state_dict(),
        "ner_label_map": ner_label_map,
        "relation_label_map": re_label_map
    }, output_path)

    print(f"\nModel saved to: {output_path}")
    return model, entity_embedder


# === Train and Save ===
model, entity_embedder = train_gnn_re_ner(
    graphs=all_graphs,
    ner_label_map=ENTITY_LABELS,
    re_label_map=relation_label_map,
    output_path="Your_Model_Path_Here.pth",  # Change this to your model save path
    max_epochs=100
)

Epoch 1 | Avg Loss: 4.7544
Epoch 2 | Avg Loss: 3.6105
Epoch 3 | Avg Loss: 3.2641
Epoch 4 | Avg Loss: 2.8901
Epoch 5 | Avg Loss: 2.5316
Epoch 6 | Avg Loss: 2.2491
Epoch 7 | Avg Loss: 1.9981
Epoch 8 | Avg Loss: 1.7069
Epoch 9 | Avg Loss: 1.4419
Epoch 10 | Avg Loss: 1.2352
Epoch 11 | Avg Loss: 1.0558
Epoch 12 | Avg Loss: 0.9538
Epoch 13 | Avg Loss: 0.8716
Epoch 14 | Avg Loss: 0.7949
Epoch 15 | Avg Loss: 0.7516
Epoch 16 | Avg Loss: 0.7381
Epoch 17 | Avg Loss: 0.6700
Epoch 18 | Avg Loss: 0.5707
Epoch 19 | Avg Loss: 0.5383
Epoch 20 | Avg Loss: 0.4928
Epoch 21 | Avg Loss: 0.5105
Epoch 22 | Avg Loss: 0.4860
Epoch 23 | Avg Loss: 0.4635
Epoch 24 | Avg Loss: 0.4350
Epoch 25 | Avg Loss: 0.3836
Epoch 26 | Avg Loss: 0.3862
Epoch 27 | Avg Loss: 0.3415
Epoch 28 | Avg Loss: 0.3249
Epoch 29 | Avg Loss: 0.3121
Epoch 30 | Avg Loss: 0.3001
Epoch 31 | Avg Loss: 0.3110
Epoch 32 | Avg Loss: 0.3058
Epoch 33 | Avg Loss: 0.2838
Epoch 34 | Avg Loss: 0.3180
Epoch 35 | Avg Loss: 0.3267
Epoch 36 | Avg Loss: 0.4057
E

In [None]:
import os
import json
import torch
import spacy
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch_geometric.nn import GCNConv
import torch.nn as nn

# Load base models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

nlp = spacy.load("en_core_web_sm")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()

# Define model architecture
class GNN_RE_NER(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_ner_classes, num_re_classes):
        super().__init__()
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.ner_classifier = nn.Linear(hidden_dim, num_ner_classes)
        self.re_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_re_classes)
        )

    def forward(self, x, edge_index, entity_pairs):
        x = self.gcn1(x, edge_index).relu()
        x = self.gcn2(x, edge_index).relu()
        ner_logits = self.ner_classifier(x)
        pair_reprs = [torch.cat([x[head], x[tail]], dim=0) for head, tail in entity_pairs]
        if pair_reprs:
            re_input = torch.stack(pair_reprs)
            re_logits = self.re_classifier(re_input)
        else:
            re_logits = torch.empty(0, self.re_classifier[-1].out_features).to(x.device)
        return ner_logits, re_logits

# Load trained model and mappings
checkpoint_path = "Your_Model_Path_Here.pth"  # Change this to your model load path
checkpoint = torch.load(checkpoint_path, map_location=device)

ner_label_map = checkpoint["ner_label_map"]
relation_label_map = checkpoint["relation_label_map"]
id_to_ner_label = {v: k for k, v in ner_label_map.items()}
relation_id_to_label = {v: k for k, v in relation_label_map.items()}

model = GNN_RE_NER(
    input_dim=768 + 16,
    hidden_dim=128,
    num_ner_classes=len(ner_label_map),
    num_re_classes=len(relation_label_map)
).to(device)

entity_embedder = nn.Embedding(len(ner_label_map), 16).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
entity_embedder.load_state_dict(checkpoint["entity_embedder_state_dict"])
model.eval()

# Utilities
def get_bert_embeddings(text):
    encoded = tokenizer(text, return_tensors="pt", return_offsets_mapping=True, truncation=True).to(device)
    offsets = encoded.pop("offset_mapping")[0].tolist()
    output = bert_model(**encoded)
    embeddings = output.last_hidden_state.squeeze(0)
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    return tokens, embeddings, offsets

def align_tokens_to_bert(doc, tokens, embeddings, offsets):
    token_embeddings = []
    for token in doc:
        token_start = token.idx
        token_end = token.idx + len(token.text)
        matched = [embeddings[i] for i, (start, end) in enumerate(offsets)
                   if start >= token_start and end <= token_end]
        if matched:
            token_embeddings.append(torch.mean(torch.stack(matched), dim=0).tolist())
        else:
            token_embeddings.append(torch.zeros(embeddings.size(-1)).tolist())
    return token_embeddings

def group_tokens_into_entities(token_texts, token_labels, id_to_ner_label):
    entities = []
    i = 0
    while i < len(token_texts):
        label = id_to_ner_label.get(token_labels[i], "O")
        if label != "O":
            start = i
            entity_tokens = [token_texts[i]]
            while i + 1 < len(token_texts) and id_to_ner_label.get(token_labels[i+1], "O") == label:
                i += 1
                entity_tokens.append(token_texts[i])
            end = i + 1
            entity_text = " ".join(entity_tokens)
            entities.append({
                "text": entity_text,
                "label": label,
                "start": start,
                "end": end
            })
        i += 1
    return entities

# Annotate a single file
def annotate_file(file_path, output_path, confidence_threshold=0.6):
    with open(file_path, "r") as f:
        doc_entry = json.load(f)

    text = doc_entry.get("document", doc_entry.get("doc", ""))
    doc = nlp(text)
    token_texts = [token.text for token in doc]

    tokens, embeddings, offsets = get_bert_embeddings(text)
    features = align_tokens_to_bert(doc, tokens, embeddings, offsets)
    edges = [(token.head.i, token.i) for token in doc if token.head.i != token.i]

    node_feats_tensor = torch.tensor(features, dtype=torch.float).to(device)
    ner_ids = [ner_label_map.get("O", 0)] * len(token_texts)
    ner_tensor = torch.tensor(ner_ids, dtype=torch.long).to(device)
    ner_embeds = entity_embedder(ner_tensor)

    x = torch.cat([node_feats_tensor, ner_embeds], dim=1)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(device) if edges else torch.empty((2, 0), dtype=torch.long).to(device)

    with torch.no_grad():
        x_reps = model.gcn1(x, edge_index).relu()
        x_reps = model.gcn2(x_reps, edge_index).relu()
        ner_logits = model.ner_classifier(x_reps)
        predicted_ner_ids = ner_logits.argmax(dim=1).cpu().tolist()

    predicted_entities = group_tokens_into_entities(token_texts, predicted_ner_ids, id_to_ner_label)

    candidate_pairs = [(e1["start"], e2["start"]) for i, e1 in enumerate(predicted_entities)
                       for j, e2 in enumerate(predicted_entities) if i != j]

    predicted_triples = []
    if candidate_pairs:
        pair_reprs = [torch.cat([x_reps[h], x_reps[t]], dim=0) for h, t in candidate_pairs]
        re_input = torch.stack(pair_reprs)
        with torch.no_grad():
            re_logits = model.re_classifier(re_input)
            probs = F.softmax(re_logits, dim=1)
            confidences, re_pred_ids = probs.max(dim=1)

        for idx, (head_idx, tail_idx) in enumerate(candidate_pairs):
            rel_id = re_pred_ids[idx].item()
            conf = confidences[idx].item()
            rel_label = relation_id_to_label.get(rel_id, "Unknown")
            h_ent = next((e for e in predicted_entities if e["start"] == head_idx), None)
            t_ent = next((e for e in predicted_entities if e["start"] == tail_idx), None)

            if h_ent and t_ent:
                if rel_label != "NoRelation" and conf >= confidence_threshold:
                    predicted_triples.append({
                        "head_text": h_ent["text"],
                        "tail_text": t_ent["text"],
                        "label": rel_label
                    })

    annotated_output = {
        "document": text,
        "predicted_entities": predicted_entities,
        "predicted_triples": predicted_triples,
        "NER-label_set": doc_entry.get("NER-label_set", []),
        "RE_label_set": doc_entry.get("RE_label_set", []),
        "id": doc_entry.get("id", "")
    }

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w") as f_out:
        json.dump(annotated_output, f_out, indent=4)

    print(f"Saved: {output_path}")
    return annotated_output

# Loop over all subfolders
root_input = "Your_folder_Path_Here"  # Change this to your input folder path
root_output = "Your_folder_Path_Here"  # Change this to your input folder path

for subdir in os.listdir(root_input):
    input_folder = os.path.join(root_input, subdir)
    output_folder = os.path.join(root_output, subdir)

    if not os.path.isdir(input_folder):
        continue

    for file_name in os.listdir(input_folder):
        if file_name.endswith(".json"):
            input_path = os.path.join(input_folder, file_name)
            output_path = os.path.join(output_folder, file_name)
            annotate_file(input_path, output_path)

Using device: cuda
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Anthropocene_Working_Group.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Applied_history.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Bolognese_bell_ringing.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Cylinder_Audio_Archive.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Drug_education.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Environmental_studies.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Essay_on_a_Course_of_Liberal_Education_for_Civil_and_Active_Life.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Legal_archaeology.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Literary_nonsense.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Master_of_Environmental_Management.json
Saved: /scratch/vsetpal/results/test/Academic_disciplines/Philosophy_of_design.json
Saved

Saved: /scratch/vsetpal/results/test/Food_and_drink/Food_engineering.json
Saved: /scratch/vsetpal/results/test/Food_and_drink/Irreechaa.json
Saved: /scratch/vsetpal/results/test/Food_and_drink/Liquor.json
Saved: /scratch/vsetpal/results/test/Food_and_drink/Night_of_the_Radishes.json
Saved: /scratch/vsetpal/results/test/Food_and_drink/Pagophagia.json
Saved: /scratch/vsetpal/results/test/Food_and_drink/Wilderness-acquired_diarrhea.json
Saved: /scratch/vsetpal/results/test/Geography/Age_of_Sail.json
Saved: /scratch/vsetpal/results/test/Geography/Churbaierische_Atlas.json
Saved: /scratch/vsetpal/results/test/Geography/Five_themes_of_geography.json
Saved: /scratch/vsetpal/results/test/Geography/Four_traditions_of_geography.json
Saved: /scratch/vsetpal/results/test/Geography/Geography_Cup.json
Saved: /scratch/vsetpal/results/test/Geography/Jebel_Akhdar_(Libya).json
Saved: /scratch/vsetpal/results/test/Geography/Land_systems.json
Saved: /scratch/vsetpal/results/test/Geography/Rosgen_Stream_Cl

Saved: /scratch/vsetpal/results/test/Language/Patter.json
Saved: /scratch/vsetpal/results/test/Language/Radical_interpretation.json
Saved: /scratch/vsetpal/results/test/Language/Unicode_font.json
Saved: /scratch/vsetpal/results/test/Law/Carceral_feminism.json
Saved: /scratch/vsetpal/results/test/Law/Continuing_legal_education.json
Saved: /scratch/vsetpal/results/test/Law/European_Judicial_Network.json
Saved: /scratch/vsetpal/results/test/Law/Judicial_populism.json
Saved: /scratch/vsetpal/results/test/Law/Law_French.json
Saved: /scratch/vsetpal/results/test/Law/Legal_risk.json
Saved: /scratch/vsetpal/results/test/Law/Legal_status_of_psilocybin_mushrooms.json
Saved: /scratch/vsetpal/results/test/Law/New_York_Anti-Secession_Ordinance.json
Saved: /scratch/vsetpal/results/test/Law/The_Justice_of_Trajan_and_Herkinbald.json
Saved: /scratch/vsetpal/results/test/Law/Trigger_law.json
Saved: /scratch/vsetpal/results/test/Life/Anthropopithecus.json
Saved: /scratch/vsetpal/results/test/Life/Formose

Saved: /scratch/vsetpal/results/test/Technology/Geomatics.json
Saved: /scratch/vsetpal/results/test/Technology/History_of_timekeeping_devices_in_Egypt.json
Saved: /scratch/vsetpal/results/test/Technology/Wei-Ying_Ma.json
Saved: /scratch/vsetpal/results/test/Universe/Cosmology.json
Saved: /scratch/vsetpal/results/test/Universe/Global_brain.json
Saved: /scratch/vsetpal/results/test/Universe/Hardness.json
Saved: /scratch/vsetpal/results/test/Universe/International_communication.json
Saved: /scratch/vsetpal/results/test/Universe/Negative_energy.json


In [None]:
import os
import json
from tabulate import tabulate

def evaluate_annotated_file(annotated_json_path):
    with open(annotated_json_path, "r") as f:
        data = json.load(f)

    gold_ner_labels = set(data.get("NER-label_set", []))
    predicted_ner_labels = {ent["label"] for ent in data.get("predicted_entities", [])}
    tp_ner = len(gold_ner_labels & predicted_ner_labels)
    fp_ner = len(predicted_ner_labels - gold_ner_labels)
    fn_ner = len(gold_ner_labels - predicted_ner_labels)
    precision_ner = tp_ner / (tp_ner + fp_ner + 1e-8)
    recall_ner = tp_ner / (tp_ner + fn_ner + 1e-8)
    f1_ner = 2 * precision_ner * recall_ner / (precision_ner + recall_ner + 1e-8)

    gold_re_labels = set(data.get("RE_label_set", []))
    predicted_re_labels = {trip["label"] for trip in data.get("predicted_triples", [])}
    tp_re = len(gold_re_labels & predicted_re_labels)
    fp_re = len(predicted_re_labels - gold_re_labels)
    fn_re = len(gold_re_labels - predicted_re_labels)
    precision_re = tp_re / (tp_re + fp_re + 1e-8)
    recall_re = tp_re / (tp_re + fn_re + 1e-8)
    f1_re = 2 * precision_re * recall_re / (precision_re + recall_re + 1e-8)

    precision_combined = (precision_ner + precision_re) / 2
    recall_combined = (recall_ner + recall_re) / 2
    f1_combined = (f1_ner + f1_re) / 2

    return precision_combined, recall_combined, f1_combined

def evaluate_all_domains(root_folder_path):
    domain_results = []

    for dirpath, dirnames, filenames in os.walk(root_folder_path):
        json_files = [f for f in filenames if f.endswith(".json")]
        if not json_files:
            continue

        total_precision = total_recall = total_f1 = 0
        count = 0

        for fname in json_files:
            fpath = os.path.join(dirpath, fname)
            try:
                p, r, f1 = evaluate_annotated_file(fpath)
                total_precision += p
                total_recall += r
                total_f1 += f1
                count += 1
            except Exception as e:
                print()

        if count > 0:
            avg_precision = total_precision / count
            avg_recall = total_recall / count
            avg_f1 = total_f1 / count

            # Extract domain name as last folder name
            domain_name = os.path.basename(dirpath)
            domain_results.append([domain_name, f"{avg_precision:.4f}", f"{avg_recall:.4f}", f"{avg_f1:.4f}"])

    # Print all as table
    if domain_results:
        print("Evaluation Summary Across Domains")
        print(tabulate(domain_results, headers=["Domain", "Precision", "Recall", "F1 Score"], tablefmt="pretty"))
    else:
        print("No annotated JSON files found.")

# Run across all folders inside /scratch/vsetpal/results
evaluate_all_domains("Your_folder_Path_Here")  # Change this to your folder path

Evaluation Summary Across Domains
+----------------------+-----------+--------+----------+
|        Domain        | Precision | Recall | F1 Score |
+----------------------+-----------+--------+----------+
| Academic_disciplines |  0.7522   | 0.0882 |  0.1530  |
|       Business       |  0.6244   | 0.0899 |  0.1523  |
|    Communication     |  0.7561   | 0.1692 |  0.2499  |
|       Culture        |  0.8594   | 0.1013 |  0.1770  |
|       Economy        |  0.7764   | 0.1076 |  0.1796  |
|      Education       |  0.8071   | 0.1144 |  0.1896  |
|        Energy        |  0.6812   | 0.0644 |  0.1137  |
|     Engineering      |  0.7003   | 0.1070 |  0.1741  |
|    Entertainment     |  0.8136   | 0.1449 |  0.2373  |
|    Food_and_drink    |  0.7033   | 0.1113 |  0.1824  |
|      Geography       |  0.7226   | 0.1093 |  0.1791  |
|      Government      |  0.8080   | 0.1014 |  0.1760  |
|        Health        |  0.7875   | 0.1182 |  0.1954  |
|       History        |  0.8002   | 0.1555 |  0.2486 