In [111]:
import os
import json
import ast
import re
from pprint import pprint
from typing import TypedDict, List, Optional
from dotenv import load_dotenv

import litellm
from neo4j import GraphDatabase, Driver
from langchain_litellm import ChatLiteLLM
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.memory import MemorySaver

# Neo4j GraphRAG library
from neo4j_graphrag.retrievers import VectorRetriever, HybridRetriever
from neo4j_graphrag.embeddings.base import Embedder  # For custom embedder
from neo4j_graphrag.types import HybridSearchRanker
#from neo4j_graphrag.types import VectorIndexConfig

# Embedding model (synchronized with ingestor)
from sentence_transformers import SentenceTransformer

In [112]:
# --- Helper for Embeddings ---
# Synchronized with ingestor: Uses sentence-transformers/all-mpnet-base-v2 for 768-dim vectors
embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

def get_embedding(text: str) -> List[float]:
    """
    Embed ``text`` with the module-level ``embedding_model``.

    Args:
        text: The string to embed.

    Returns:
        The encoded vector converted to a plain Python list via ``.tolist()``
        so it can be passed as a Neo4j query parameter.

    NOTE(review): ``encode`` is called with its default arguments, so whether
    the vector is L2-normalized depends on the model's own pipeline — confirm
    before relying on unit-length vectors.
    """
    return embedding_model.encode(text).tolist()
# --- Custom Embedder for GraphRAG Compatibility ---
class CustomEmbedder(Embedder):
    """
    Wraps SentenceTransformer as a GraphRAG Embedder to satisfy the interface.

    Both ``embed_query`` and ``embed_documents`` encode with the instance's own
    model, so a non-default ``model_name`` is honoured consistently for queries
    and documents.

    Args:
        model_name: Name or path handed to ``SentenceTransformer``.
    """
    def __init__(self, model_name: str = 'sentence-transformers/all-mpnet-base-v2'):
        # Let the base class run whatever initialisation it defines.
        super().__init__()
        self.model = SentenceTransformer(model_name)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query string as a list of floats."""
        # Use self.model (not the module-level helper) so the query vector
        # always comes from the same model as embed_documents.
        return self.model.encode(text).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding (list of floats) per input text, in order."""
        return [vector.tolist() for vector in self.model.encode(texts)]

# Instantiate the embedder once at module level; initial_discovery passes this
# instance to HybridRetriever as its `embedder`.
custom_embedder = CustomEmbedder()
print("Custom Embedder initialized successfully.")

Custom Embedder initialized successfully.


In [113]:
# === 1. Load Config & Set Up Clients ===
load_dotenv()

# --- LiteLLM Configuration (Fixed Deprecation) ---
os.environ['LITELLM_LOG'] = 'INFO'  # Replaces litellm.set_verbose
LLM_MODEL = "groq/llama-3.1-8b-instant"  # alternative: "groq/llama-3.3-70b-versatile"
litellm.api_key = os.getenv("GROQ_API_KEY")
litellm.model_list = [
    {
        "model_name": LLM_MODEL,
        "litellm_params": {"model": LLM_MODEL, "api_key": os.getenv("GROQ_API_KEY")}
    }
]

# --- LangChain/LangGraph LLM ---
llm = ChatLiteLLM(model=LLM_MODEL)

# --- Neo4j Connection ---
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")

try:
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    try:
        driver.verify_connectivity()
        print("Neo4j connection successful.")

        # Verify the indexes the retrieval nodes depend on (GraphRAG best practice)
        with driver.session(database=NEO4J_DATABASE) as session:
            vector_indexes = session.run("SHOW VECTOR INDEXES YIELD name, state").data()
            fulltext_indexes = session.run("SHOW FULLTEXT INDEXES YIELD name, state").data()
        for idx in vector_indexes:
            print(f"Vector Index '{idx['name']}': {idx['state']}")
        for idx in fulltext_indexes:
            print(f"Fulltext Index '{idx['name']}': {idx['state']}")
        if not any(idx['name'] == 'fact_embeddings' for idx in vector_indexes):
            raise ValueError("Required 'fact_embeddings' index not found. Run ingestor first.")
        # initial_discovery's HybridRetriever also needs the full-text index.
        if not any(idx['name'] == 'text_index' for idx in fulltext_indexes):
            raise ValueError("Required 'text_index' full-text index not found. Run ingestor first.")
    except Exception:
        # Do not leave a half-initialised driver open when verification fails.
        driver.close()
        raise

