In [None]:
# Import required libraries
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

import os

# Neo4j connection settings. Values can be supplied through the NEO4J_URI,
# NEO4J_USER and NEO4J_PASSWORD environment variables so that credentials are not
# stored in the notebook; the fallbacks are the previous hard-coded values.
uri = os.environ.get("NEO4J_URI", "bolt://127.0.0.1:7687")
username = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "")

# Sentence-embedding model used to encode session text.
# (An alternative that was tried: "nomic-ai/nomic-embed-text-v1", which needs
# trust_remote_code=True.)
model = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
# Sanity check: encode one representative paragraph and display the length of the
# returned embedding (for a single input string this is the vector dimension).
text = """Urology in veterinary medicine focuses on the diagnosis, treatment, and management of diseases and conditions affecting the urinary tract and kidneys in animals. 
This category encompasses a wide range of topics including proteinuria interpretation, urinary tract infections, obstructive conditions like blocked bladder, and surgical
interventions for urolithiasis. It also highlights the importance of diagnostic imaging techniques and the critical role of veterinary nurses in managing chronic kidney disease. 
Overall, urology integrates medical, surgical, and supportive care approaches to maintain and restore urinary health in various animal species."""
len(model.encode(text))

In [None]:
# Fallback values written onto Visitor_this_year nodes that lack these properties.
# Insertion order is preserved and determines the order of the SET assignments.
default_properties = dict(
    Days_since_registration="119",
    Country="UK",
    Source="BVA Key Stakeholders",
    Email_domain="effem.com",
    assist_year_before="1",
    job_role="NA",
    what_type_does_your_practice_specialise_in="NA",
    organisation_type="NA",
    JobTitle="NA",
)


def set_default_properties(tx, properties):
    """Fill in missing properties on every ``Visitor_this_year`` node.

    For each key in ``properties`` the node keeps its existing value when that
    value is non-null; otherwise the value from ``properties`` is written
    (``COALESCE`` returns its first non-null argument).

    Args:
        tx: Neo4j transaction (as passed by ``session.execute_write``).
        properties: mapping of property name -> default value.

    Returns:
        None. Nothing is sent to the database when ``properties`` is empty.
    """
    if not properties:
        # An empty SET clause is invalid Cypher, so there is nothing to run.
        return

    def _quote(name):
        # Backtick-quote property names so keys containing spaces or punctuation
        # are safe to splice into the query (embedded backticks are doubled).
        return "`" + str(name).replace("`", "``") + "`"

    assignments = ",\n    ".join(
        f"n.{_quote(key)} = COALESCE(n.{_quote(key)}, $props.{_quote(key)})"
        for key in properties
    )
    query = """
    MATCH (n:Visitor_this_year)
    SET
    """ + assignments

    tx.run(query, props=properties)


# Driver used by `main()` below; main() closes it when it finishes.
driver = GraphDatabase.driver(uri, auth=(username, password))


def main():
    """Apply ``default_properties`` to all Visitor_this_year nodes, then close the driver.

    The driver is closed in a ``finally`` block so it is released even if the
    write fails. Because the driver is closed afterwards, re-run this whole cell
    (which recreates ``driver``) before calling ``main()`` again.
    """
    try:
        with driver.session() as session:
            session.execute_write(set_default_properties, default_properties)
    finally:
        driver.close()
    print("Missing properties set to default values for all Visitor_this_year nodes.")

In [None]:
# Run the one-off default-property backfill defined in the previous cell.
main()

In [None]:
from neo4j import GraphDatabase
import time


def get_db_session():
    """Return a new session from the module-level ``neo4j_driver``, (re)creating it if needed.

    ``Driver.verify_connectivity()`` raises when the server cannot be reached; in
    the neo4j 5.x driver (the API this notebook uses, cf. ``execute_read``) it
    returns ``None`` on success. A healthy driver is therefore detected by the
    absence of an exception, not by a truthy return value. When the global driver
    does not exist yet or fails the check, the old driver (if any) is closed and
    a fresh one is created.

    Returns:
        A ``neo4j`` session; the caller is responsible for closing it.
    """
    global neo4j_driver
    existing = globals().get("neo4j_driver")
    if existing is not None:
        try:
            existing.verify_connectivity()
            return existing.session()
        except Exception as e:
            print(f"Error getting database session: {e}")
            try:
                existing.close()
            except Exception:
                # Best effort only: the driver is being replaced anyway.
                pass

    neo4j_driver = GraphDatabase.driver(
        "neo4j://localhost:7687", auth=(username, password)
    )
    return neo4j_driver.session()

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import time
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from neo4j import GraphDatabase
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
import seaborn as sns

# Notebook display options: show every column, cap printed rows at 20 and widen
# the text output (option/value pairs are applied in a single call).
pd.set_option(
    "display.max_columns", None,
    "display.max_rows", 20,
    "display.width", 1000,
)

# 1. Connect to Neo4j database
# ------------------------------
# Connection details come from the environment (NEO4J_URI / NEO4J_USER /
# NEO4J_PASSWORD) so credentials are not stored in the notebook; the fallbacks
# are the previous hard-coded values.
import os

uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "")

# Create the driver
driver = GraphDatabase.driver(uri, auth=(username, password))


# Test connection
def test_connection():
    with driver.session() as session:
        result = session.run("MATCH (n) RETURN count(n) as count").single()
        if result:
            print(f"✅ Successfully connected to Neo4j - Found {result['count']} nodes")
            return True
        else:
            print("❌ Failed to connect to Neo4j")
            return False


# Check the database connection before the expensive steps below.
test_connection()

# 2. Job-role categories
# ----------------------
# The recommendation helpers are defined directly in this notebook. Job roles are
# grouped so that role-specific stream rules can be applied.

# Veterinarian roles
VET_ROLES = [
    "Vet/Vet Surgeon", "Assistant Vet", "Vet/Owner",
    "Clinical or other Director", "Locum Vet", "Academic",
]

