In [22]:
from neo4j import GraphDatabase
from datasets import load_dataset
from tqdm.auto import tqdm

from utils.env_loader import load_project_env
import os
import re

In [2]:
load_project_env()

In [3]:
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", os.getenv("NEO4J_PASSWORD"))
SCHEME_ID = "CPC"

In [4]:
STREAMING = True           # keeps memory tiny
COUNT_TOTAL = False        # set True to compute exact totals (does a cached local pass)

In [5]:
driver = GraphDatabase.driver(URI, auth=AUTH)

In [6]:
DDL = [
    """CREATE CONSTRAINT scheme_id IF NOT EXISTS
       FOR (s:Scheme) REQUIRE s.schemeId IS UNIQUE""",
    """CREATE CONSTRAINT concept_identity IF NOT EXISTS
       FOR (c:Concept) REQUIRE (c.schemeId, c.conceptId) IS UNIQUE"""
]

ENSURE_SCHEME = """
MERGE (s:Scheme {schemeId: $schemeId})
  ON CREATE SET s.name = $name, s.createdAt = datetime()
"""

UPSERT_CONCEPT = """
MERGE (c:Concept {schemeId:$schemeId, conceptId:$conceptId})
  ON CREATE SET c.title=$title, c.fullTitle=$fullTitle, c.createdAt=datetime(),
                c.pathKeys=$pathKeys
  ON MATCH  SET c.title=$title, c.fullTitle=$fullTitle, c.updatedAt=datetime(),
                c.pathKeys=$pathKeys
WITH c
MATCH (s:Scheme {schemeId:$schemeId})
MERGE (s)-[:HAS_CONCEPT]->(c)
"""

LINK_PARENT = """
MATCH (child:Concept {schemeId:$schemeId, conceptId:$child})
MATCH (parent:Concept {schemeId:$schemeId, conceptId:$parent})
MERGE (child)-[:BROADER]->(parent)
"""

LINK_REFERS_TO = """
MATCH (src:Concept {schemeId:$schemeId, conceptId:$src})
MATCH (dst:Concept {schemeId:$schemeId, conceptId:$dst})
MERGE (src)-[:REFERS_TO]->(dst)
"""

In [7]:
def run_tx(tx, cypher, **params):
    # small helper for readability
    tx.run(cypher, **params)

In [8]:
def ensure_constraints_and_scheme():
    with driver.session() as s:
        for stmt in DDL:
            s.run(stmt)
        s.run(ENSURE_SCHEME, schemeId=SCHEME_ID,
              name="Cooperative Patent Classification (MVP)")

In [9]:
def get_total_rows() -> int | None:
    """
    Optionally fetch the exact number of rows for pretty progress bars.
    This does a non-streaming load (cached locally by HF Datasets).
    """
    if not COUNT_TOTAL:
        return None
    ds = load_dataset("mhurhangee/cpc-classifications", split="train", streaming=False)
    try:
        return ds.num_rows  # fast metadata path
    except Exception:
        try:
            return len(ds)   # fallback
        except Exception:
            return None

In [10]:
def iter_rows(streaming: bool):
    # Streaming iterator: tiny memory footprint; re-iterable for separate passes
    return load_dataset("mhurhangee/cpc-classifications", split="train", streaming=streaming)


In [None]:
def pass1_nodes():
    """
    PASS 1: Upsert all nodes (ancestors from treePath first, then record node).
    Using per-record writes for simplicity; idempotent MERGEs keep it safe.
    """
    total = get_total_rows()
    ds_iter = iter_rows(STREAMING)

    with driver.session() as s, tqdm(ds_iter, total=total, desc="Pass 1/2: upserting nodes", unit="rec") as pbar:
        for rec in pbar:
            # Upsert the node itself
            s.run(
                UPSERT_CONCEPT,
                schemeId=SCHEME_ID,
                conceptId=rec["key"],
                title=rec.get("title"),
                fullTitle=rec.get("fullTitle"),
                pathKeys=[tp["key"] for tp in (rec.get("treePath") or [])]
            )