except Exception as e:
    print(f"Failed to connect to Neo4j or verify indexes: {e}")
    # Re-raise so the cell fails visibly: every later cell needs a working
    # driver and the verified indexes, and continuing only defers the error.
    raise

Neo4j connection successful.
Vector Index 'fact_embeddings': ONLINE
Vector Index 'section_embeddings': ONLINE
Fulltext Index 'text_index': ONLINE


In [114]:
# === 2. LangGraph State Definition ===

# Evolved from the state in your retriever.ipynb
class RetrievalState(TypedDict):
    """
    Shared state passed between the LangGraph nodes of the retrieval agent.

    ``decision`` is written by ``evaluate_answer`` and read by the conditional
    edge in ``build_graph``; it must be declared here because the routing
    function looks it up on the graph state. It is not part of the initial
    input built in ``main`` and only exists once the evaluator has run.
    """
    question: str             # Query currently being searched (may be re-planned)
    original_question: str    # The user's question, never rewritten
    current_nodes: List[str]  # Facts or sections found in the last step
    notebook: List[str]       # The full "context" of all facts found so far
    traversal_count: int      # Number of hops performed so far
    traversal_limit: int      # Maximum number of hops allowed
    decision: str             # Router verdict: 'sufficient' | 'hop_more' | 'deepdive'

In [115]:
# === 3. LangGraph Agent Nodes ===
def _hybrid_node_text(node) -> str:
    """Return the text of a retrieved node: ``fact`` first, then ``text``, else any string property."""
    props = dict(node)
    # Prefer the two properties the rest of the pipeline matches on
    # (hop_analyzer looks nodes up by n.fact / n.text).
    for key in ("fact", "text"):
        value = props.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
    # Fallback: first non-empty string property. Null or non-string values
    # (for example a nulled-out embedding) are skipped instead of crashing.
    for value in props.values():
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""


def initial_discovery(state: RetrievalState) -> RetrievalState:
    """
    GraphRAG Entrypoint: Uses hybrid search (via neo4j-graphrag) to find initial nodes.
    Includes Vector Index & Full-text Index.

    Args:
        state: Current agent state; ``state['question']`` is the query text.

    Returns:
        The same state object with ``current_nodes`` replaced by the texts of
        hits scoring at least 0.5 and ``notebook`` extended by a header line
        plus those texts. Retrieval errors are printed and yield no hits.
    """
    print("--- Node: initial_discovery ---")
    question = state['question']

    current_nodes = []
    try:
        hybrid_retriever = HybridRetriever(
            driver=driver,
            vector_index_name="fact_embeddings",
            fulltext_index_name="text_index",
            embedder=custom_embedder,
            neo4j_database=NEO4J_DATABASE
        )
        raw_results = hybrid_retriever.get_search_results(
            query_text=question,
            top_k=5,
            effective_search_ratio=1,  # Candidate pool multiplier
            ranker=HybridSearchRanker.LINEAR,
            alpha=0.1  # Weighting used by the LINEAR ranker (see library docs)
        )
        parsed_content = []
        for record in raw_results.records:
            try:
                content_text = _hybrid_node_text(record['node'])
                score = record['score']
            except (KeyError, TypeError, ValueError, AttributeError) as e:
                print(f"  Warning: Failed to parse content: {repr(record)[:100]}... | Error: {e}")
                continue
            if content_text and score >= 0.5:
                parsed_content.append((content_text, score))
        current_nodes = [text for text, _ in parsed_content]
        scores = [score for _, score in parsed_content]
        print(f"  Vector hits (facts): {len(parsed_content)} (all scores: {scores})")
    except Exception as e:
        print(f"Error during retrieval: {e}")

    # Store clean text only
    state['current_nodes'] = current_nodes
    state['notebook'].extend(["Initial relevant facts:"] + current_nodes)
    return state
    