# Nursing roles
NURSE_ROLES = [
    "Head Nurse/Senior Nurse",
    "Vet Nurse",
    "Locum RVN",
]

# Business / management roles
BUSINESS = [
    "Practice Manager",
    "Practice Partner/Owner",
]

# Other roles can attend any session
OTHER_ROLES = [
    "Student",
    "Receptionist",
    "Other (please specify)",
]


# Global cache for reuse of data
class RecommendationCache:
    """Holder for data that is reused across recommendation calls."""

    def __init__(self):
        # Filled in by the precomputation step; None until then.
        self.session_embeddings = None
        self.all_sessions_data = None
        self.stream_descriptions = None
        # Additional lookup tables; each starts as its own empty dict.
        for attr_name in (
            "visitor_info_cache",
            "past_sessions_cache",
            "similar_visitors_cache",
            "filtered_sessions_cache",
        ):
            setattr(self, attr_name, {})

    def is_initialized(self):
        """True once session embeddings have been stored on the cache."""
        return not (self.session_embeddings is None)


cache = RecommendationCache()

# 3. Sentence-transformer model
# -----------------------------
# A small, fast model is used for every session embedding in this cell.
print("Loading sentence transformer model...")
model = SentenceTransformer(
    "all-MiniLM-L6-v2"
)
print("Model loaded successfully!")

# 4. Load the visitor data
# -----------------------
# Replace with your actual file path
csv_file_path = "visitor_data.csv"

# Load data. When the file is missing, fall back to a small synthetic frame so
# that both `data` and `list_badgeId_this` are always defined for later cells.
try:
    data = pd.read_csv(csv_file_path)
    print(f"✅ Successfully loaded CSV with {len(data)} rows")
    # Display sample of data
    display(data.head())
    # Get unique badge IDs
    list_badgeId_this = list(data["BadgeId"].unique())
    print(f"Found {len(list_badgeId_this)} unique badge IDs")
except FileNotFoundError:
    print(f"❌ File not found: {csv_file_path}")
    print("Creating dummy data for demonstration")
    list_badgeId_this = [f"BADGE{i}" for i in range(1, 101)]
    data = pd.DataFrame({"BadgeId": list_badgeId_this})
    print(f"Created {len(list_badgeId_this)} dummy badge IDs")


# 5. Precompute all session embeddings
# ------------------------------------
def _unique_known_streams(raw_stream, stream_descriptions):
    """Distinct, lower-cased streams from a ';'-separated string that exist in ``stream_descriptions``.

    Order of first appearance is kept so the embedding text is identical on every
    run (iteration order of a ``set`` of strings changes between interpreter runs).
    """
    if not raw_stream:
        return []
    tokens = (token.strip().lower() for token in raw_stream.split(";"))
    return [t for t in dict.fromkeys(tokens) if t in stream_descriptions]


def _session_embedding_text(session, streams):
    """Text embedded for a session: title, synopsis, theatre name, then stream names.

    Missing (None/empty) fields are skipped instead of being rendered as "None".
    """
    parts = [
        session.get("title"),
        session.get("synopsis_stripped"),
        session.get("theatre__name"),
        *streams,
    ]
    return " ".join(str(p) for p in parts if p)


def precompute_all_data(tx):
    """Fetch all sessions and streams, embed every session, and populate ``cache``.

    Intended to run inside ``session.execute_read``. Side effects: sets
    ``cache.stream_descriptions`` and ``cache.all_sessions_data``.

    Args:
        tx: Neo4j transaction.

    Returns:
        dict session_id -> {"text", "type", "theatre__name", "stream", "title",
        "sponsored_by", "sponsored_session", "key_text", "embedding"}.
    """
    start_time = time.time()
    print("Precomputing all session data...")

    # 1. Fetch all sessions in a single query
    sessions_query = """
    MATCH (s)
    WHERE s:Sessions_past_year_dva OR s:Sessions_past_year_lva OR s:Sessions_this_year OR s:Sessions_past_year
    RETURN s.session_id as session_id, s.title as title, 
           s.stream as stream, s.synopsis_stripped as synopsis_stripped,
           s.theatre__name as theatre__name, s.sponsored_by as sponsored_by,
           s.sponsored_session as sponsored_session, labels(s)[0] as type,
           CASE WHEN s.key_text IS NOT NULL THEN s.key_text ELSE '' END as key_text
    """
    all_sessions = tx.run(sessions_query).data()
    print(
        f"Fetched {len(all_sessions)} sessions from database in {time.time() - start_time:.2f}s"
    )

    # 2. Fetch all stream descriptions at once
    stream_query = """
    MATCH (s:Stream)
    RETURN s.stream as stream, s.description as description
    """
    stream_data = tx.run(stream_query).data()

    # Lower-cased stream name -> description; Stream nodes without a name are skipped.
    stream_descriptions = {
        row["stream"].lower(): row["description"]
        for row in stream_data
        if row["stream"]
    }
    cache.stream_descriptions = stream_descriptions

    # Store raw session data for quick lookup later
    cache.all_sessions_data = {s["session_id"]: s for s in all_sessions}

    print("Computing embeddings for all sessions...")
    batch_size = 100  # encode in batches to bound memory use
    total = len(all_sessions)
    embeddings = {}

    for i in range(0, total, batch_size):
        batch_info = {}
        for s in all_sessions[i : i + batch_size]:
            streams = _unique_known_streams(s["stream"], stream_descriptions)
            sponsored_by = s.get("sponsored_by")
            sponsored_session = s.get("sponsored_session")
            batch_info[s["session_id"]] = {
                "text": _session_embedding_text(s, streams),
                "type": s["type"],
                "theatre__name": s["theatre__name"],
                "stream": s["stream"],
                "title": s["title"],
                # The query always returns these keys (possibly with null values),
                # so apply the defaults on None rather than on key presence.
                "sponsored_by": (
                    sponsored_by if sponsored_by is not None else "Not Sponsored"
                ),
                "sponsored_session": (
                    sponsored_session if sponsored_session is not None else "False"
                ),
                "key_text": s.get("key_text") or "",
            }

        # Encode the whole batch with one model call.
        texts = [info["text"] for info in batch_info.values()]
        vectors = model.encode(texts, show_progress_bar=False)
        for (session_id, info), vector in zip(batch_info.items(), vectors):
            info["embedding"] = vector
            embeddings[session_id] = info

        done = min(i + batch_size, total)
        if (i + batch_size) % 500 == 0 or done >= total:
            print(f"Processed {done}/{total} sessions")

    print(f"Finished computing all embeddings in {time.time() - start_time:.2f}s")
    return embeddings