In [12]:
def pass2_rels():
    """
    PASS 2: Create relationships.
    - BROADER edges (explicit 'broader' + stitch linear treePath)
    - REFERS_TO edges (only if dst exists; uses MATCH to keep MVP clean)
    """
    total = get_total_rows()
    ds_iter = iter_rows(STREAMING)

    with driver.session() as s, tqdm(ds_iter, total=total, desc="Pass 2/2: linking rels", unit="rec") as pbar:
        for rec in pbar:
            child = rec["key"]

            # Explicit parents from 'broader'
            for parent in (rec.get("broader") or []):
                s.run(LINK_PARENT, schemeId=SCHEME_ID, child=child, parent=parent)

            # Stitch the ordered treePath chain (child -> immediate parent)
            tp_keys = [tp["key"] for tp in (rec.get("treePath") or [])]
            for i in range(1, len(tp_keys)):
                s.run(LINK_PARENT, schemeId=SCHEME_ID, child=tp_keys[i], parent=tp_keys[i-1])

            # Cross-references (skip silently if target missing)
            for dst in (rec.get("references") or []):
                s.run(LINK_REFERS_TO, schemeId=SCHEME_ID, src=child, dst=dst)

In [13]:
ensure_constraints_and_scheme()

In [16]:
pass1_nodes()
pass2_rels()


Pass 1/2: upserting nodes: 261962rec [09:49, 444.39rec/s]
Pass 2/2: linking rels: 261962rec [44:27, 98.20rec/s] 


In [36]:
Q_COUNTS = """
MATCH (c:Concept {schemeId:$scheme}) RETURN count(c) AS nodes;
"""

Q_REL_COUNTS = """
MATCH (:Concept {schemeId:$scheme})-[r:BROADER]->(:Concept {schemeId:$scheme}) RETURN count(r) AS broader;
"""

Q_REF_COUNTS = """
MATCH (:Concept {schemeId:$scheme})-[r:REFERS_TO]->(:Concept {schemeId:$scheme}) RETURN count(r) AS refers;
"""

Q_ROOTS = """
MATCH (c:Concept {schemeId:$scheme})
WHERE NOT (c)-[:BROADER]->(:Concept {schemeId:$scheme})
RETURN count(c) AS roots, collect(c.conceptId)[0..$k] AS sample;
"""

Q_LEAFS = """
MATCH (p:Concept {schemeId:$scheme})
WHERE NOT (:Concept {schemeId:$scheme})-[:BROADER]->(p)
RETURN count(p) AS leafs, collect(p.conceptId)[0..$k] AS sample;
"""

Q_CYCLE = """
MATCH (n:Concept {schemeId:$scheme})
MATCH p=(n)-[:BROADER*1..]->(n)
RETURN n.conceptId AS node, length(p) AS len
LIMIT 1;
"""

Q_MAX_DEPTH = """
// Max child->root path length (rough sense of tree height)
MATCH (c:Concept {schemeId:$scheme})
OPTIONAL MATCH p=(c)-[:BROADER*0..]->(r:Concept {schemeId:$scheme})
WHERE NOT (r)-[:BROADER]->(:Concept {schemeId:$scheme})
RETURN coalesce(max(length(p)),0) AS maxDepth;
"""

Q_ROOT_BREAKDOWN = """
MATCH (c:Concept {schemeId:$scheme})
WHERE NOT (c)-[:BROADER]->(:Concept {schemeId:$scheme})
WITH c,
  CASE
    WHEN c.conceptId =~ '^[A-HY]$' THEN 'SECTION'
    WHEN c.conceptId ENDS WITH '-00-GH' THEN 'GENERAL_HOLDING'
    ELSE 'OTHER'
  END AS tag
RETURN tag, count(*) AS n
ORDER BY n DESC
"""

