# In [None]:  (notebook cell marker left over from the Jupyter export)
import re
import os
import logging
from neo4j import GraphDatabase
import fitz  # PyMuPDF for PDF reading
from typing import List, Any, Dict

# Langchain imports
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.chains.router import MultiRetrievalQAChain
from langchain_core.callbacks import CallbackManagerForRetrieverRun

# For custom router prompt in MultiRetrievalQAChain
from langchain.chains.router.llm_router import RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import ROUTER_TEMPLATE


# Optional NLP imports (omitted from this listing; add them here if your
# ingestion code needs them)
# ...

# Runtime configuration. Every setting is read from the environment and falls
# back to a local-development default when the variable is unset.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASS = os.environ.get("NEO4J_PASS", "password")
# Placeholder default: set PDF_PATH (or edit this value) to point at the real PDF.
pdf_path = os.environ.get("PDF_PATH", "your_pdf_path.pdf")

# Root logger: INFO and above, rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# --- Neo4j Data Ingestion Script (Placeholder - Use your full script) ---
# Driver intended for the ingestion step; it is created as soon as this module
# is loaded. NOTE(review): nothing visible in this file runs queries through
# it -- the __main__ guard only closes it. Confirm against the full ingestion
# script before relying on or removing it.
neo4j_ingestion_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))
def populate_data_from_pdf() -> None:
    """Stand-in for the PDF ingestion step; it only logs that it was called."""
    logging.info("Placeholder: PDF data population.")
def save_to_neo4j() -> None:
    """Stand-in for the Neo4j persistence step; it only logs that it was called."""
    logging.info("Placeholder: Saving to Neo4j.")
# --- End of Ingestion Placeholder ---

def create_neo4j_ft_index(driver_param: Any) -> None:
    """Ensure the ``ruleTextIndex`` full-text index exists.

    The index covers the ``title`` and ``text`` properties of ``RuleSection``
    and ``Clarification`` nodes. Creation is best-effort: a failure while
    running the statement is logged (with traceback) and swallowed so the
    caller can carry on.

    Args:
        driver_param: An open Neo4j driver; only its ``session()`` context
            manager is used here. Annotated as ``Any`` because
            ``GraphDatabase.driver`` is a factory method, not a class, and is
            therefore not a meaningful type annotation.
    """
    index_query = """
    CREATE FULLTEXT INDEX ruleTextIndex IF NOT EXISTS
    FOR (n:RuleSection|Clarification) ON EACH [n.title, n.text]
    """
    with driver_param.session() as session:
        try:
            session.run(index_query)
        except Exception as e:
            # Best-effort by design: keep going, but record the traceback so
            # the root cause (auth, syntax, server version...) is visible.
            logging.error(f"Error ensuring Neo4j full-text index: {e}", exc_info=True)
        else:
            logging.info("Full-text index 'ruleTextIndex' ensured or already exists.")