In [116]:
# === 3. LangGraph Agent Nodes ===
def _vector_hits(index_name: str, property_name: str, query_vector: List[float], min_score: float = 0.7):
    """Run a VectorRetriever search on one index; return (texts scoring >= min_score, all scores)."""
    retriever = VectorRetriever(
        driver=driver,
        index_name=index_name,
        return_properties=[property_name],  # Fetch specific property
        neo4j_database=NEO4J_DATABASE
    )
    results = retriever.search(query_vector=query_vector, top_k=5)
    texts = []
    for item in results.items:
        try:
            # item.content is parsed as the string form of a dict of returned properties.
            payload = ast.literal_eval(item.content)
            text = (payload.get(property_name) or '').strip()
        except (ValueError, SyntaxError, AttributeError) as e:
            print(f"  Warning: Failed to parse content: {str(item.content)[:100]}... | Error: {e}")
            continue
        if text and item.metadata.get('score', 0) >= min_score:
            texts.append(text)
    all_scores = [item.metadata.get('score', 0) for item in results.items]
    return texts, all_scores


def initial_discovery1(state: RetrievalState) -> RetrievalState:
    """
    GraphRAG Entrypoint: Uses vector search (via neo4j-graphrag) to find initial nodes.
    Includes full-text fallback.

    Search order: the ``fact_embeddings`` index, then ``section_embeddings``,
    then the ``text_index`` full-text index. Each stage only runs when the
    previous one produced no hits; errors in a stage are printed and treated
    as "no hits".

    Args:
        state: Current agent state; ``state['question']`` is the query text.

    Returns:
        The same state object with ``current_nodes`` replaced by the hit texts
        and ``notebook`` extended by a header line plus those texts.
    """
    print("--- Node: initial_discovery ---")
    question = state['question']
    question_embedding = get_embedding(question)

    current_nodes = []

    # 1. Primary: GraphRAG Vector Search on facts (score >= 0.7)
    try:
        current_nodes, all_scores = _vector_hits("fact_embeddings", "fact", question_embedding)
        print(f"  Vector hits (facts): {len(current_nodes)} | Scores: {all_scores}")
    except Exception as e:
        print(f"  Error in fact vector search: {e}")

    if not current_nodes:
        # 1b. Vector Fallback: Search section chunks
        try:
            print("  No fact hits, trying section chunks...")
            current_nodes, all_scores = _vector_hits("section_embeddings", "text", question_embedding)
            print(f"  Vector hits (chunks): {len(current_nodes)} | Scores: {all_scores}")
        except Exception as e:
            print(f"  Error in chunk vector search: {e}")

    if not current_nodes:
        # 2. Fallback: Full-Text Search (raw Cypher); the session is only opened here.
        print("  No vector hits, using full-text fallback...")
        try:
            with driver.session(database=NEO4J_DATABASE) as session:
                fulltext_hits = session.run(
                    """
                    CALL db.index.fulltext.queryNodes('text_index', $question)
                    YIELD node, score
                    RETURN COALESCE(node.fact, node.text) AS text, score
                    ORDER BY score DESC
                    LIMIT 5
                    """,
                    question=question
                ).data()
            # Nodes with neither property yield a null text; drop those.
            current_nodes = [row['text'] for row in fulltext_hits if row['text']]
            print(f"  Full-text hits: {len(current_nodes)}")
        except Exception as e:
            print(f"  Error in full-text search: {e}")

    # Store clean text only
    state['current_nodes'] = current_nodes
    state['notebook'].extend(["Initial relevant facts:"] + current_nodes)
    return state


