# In [None]:
import os
from typing import List

from langchain.chains import MultiRetrievalQAChain, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_openai import ChatOpenAI
from neo4j import GraphDatabase # For driver

# Assume 'driver' is already configured as per your PDF processing script
# NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
# NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
# NEO4J_PASS = os.getenv("NEO4J_PASS", "password")
# driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))


# Placeholder for your actual vector_retriever.
# This needs to be properly initialized, e.g., from ChromaDB with OpenAIEmbeddings.
# Ensure documents in your vector store have 'rule_id' in their metadata.
from langchain_community.vectorstores import Chroma # Example
from langchain_openai import OpenAIEmbeddings # Example

# Example: This is how you might set up your vector_retriever
# Make sure this is done *before* graph_retriever if it depends on it.
# try:
#     embeddings = OpenAIEmbeddings()
#     # Ensure your Chroma DB was populated with documents having 'rule_id' in metadata
#     vector_store = Chroma(persist_directory="./chroma_db_usga", embedding_function=embeddings)
#     vector_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
# except Exception as e:
#     print(f"Error setting up vector retriever: {e}")
#     print("Please ensure your vector store is correctly initialized and populated.")
#     # Fallback to a dummy retriever for the script to run, replace with your actual setup
class DummyVectorRetriever(BaseRetriever):
    """Placeholder retriever so the script runs without a populated vector store.

    Every query returns one fixed ``Document`` whose metadata carries a
    ``rule_id`` (the key the hybrid retriever is expected to join on).
    Replace it with a real retriever, e.g. Chroma + OpenAIEmbeddings.
    """

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        # Only this hook is overridden: BaseRetriever.invoke() (and the
        # deprecated get_relevant_documents()) call it with callbacks wired
        # up. Overriding get_relevant_documents() itself with a narrower
        # signature would break callers that pass callbacks/kwargs.
        print(f"[DummyVectorRetriever] Warning: Using dummy vector retriever. Query: {query}")
        return [
            Document(
                page_content="Dummy context from vector store.",
                metadata={"rule_id": "Rule 0.0"},
            )
        ]
vector_retriever = DummyVectorRetriever() # Replace with your actual vector_retriever

# `driver` is normally created beforehand (an earlier notebook cell / the PDF
# processing script). If it does not exist yet, e.g. when this file is run on
# its own, build one from the same environment variables the __main__ block
# reads. Otherwise the module fails with NameError before __main__ is reached.
# The neo4j driver connects lazily; connectivity is only checked by
# verify_connectivity() in __main__.
if "driver" not in globals():
    driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        auth=(os.getenv("NEO4J_USER", "neo4j"), os.getenv("NEO4J_PASS", "password")),
    )

# Instantiate your Neo4jHybridRetriever for the graph_retriever role.
# This means the "Graph-only" chain will actually use your hybrid logic.
# If you want a PURELY graph (e.g., only full-text) retriever, you'd define a different class.
graph_retriever = Neo4jHybridRetriever(driver=driver, vector_retriever=vector_retriever, limit=5)

# ─── Build QA chains ──────────────────────────────────────────────────────
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

# Prompt for the create_retrieval_chain-based chains (vector_chain, graph_chain):
# create_stuff_documents_chain fills {context}; the caller supplies {input}.
standard_prompt = PromptTemplate(
    input_variables=["context", "input"],
    template=(
        "Use the following context to answer the question precisely. "
        "If it’s not in the context, say “I don’t know.”\n\n"
        "Context:\n{context}\n\nQuestion: {input}\nAnswer:"
    ),
)

# Prompt for the RetrievalQA chains that MultiRetrievalQAChain builds.
# RetrievalQA receives the question under the "query" key but passes it to its
# combine-documents chain as "question". The prompt must therefore use
# {question}: a {query} placeholder is never filled and raises
# "Missing some input keys".
multi_retrieval_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Use the following context to answer the question precisely. "
        "If it’s not in the context, say “I don’t know.”\n\n"
        "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    ),
)

combine_documents_chain = create_stuff_documents_chain(llm=llm, prompt=standard_prompt)

