In [None]:
import json
import os

import numpy as np
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer

# Neo4j connection parameters.
# Read from the environment so credentials never have to be typed into (and
# committed with) the notebook. The fallbacks are the previous hard-coded
# local-development values, so an unconfigured environment behaves as before.
uri = os.environ.get("NEO4J_URI", "bolt://127.0.0.1:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "")

# Initialize Neo4j driver
driver = GraphDatabase.driver(uri, auth=(username, password))

# Load text embedding model - same as in original code
model = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
def create_session_embedding(session, stream_descriptions=None):
    """Build the text for one session and return its embedding.

    The text is the session's ``title``, ``synopsis_stripped`` and
    ``theatre__name`` joined by single spaces, followed by the given stream
    descriptions (also space-separated).

    Args:
        session: Mapping with the keys ``"title"``, ``"synopsis_stripped"``
            and ``"theatre__name"``. Any of the values may be ``None``.
        stream_descriptions: Optional iterable of description strings.
            ``None`` and empty entries are ignored.

    Returns:
        The result of ``model.encode(text).tolist()`` for the assembled text.

    Raises:
        KeyError: If ``session`` lacks one of the three required keys.
    """
    base_fields = (
        session["title"],
        session["synopsis_stripped"],
        session["theatre__name"],
    )
    # Skip missing values: formatting None directly would put the literal
    # word "None" into the text that gets embedded.
    parts = [str(value) for value in base_fields if value not in (None, "")]

    # Same rule for descriptions; a None entry would also make str.join raise.
    if stream_descriptions:
        parts.extend(str(desc) for desc in stream_descriptions if desc)

    text = " ".join(parts)

    # Return the embedding as a list (will be converted to JSON)
    return model.encode(text).tolist()

In [None]:
def _descriptions_for_streams(raw_streams, descriptions_by_stream):
    """Return descriptions of the known streams named in a ';'-separated string."""
    if not raw_streams:
        return []
    names = (name.strip().lower() for name in raw_streams.split(";"))
    # dict.fromkeys removes duplicates while keeping first-seen order. A set
    # would reorder the names differently from one Python process to the next
    # (string hash randomisation), making the embedded text non-reproducible.
    return [
        descriptions_by_stream[name]
        for name in dict.fromkeys(names)
        if name in descriptions_by_stream
    ]


def compute_and_save_embeddings(tx):
    """Compute embeddings for all sessions and save them directly to the nodes.

    Reads every node labelled ``Sessions_this_year`` or ``Sessions_past_year``,
    looks up the descriptions of the streams listed in its ``stream`` property
    (a ';'-separated string, matched case-insensitively against ``Stream``
    nodes), builds an embedding with ``create_session_embedding`` and writes
    it back to the node's ``embedding`` property as a JSON string.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)`` whose
            result provides ``.data()``.

    Returns:
        ``True`` once every session has been processed.
    """
    # 1. Get all sessions data
    query = """
    MATCH (s)
    WHERE  s:Sessions_this_year OR s:Sessions_past_year
    RETURN s.session_id as session_id, s.title as title,
           s.stream as stream, s.synopsis_stripped as synopsis_stripped,
           s.theatre__name as theatre__name, labels(s)[0] as type,
           CASE WHEN s.key_text IS NOT NULL THEN s.key_text ELSE '' END as key_text
    """
    sessions = tx.run(query).data()

    # 2. Fetch all stream descriptions once
    stream_query = """
    MATCH (s:Stream)
    RETURN s.stream as stream, s.description as description
    """
    stream_data = tx.run(stream_query).data()

    # Lookup table: normalised stream name -> description. Rows lacking a
    # name or a description are skipped (a null name would crash .lower(),
    # and a null description contributes nothing to the embedding text).
    stream_descriptions = {
        row["stream"].strip().lower(): row["description"]
        for row in stream_data
        if row["stream"] and row["description"]
    }

    # Track progress
    total_sessions = len(sessions)
    print(f"Processing embeddings for {total_sessions} sessions...")

    # Restrict the write to session nodes carrying the label that was read for
    # this row. Matching on session_id alone would also overwrite any other
    # node that happens to share the id (for example the same id in the other
    # year's session set).
    update_query = """
    MATCH (s)
    WHERE (s:Sessions_this_year OR s:Sessions_past_year)
      AND s.session_id = $session_id
      AND $session_type IN labels(s)
    SET s.embedding = $embedding
    """

    # Process sessions in batches (batching only drives the progress output)
    batch_size = 100
    total_batches = (total_sessions + batch_size - 1) // batch_size
    for start in range(0, total_sessions, batch_size):
        print(f"Processing batch {start // batch_size + 1}/{total_batches}")

        for session in sessions[start : start + batch_size]:
            # Pass the stream *descriptions* (not the stream names) as context,
            # which is what create_session_embedding expects.
            session_stream_descriptions = _descriptions_for_streams(
                session["stream"], stream_descriptions
            )
            embedding = create_session_embedding(session, session_stream_descriptions)

            # Save the embedding back to the node as a JSON string
            tx.run(
                update_query,
                session_id=session["session_id"],
                session_type=session["type"],
                embedding=json.dumps(embedding),
            )

    print("All session embeddings have been computed and saved.")
    return True

In [None]:
# Default values for properties.
# A second module-level driver used to be created here; it was never used or
# closed (set_p() opens and closes its own), so it has been removed.
default_properties = {
    "Days_since_registration": "119",
    "Country": "UK",
    "Source": "BVA Key Stakeholders",
    "Email_domain": "effem.com",
    "assist_year_before": "1",
    "job_role": "NA",
    "what_type_does_your_practice_specialise_in": "NA",
    "organisation_type": "NA",
    "JobTitle": "NA",
}


def set_default_properties(tx, properties):
    """Give every ``Visitor_this_year`` node a default for each missing property.

    For each key in ``properties`` the node's existing value is kept when it
    is non-null (``COALESCE``); otherwise the value from ``properties`` is
    written. Values are passed as the ``$props`` query parameter.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)``.
        properties: Mapping of property name to default value. When empty,
            no query is run.
    """
    if not properties:
        # "SET" followed by nothing is a Cypher syntax error, and there is
        # nothing to update anyway.
        return

    def quote(name):
        # Property names cannot be query parameters, so they are written into
        # the query text. Backtick-quoting (with embedded backticks doubled)
        # keeps names containing spaces or punctuation valid and stops a name
        # from altering the query structure.
        return "`" + str(name).replace("`", "``") + "`"

    assignments = ",\n    ".join(
        f"n.{quote(key)} = COALESCE(n.{quote(key)}, $props.{quote(key)})"
        for key in properties
    )
    query = f"""
    MATCH (n:Visitor_this_year)
    SET
    {assignments}"""

    tx.run(query, props=properties)


def set_p():
    """Apply ``default_properties`` to all ``Visitor_this_year`` nodes.

    Opens a driver with the module-level ``uri``/``username``/``password``,
    runs ``set_default_properties`` in a write transaction, and always closes
    the driver afterwards.
    """
    # Create the driver *before* the try block. If creation fails there is
    # nothing to close, and the real error propagates instead of being masked
    # by an UnboundLocalError from ``driver.close()`` in the finally clause.
    driver = GraphDatabase.driver(uri, auth=(username, password))
    try:
        with driver.session() as session:
            session.execute_write(set_default_properties, default_properties)
            print(
                "Missing properties set to default values for all Visitor_this_year nodes."
            )
    finally:
        driver.close()


# Run the default-property backfill now; this performs a write against the database.
set_p()

In [None]:
def main():
    """Compute and store embeddings for all session nodes.

    Opens a driver with the module-level ``uri``/``username``/``password``,
    runs ``compute_and_save_embeddings`` in a write transaction, prints a
    confirmation when it returns a truthy result, and always closes the driver.
    """
    # Create the driver *before* the try block. If creation fails there is
    # nothing to close, and the real error propagates instead of being masked
    # by an UnboundLocalError from ``driver.close()`` in the finally clause.
    driver = GraphDatabase.driver(uri, auth=(username, password))
    try:
        with driver.session() as session:
            result = session.execute_write(compute_and_save_embeddings)
            if result:
                print("Successfully saved embeddings to session nodes.")
    finally:
        driver.close()

In [None]:
# Compute the session embeddings and write them to the database.
main()