In [None]:
import neo4j
from neo4j import GraphDatabase
import json
from dotenv import load_dotenv
import os

# Load variables from a local .env file into the process environment so the
# os.getenv() lookups below can see them.
load_dotenv()

# Neo4j connection settings; each is None when the variable is not set.
URI = os.getenv('NEO4J_URI')
USERNAME = os.getenv('NEO4J_USERNAME')
PASSWORD = os.getenv('NEO4J_PASSWORD')

# Input files. The defaults are the original workstation paths; set RDF_PATH /
# TAG_PATH in the environment (or in .env) to run the notebook elsewhere
# without editing this cell.
RDF_PATH = os.getenv('RDF_PATH', '/home/nhutpham/Public/EpsteinGraphRAG/data/rdf_triples.json')
TAG_PATH = os.getenv('TAG_PATH', '/home/nhutpham/Public/EpsteinGraphRAG/data/tag_clusters.json')

# Module-level driver shared by every later cell.
driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
try:
    driver.verify_connectivity()
    print("✅ Successfully connected to Neo4j!")
except Exception as e:
    # Report the failure but keep the cell from raising, so the rest of the
    # notebook can still be defined; later cells will fail when they use `driver`.
    print(f"❌ Connection failed: {e}")

✅ Successfully connected to Neo4j!


In [2]:
def create_constraints(tx):
    """Create the uniqueness constraints used by the import.

    Sends five ``CREATE CONSTRAINT ... IF NOT EXISTS`` statements through
    ``tx.run``, one call per statement, in the order listed below:
    Entity.name, Event.id, Document.doc_id, Tag.name and Location.name.

    Args:
        tx: Object exposing ``run(query)``; in this notebook it is the
            transaction handed over by ``session.execute_write``.
    """
    uniqueness_statements = (
        "CREATE CONSTRAINT entity_name IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE",
        "CREATE CONSTRAINT event_id IF NOT EXISTS FOR (e:Event) REQUIRE e.id IS UNIQUE",
        "CREATE CONSTRAINT document_id IF NOT EXISTS FOR (d:Document) REQUIRE d.doc_id IS UNIQUE",
        "CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE t.name IS UNIQUE",
        "CREATE CONSTRAINT location_name IF NOT EXISTS FOR (l:Location) REQUIRE l.name IS UNIQUE",
    )
    for statement in uniqueness_statements:
        tx.run(statement)

In [3]:
def import_batch(tx, batch):
    """Write one batch of prepared rows to the graph in a single statement.

    The Cypher below unwinds ``$rows`` and, for every row, merges the actor
    and target ``Entity`` nodes, the ``Event`` (setting its scalar
    properties), the two ``PARTICIPATED_IN`` relationships, the source
    ``Document``, an optional ``Location`` (only when ``row.location`` is not
    null) and one ``Tag`` per entry of ``row.tags``.

    Args:
        tx: Object exposing ``run(query, **params)``; here the transaction
            supplied by ``session.execute_write``.
        batch: Rows to import; passed unchanged as the ``rows`` parameter.
    """
    cypher = """
    UNWIND $rows AS row

    MERGE (actor:Entity {name: row.actor})
    MERGE (target:Entity {name: row.target})

    MERGE (e:Event {id: row.id})
    SET e.action = row.action,
        e.sequence_order = row.sequence_order,
        e.timestamp = row.timestamp,
        e.explicit_topic = row.explicit_topic,
        e.implicit_topic = row.implicit_topic,
        e.created_at = datetime(row.created_at)

    MERGE (actor)-[:PARTICIPATED_IN {role:'actor'}]->(e)
    MERGE (target)-[:PARTICIPATED_IN {role:'target'}]->(e)

    MERGE (d:Document {doc_id: row.doc_id})
    MERGE (e)-[:FROM_DOCUMENT]->(d)

    FOREACH (_ IN CASE WHEN row.location IS NOT NULL THEN [1] ELSE [] END |
        MERGE (l:Location {name: row.location})
        MERGE (e)-[:LOCATED_IN]->(l)
    )

    FOREACH (tag IN row.tags |
        MERGE (t:Tag {name: tag})
        MERGE (e)-[:HAS_TAG]->(t)
    )
    """
    # The whole batch travels as one parameter; nothing is interpolated into the text.
    tx.run(cypher, rows=batch)