vector_chain = create_retrieval_chain(vector_retriever, combine_documents_chain)
# graph_chain will use Neo4jHybridRetriever. Its prompt (standard_prompt) expects "input".
graph_chain = create_retrieval_chain(graph_retriever, combine_documents_chain)

# ─── Multi-retrieval router ─────────────────────────────────────────────
# In LangChain's implementation, the router LLM reads the user question under
# "input", picks a destination, and forwards the (possibly rephrased) question
# as "query" to a RetrievalQA chain, whose answer comes back under "result".
# default_prompt only applies to the fallback (default_retriever) chain, so
# each destination gets the same prompt explicitly via "prompt". This keeps
# the answering instructions identical on every route.
multi_chain = MultiRetrievalQAChain.from_retrievers(
    llm=llm,
    retriever_infos=[
        {
            "name": "vector_search",
            "description": "Good for semantic similarity search over rules text using ADA embeddings and ChromaDB",
            "retriever": vector_retriever,
            "prompt": multi_retrieval_prompt,
        },
        {
            "name": "graph_database_search",
            "description": "Good for queries involving rule relationships, specific rule IDs, or full-text search within rules and clarifications in Neo4j",
            "retriever": graph_retriever,
            "prompt": multi_retrieval_prompt,
        },
    ],
    default_retriever=vector_retriever, # You can choose graph_retriever or your custom Neo4jHybridRetriever too
    default_prompt=multi_retrieval_prompt,
    verbose=True,
)

# ─── Test Queries ───────────────────────────────────────────────────────
# This question is asked twice (first and ninth entries) on purpose, to check
# that each chain answers the same question consistently across the run.
_CONSISTENCY_QUERY = "What happens when a ball is moved by an outside influence?"

QUERIES = [
    _CONSISTENCY_QUERY,
    # Relationship-style question: a graph-specific query generator may do
    # better here if the reference is not spelled out in the rule text.
    "Which rules reference Rule 16?",
    "What if the normal free relief from an unplayable lie is not enough space for an aide to help a handicapped person aim.",
    "What are some example reliefs for a handicapped player?",
    "What is the penalty for playing the wrong ball?",
    "Describe Rule 4 in the Rules of Golf.",
    "When should you replace a damaged ball?",
    "What is the difference between Rule 3 and Rule 4?",
    _CONSISTENCY_QUERY,  # deliberate repeat
    "What is the penalty for playing a ball that is not yours?",
]

def extract_answer(output: object) -> str:
    """Return the answer text from a chain's ``invoke`` result.

    ``create_retrieval_chain`` chains put the answer under ``"answer"``, and
    RetrievalQA-based chains (as used by ``MultiRetrievalQAChain``) put it
    under ``"result"``. ``"answer"`` wins if both are present. Non-dict
    outputs are stringified as-is.

    If a dict has neither key, an error string listing the keys that were
    present is returned. Picking an arbitrary value instead would be
    misleading: invoke results typically echo the input question first.
    """
    if not isinstance(output, dict):
        return str(output)
    for key in ("answer", "result"):
        if key in output:
            return str(output[key])
    present = ", ".join(map(str, output)) or "none"
    return f"Error: Could not extract answer from dict (keys: {present})"

def run_retrieval_comparison(queries, chains=None):
    """Run every query through each chain and print the answers side by side.

    Args:
        queries: Iterable of question strings.
        chains: Optional sequence of ``(label, chain)`` pairs. Defaults to the
            module-level ``vector_chain``, ``graph_chain`` and ``multi_chain``.
            These are looked up at call time, so chains rebuilt in ``__main__``
            are the ones used.

    Every chain is invoked with the question under ``"input"``.
    ``create_retrieval_chain`` chains expect that key, and so does
    ``MultiRetrievalQAChain``: its router prompt reads ``{input}`` and it
    forwards ``"query"`` to the destination RetrievalQA chain itself. Passing
    ``{"query": q}`` to the router fails with a missing-input-key error.

    An exception from one chain is printed inline, so one failing chain does
    not abort the rest of the comparison.
    """
    if chains is None:
        chains = [
            ("Vector-only", vector_chain),
            ("Graph-only (Hybrid Retriever)", graph_chain),
            ("Multi-Retriever (Router)", multi_chain),
        ]
    for q in queries:
        print(f"\n=== Query: {q!r} ===")
        for label, chain_obj in chains:
            try:
                ans = extract_answer(chain_obj.invoke({"input": q}))
            except Exception as e:  # report and keep comparing the other chains
                ans = f"<ERROR: {type(e).__name__}: {e}>"
            print(f"{label.ljust(30)}: {ans}")
        print("-" * 80)