In [117]:
def hop_analyzer(state: RetrievalState) -> RetrievalState:
    """
    Agentic Traversal: Performs omni-directional hops across any nodes/relationships.
    Explores 1-2 hop neighbors, excluding notebook content.

    Args:
        state: Current agent state. ``current_nodes`` holds the texts to start
            from; ``notebook`` holds texts that must not be returned again.

    Returns:
        The same state object. ``traversal_count`` is incremented (unless the
        limit was already reached), ``current_nodes`` is replaced by up to 10
        newly found texts, and those texts are appended to ``notebook``.
    """
    print("--- Node: hop_analyzer ---")
    if state['traversal_count'] >= state['traversal_limit']:
        print("  Traversal limit reached. Stopping hops.")
        state['current_nodes'] = []
        return state

    state['traversal_count'] += 1

    if not state['current_nodes']:
        # Nothing to start from; skip the database round-trips.
        print("  No valid current node IDs found.")
        state['current_nodes'] = []
        return state

    with driver.session(database=NEO4J_DATABASE) as session:
        # Get IDs of current nodes for traversal (supports mixed types)
        current_ids_result = session.run(
            """
            UNWIND $current_nodes AS node_text
            MATCH (n)
            WHERE (n.fact = node_text OR n.text = node_text)
            RETURN collect(elementId(n)) AS current_ids
            """,
            current_nodes=state['current_nodes']
        ).single()
        print(f"  Current node IDs for traversal: {current_ids_result}")
        current_ids = current_ids_result['current_ids'] if current_ids_result else []

        if not current_ids:
            print("  No valid current node IDs found.")
            state['current_nodes'] = []
            return state

        # Omni-directional traversal: Any node, any relationship, 1-2 hops.
        # `x IN $list` is null when x is null and the list is non-empty, and a
        # null predicate drops the row. Wrapping each membership test in
        # coalesce(..., false) keeps neighbours that carry only one of the two
        # properties. Null contents are removed before LIMIT so they do not
        # use up result slots.
        new_nodes = session.run(
            """
            UNWIND $current_ids AS start_id
            MATCH (start) WHERE elementId(start) = start_id
            MATCH (start)-[*1..2]-(neighbor)
            WHERE NOT elementId(neighbor) IN $current_ids
              AND NOT coalesce(neighbor.fact IN $notebook, false)
              AND NOT coalesce(neighbor.text IN $notebook, false)
            WITH DISTINCT coalesce(neighbor.fact, neighbor.text) AS new_content
            WHERE new_content IS NOT NULL
            RETURN new_content
            LIMIT 10
            """,
            current_ids=current_ids,
            notebook=state['notebook']
        ).data()
        print(f"  New nodes found: {len(new_nodes)}")
        found_contents = [row['new_content'] for row in new_nodes if row['new_content']]
        state['current_nodes'] = found_contents

        if found_contents:
            state['notebook'].extend(["Found related contents via omni-traversal:"] + found_contents)
            print(f"  Found {len(found_contents)} new related contents.")
        else:
            print("  No new related contents found.")

    return state

def _parse_filtered_notebook(raw_output: str) -> Optional[List[str]]:
    """Extract the ``filtered_notebook`` string list from an LLM reply, or None if absent/invalid."""
    start = raw_output.find('{')
    if start == -1:
        print(" Failed to parse LLM JSON: no JSON object found")
        return None
    try:
        # raw_decode parses the first complete JSON value and ignores trailing
        # text (closing code fences, commentary); braces inside strings are
        # handled by the JSON parser rather than a regex.
        data, _ = json.JSONDecoder().raw_decode(raw_output[start:])
    except json.JSONDecodeError as e:
        print(f" Failed to parse LLM JSON: {e}")
        return None
    filtered = data.get("filtered_notebook") if isinstance(data, dict) else None
    if not isinstance(filtered, list):
        return None
    # Keep only non-empty strings so the notebook stays a List[str].
    return [item for item in filtered if isinstance(item, str) and item.strip()]


def context_manager(state: RetrievalState) -> RetrievalState:
    """
    Intelligent Context Manager: Summarizes the notebook if it gets too large.

    Two stages:
      1. More than 20 notebook items: ask the LLM for a relevance-filtered
         list and adopt it when a non-empty, well-formed list comes back.
      2. Serialized notebook (measured after stage 1) longer than 4000
         characters: ask the LLM for a one-paragraph summary; if that call
         fails, keep only the last 10 items.

    Args:
        state: Current agent state; reads ``notebook`` and ``original_question``.

    Returns:
        The same state object, with ``notebook`` possibly replaced.
    """
    print("--- Node: context_manager ---")

    if len(state['notebook']) > 20:  # Additional trigger for length
        print(f"  Notebook length ({len(state['notebook'])}) > 20. Pruning via relevance...")
        try:
            relevance_prompt = f"""
            Original Question: {state['original_question']}
            Notebook: {json.dumps(state['notebook'], indent=2)}
            
            Rank notebook items by relevance (1-10 scale) to the question.
            Keep only items >=7, discard others.
            Return filtered list as JSON: {{"filtered_notebook": ["item1", "item2"]}}
            """
            response = llm.invoke(relevance_prompt)
            filtered = _parse_filtered_notebook(response.content.strip())
            if filtered:
                state['notebook'] = filtered
                print(f"  Notebook pruned to {len(state['notebook'])} items via LLM relevance.")
        except Exception as e:
            print(f"  Error in relevance pruning: {e}. Falling back to summarization.")

    # Measure after pruning so the size check reflects the notebook as it is now.
    current_context_length = len(json.dumps(state['notebook']))

    # Example trigger: 4000 characters
    if current_context_length > 4000:
        print(f"  Context length ({current_context_length}) > 4000. Summarizing...")
        try:
            response = llm.invoke(
                f"Concisely summarize the following facts into a single paragraph: \n\n"
                f"{json.dumps(state['notebook'])}"
            )
            state['notebook'] = ["Context summarized:", response.content]
            print("  Context successfully summarized.")
        except Exception as e:
            print(f"  Error summarizing context: {e}. Pruning instead.")
            # Fallback for summarization: simple pruning
            state['notebook'] = state['notebook'][-10:]  # Keep last 10 items
    return state