Q_OTHER_ROOTS = """
MATCH (c:Concept {schemeId:$scheme})
WHERE NOT (c)-[:BROADER]->(:Concept {schemeId:$scheme})
  AND NOT (c.conceptId =~ '^[A-HY]$' OR c.conceptId ENDS WITH '-00-GH')
RETURN c.conceptId AS id, c.title AS title
"""

In [27]:
def _chunks(it, size):
    it = iter(it)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk

def _verify_declared_edges(sample=10, batch=1000):
    """
    Streams the HF dataset and checks that each declared pair exists:
      - broader: (child)-[:BROADER]->(parent)
      - references: (src)-[:REFERS_TO]->(dst)
    Reports total missing edges and small samples.
    """
    ds = load_dataset("mhurhangee/cpc-classifications", split="train", streaming=True)

    broader_missing = 0
    refers_missing = 0
    sample_broader = []
    sample_refers = []

    buf_b = []   # (child, parent)
    buf_r = []   # (src, dst)

    for rec in ds:
        c = rec["key"]
        for p in (rec.get("broader") or []):
            buf_b.append({"child": c, "parent": p})
        for d in (rec.get("references") or []):
            buf_r.append({"src": c, "dst": d})

        # Check in batches to keep memory low
        if len(buf_b) >= batch:
            mb, sb = _missing_broader(buf_b, sample - len(sample_broader))
            broader_missing += mb
            sample_broader.extend(sb)
            buf_b.clear()

        if len(buf_r) >= batch:
            mr, sr = _missing_refers(buf_r, sample - len(sample_refers))
            refers_missing += mr
            sample_refers.extend(sr)
            buf_r.clear()

    # Flush remaining
    if buf_b:
        mb, sb = _missing_broader(buf_b, sample - len(sample_broader))
        broader_missing += mb
        sample_broader.extend(sb)
    if buf_r:
        mr, sr = _missing_refers(buf_r, sample - len(sample_refers))
        refers_missing += mr
        sample_refers.extend(sr)

    print("\n=== Declared edges — verification ===")
    print(f"Missing BROADER edges : {broader_missing:,}  sample={sample_broader}")
    print(f"Missing REFERS_TO edges: {refers_missing:,}  sample={sample_refers}")

def _missing_broader(pairs, want_sample):
    """
    Returns (missing_count, sample_list) for BROADER pairs.
    Only counts edges missing when BOTH nodes exist; if a node is absent, we skip (edge can’t exist).
    """
    q = """
    UNWIND $pairs AS pair
    MATCH (child:Concept {schemeId:$scheme, conceptId: pair.child})
    MATCH (parent:Concept {schemeId:$scheme, conceptId: pair.parent})
    WITH child, parent, pair
    WHERE NOT (child)-[:BROADER]->(parent)
    RETURN count(*) AS missing, collect(pair)[0..$k] AS sample;
    """
    with driver.session() as s:
        res = s.run(q, pairs=pairs, scheme=SCHEME_ID, k=max(0, want_sample)).single()
        return res["missing"], res["sample"]

def _missing_refers(pairs, want_sample):
    """
    Returns (missing_count, sample_list) for REFERS_TO pairs.
    """
    q = """
    UNWIND $pairs AS pair
    MATCH (src:Concept {schemeId:$scheme, conceptId: pair.src})
    MATCH (dst:Concept {schemeId:$scheme, conceptId: pair.dst})
    WITH src, dst, pair
    WHERE NOT (src)-[:REFERS_TO]->(dst)
    RETURN count(*) AS missing, collect(pair)[0..$k] AS sample;
    """
    with driver.session() as s:
        res = s.run(q, pairs=pairs, scheme=SCHEME_ID, k=max(0, want_sample)).single()
        return res["missing"], res["sample"]