if __name__ == "__main__":
    # Ensure NEO4J variables are set in your environment or defined here
    NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
    NEO4J_PASS = os.getenv("NEO4J_PASS", "password") # Replace with your actual password or env var

    # NOTE(review): the module-level `llm = ChatOpenAI(...)` already ran before
    # this point, so a missing key may surface earlier than this message.
    if not os.getenv("OPENAI_API_KEY"):
        print("Error: OPENAI_API_KEY environment variable not set.")
        # SystemExit(1) signals failure; the bare exit() builtin exits with
        # status 0 and is not guaranteed to exist (or, in IPython, to behave
        # like sys.exit).
        raise SystemExit(1)

    driver = None
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))
        driver.verify_connectivity()
    except Exception as e:
        print(f"Error connecting to Neo4j: {e}")
        print("Please ensure Neo4j is running and credentials are correct.")
        if driver is not None:
            driver.close()  # don't leak a driver that failed verification
        raise SystemExit(1)
    print("Neo4j connection successful.")

    try:
        # --- Vector retriever: use the real Chroma store when available ---
        # Only the store setup is guarded. A later failure is then not
        # misreported as a vector-store problem, and `vector_retriever` is
        # rebound only once the real retriever exists.
        try:
            print("Setting up vector retriever...")
            from langchain_community.vectorstores import Chroma
            from langchain_openai import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings() # Requires OPENAI_API_KEY
            # Replace "./chroma_db_usga" with your actual persist directory.
            # Documents should include 'rule_id' in their metadata for
            # Neo4jHybridRetriever to work optimally; an empty or missing
            # directory gives poor results.
            vector_store = Chroma(persist_directory="./chroma_db_usga", embedding_function=embeddings)
            vector_retriever = vector_store.as_retriever(search_kwargs={"k": 3}) # k is number of docs
            print("Vector retriever set up successfully from Chroma.")
        except Exception as e:
            print(f"Could not set up actual vector retriever: {e}")
            print("Falling back to DummyVectorRetriever. Results may not be meaningful.")
            # vector_retriever keeps its module-level value (the DummyVectorRetriever).

        # Rebuild everything that depends on the retriever or the verified driver.
        # One construction path serves both the real and fallback retrievers,
        # so the router always sees the same destination descriptions.
        graph_retriever = Neo4jHybridRetriever(driver=driver, vector_retriever=vector_retriever, limit=5)
        vector_chain = create_retrieval_chain(vector_retriever, combine_documents_chain)
        graph_chain = create_retrieval_chain(graph_retriever, combine_documents_chain)
        # default_prompt applies only to the fallback chain; destinations
        # without a "prompt" entry use RetrievalQA's built-in prompt.
        multi_chain = MultiRetrievalQAChain.from_retrievers(
            llm=llm,
            retriever_infos=[
                {"name": "vector_search", "description": "Good for semantic similarity search over rules text using ADA embeddings and ChromaDB", "retriever": vector_retriever},
                {"name": "graph_database_search", "description": "Good for queries involving rule relationships, specific rule IDs, or full-text search within rules and clarifications in Neo4j", "retriever": graph_retriever},
            ],
            default_retriever=vector_retriever,
            default_prompt=multi_retrieval_prompt,
            verbose=True,
        )

        run_retrieval_comparison(QUERIES)
    finally:
        # Always release the Neo4j connection pool, even if the run raises.
        driver.close()