In [6]:
def main():
    """Load the RDF triples and tag clusters from disk and import them into Neo4j.

    Reads ``RDF_PATH`` and ``TAG_PATH`` as JSON, converts every RDF row into
    the flat dict expected by ``import_batch``, creates the constraints and
    then writes the rows in batches of 500, printing progress after each
    batch.

    Each row's ``tags`` list is built from that row alone: rows whose
    ``top_cluster_ids`` is missing, empty or decodes to an empty value get
    ``[]``. (Previously ``tags`` was only assigned for tagged rows, so an
    untagged first row raised ``UnboundLocalError`` and later untagged rows
    silently reused the preceding row's tags.)

    Side effects:
        Closes the module-level ``driver`` when the import finishes.
        NOTE(review): cells that use ``driver`` afterwards presumably need it
        re-created first — confirm against the driver version in use.
    """
    with open(RDF_PATH, "r", encoding="utf-8") as f:
        rdf = json.load(f)
    with open(TAG_PATH, "r", encoding="utf-8") as f:
        tag = json.load(f)

    processed = []
    for row in rdf:
        # `top_cluster_ids` holds a JSON-encoded list of ids used to index `tag`.
        raw_ids = row.get("top_cluster_ids")
        cluster_ids = json.loads(raw_ids) if raw_ids else []
        # Rebuilt for every row so nothing leaks over from the previous one.
        tags = [tag[cluster_id]['name'] for cluster_id in cluster_ids] if cluster_ids else []
        processed.append({
            "id": row["id"],
            "actor": row["actor"],
            "target": row["target"],
            "action": row["action"],
            "doc_id": row["doc_id"],
            "sequence_order": row.get("sequence_order"),
            "timestamp": row.get("timestamp"),
            "explicit_topic": row.get("explicit_topic"),
            "implicit_topic": row.get("implicit_topic"),
            # "YYYY-MM-DD hh:mm:ss" -> "YYYY-MM-DDThh:mm:ss"; presumably so that
            # Cypher's datetime() in import_batch accepts the value.
            "created_at": row["created_at"].replace(" ", "T"),
            "location": row.get("location"),
            "tags": tags,
        })
    BATCH_SIZE = 500

    with driver.session() as session:

        session.execute_write(create_constraints)

        for i in range(0, len(processed), BATCH_SIZE):
            batch = processed[i:i+BATCH_SIZE]
            session.execute_write(import_batch, batch)
            print(f"Imported {i + len(batch)} / {len(processed)}")

    driver.close()
    print("✅ Import completed.")

In [7]:
# Entry point for the import cell: calls main() when this code runs as the
# top-level module (`__name__ == "__main__"`).
if __name__ == "__main__":
    main()

Imported 500 / 107030
Imported 1000 / 107030
Imported 1500 / 107030
Imported 2000 / 107030
Imported 2500 / 107030
Imported 3000 / 107030
Imported 3500 / 107030
Imported 4000 / 107030
Imported 4500 / 107030
Imported 5000 / 107030
Imported 5500 / 107030
Imported 6000 / 107030
Imported 6500 / 107030
Imported 7000 / 107030
Imported 7500 / 107030
Imported 8000 / 107030
Imported 8500 / 107030
Imported 9000 / 107030
Imported 9500 / 107030
Imported 10000 / 107030
Imported 10500 / 107030
Imported 11000 / 107030
Imported 11500 / 107030
Imported 12000 / 107030
Imported 12500 / 107030
Imported 13000 / 107030
Imported 13500 / 107030
Imported 14000 / 107030
Imported 14500 / 107030
Imported 15000 / 107030
Imported 15500 / 107030
Imported 16000 / 107030
Imported 16500 / 107030
Imported 17000 / 107030
Imported 17500 / 107030
Imported 18000 / 107030
Imported 18500 / 107030
Imported 19000 / 107030
Imported 19500 / 107030
Imported 20000 / 107030
Imported 20500 / 107030
Imported 21000 / 107030
Imported 215

In [3]:
from openai import OpenAI
from typing import List, Dict, Any

# Module-level OpenAI client used by get_embedding below.
# NOTE(review): no api_key is passed here, so credentials presumably come from
# the environment (e.g. OPENAI_API_KEY) — confirm.
client = OpenAI()

def get_embedding(text, model="text-embedding-3-small"):
    """Return the embedding vector for a single piece of text.

    Newlines in ``text`` are replaced by spaces, the result is sent as a
    one-element ``input`` list to ``client.embeddings.create`` and the
    ``embedding`` of the first returned item is handed back.

    Args:
        text: String to embed.
        model: Embedding model name forwarded to the API.

    Returns:
        The ``embedding`` attribute of ``response.data[0]``.
    """
    single_line = text.replace("\n", " ")
    response = client.embeddings.create(input=[single_line], model=model)
    first_item = response.data[0]
    return first_item.embedding