In [37]:
def tag_expected_root(concept_id: str) -> str:
    """
    Returns a tag for known/expected root types.
    """
    if re.fullmatch(r"[A-HY]", concept_id):  # section letters (A–H, Y)
        return "SECTION"
    if concept_id.endswith("-00-GH"):        # general holding group codes
        return "GENERAL_HOLDING"
    # Add other patterns here as needed
    return "OTHER"

In [38]:
def diagnose_graph(
    check_declared: bool = False,
    sample: int = 10,
    show_degree_heads: bool = False,
    degree_top_n: int = 10,
):
    """
    Lightweight by default:
      - node/edge counts
      - roots & leafs (with tiny samples)
      - cycle presence
      - max depth estimate

    Optional:
      - check_declared=True → verify dataset-declared BROADER/REFERS_TO pairs exist
      - show_degree_heads=True → show parents with most children (quick fan-out sense)
    """
    with driver.session() as s:
        nodes = s.run(Q_COUNTS, scheme=SCHEME_ID).single()["nodes"]
        broader = s.run(Q_REL_COUNTS, scheme=SCHEME_ID).single()["broader"]
        refers = s.run(Q_REF_COUNTS, scheme=SCHEME_ID).single()["refers"]
        roots = s.run(Q_ROOTS, scheme=SCHEME_ID, k=sample).single()
        tagged_roots = [(rid, tag_expected_root(rid)) for rid in roots["sample"]]
        num_expected = sum(1 for _, t in tagged_roots if t != "OTHER")
        num_other = sum(1 for _, t in tagged_roots if t == "OTHER")
        leafs = s.run(Q_LEAFS, scheme=SCHEME_ID, k=sample).single()
        cyc = s.run(Q_CYCLE, scheme=SCHEME_ID).single()
        max_depth = s.run(Q_MAX_DEPTH, scheme=SCHEME_ID).single()["maxDepth"]
        rows = s.run(Q_ROOT_BREAKDOWN, scheme=SCHEME_ID).data()
        total_roots = sum(r["n"] for r in rows)
        others = s.run(Q_OTHER_ROOTS, scheme=SCHEME_ID).data()


        print("\n=== CPC Graph — sanity check ===")
        print(f"Concept nodes     : {nodes:,}")
        print(f"BROADER relations : {broader:,}")
        print(f"REFERS_TO relations: {refers:,}")
        print(f"Roots (no parent) : {roots['roots']:,}")
        print(f"Roots (no parent) : {total_roots:,}")
        for r in rows:
            print(f"  {r['tag']:<16}: {r['n']:,}")
        if others:
            print("\nUnexpected root nodes (tag=OTHER):")
            for o in others:
                print(f"  {o['id']}: {o['title']}")
        else:
            print("\nNo unexpected root nodes found.")
        
        print(f"  Expected special cases: {num_expected}, Unexpected: {num_other}")
        print(f"  Sample (id, tag): {tagged_roots}")
        print(f"Leafs (no children): {leafs['leafs']:,}  sample={leafs['sample']}")
        print(f"Max depth (child→root): {max_depth}")
        if cyc:
            print(f"⚠️ Cycle detected via {cyc['node']} (path len {cyc['len']})")
        else:
            print("✅ No cycles detected in BROADER edges.")

        if show_degree_heads:
            q = """
            MATCH (parent:Concept {schemeId:$scheme})<-[:BROADER]-(:Concept {schemeId:$scheme})
            WITH parent, count(*) AS kids
            ORDER BY kids DESC
            RETURN parent.conceptId AS parent, kids
            LIMIT $n
            """
            rows = s.run(q, scheme=SCHEME_ID, n=degree_top_n).data()
            print("\nTop parents by children:")
            for r in rows:
                print(f"  {r['parent']}: {r['kids']}")

    if check_declared:
        _verify_declared_edges(sample=sample)

In [39]:
diagnose_graph(check_declared=False, show_degree_heads=False)