# Populate the shared cache with session embeddings computed in one read transaction.
with driver.session() as session:
    print("Starting data precomputation...")
    precomputed_embeddings = session.execute_read(precompute_all_data)
    cache.session_embeddings = precomputed_embeddings
    print("Data precomputation complete!")

In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Job role categories (same values as the earlier cell; repeated so this cell is
# self-contained).
VET_ROLES = [
    "Vet/Vet Surgeon", "Assistant Vet", "Vet/Owner",
    "Clinical or other Director", "Locum Vet", "Academic",
]
NURSE_ROLES = [
    "Head Nurse/Senior Nurse",
    "Vet Nurse",
    "Locum RVN",
]
BUSINESS = [
    "Practice Manager",
    "Practice Partner/Owner",
]
# Other roles can attend any session
OTHER_ROLES = [
    "Student",
    "Receptionist",
    "Other (please specify)",
]


# Define session embedding
def create_session_embedding(session, stream_descriptions=None, encoder=None):
    """Embed a session from its title, synopsis, theatre name and optional stream terms.

    Args:
        session: mapping with ``title``, ``synopsis_stripped`` and ``theatre__name``.
        stream_descriptions: optional list of extra terms appended to the text.
            Despite the name, ``embed_all_sessions`` passes stream *names*.
        encoder: object with an ``encode(text)`` method; defaults to the global
            ``model``.

    Returns:
        Whatever ``encoder.encode`` returns for the combined text.

    Missing (None/empty) fields are skipped rather than embedded as the literal
    string "None".
    """
    if encoder is None:
        encoder = model
    parts = [
        session.get("title"),
        session.get("synopsis_stripped"),
        session.get("theatre__name"),
        *(stream_descriptions or []),
    ]
    text = " ".join(str(p) for p in parts if p)
    return encoder.encode(text)


# Embedding function for all sessions
def embed_all_sessions(tx):
    """Embed every past- and this-year session, one session at a time.

    Queries sessions labelled Sessions_past_year_dva, Sessions_past_year_lva,
    Sessions_this_year or Sessions_past_year, plus all Stream nodes, and embeds
    each session with ``create_session_embedding`` using the names of its known
    streams.

    Args:
        tx: Neo4j transaction.

    Returns:
        dict session_id -> {"type", "embedding", "theatre__name", "stream",
        "title", "sponsored_by", "sponsored_session", "key_text"}.
    """
    query = """
    MATCH (s)
    WHERE s:Sessions_past_year_dva OR s:Sessions_past_year_lva OR s:Sessions_this_year OR s:Sessions_past_year
    RETURN s.session_id as session_id, s.title as title, 
           s.stream as stream, s.synopsis_stripped as synopsis_stripped,
           s.theatre__name as theatre__name, s.sponsored_by as sponsored_by,
           s.sponsored_session as sponsored_session, labels(s)[0] as type,
           CASE WHEN s.key_text IS NOT NULL THEN s.key_text ELSE '' END as key_text
    """
    sessions = tx.run(query).data()

    # Fetch all stream descriptions once to avoid multiple queries
    stream_query = """
    MATCH (s:Stream)
    RETURN s.stream as stream, s.description as description
    """
    stream_data = tx.run(stream_query).data()

    # Lower-cased stream name -> description; Stream nodes without a name are skipped.
    stream_descriptions = {
        row["stream"].lower(): row["description"]
        for row in stream_data
        if row["stream"]
    }

    embeddings = {}
    for s in sessions:
        session_streams = []
        if s["stream"]:
            tokens = [token.strip().lower() for token in s["stream"].split(";")]
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # embedding text is the same on every run (set order is not).
            session_streams = [
                t for t in dict.fromkeys(tokens) if t in stream_descriptions
            ]

        sponsored_by = s.get("sponsored_by")
        sponsored_session = s.get("sponsored_session")
        embeddings[s["session_id"]] = {
            "type": s["type"],
            "embedding": create_session_embedding(s, session_streams),
            "theatre__name": s["theatre__name"],
            "stream": s["stream"],
            "title": s["title"],
            # The query always returns these keys (possibly null), so default on None.
            "sponsored_by": (
                sponsored_by if sponsored_by is not None else "Not Sponsored"
            ),
            "sponsored_session": (
                sponsored_session if sponsored_session is not None else "False"
            ),
            "key_text": s.get("key_text") or "",
        }

    return embeddings


# Visitor-based similarity
def visitor_similarity(v1, v2):
    """
    Fraction of profile attributes on which two visitors agree.

    Compared attributes: what_type_does_your_practice_specialise_in, job_role and
    organisation_type. An attribute counts only if it is present and not "NA" on
    BOTH visitors.

    If no attribute qualifies, v1 is compared as if its job_role were "Vet/Owner"
    and its practice type "Mixed". A copy is used, so v1 itself is never modified.

    Args:
        v1: the visitor recommendations are being computed for (mapping-like).
        v2: the visitor to compare against (mapping-like).

    Returns:
        float in [0, 1]; 0.0 when nothing is comparable even after the defaults.
    """
    attrs_to_compare = (
        "what_type_does_your_practice_specialise_in",
        "job_role",
        "organisation_type",
    )

    def _comparable(a, b):
        # Attributes present and not "NA" on both sides.
        return [
            attr
            for attr in attrs_to_compare
            if attr in a and a[attr] != "NA" and attr in b and b[attr] != "NA"
        ]

    valid_attrs = _comparable(v1, v2)
    if not valid_attrs:
        v1 = dict(v1)
        v1["job_role"] = "Vet/Owner"
        v1["what_type_does_your_practice_specialise_in"] = "Mixed"
        valid_attrs = _comparable(v1, v2)

    if not valid_attrs:
        return 0.0

    matches = sum(1 for attr in valid_attrs if v1[attr] == v2[attr])
    return matches / len(valid_attrs)


# Modularized function to get this year's sessions filtered by practice type and role
def _practice_excluded_streams(practice_type):
    """Streams excluded by practice-type rules 1 and 2; [] when neither applies."""
    if not practice_type or practice_type.lower() in ["no_data", "na"]:
        return []
    practice_types = [pt.strip().lower() for pt in practice_type.split(";")]
    if "equine" in practice_types or "mixed" in practice_types:
        # Rule 1: Equine/Mixed practices
        return ["exotics", "feline", "exotic animal", "farm", "small animal"]
    if any("small animal" in pt for pt in practice_types):
        # Rule 2: Small Animal practices
        return ["equine", "farm animal", "farm", "large animal"]
    return []


def get_filtered_sessions_for_visitor(tx, visitor_id, practice_type, job_role):
    """
    Return this year's sessions that suit the visitor's practice type and job role.

    Rules (stream names are compared case-insensitively):
      1. Equine/Mixed practices: exclude exotics, feline, exotic animal, farm and
         small animal streams.
      2. Small-animal practices: exclude equine, farm animal, farm and large
         animal streams.
      3. NURSE_ROLES: keep only sessions with a nursing, wellbeing or welfare stream.
      4. VET_ROLES: exclude sessions with a nursing stream.
    Practice rules are skipped when practice_type is empty, "no_data" or "NA".
    Sessions whose title mentions "President" are always excluded.

    Args:
        tx: Neo4j transaction.
        visitor_id: not used by the query; kept for interface compatibility.
        practice_type: semicolon-separated practice types (may be None).
        job_role: the visitor's job role.

    Returns:
        list of {"session_id": ...} rows.
    """
    # Comparisons with null yield null in Cypher, so sessions with a null title
    # are dropped by this filter as well.
    base_query = """
    MATCH (s:Sessions_this_year)
    WHERE NOT s.title = "BVA's President's Welcome" AND NOT s.title CONTAINS "President"
    """

    excluded_streams = _practice_excluded_streams(practice_type)
    params = {}

    if job_role in NURSE_ROLES:
        # Rule 3, combined with the practice rule when one applies.
        params["nurse_streams"] = ["nursing", "wellbeing", "welfare"]
        role_clause = """
    MATCH (s)-[:HAS_STREAM]->(stream:Stream)
    WHERE toLower(stream.stream) IN $nurse_streams
    """
        if excluded_streams:
            params["excluded_streams"] = excluded_streams
            role_clause += """AND NOT EXISTS {
        MATCH (s)-[:HAS_STREAM]->(excluded:Stream)
        WHERE toLower(excluded.stream) IN $excluded_streams
    }
    """
    else:
        blocked_streams = list(excluded_streams)
        if job_role in VET_ROLES:
            blocked_streams.append("nursing")  # Rule 4
        role_clause = ""
        if blocked_streams:
            params["excluded_streams"] = blocked_streams
            role_clause = """
    MATCH (s)
    WHERE NOT EXISTS {
        MATCH (s)-[:HAS_STREAM]->(stream:Stream)
        WHERE toLower(stream.stream) IN $excluded_streams
    }
    """

    query = base_query + role_clause + "RETURN DISTINCT s.session_id as session_id\n"
    return tx.run(query, **params).data()


# Function to check if a session is appropriate for a visitor's job role
def is_session_appropriate_for_role(session_stream, job_role):
    """
    Decide whether a session's streams suit a job role.

    Rule 3: NURSE_ROLES only get sessions where some stream contains
            "nursing", "wellbeing" or "welfare".
    Rule 4: VET_ROLES never get sessions where some stream contains "nursing".

    ``session_stream`` may hold several semicolon-separated streams; each one is
    stripped, lower-cased and checked with substring matching.

    Returns:
        (is_appropriate, reason) tuple.
    """
    if not (session_stream and job_role):
        return True, "No role or session stream restrictions"

    tokens = [part.strip().lower() for part in session_stream.split(";")]

    def _mentions(keyword):
        return any(keyword in token for token in tokens)

    if job_role in NURSE_ROLES:
        if _mentions("nursing") or _mentions("wellbeing") or _mentions("welfare"):
            return True, "Nursing content"
        return False, "Not suitable for nursing roles"

    if job_role in VET_ROLES and _mentions("nursing"):
        return False, "Not suitable for veterinary roles"

    return True, "Suitable for role"


# Modularized function to get visitor information
def get_visitor_info(tx, visitor_id):
    """Look up a Visitor_this_year node by BadgeId.

    Args:
        tx: Neo4j transaction.
        visitor_id: BadgeId to look up.

    Returns:
        A 4-tuple ``(visitor, assisted, practice_type, job_role)``. ``assisted``
        defaults to "0", and ``practice_type`` and ``job_role`` default to "" when
        the property is absent. When no visitor matches,
        ``(None, None, None, None)`` is returned so callers can always unpack
        four values.
    """
    visitor_query = """
    MATCH (v:Visitor_this_year {BadgeId: $visitor_id})
    RETURN v
    """
    record = tx.run(visitor_query, visitor_id=visitor_id).single()
    if not record:
        return None, None, None, None

    visitor = record["v"]
    assisted = visitor.get("assist_year_before", "0")
    practice_type = visitor.get("what_type_does_your_practice_specialise_in", "")
    job_role = visitor.get("job_role", "")

    return visitor, assisted, practice_type, job_role


# Modularized function to get past sessions for a visitor who assisted last year
def get_past_sessions(tx, visitor_id):
    """Return ``[{"session_id": ...}, ...]`` for sessions the visitor attended last year.

    Follows Same_Visitor to the visitor's Visitor_last_year_bva and
    Visitor_last_year_lva records, then attended_session to Sessions_past_year
    nodes. UNION removes duplicate rows.
    """
    query_past = """
    MATCH (v:Visitor_this_year {BadgeId: $visitor_id})-[:Same_Visitor]->(vp_bva:Visitor_last_year_bva)-[:attended_session]->(sp_bva:Sessions_past_year)
    RETURN sp_bva.session_id as session_id
    UNION
    MATCH (v:Visitor_this_year {BadgeId: $visitor_id})-[:Same_Visitor]->(vp_lva:Visitor_last_year_lva)-[:attended_session]->(sp_lva:Sessions_past_year)
    RETURN sp_lva.session_id as session_id
    """
    result = tx.run(query_past, visitor_id=visitor_id)
    return result.data()


# Modularized function to get similar visitors
def get_similar_visitors(tx, visitor, top_k=3):
    """Return the BadgeIds of the ``top_k`` returning visitors most similar to ``visitor``.

    Candidates are all Visitor_this_year nodes with assist_year_before = '1'.
    Similarity comes from ``visitor_similarity``. The sort is stable, so ties
    keep the order in which the database returned the candidates.

    Args:
        tx: Neo4j transaction.
        visitor: the visitor's properties (mapping-like).
        top_k: how many BadgeIds to return (default 3, the previous fixed value).

    Returns:
        list of up to ``top_k`` BadgeIds, most similar first.
    """
    all_visitors = tx.run(
        """
        MATCH (v:Visitor_this_year)
        WHERE v.assist_year_before = '1'
        RETURN v
        """
    ).data()

    scored = [
        (row["v"]["BadgeId"], visitor_similarity(visitor, row["v"]))
        for row in all_visitors
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [badge_id for badge_id, _ in scored[:top_k]]


# Modularized function to get sessions attended by similar visitors
def get_similar_visitor_sessions(tx, similar_visitor_badge_ids):
    """Collect last year's sessions attended by each of the given visitors.

    Runs one query per BadgeId (via Same_Visitor to a BVA or LVA record of last
    year) and concatenates the rows in input order, without de-duplication.

    Returns:
        list of {"session_id": ...} rows.
    """
    query_sim_past = """
            MATCH (v:Visitor_this_year {BadgeId: $similar_vid})-[:Same_Visitor]->(vp)-[:attended_session]->(sp:Sessions_past_year)
            WHERE vp:Visitor_last_year_bva OR vp:Visitor_last_year_lva
            RETURN sp.session_id AS session_id
        """
    collected = []
    for badge_id in similar_visitor_badge_ids:
        collected += tx.run(query_sim_past, similar_vid=badge_id).data()
    return collected


# Modularized function to calculate session similarities
def calculate_session_similarities(
    past_sessions, session_embeddings, this_year_sessions, practice_type, job_role
):
    """Score every (past session, eligible this-year session) pair by cosine similarity.

    A this-year session is eligible only if both
    ``is_session_appropriate_for_practice_type`` and
    ``is_session_appropriate_for_role`` accept it; their reasons are joined into
    the recommendation's "reason". Past sessions without an embedding are skipped.

    NOTE(review): ``is_session_appropriate_for_practice_type`` is not defined in
    the visible part of this notebook, and ``recommend_sessions`` does not call
    this function. Confirm the helper exists before using this.

    Returns:
        list of {"session_id", "similarity", "reason"} dicts, one per accepted
        pair (duplicates are not removed).
    """
    results = []
    for past in past_sessions:
        past_id = past["session_id"]
        if past_id not in session_embeddings:
            continue
        past_vector = session_embeddings[past_id]["embedding"]

        for session_id, info in this_year_sessions.items():
            ok_practice, practice_reason = is_session_appropriate_for_practice_type(
                info["stream"], practice_type, info.get("title", "")
            )
            if not ok_practice:
                continue

            ok_role, role_reason = is_session_appropriate_for_role(
                info["stream"], job_role
            )
            if not ok_role:
                continue

            # Embedding similarity (Rule 7)
            score = cosine_similarity([past_vector], [info["embedding"]])[0][0]
            results.append(
                {
                    "session_id": session_id,
                    "similarity": score,
                    "reason": f"{role_reason}. {practice_reason}",
                }
            )
    return results


# Refactored recommend_sessions function
def _history_reason(job_role, source):
    """Reason text for a job role and history source ("past" or "similar")."""
    if source == "past":
        if job_role in NURSE_ROLES:
            return "Nursing content based on your past attendance"
        if job_role in VET_ROLES:
            return "Veterinary content based on your past attendance"
        return "Based on your past attendance"
    if job_role in NURSE_ROLES:
        return "Nursing content based on similar visitors"
    if job_role in VET_ROLES:
        return "Veterinary content based on similar visitors"
    return "Based on similar visitors with your profile"


def _best_matches(history_sessions, session_embeddings, this_year_sessions, reason):
    """For each this-year session, its highest cosine similarity to any history session."""
    # Unique history ids that have an embedding (duplicates cannot change a max).
    history_ids = [
        sid
        for sid in dict.fromkeys(row["session_id"] for row in history_sessions)
        if sid in session_embeddings
    ]
    if not history_ids or not this_year_sessions:
        return []

    candidate_ids = list(this_year_sessions)
    history_matrix = np.vstack(
        [session_embeddings[sid]["embedding"] for sid in history_ids]
    )
    candidate_matrix = np.vstack(
        [this_year_sessions[sid]["embedding"] for sid in candidate_ids]
    )
    # One matrix call instead of one cosine_similarity call per pair; column j
    # holds candidate j's scores against every history session (Rule 7).
    best_scores = cosine_similarity(history_matrix, candidate_matrix).max(axis=0)
    return [
        {"session_id": sid, "similarity": score, "reason": reason}
        for sid, score in zip(candidate_ids, best_scores)
    ]


def recommend_sessions(tx, visitor_id, session_embeddings):
    """
    Recommend this year's sessions for a visitor, best match first.

    1. Look up the visitor; return [] if unknown.
    2. Restrict candidates to the Sessions_this_year entries of
       ``session_embeddings`` that pass ``get_filtered_sessions_for_visitor``
       (practice-type and job-role rules).
    3. History: the visitor's own past sessions when assist_year_before == "1",
       otherwise the sessions attended by the most similar returning visitors.
    4. Score each candidate by its highest cosine similarity to any history
       session embedding.

    Returns:
        list of {"session_id", "similarity", "reason"} dicts sorted by
        descending similarity, with at most one entry per candidate session;
        empty when no history session has an embedding.
    """
    visitor_info = get_visitor_info(tx, visitor_id)
    # Guard before unpacking: an unknown visitor may not yield four values.
    if not visitor_info or not visitor_info[0]:
        return []
    visitor, assisted, practice_type, job_role = visitor_info

    # Get filtered sessions that match the visitor's practice type and role
    filtered_sessions_data = get_filtered_sessions_for_visitor(
        tx, visitor_id, practice_type, job_role
    )
    filtered_session_ids = {s["session_id"] for s in filtered_sessions_data}

    this_year_sessions = {
        k: v
        for k, v in session_embeddings.items()
        if v["type"] == "Sessions_this_year" and k in filtered_session_ids
    }

    if assisted == "1":
        # Case 1: Visitor attended last year
        history = get_past_sessions(tx, visitor_id)
        reason = _history_reason(job_role, "past")
    else:
        # Case 2: New visitor - use sessions of similar visitors with history
        similar_visitor_badge_ids = get_similar_visitors(tx, visitor)
        history = get_similar_visitor_sessions(tx, similar_visitor_badge_ids)
        reason = _history_reason(job_role, "similar")

    recommendations = _best_matches(
        history, session_embeddings, this_year_sessions, reason
    )
    # Stable sort: ties keep the candidate order of this_year_sessions.
    recommendations.sort(key=lambda rec: -rec["similarity"])
    return recommendations


# Modified function to filter sessions with similarity scores
def filter_sessions_by_visitor_stream_relationships(
    tx, visitor_id, session_recommendations
):
    """
    Keep only recommendations whose session shares a Stream with the visitor.

    A session qualifies when it has a HAS_STREAM relationship to a Stream that
    the visitor points to through a specialization_to_stream or job_to_stream
    relationship. The recommendation dicts (including their similarity scores)
    and their input order are preserved.
    """
    if not session_recommendations:
        return []

    query = """
    MATCH (v:Visitor_this_year {BadgeId: $visitor_id})
    MATCH (s:Sessions_this_year)
    WHERE s.session_id IN $session_ids
    MATCH (s)-[:HAS_STREAM]->(stream:Stream)<-[r]-(v)
    WHERE type(r) IN ['specialization_to_stream', 'job_to_stream']
    RETURN DISTINCT s.session_id as session_id
    """
    rows = tx.run(
        query,
        visitor_id=visitor_id,
        session_ids=[rec["session_id"] for rec in session_recommendations],
    ).data()
    allowed = {row["session_id"] for row in rows}
    return [rec for rec in session_recommendations if rec["session_id"] in allowed]


# Fetch display attributes for recommended sessions and attach their scores.
def get_session_attributes(session_recommendations, session=None):
    """
    Look up session details in Neo4j and merge in similarity score and reason.

    Parameters:
    - session_recommendations: list of dicts with "session_id", "similarity"
      and "reason" keys.
    - session: optional open Neo4j session/transaction (anything exposing
      `.run(query, **params)`). When None, a short-lived session is opened on
      the module-level `driver`, matching the previous behaviour.

    Returns:
    - A list of session-detail dicts (stream, session_id, title, ... as
      projected by the query) with added "similarity_score" (a built-in float)
      and "recommendation_reason", sorted by similarity_score descending.
      Sessions whose session_id has no Sessions_this_year node are absent.
    """
    # Handle empty recommendations to prevent an unnecessary query.
    if not session_recommendations:
        return []

    session_ids = [rec["session_id"] for rec in session_recommendations]

    # session_id -> score/reason, used to decorate the query results.
    similarity_map = {
        rec["session_id"]: {"similarity": rec["similarity"], "reason": rec["reason"]}
        for rec in session_recommendations
    }

    query = """
    MATCH (s:Sessions_this_year)
    WHERE s.session_id IN $session_ids
    RETURN s {
        .stream,
        .session_id,
        .title,
        .synopsis_stripped,
        .end_time,
        .start_time,
        .date,
        .theatre__name,
        .sponsored_by,
        .sponsored_session
    } AS session_details
    """

    # Materialise the records while the session is open.
    if session is None:
        with driver.session() as own_session:
            records = list(own_session.run(query, session_ids=session_ids))
    else:
        records = list(session.run(query, session_ids=session_ids))

    session_details = []
    for record in records:
        details = dict(record["session_details"])
        meta = similarity_map.get(details["session_id"])
        if meta is not None:
            # cosine_similarity yields numpy scalars, which json.dump rejects;
            # store a plain float so results serialise directly.
            details["similarity_score"] = float(meta["similarity"])
            details["recommendation_reason"] = meta["reason"]
        session_details.append(details)

    # Sort by similarity score (highest first).
    session_details.sort(key=lambda x: x.get("similarity_score", 0), reverse=True)

    return session_details


# Cached embeddings at the module level: filled lazily by the first
# get_recommendations() call and reused for every later visitor.
_session_embeddings = None


def _limit_recommendations(items, limit):
    """Return the first `limit` items, or all of them when limit is None or <= 0."""
    if limit is not None and limit > 0:
        return items[:limit]
    return items


def get_recommendations(visitor_id, max_recommendations=None):
    """
    Main function to get recommendations for a visitor.
    Returns two sets of session details: filtered and unfiltered, both including similarity scores.

    The stream filter is applied to the FULL ranked list before truncation, so
    the filtered set contains the top `max_recommendations` sessions that pass
    the filter. Previously the list was truncated first, which could leave the
    filtered set nearly empty even when lower-ranked valid sessions existed.

    Parameters:
    - visitor_id: The ID of the visitor to get recommendations for
    - max_recommendations: Maximum number of recommendations to return in each
      set (None, 0 or negative = no limit)

    Returns:
    - A tuple (filtered_recommendations, unfiltered_recommendations) with session details
    """
    global _session_embeddings

    # Create session embeddings once if they don't exist.
    # NOTE(review): not guarded by a lock; if this is called concurrently the
    # embeddings may be computed more than once — confirm that is acceptable.
    if _session_embeddings is None:
        with driver.session() as session:
            _session_embeddings = session.execute_read(embed_all_sessions)

    with driver.session() as session:
        # Ranked candidate sessions (similarity, reason) for this visitor.
        recommended_sessions = session.execute_read(
            recommend_sessions,
            visitor_id=visitor_id,
            session_embeddings=_session_embeddings,
        )

        # Filter the whole ranked list; order (by similarity) is preserved.
        filtered_sessions = session.execute_read(
            filter_sessions_by_visitor_stream_relationships,
            visitor_id=visitor_id,
            session_recommendations=recommended_sessions,
        )

    # Truncate each list independently, then fetch details outside the session
    # above (get_session_attributes manages its own database access).
    unfiltered_recommendations = get_session_attributes(
        _limit_recommendations(recommended_sessions, max_recommendations)
    )
    filtered_recommendations = get_session_attributes(
        _limit_recommendations(filtered_sessions, max_recommendations)
    )

    return filtered_recommendations, unfiltered_recommendations

In [None]:
# Registration + demographics export for this year's visitors (directory, then file).
csv_file_path = (
    "data/bva/output/"
    "df_reg_demo_this.csv"
)

In [None]:
# Load this year's registration/demographic table and show its row count.
data = pd.read_csv(csv_file_path)
data.shape[0]

In [None]:
# Distinct badge IDs, in first-appearance order (Series.unique keeps that order).
list_badgeId_this = [*data["BadgeId"].unique()]
len(list_badgeId_this)

In [None]:
# Only process a sample of visitors while iterating on the notebook.
# NOTE: every export below (JSON/CSV) covers just this sample; set
# MAX_VISITORS = None to generate recommendations for all visitors.
MAX_VISITORS = 100
list_badgeId_this = list_badgeId_this[:MAX_VISITORS]

In [None]:
# Score every sampled badge. The two results are keyed by badge ID — presumably
# (stream-filtered, unfiltered), mirroring get_recommendations; confirm in
# get_batch_recommendations. Tune num_workers to your CPU cores.
pa_recommendations, pa_recommendations_full = get_batch_recommendations(
    list_badgeId_this, max_recommendations=10, num_workers=4
)

print(f"Total badges processed: {len(pa_recommendations)}")

In [None]:
import json

In [None]:
# Coerce similarity scores to plain floats so json.dump can serialise them
# (they presumably arrive as numpy scalars from cosine_similarity).
for session_list in pa_recommendations.values():
    for rec in session_list:
        rec["similarity_score"] = float(rec["similarity_score"])

In [None]:
# How many visitors ended up with no stream-filtered recommendations.
count = sum(1 for session_list in pa_recommendations.values() if len(session_list) == 0)
print(count)

In [None]:
# Same float coercion and empty-visitor count, for the unfiltered results.
for session_list in pa_recommendations_full.values():
    for rec in session_list:
        rec["similarity_score"] = float(rec["similarity_score"])
count = sum(
    1 for session_list in pa_recommendations_full.values() if len(session_list) == 0
)
print(count)

In [None]:
# Persist both recommendation sets as pretty-printed JSON.
for payload, out_path in (
    (pa_recommendations, "data/bva/bva_pa_recomendations.json"),
    (pa_recommendations_full, "data/bva/bva_pa_recomendations_full.json"),
):
    with open(out_path, "w") as fh:
        json.dump(payload, fh, indent=4)

In [None]:
# Peek at the fields of one recommendation. Prefer the original example badge,
# but fall back to the first visitor with any recommendations so the cell does
# not raise KeyError/IndexError when that badge is absent or has none.
SAMPLE_BADGE_ID = "YCRZ6F4"
sample_badge = (
    SAMPLE_BADGE_ID
    if pa_recommendations.get(SAMPLE_BADGE_ID)
    else next((badge for badge, recs in pa_recommendations.items() if recs), None)
)
pa_recommendations[sample_badge][0].keys() if sample_badge is not None else None

In [None]:
def transform_recommendations_to_dataframe(pa_recommendations):
    """
    Flatten a {badge_id: [session_dict, ...]} mapping into one row per recommendation.

    Args:
        pa_recommendations (dict): Keys are badge IDs; values are lists of
            session dictionaries describing recommended sessions.

    Returns:
        pandas.DataFrame: Columns, in order: 'badgeid', 'session_id', 'stream',
            'title', 'synopsis_stripped', 'date', 'start_time', 'end_time',
            'theatre__name', 'sponsored_by', 'similarity_score',
            'recommendation_reason'. Keys missing from a session dict become
            None; other keys (e.g. 'sponsored_session') are dropped. When there
            are no recommendations at all, the frame is empty but still has
            these columns, so downstream column access (groupby("badgeid"),
            merges) does not raise KeyError.
    """
    session_fields = (
        "session_id",
        "stream",
        "title",
        "synopsis_stripped",
        "date",
        "start_time",
        "end_time",
        "theatre__name",
        "sponsored_by",
        "similarity_score",
        "recommendation_reason",
    )

    # Named `rows` (not `data`) to avoid shadowing the notebook's global `data`.
    rows = [
        {"badgeid": badgeid, **{field: session.get(field) for field in session_fields}}
        for badgeid, session_list in pa_recommendations.items()
        for session in session_list
    ]

    # Explicit columns keep the schema even when `rows` is empty.
    return pd.DataFrame(rows, columns=["badgeid", *session_fields])


# One flat frame for the stream-filtered results, one for the unfiltered ones.
df, df_full = (
    transform_recommendations_to_dataframe(recs)
    for recs in (pa_recommendations, pa_recommendations_full)
)

In [None]:
# Row counts of the filtered and unfiltered recommendation frames.
df.shape[0], df_full.shape[0]

In [None]:
import pandas as pd
import numpy as np


def flag_overlapping_sessions(df):
    """
    Return a copy of `df` with an 'overlapping_sessions' column appended.

    For each row, the column holds the '|'-joined session_ids of the OTHER
    rows with the same 'badgeid' whose time window intersects this row's
    window, or None when there are none. Windows are built from
    'date' + ' ' + 'start_time' / 'end_time'; intervals that merely touch
    (one ends exactly when the other starts) do not count as overlapping.
    The input frame is not modified.
    """
    result = df.copy()

    # Full timestamps kept as side Series rather than temporary columns.
    starts = pd.to_datetime(result["date"] + " " + result["start_time"])
    ends = pd.to_datetime(result["date"] + " " + result["end_time"])

    result["overlapping_sessions"] = None

    for members in result.groupby("badgeid").groups.values():
        # A lone session cannot overlap anything.
        if len(members) < 2:
            continue

        member_starts = starts.loc[members]
        member_ends = ends.loc[members]
        member_ids = result.loc[members, "session_id"]

        for row in members:
            # Two intervals intersect when each starts before the other ends.
            clashes = member_ids[
                (member_starts < ends.at[row])
                & (member_ends > starts.at[row])
                & (member_ids != member_ids.at[row])
            ]
            if len(clashes):
                result.at[row, "overlapping_sessions"] = "|".join(
                    str(sid) for sid in clashes
                )

    return result


# Annotate both recommendation frames with clashing-session information.
df, df_full = (flag_overlapping_sessions(frame) for frame in (df, df_full))

In [None]:
# Frequency of each overlap pattern (rows with no overlap are excluded as NA).
df["overlapping_sessions"].value_counts()

In [None]:
# Preview the first two filtered recommendation rows.
df.iloc[:2]

In [None]:
# Distinct recommendation reasons present in the unfiltered results.
df_full["recommendation_reason"].unique()

In [None]:
# Small example extract: the first 100 filtered recommendation rows.
df.head(100).to_csv("data/bva/bva_pa_recomendations_example.csv", index=False)

In [None]:
# Write the filtered and unfiltered recommendation frames to CSV.
for frame, out_path in (
    (df, "data/bva/bva_pa_recomendations.csv"),
    (df_full, "data/bva/bva_pa_recomendations_full.csv"),
):
    frame.to_csv(out_path, index=False)

# Add registration demographic information

In [None]:
# Show the raw CSV column names. The next cell renames them by position,
# so this order must line up with that list.
data.columns

In [None]:
# Rename the registration columns so they can be joined on 'badgeid'.
# NOTE(review): this is a POSITIONAL rename — if the CSV's column order ever
# changes, columns are silently mislabelled. set_axis raises only when the
# column COUNT differs.
data = data.set_axis(
    [
        "Email",
        "Email_domain",
        "Company",
        "JobTitle",
        "Country",
        "BadgeType",
        "ShowRef",
        "badgeid",
        "Source",
        "Days_since_registration",
        "assist_year_before",
        "BadgeId_last_year_bva",
        "BadgeId_last_year_lva",
        "what_type_does_your_practice_specialise_in",
        "organisation_type",
        "job_role",
    ],
    axis=1,
)

In [None]:
# Attach registration/demographic columns to every recommendation row.
# validate="many_to_one" enforces one registration row per badgeid; without it,
# a duplicated badge in `data` would silently duplicate that visitor's
# recommendation rows instead of failing loudly.
data_final = pd.merge(df, data, on=["badgeid"], how="left", validate="many_to_one")
data_final_full = pd.merge(
    df_full, data, on=["badgeid"], how="left", validate="many_to_one"
)

In [None]:
# Sanity check: a left join on a unique key must not change the row counts.
tuple(len(frame) for frame in (df, data_final, df_full, data_final_full))

In [None]:
# Write the demographic-enriched recommendation frames to CSV.
for frame, out_path in (
    (data_final, "data/bva/bva_pa_recomendations_with_demo.csv"),
    (data_final_full, "data/bva/bva_pa_recomendations_with_demo_full.csv"),
):
    frame.to_csv(out_path, index=False)

In [None]:
# Number of distinct visitors represented in the unfiltered output (NA counted once).
data_final_full["badgeid"].nunique(dropna=False)

In [None]:
# Distribution of job roles across unfiltered recommendation rows (NA excluded).
data_final_full["job_role"].value_counts(dropna=True)