def embed_texts(texts, model="text-embedding-3-small", batch_size=256):
    """Embed many texts, sending them to the API in chunks instead of one by one.

    The embeddings endpoint accepts a list as ``input``, so each chunk of up
    to ``batch_size`` texts costs a single request rather than one request
    per text. As in ``get_embedding``, newlines are replaced by spaces before
    sending.

    Args:
        texts: Iterable of strings to embed.
        model: Embedding model name forwarded to the API.
        batch_size: Maximum number of texts per request (must be >= 1).

    Returns:
        A list with one embedding per input text, in input order. An empty
        input yields ``[]`` without calling the API.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    cleaned = [t.replace("\n", " ") for t in texts]
    vectors = []
    for start in range(0, len(cleaned), batch_size):
        chunk = cleaned[start:start + batch_size]
        response = client.embeddings.create(input=chunk, model=model)
        # Each returned item carries the position of its input in `index`;
        # sort on it so the output order never depends on response order.
        ordered = sorted(response.data, key=lambda item: item.index)
        vectors.extend(item.embedding for item in ordered)
    return vectors

In [7]:
def create_vector_index(tx):
    """Create the ``event_embedding`` vector index on ``Event.embedding``.

    Issues one ``CREATE VECTOR INDEX ... IF NOT EXISTS`` statement configured
    for 1536 dimensions and cosine similarity.

    Args:
        tx: Object exposing ``run(query)``; here the transaction supplied by
            ``session.execute_write``.
    """
    index_statement = """
        CREATE VECTOR INDEX event_embedding IF NOT EXISTS
        FOR (e:Event)
        ON (e.embedding)
        OPTIONS {indexConfig: {
            `vector.dimensions`: 1536,
            `vector.similarity_function`: 'cosine'
        }}
    """
    tx.run(index_statement)
# Executed when the cell runs: open a session on the module-level `driver` and
# run create_vector_index inside a managed write transaction.
# NOTE(review): the import main() defined earlier ends with driver.close(); if
# it ran in the same kernel, `driver` presumably has to be re-created first — confirm.
with driver.session() as session:
    session.execute_write(create_vector_index)

In [8]:
def build_event_text(record):
    """Render one event record as the sentence-like text that gets embedded.

    The text is assembled from up to five fragments, joined by single spaces:
    ``"<actor> <action> <target>."`` (only when all three are truthy),
    ``"Location: ..."``, ``"Topic: ..."``, ``"Context: ..."`` and
    ``"Tags: a, b."``. Fragments whose value is missing or falsy are left out.

    Args:
        record: Mapping with ``actor``, ``action`` and ``target`` keys and
            optional ``location``, ``explicit_topic``, ``implicit_topic`` and
            ``tags`` (list of strings) keys.

    Returns:
        The joined text; an empty string when no fragment applies.
    """
    sentences = []

    has_full_triple = record["actor"] and record["action"] and record["target"]
    if has_full_triple:
        sentences.append(f"{record['actor']} {record['action']} {record['target']}.")

    location = record.get("location")
    if location:
        sentences.append(f"Location: {location}.")

    explicit_topic = record.get("explicit_topic")
    if explicit_topic:
        sentences.append(f"Topic: {explicit_topic}.")

    implicit_topic = record.get("implicit_topic")
    if implicit_topic:
        sentences.append(f"Context: {implicit_topic}.")

    tag_names = record.get("tags")
    if tag_names:
        sentences.append("Tags: " + ", ".join(tag_names) + ".")

    return " ".join(sentences)

In [None]:
def fetch_events(tx):
    """Fetch the events that still need an embedding, with the fields used to describe them.

    Matches every ``Event`` whose ``embedding`` property is null together with
    its actor and target entities, optionally its location, and collects its
    tag names. Restricting the match to events without an embedding is what
    the caller's "events without embeddings" message describes, and it lets an
    interrupted embedding run resume without re-embedding finished events.

    Args:
        tx: Object exposing ``run(query)``; here the transaction supplied by
            ``session.execute_read``.

    Returns:
        A list of result records with the fields ``id``, ``actor``,
        ``action``, ``target``, ``location``, ``explicit_topic``,
        ``implicit_topic`` and ``tags``.
    """
    query = """
    MATCH (actor:Entity)-[:PARTICIPATED_IN {role:'actor'}]->(e:Event)
    WHERE e.embedding IS NULL
    MATCH (target:Entity)-[:PARTICIPATED_IN {role:'target'}]->(e)
    OPTIONAL MATCH (e)-[:LOCATED_IN]->(l:Location)
    OPTIONAL MATCH (e)-[:HAS_TAG]->(t:Tag)

    RETURN e.id AS id,
           actor.name AS actor,
           e.action AS action,
           target.name AS target,
           l.name AS location,
           e.explicit_topic AS explicit_topic,
           e.implicit_topic AS implicit_topic,
           collect(t.name) AS tags
    """
    # Materialise the result while the transaction is still open.
    return list(tx.run(query))

In [10]:
def store_embedding(tx, event_id, vector):
    """Save an embedding vector on the ``Event`` node with the given id.

    Args:
        tx: Object exposing ``run(query, **params)``; here the transaction
            supplied by ``session.execute_write``.
        event_id: Value matched against ``Event.id`` (sent as ``$id``).
        vector: Embedding written to ``e.embedding`` (sent as ``$vector``).
    """
    set_embedding = """
        MATCH (e:Event {id:$id})
        SET e.embedding = $vector
        """
    tx.run(set_embedding, id=event_id, vector=vector)

In [17]:
def _store_embedding_batch(tx, pairs):
    """Write every (event_id, vector) pair of one batch inside the same transaction."""
    for event_id, vector in pairs:
        store_embedding(tx, event_id, vector)


def main():
    """Embed the events returned by ``fetch_events`` and store the vectors in Neo4j.

    Events are processed in batches of 500. For each batch the event texts
    are built with ``build_event_text``, embedded with ``embed_texts`` and
    written back through ``store_embedding``. All writes of a batch share one
    managed write transaction, so a batch costs one commit instead of one
    commit per event.

    Events whose text is empty are skipped, and a count is printed, so an
    empty string is never sent to the embeddings API.

    Side effects:
        Prints progress after every batch and closes the module-level
        ``driver`` when all batches are done.
    """

    with driver.session() as session:
        records = session.execute_read(fetch_events)

    print(f"Found {len(records)} events without embeddings")
    BATCH_SIZE = 500
    for i in range(0, len(records), BATCH_SIZE):

        batch = records[i:i+BATCH_SIZE]

        # Keep the id next to its text so skipped events cannot misalign the
        # ids and the vectors returned by embed_texts.
        embeddable = []
        for r in batch:
            text = build_event_text(dict(r))
            if text:
                embeddable.append((r["id"], text))

        skipped = len(batch) - len(embeddable)
        if skipped:
            print(f"Skipped {skipped} events with no text to embed")

        if embeddable:
            vectors = embed_texts([text for _, text in embeddable])
            pairs = [
                (event_id, vector)
                for (event_id, _), vector in zip(embeddable, vectors)
            ]
            with driver.session() as session:
                session.execute_write(_store_embedding_batch, pairs)

        print(f"Embedded {i + len(batch)} / {len(records)}")

    driver.close()
    print("✅ All embeddings stored.")


# Entry point for the embedding cell: calls main() when this code runs as the
# top-level module (`__name__ == "__main__"`).
if __name__ == "__main__":
    main()

Found 107030 events without embeddings
Embedded 500 / 107030
Embedded 1000 / 107030
Embedded 1500 / 107030
Embedded 2000 / 107030
Embedded 2500 / 107030
Embedded 3000 / 107030
Embedded 3500 / 107030
Embedded 4000 / 107030
Embedded 4500 / 107030
Embedded 5000 / 107030
Embedded 5500 / 107030
Embedded 6000 / 107030
Embedded 6500 / 107030
Embedded 7000 / 107030
Embedded 7500 / 107030
Embedded 8000 / 107030
Embedded 8500 / 107030
Embedded 9000 / 107030
Embedded 9500 / 107030
Embedded 10000 / 107030
Embedded 10500 / 107030
Embedded 11000 / 107030
Embedded 11500 / 107030
Embedded 12000 / 107030
Embedded 12500 / 107030
Embedded 13000 / 107030
Embedded 13500 / 107030
Embedded 14000 / 107030
Embedded 14500 / 107030
Embedded 15000 / 107030
Embedded 15500 / 107030
Embedded 16000 / 107030
Embedded 16500 / 107030
Embedded 17000 / 107030
Embedded 17500 / 107030
Embedded 18000 / 107030
Embedded 18500 / 107030
Embedded 19000 / 107030
Embedded 19500 / 107030
Embedded 20000 / 107030
Embedded 20500 / 1070