In [20]:
from neo4j import GraphDatabase
from datasets import load_dataset
from tqdm.auto import tqdm

from utils.env_loader import load_project_env
import os

In [21]:
# Load project environment settings (presumably from a .env file — see
# utils.env_loader) so the next cell's os.getenv("NEO4J_PASSWORD") can see them.
load_project_env()

In [22]:
# Neo4j connection settings. URI and user default to a local instance; set
# NEO4J_URI / NEO4J_USER in the environment to point somewhere else.
URI = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
if not NEO4J_PASSWORD:
    # Fail here with a clear message instead of an opaque auth error from the driver later.
    raise RuntimeError("NEO4J_PASSWORD is not set; check the environment loaded by load_project_env().")
AUTH = (os.getenv("NEO4J_USER", "neo4j"), NEO4J_PASSWORD)
# schemeId written on every Concept and used to look up its Scheme node.
SCHEME_ID = "CPC"

In [23]:
# One driver shared by every cell below.
driver = GraphDatabase.driver(URI, auth=AUTH)
# Fail fast if the server cannot be reached, rather than partway into the upsert loop.
driver.verify_connectivity()
# TODO: call driver.close() in a final cell once the notebook is done with Neo4j.

In [7]:
# Cypher upsert for a single Concept, keyed on (schemeId, conceptId).
# - ON CREATE: sets title, fullTitle, fullTitleEmbedding, pathKeys and createdAt.
# - ON MATCH: coalesce() keeps the stored value for any parameter passed as null,
#   and stamps updatedAt.
# The Concept is then linked to its Scheme through HAS_CONCEPT.
# NOTE(review): the Scheme is MATCHed, not MERGEd — if no Scheme node with this
# schemeId exists, the Concept is still upserted but the HAS_CONCEPT edge is
# silently skipped. Presumably the Scheme is created elsewhere; confirm.
# TODO(review): confirm a uniqueness constraint/index exists on
# (:Concept {schemeId, conceptId}); without one every MERGE scans Concept nodes.
UPSERT_CONCEPT = """
MERGE (c:Concept {schemeId:$schemeId, conceptId:$conceptId})
  ON CREATE SET c.title = $title,
                c.fullTitle = $fullTitle,
                c.createdAt = datetime(),
                c.fullTitleEmbedding = $embedding,
                c.pathKeys = $pathKeys
  ON MATCH  SET c.title = coalesce($title, c.title),
                c.fullTitle = coalesce($fullTitle, c.fullTitle),
                c.updatedAt = datetime(),
                c.fullTitleEmbedding = coalesce($embedding, c.fullTitleEmbedding),
                c.pathKeys = coalesce($pathKeys, c.pathKeys)
WITH c
MATCH (s:Scheme {schemeId:$schemeId})
MERGE (s)-[:HAS_CONCEPT]->(c)
"""

In [8]:
def iter_rows(
    streaming: bool,
    dataset_name: str = "mhurhangee/cpc-classifications-embeddings",
    split: str = "train",
):
    """
    Load one split of the CPC classifications-with-embeddings dataset from the Hugging Face Hub.

    Args:
        streaming: If True, records are fetched lazily while iterating; if False,
            the whole split is downloaded before returning.
        dataset_name: Hub dataset id; defaults to the CPC embeddings dataset.
        split: Dataset split to load; defaults to "train".

    Returns:
        The object ``datasets.load_dataset`` returns for these arguments.
    """
    return load_dataset(dataset_name, split=split, streaming=streaming)