In [118]:
# Decisions the evaluator may return; these are the keys the graph router maps.
_EVALUATOR_DECISIONS = ("sufficient", "hop_more", "deepdive")


def _parse_decision(raw_output: str) -> str:
    """Extract and validate the ``decision`` value from an LLM reply; raise ValueError if unusable."""
    start = raw_output.find('{')
    if start == -1:
        raise ValueError("no JSON object in evaluator response")
    # raw_decode tolerates text around the JSON object (code fences, commentary).
    data, _ = json.JSONDecoder().raw_decode(raw_output[start:])
    decision = str(data["decision"]).strip().lower()
    if decision not in _EVALUATOR_DECISIONS:
        raise ValueError(f"unknown decision {decision!r}")
    return decision


def evaluate_answer(state: RetrievalState) -> dict:
    """
    Analyzer Node: Decides if the answer is sufficient, needs a deepdive,
    or needs more hops.

    Args:
        state: Current agent state; reads ``traversal_count``,
            ``traversal_limit``, ``current_nodes``, ``original_question`` and
            ``notebook``.

    Returns:
        A partial state update ``{"decision": <value>}`` where the value is one
        of 'sufficient', 'hop_more' or 'deepdive'. Only this key is returned,
        so extra keys in the LLM's JSON cannot overwrite other state fields.
        Falls back to 'sufficient' when the traversal limit is reached, when
        the last hop found nothing, or when the LLM reply cannot be parsed or
        names an unknown decision.
    """
    print("--- Node: evaluate_answer ---")

    # Loop fallback (if analyzer is hit after limit)
    if state['traversal_count'] >= state['traversal_limit']:
        print("  Traversal limit reached. Force-ending.")
        return {"decision": "sufficient"}

    # If no new nodes were found on the last hop, we're stuck.
    if not state['current_nodes']:
        print("  No new nodes found. Force-ending.")
        return {"decision": "sufficient"}

    try:
        prompt = f"""
        You are an expert evaluator for a RAG agent.
        Original Question: {state['original_question']}
        
        Current Facts Notebook:
        {json.dumps(state['notebook'], indent=2)}
        
        Based *only* on the facts in the notebook, is the Original Question
        fully and completely answered?
        
        Choose one of the following decisions:
        1. 'sufficient': The answer is complete and no more info is needed.
        2. 'hop_more': The answer is incomplete. We need to find more *related* facts.
        3. 'deepdive': The answer is incomplete and we need a *new line of questioning*
           to find different, specific information.
           
        Respond with a single JSON object: {{"decision": "your_choice"}}
        """

        response = llm.invoke(prompt)
        decision = _parse_decision(response.content)
        print(f"  LLM Decision: {decision}")
        return {"decision": decision}

    except Exception as e:
        print(f"  Error evaluating answer: {e}. Defaulting to 'sufficient'.")
        return {"decision": "sufficient"}

def _parse_new_query(raw_output: str) -> str:
    """Extract a non-empty ``new_query`` string from an LLM reply; raise ValueError if unusable."""
    start = raw_output.find('{')
    if start == -1:
        raise ValueError("no JSON object in replanner response")
    # raw_decode tolerates text around the JSON object (code fences, commentary).
    data, _ = json.JSONDecoder().raw_decode(raw_output[start:])
    new_query = data["new_query"]
    if not isinstance(new_query, str) or not new_query.strip():
        raise ValueError("'new_query' must be a non-empty string")
    return new_query.strip()