class SimpleNeo4jRetriever(BaseRetriever):
    """Retriever backed by the Neo4j full-text index ``ruleTextIndex``.

    Each matching node becomes a ``Document`` whose content is the node's
    title and text (separated by a blank line) and whose metadata carries the
    node id (as ``rule_id``) and the full-text ``score``.
    """

    # Neo4j driver used to open one session per lookup.
    driver: Any
    # Maximum number of nodes returned per query.
    limit: int = 5

    @staticmethod
    def _escape_lucene(text: str) -> str:
        """Backslash-escape Lucene query-syntax characters in *text*.

        The full-text procedure parses its argument as a Lucene query, so a
        plain-language question must be escaped first: a trailing ``?`` would
        otherwise act as a single-character wildcard, and characters such as
        ``:``, ``"``, ``(`` or ``/`` can make the query fail to parse.
        """
        return re.sub(r'([\\+\-!():^\[\]"{}~*?|&/])', r"\\\1", text)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``limit`` documents matching *query*, best score first.

        Failures during the search are logged and result in whatever was
        collected so far (normally an empty list) rather than an exception.
        """
        results: List[Document] = []
        logging.debug(f"SimpleNeo4jRetriever _get_relevant_documents for user query: '{query}'")

        search_term = self._escape_lucene(query).strip()
        if not search_term:
            # Nothing to search for; skip the database round trip.
            logging.debug("SimpleNeo4jRetriever received a blank query; returning no documents.")
            return results

        # The bound parameter is called $search_term rather than $query:
        # `query` is the name of session.run's own first parameter, so it
        # cannot also be supplied as a keyword argument.
        cypher_statement = """
        CALL db.index.fulltext.queryNodes('ruleTextIndex', $search_term) YIELD node, score
        WITH node, score 
        WHERE node.id IS NOT NULL AND (node.title IS NOT NULL OR node.text IS NOT NULL)
        RETURN node.id AS id, 
               coalesce(node.title, "") AS title, 
               coalesce(node.text, "") AS text, 
               score
        ORDER BY score DESC
        LIMIT $limit
        """
        with self.driver.session() as session:
            try:
                records = session.run(cypher_statement, search_term=search_term, limit=self.limit)
                for rec in records:
                    # Title and text are coalesced to "" in Cypher, so keep
                    # only the non-empty parts.
                    page_content_parts = [part for part in (rec["title"], rec["text"]) if part]
                    content = "\n\n".join(page_content_parts).strip()
                    if not content:
                        continue
                    metadata = {"rule_id": str(rec["id"]), "score": rec["score"]}
                    results.append(Document(page_content=content, metadata=metadata))
            except Exception as e:
                logging.error(f"Neo4j full-text search failed in _get_relevant_documents: {e}", exc_info=True)
                if "No such index" in str(e) or "index does not exist" in str(e):
                    logging.error("The 'ruleTextIndex' does not exist.")
        logging.debug(f"SimpleNeo4jRetriever _get_relevant_documents returning {len(results)} documents.")
        return results

class MockVectorRetriever(BaseRetriever):
    """Stand-in vector retriever that fabricates a single document per query."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return one mock document whose content echoes *query*."""
        logging.debug(f"MockVectorRetriever _get_relevant_documents for query: '{query}'")
        mock_document = Document(
            metadata={"rule_id": "Rule Vector.Mock"},
            page_content=f"Mock vector content for: {query}",
        )
        return [mock_document]

def run_qa_system():
    """Build the vector-only, graph-only and hybrid QA chains and compare them.

    Opens a dedicated Neo4j driver, makes sure the full-text index exists,
    wires up the three chains and runs a fixed list of questions through
    ``run_retrieval_comparison_tests``. The driver is always closed, even if
    chain construction or a query raises.
    """
    qa_neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))
    try:
        create_neo4j_ft_index(qa_neo4j_driver)
        llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

        answer_instructions = (
            "Use the following context to answer the question precisely. "
            "If it’s not in the context, say “I don’t know.”\n\n"
        )
        # create_retrieval_chain hands the user question to the prompt as "input".
        retrieval_chain_prompt = PromptTemplate(
            input_variables=["context", "input"],
            template=answer_instructions + "Context:\n{context}\n\nQuestion: {input}\nAnswer:",
        )
        # The RetrievalQA chains built by MultiRetrievalQAChain feed their
        # prompt the variables "context" and "question" (not "query").
        hybrid_default_prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=answer_instructions + "Context:\n{context}\n\nQuestion: {question}\nAnswer:",
        )

        combine_docs_chain_for_input = create_stuff_documents_chain(llm=llm, prompt=retrieval_chain_prompt)
        vector_retriever = MockVectorRetriever()
        graph_retriever_instance = SimpleNeo4jRetriever(driver=qa_neo4j_driver, limit=5)
        vector_chain = create_retrieval_chain(vector_retriever, combine_docs_chain_for_input)
        graph_chain = create_retrieval_chain(graph_retriever_instance, combine_docs_chain_for_input)

        # from_retrievers builds its own router prompt (input variable "input")
        # from the retriever descriptions. It has no router_prompt parameter:
        # unknown keyword arguments are forwarded to the chain constructor,
        # which rejects them, so no custom router prompt is passed here.
        multi_chain = MultiRetrievalQAChain.from_retrievers(
            llm=llm,
            retriever_infos=[
                {"name": "vector_search", "description": "Good for semantic similarity...", "retriever": vector_retriever},
                {"name": "graph_search", "description": "Good for specific rule lookup...", "retriever": graph_retriever_instance},
            ],
            default_retriever=vector_retriever,
            default_prompt=hybrid_default_prompt,
            default_chain_llm=llm,
            verbose=True,
        )

        QUERIES = [
            "What happens when a ball is moved by an outside influence?",
            "Which rules reference Rule 16?",
        ]
        run_retrieval_comparison_tests(QUERIES, vector_chain, graph_chain, multi_chain)
    finally:
        # Release the connection pool regardless of how the run ended.
        qa_neo4j_driver.close()