In [15]:
def update_nodes(streaming: bool = True):
    """
    Upsert every dataset record as a Concept node using UPSERT_CONCEPT.

    This both creates missing Concept nodes and updates existing ones. On create it
    sets title, fullTitle, fullTitleEmbedding and pathKeys. On match, any field the
    record leaves null keeps its stored value (the query uses coalesce()).

    Args:
        streaming: Forwarded to ``iter_rows``; True (the default) streams records
            instead of downloading the whole split first.
    """
    records = iter_rows(streaming)

    # TODO(perf): one auto-commit transaction per record is slow (~25 min for
    # ~262k records in the recorded run); batching via UNWIND would help.
    with driver.session() as session, tqdm(records, desc="Updating nodes with embeddings", unit="rec") as pbar:
        for rec in pbar:
            tree_path = rec.get("treePath")
            # Send null (not []) when the record has no treePath, so the query's
            # coalesce() keeps any pathKeys already stored on the node instead of
            # overwriting them with an empty list.
            path_keys = None if tree_path is None else [tp["key"] for tp in tree_path]
            session.run(
                UPSERT_CONCEPT,
                schemeId=SCHEME_ID,
                conceptId=rec["key"],
                title=rec.get("title"),
                fullTitle=rec.get("fullTitle"),
                pathKeys=path_keys,
                embedding=rec.get("fullTitleEmbedding"),
            ).consume()  # surface a failure on this record, not on the next run()

In [16]:
# Full upsert pass over the dataset (took ~25 min for ~262k records in the run
# recorded below). Re-running does not duplicate nodes: UPSERT_CONCEPT MERGEs
# on (schemeId, conceptId).
update_nodes()

Updating nodes with embeddings: 98220rec [09:31, 181.36rec/s]'(ReadTimeoutError("HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: fe4b2b0a-11a8-4f9a-b3d2-f2634d574982)')' thrown while requesting GET https://huggingface.co/datasets/mhurhangee/cpc-classifications-embeddings/resolve/03a8c9005d04d208295702df63eed9e7ce72ce25/data/train-00003-of-00008.parquet
Retrying in 1s [Retry 1/5].
Updating nodes with embeddings: 261960rec [25:11, 173.27rec/s]


In [None]:
import time

INDEX_NAME = "concept_fulltitle_embedding_idx"
# NOTE(review): must equal the length of the fullTitleEmbedding vectors in the
# dataset — confirm against the embedding model that produced them.
EMBEDDING_DIMENSIONS = 1536

# Create the vector index on :Concept, the label UPSERT_CONCEPT writes. Indexing
# :CPC instead would cover none of the nodes this notebook creates.
# IF NOT EXISTS: if an index with this name already exists (e.g. one created
# earlier on :CPC), this statement does nothing — DROP that index first.
with driver.session() as s:
    s.run(f"""
    CREATE VECTOR INDEX {INDEX_NAME} IF NOT EXISTS
    FOR (c:Concept) ON (c.fullTitleEmbedding)
    OPTIONS {{ indexConfig: {{ `vector.dimensions`: {EMBEDDING_DIMENSIONS}, `vector.similarity_function`: "cosine" }} }}
    """).consume()


[INFO] Index state: POPULATING
[INFO] Index state: ONLINE
[OK] Vector index 'cpc_fulltitle_embedding_idx' is ONLINE and ready to use.


In [24]:

# Poll until the index is ONLINE. Fail fast if it is missing or FAILED, and give
# up after POLL_TIMEOUT_S instead of looping forever.
POLL_INTERVAL_S = 5
POLL_TIMEOUT_S = 30 * 60  # arbitrary upper bound; raise it if population needs longer

deadline = time.monotonic() + POLL_TIMEOUT_S
while True:
    with driver.session() as s:
        # Pass the name as a parameter rather than interpolating it into the query text.
        res = s.run(
            """
            SHOW INDEXES YIELD name, state
            WHERE name = $name
            RETURN state
            """,
            name=INDEX_NAME,
        ).single()
        state = res["state"] if res else None

    print(f"[INFO] Index state: {state}")
    if state == "ONLINE":
        break
    if state is None:
        raise LookupError(f"No index named '{INDEX_NAME}' exists; run the index-creation cell first.")
    if state == "FAILED":
        raise RuntimeError(f"Vector index '{INDEX_NAME}' failed to populate; inspect it with SHOW INDEXES.")
    if time.monotonic() >= deadline:
        raise TimeoutError(f"Vector index '{INDEX_NAME}' still {state} after {POLL_TIMEOUT_S}s.")

    time.sleep(POLL_INTERVAL_S)  # wait before checking again

print(f"[OK] Vector index '{INDEX_NAME}' is ONLINE and ready to use.")


[INFO] Index state: ONLINE
[OK] Vector index 'cpc_fulltitle_embedding_idx' is ONLINE and ready to use.