def replan_query(state: RetrievalState) -> RetrievalState:
    """
    Query Modification Node: Creates a new, more specific query for a deepdive.

    Args:
        state: Current agent state; reads ``original_question`` and ``notebook``.

    Returns:
        The same state object. On success ``question`` is set to the LLM's new
        query, a "RE-PLANNING" line is appended to ``notebook`` and
        ``current_nodes`` is cleared. On any error ``question`` is reset to
        ``original_question``.
    """
    print("--- Node: replan_query ---")
    try:
        prompt = f"""
        You are a query replanner.
        Original Question: {state['original_question']}
        Current Context: {json.dumps(state['notebook'], indent=2)}
        
        The current context is not enough. What *new, specific query*
        should we use to find the missing information?
        
        Respond with a single JSON object: {{"new_query": "your_new_query"}}
        """

        response = llm.invoke(prompt)
        new_query = _parse_new_query(response.content)

        # This is the "new trajectory"
        state['question'] = new_query
        state['notebook'].append(f"RE-PLANNING: New query: {new_query}")
        state['current_nodes'] = []  # Clear nodes for new discovery

        print(f"  New query: {state['question']}")

    except Exception as e:
        print(f"  Error replanning query: {e}. Reverting to original.")
        state['question'] = state['original_question']

    return state

def compile_final_answer(state: RetrievalState) -> RetrievalState:
    """
    Ask the LLM for the final answer, grounded in the notebook contents.

    The reply is appended to ``state['notebook']`` as a "Final Answer: ..."
    entry. If the call fails, an "Error compiling final answer: ..." entry is
    appended instead. The same state object is returned either way.
    """
    print("--- Node: compile_final_answer ---")
    try:
        question = state['original_question']
        facts_json = json.dumps(state['notebook'], indent=2)
        # Three sections separated by blank lines: question, facts, answer cue.
        prompt = "\n\n".join([
            f"Original Question: {question}",
            f"Use *only* the following facts to answer the question:\n{facts_json}",
            "Answer:",
        ])
        answer = llm.invoke(prompt).content
        # The final answer is stored back in the notebook.
        state['notebook'].append(f"Final Answer: {answer}")
        print(f"  Final Answer: {answer}")
    except Exception as err:
        print(f"  Error compiling final answer: {err}")
        state['notebook'].append(f"Error compiling final answer: {err}")

    return state

In [119]:
# === 4. LangGraph Graph Definition ===

def build_graph():
    """
    Assemble and compile the retrieval agent's LangGraph workflow.

    Flow: initial_discovery -> hop_analyzer -> context_manager ->
    evaluate_answer, then a conditional edge on the evaluator's decision:
    'sufficient' -> compile_final_answer -> END, 'hop_more' -> hop_analyzer,
    'deepdive' -> replan_query -> initial_discovery.

    Returns:
        The compiled graph, checkpointed with an in-memory ``MemorySaver``.
    """
    # Decision value -> next node.
    routes = {
        "sufficient": "compile_final_answer",
        "hop_more": "hop_analyzer",  # Main loop
        "deepdive": "replan_query"   # New trajectory
    }

    def route_after_evaluation(state) -> str:
        """Return the evaluator's decision, or 'sufficient' if it is missing or not a known route."""
        # A missing or unexpected value would otherwise abort the run with a
        # KeyError; finishing with what has been gathered is the safe default.
        decision = state.get("decision", "sufficient")
        return decision if decision in routes else "sufficient"

    workflow = StateGraph(RetrievalState)

    # 1. Add Nodes
    workflow.add_node("initial_discovery", initial_discovery)
    workflow.add_node("hop_analyzer", hop_analyzer)
    workflow.add_node("context_manager", context_manager)
    workflow.add_node("evaluate_answer", evaluate_answer)
    workflow.add_node("replan_query", replan_query)
    workflow.add_node("compile_final_answer", compile_final_answer)

    # 2. Set Entry Point
    workflow.set_entry_point("initial_discovery")

    # 3. Add Edges
    workflow.add_edge("initial_discovery", "hop_analyzer")
    workflow.add_edge("hop_analyzer", "context_manager")
    workflow.add_edge("context_manager", "evaluate_answer")

    # 4. Add Conditional Edges (The "Agentic" part)
    workflow.add_conditional_edges("evaluate_answer", route_after_evaluation, routes)

    # 5. Add New Trajectory Edge
    workflow.add_edge("replan_query", "initial_discovery")

    # 6. Set End Point
    workflow.add_edge("compile_final_answer", END)

    # 7. Compile
    return workflow.compile(checkpointer=MemorySaver())