=== CPC Graph — sanity check ===
Concept nodes     : 261,962
BROADER relations : 261,016
REFERS_TO relations: 14,241
Roots (no parent) : 946
Roots (no parent) : 946
  GENERAL_HOLDING : 936
  SECTION         : 9
  OTHER           : 1

Unexpected root nodes (tag=OTHER):
  scheme: scheme
  Expected special cases: 10, Unexpected: 0
  Sample (id, tag): [('Y10S336-00-GH', 'GENERAL_HOLDING'), ('A63C1-00-GH', 'GENERAL_HOLDING'), ('D06F1-00-GH', 'GENERAL_HOLDING'), ('Y10S122-00-GH', 'GENERAL_HOLDING'), ('A22C25-00-GH', 'GENERAL_HOLDING'), ('Y10S600-00-GH', 'GENERAL_HOLDING'), ('F02P19-00-GH', 'GENERAL_HOLDING'), ('C07J1-00-GH', 'GENERAL_HOLDING'), ('A63B27-00-GH', 'GENERAL_HOLDING'), ('D04B31-00-GH', 'GENERAL_HOLDING')]
Leafs (no children): 187,587  sample=['C07C239-22', 'H01L2224-8056', 'B60R21-01524', 'C14C3-12', 'A23C2220-208', 'B60J7-192', 'F24F1-0038', 'E06C9-12', 'B66B1-468', 'C12Y201-0121']
Max depth (child→root): 16
✅ No cycles detected in BROADER edges.


In [21]:
diagnose_graph(check_declared=True, show_degree_heads=True)


=== CPC Graph — sanity check ===
Concept nodes     : 261,962
BROADER relations : 261,016
REFERS_TO relations: 14,241
Roots (no parent) : 946  sample=['Y10S336-00-GH', 'A63C1-00-GH', 'D06F1-00-GH', 'Y10S122-00-GH', 'A22C25-00-GH', 'Y10S600-00-GH', 'F02P19-00-GH', 'C07J1-00-GH', 'A63B27-00-GH', 'D04B31-00-GH']
Leafs (no children): 187,587  sample=['C07C239-22', 'H01L2224-8056', 'B60R21-01524', 'C14C3-12', 'A23C2220-208', 'B60J7-192', 'F24F1-0038', 'E06C9-12', 'B66B1-468', 'C12Y201-0121']
Max depth (child→root): 16
✅ No cycles detected in BROADER edges.

Top parents by children:
  G05B2219-37: 571
  G05B2219-40: 567
  G05B2219-36: 533
  G05B2219-35: 530
  G05B2219-39: 521
  G05B2219-34: 445
  H04Q2213-00: 445
  G05B2219-31: 437
  G05B2219-25: 436
  G05B2219-41: 433

=== Declared edges — verification ===
Missing BROADER edges : 0  sample=[]
Missing REFERS_TO edges: 0  sample=[]


In [45]:
Q_FIND_UNEXPECTED_ROOTS = """
MATCH (c:Concept {schemeId:$scheme})
WHERE NOT (c)-[:BROADER]->(:Concept {schemeId:$scheme})
  AND NOT (c.conceptId =~ '^[A-HY]$' OR c.conceptId ENDS WITH '-00-GH')
RETURN c.conceptId AS id, c.title AS title, labels(c) AS labels,
       c.createdAt AS createdAt, c.updatedAt AS updatedAt
ORDER BY id
"""

Q_NODE_WITH_DEGREE = """
MATCH (c:Concept {schemeId:$scheme, conceptId:$id})
OPTIONAL MATCH (c)-[r]-()
RETURN c.conceptId AS id, c.title AS title, labels(c) AS labels, count(r) AS degree
"""

Q_DELETE_IDS = """
UNWIND $ids AS cid
MATCH (c:Concept {schemeId:$scheme, conceptId:cid})
DETACH DELETE c
"""