def extract_answer(output: Any) -> str:
    """Pull a printable answer out of a chain response.

    For a dict, the first non-None value among the keys ``answer``,
    ``result`` and ``text`` (in that order) is returned as a string. If none
    of those is usable, the first string value found in the dict is returned;
    failing that, a ``<NO_ANSWER_KEY_IN_DICT: ...>`` marker. Any non-dict
    value is simply converted with ``str``.
    """
    if not isinstance(output, dict):
        return str(output)

    # Preferred keys, in priority order.
    for candidate in ("answer", "result", "text"):
        found = output.get(candidate)
        if found is not None:
            return str(found)

    # Fallback: the first value that is already a string.
    first_string = next(
        (item for item in output.values() if isinstance(item, str)),
        None,
    )
    if first_string is not None:
        return first_string

    return f"<NO_ANSWER_KEY_IN_DICT: {output}>"


def run_retrieval_comparison_tests(queries_list, vc, gc, mc):
    """Run every query through the three chains and print the answers.

    Args:
        queries_list: Questions to ask, in order.
        vc: Vector-only chain.
        gc: Graph-only chain.
        mc: Hybrid (router) chain.

    Each chain is invoked with a one-entry dict keyed by the input name it
    expects. A failure in one chain is logged and printed as ``<ERROR: ...>``
    and does not stop the remaining chains or queries.
    """
    # (label, chain, input key). MultiRetrievalQAChain takes the question
    # under "input" as well: its input key comes from the router prompt, and
    # the router then forwards the text to the chosen retrieval chain as "query".
    test_scenarios = [
        ("Vector-only", vc, "input"),
        ("Graph-only", gc, "input"),
        ("Hybrid", mc, "input"),
    ]
    for q_idx, q_text in enumerate(queries_list, start=1):
        print(f"\n=== Query {q_idx}: {q_text!r} ===")
        for label, chain_instance, input_key_name in test_scenarios:
            try:
                logging.info(f"Running chain '{label}' with input key '{input_key_name}' for query: '{q_text}'")
                chain_response = chain_instance.invoke({input_key_name: q_text})
                ans_output = extract_answer(chain_response)
                print(f"{label.ljust(14)}: {ans_output}")
            except Exception as e:
                # Keep comparing the other chains even if this one fails.
                error_detail = f"<ERROR: {type(e).__name__} - {e}>"
                logging.error(f"Error running {label} for query '{q_text}': {error_detail}", exc_info=True)
                print(f"{label.ljust(14)}: {error_detail}")
        print("-" * 80)

if __name__ == "__main__":
    # Ingestion is disabled by default. To run it, uncomment the lines below
    # and make sure the full ingestion code replaces the placeholders above.
    # print("Starting Neo4j data ingestion process...")
    # save_to_neo4j()
    # print("Neo4j data ingestion process complete.")

    # Release the ingestion driver (if one was created) before starting QA.
    ingestion_driver = globals().get("neo4j_ingestion_driver")
    if ingestion_driver:
        try:
            ingestion_driver.close()
        except Exception:
            # Best-effort cleanup: report the problem instead of hiding it,
            # but do not prevent the QA run.
            logging.warning("Failed to close Neo4j ingestion driver.", exc_info=True)
        else:
            logging.info("Neo4j ingestion driver closed.")

    print("\nStarting QA System...")
    run_qa_system()