In [120]:
# === 5. Main Execution Function ===
DEFAULT_QUESTION = "What are dos & don'ts associated with Delivery Instruction Form & what is it?"


def main(
    question: str = DEFAULT_QUESTION,
    traversal_limit: int = 5,
    thread_id: str = "1",
    close_driver: bool = True,
) -> None:
    """
    Build the graph, run the retrieval agent for one question and print the trace.

    Args:
        question: The question to answer; used as both ``question`` and
            ``original_question`` in the initial state.
        traversal_limit: Maximum number of hops stored in the initial state.
        thread_id: Checkpointer thread id passed in the run config.
        close_driver: When True (default, the previous behaviour) the
            module-level Neo4j ``driver`` is closed afterwards. Pass False to
            keep the connection open so ``main`` can be called again without
            re-running the connection cell.

    Errors raised during the run are printed, not propagated.
    """
    app = build_graph()

    # To render the graph in a notebook:
    #   from IPython.display import Image, display
    #   display(Image(app.get_graph().draw_mermaid_png()))

    # Define the initial state for the query
    inputs = {
        "question": question,
        "original_question": question,
        "notebook": [],
        "current_nodes": [],
        "traversal_count": 0,
        "traversal_limit": traversal_limit
    }

    config = {"configurable": {"thread_id": thread_id}}

    print("--- Running Agentic Retrieval ---")

    # Stream the events
    try:
        for event in app.stream(inputs, config=config, stream_mode="values"):
            print("\n" + "="*40)
            pprint(event)
            print("="*40)

        print("--- Run Complete ---")
        final_state = app.get_state(config)
        print(f"Final Notebook:\n{json.dumps(final_state.values['notebook'], indent=2)}")

    except Exception as e:
        print(f"An error occurred during the agent run: {e}")
    finally:
        if close_driver and 'driver' in globals():
            driver.close()
            print("\nNeo4j connection closed.")

In [121]:
# Run the agent end-to-end on the sample question defined in main().
main()

--- Running Agentic Retrieval ---

{'current_nodes': [],
 'notebook': [],
 'original_question': "What are dos & don'ts associated with Delivery "
                      'Instruction Form & what is it?',
 'question': "What are dos & don'ts associated with Delivery Instruction Form "
             '& what is it?',
 'traversal_count': 0,
 'traversal_limit': 5}
--- Node: initial_discovery ---
  Vector hits (facts): 5 (all scores: [0.998387154600477, 0.8929111361121046, 0.7874265564005567, 0.7832103993939485, 0.7128688043851253])

{'current_nodes': ['The correct ID of the authorised intermediary must be '
                   'filled in the Delivery Instruction Form.',
                   'Do not leave signed blank delivery instruction slips with '
                   'anyone while meeting pay in obligations.',
                   'Basics of Financial Markets\n'
                   '14\n'
                   ' /head2right Ensure that you do not undertake deals on '
                   'behalf of othe

In [122]:
################################################################
# NOTE(review): `srd` is not defined anywhere visible, so this cell raises
# NameError (see the recorded output). Presumably a manual "stop here" marker
# for Run All — confirm, and remove it before sharing the notebook.
srd

NameError: name 'srd' is not defined

In [None]:
# Second run of the agent.
# NOTE(review): main() closes the module-level Neo4j driver in its `finally`
# block, so this call presumably needs the connection cell re-run first — confirm.
main()

--- Running Agentic Retrieval ---

{'current_nodes': [],
 'notebook': [],
 'original_question': "What are dos & don'ts associated with Delivery "
                      'Instruction Form & what is it?',
 'question': "What are dos & don'ts associated with Delivery Instruction Form "
             '& what is it?',
 'traversal_count': 0,
 'traversal_limit': 5}
--- Node: initial_discovery ---
  Vector hits (facts): 5 | Scores: [0.7980027198791504, 0.7949621677398682, 0.7851321697235107, 0.7766571044921875, 0.7348215579986572]

{'current_nodes': ['Do not sign blank delivery instruction slips while meeting '
                   'security payin obligations.',
                   'Do not leave signed blank delivery instruction slips with '
                   'anyone while meeting pay in obligations.',
                   'The correct ID of the authorised intermediary must be '
                   'filled in the Delivery Instruction Form.',
                   'The Demat delivery instruction slip shou