# Module 5 - GraphRAG and Agent

In this notebook we create an agent on top of the graph.

Import our usual suspects (and some more...)

In [None]:
import os
import pandas as pd
from dotenv import load_dotenv
from graphdatascience import GraphDataScience
from neo4j import Query, GraphDatabase, RoutingControl, Result
from langchain.schema import HumanMessage
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.prebuilt import create_react_agent
from openai import OpenAI
from typing import List, Optional
from pydantic import BaseModel, Field, validator
import functools
from langchain_core.tools import tool
import gradio as gr
import time
from json import loads, dumps
from typing_extensions import TypedDict
from typing import Literal
from langgraph.graph import StateGraph, Graph, START, END
from langgraph.checkpoint.memory import MemorySaver

## Setup

Load env variables

In [None]:
env_file = 'ws.env'

In [None]:
# Load connection settings and model names from the env file into module-level
# constants that the rest of the notebook reads.
if os.path.exists(env_file):
    load_dotenv(env_file, override=True)

    # Neo4j
    HOST = os.getenv('NEO4J_URI')
    USERNAME = os.getenv('NEO4J_USERNAME')
    PASSWORD = os.getenv('NEO4J_PASSWORD')
    DATABASE = os.getenv('NEO4J_DATABASE')

    # AI
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    # os.environ only accepts string values: assigning None (key absent from
    # the env file and the environment) raises a TypeError, so only export the
    # key when it is actually set and report the problem otherwise.
    if OPENAI_API_KEY:
        os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
    else:
        print(f"OPENAI_API_KEY is not set in {env_file}.")
    LLM = os.getenv('LLM')
    EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL')
else:
    print(f"File {env_file} not found.")

Setup connection to the database with the [Python Driver](https://neo4j.com/docs/python-manual/5/).

In [None]:
driver = GraphDatabase.driver(
    HOST,
    auth=(USERNAME, PASSWORD)
)

Test the connection

In [None]:
driver.execute_query(
    """
    MATCH (n) RETURN COUNT(n) as Count
    """,
    database_=DATABASE,
    routing_=RoutingControl.READ,
    result_transformer_= lambda r: r.to_df()
)

Check which indexes exist in the database

In [None]:
schema_result_df  = driver.execute_query(
    'show indexes',
    database_=DATABASE,
    routing_=RoutingControl.READ,
    result_transformer_= lambda r: r.to_df()
)

In [None]:
schema_result_df.head(100)

## Agents with GraphRAG

### Let's create a Retrieval agent

In [None]:
client = OpenAI()

In [None]:
llm = ChatOpenAI(model_name=LLM, temperature=0)

In [None]:
llm.model_name

In [None]:
embedding_model = OpenAIEmbeddings(
    model=EMBEDDINGS_MODEL,
    openai_api_key=OPENAI_API_KEY
)

In [None]:
embedding_model.model

### Tool 1 - Retrieve Products

Define a function that retrieves products from the Database

In [None]:
def retrieve_products() -> List[str]:
    """Retrieve the products in the database. Products are specified with name. """
    # The docstring above is left as-is because this function is handed to the
    # agent as a tool. The return annotation now matches what is actually
    # returned: the `name` column converted to a plain list, not a DataFrame.
    products_df = driver.execute_query(
        """
        MATCH (p:ProductType)
        RETURN p.name as name
        """,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        result_transformer_=lambda r: r.to_df(),
    )
    return products_df['name'].tolist()

In [None]:
retrieve_products()

### Tool 2 - Map Products to Database

A product mentioned in a question has to be mapped to the corresponding product name in the database.

In [None]:
# Few-shot system prompt for mapping a free-text product name onto one of the
# product names stored in the database. Every example has the same three-line
# shape (Product / Database Products / Assistant) so the model sees a single,
# consistent answer format: "Product: <database name>".
map_products_prompt = """
As an intelligent assistant, your primary objective is to map a product name to product names in the database.

Examples:
#####
Product: savings account.
Database Products: ['SpaarRekening', 'DirectRekening', 'Kortlopende Reis', 'BeleggersRekening', 'RaboBusiness Banking']
Assistant: Product: SpaarRekening
#####
#####
Product: Direct Rekening.
Database Products: ['SpaarRekening', 'DirectRekening', 'Kortlopende Reis', 'BeleggersRekening', 'RaboBusiness Banking']
Assistant: Product: DirectRekening
#####
#####
Product: Reis verzekering.
Database Products: ['SpaarRekening', 'DirectRekening', 'Kortlopende Reis', 'BeleggersRekening', 'RaboBusiness Banking']
Assistant: Product: Kortlopende Reis
#####
"""

In [None]:
def map_product_to_database_products(product) -> str:
    """Map products from the user question to the actual products in the database."""
    # Conversation layout: the few-shot system prompt, then the product as the
    # user phrased it, then the list of names currently stored in the database.
    conversation = [{"role": "system", "content": map_products_prompt}]
    conversation.append({"role": "user", "content": "Product: " + product})
    known_products = retrieve_products()
    conversation.append(
        {"role": "user", "content": "Database Products: " + str(known_products)}
    )

    completion = client.beta.chat.completions.parse(
        model=LLM,
        temperature=0,
        messages=conversation,
    )
    # Only the text of the first choice is returned to the caller.
    first_choice = completion.choices[0]
    return first_choice.message.content

In [None]:
map_product_to_database_products('savings account')

### Tool 3 - Retrieve product documents

Define a function that retrieves a document from a specific product

In [None]:
def retrieve_document_from_product(product_name: str) -> str:
    """Retrieve the documents of products in the database. Products are specified with their name. """
    # Returns the file name of the first document related to the product
    # (case-insensitive name match), so the return type is a string rather
    # than a DataFrame.
    documents_df = driver.execute_query(
        """
        MATCH (p:ProductType)<-[:RELATED_TO]-(d:Document)
        WHERE LOWER(p.name) = LOWER($product_name)
        RETURN d.file_name
        """,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        product_name=product_name,
        result_transformer_=lambda r: r.to_df(),
    )
    # An unknown product yields an empty frame; `.iloc[0]` would then fail with
    # an opaque positional-indexing error. Keep the exception type (IndexError)
    # but say what actually went wrong.
    if documents_df.empty:
        raise IndexError(f"No document found for product '{product_name}'.")
    return documents_df.iloc[0]['d.file_name']

In [None]:
retrieve_document_from_product('SpaarRekening')

### Tool 4 - GraphRAG Document Search

The following function performs GraphRAG but only on a specific document.

In [None]:
def get_context_graphrag(search_prompt: str, document: str) -> str:
    """Collect GraphRAG context for a question, restricted to one document.

    The question is embedded once and used for two lookups: a vector search on
    chunks (expanded with chunks that share definitions) and a vector search on
    definitions (adding chunks that mention the best-matching definition).

    Args:
        search_prompt: The question text that is embedded for the vector search.
        document: Value matched against `Document.file_name`; only chunks that
            are PART_OF this document are returned.

    Returns:
        A JSON string (indent=2) with one record per chunk, holding
        `file_name`, `chunk_id`, `page` and `chunk`.
    """
    query_vector = embedding_model.embed_query(search_prompt)

    # Step 1: the 30 nearest chunks from the "chunk-embeddings" vector index,
    # each expanded with chunks connected through OVERLAPPING_DEFINITIONS with
    # an overlap above 3, then filtered to the requested document.
    similarity_query = """ 
        CALL db.index.vector.queryNodes("chunk-embeddings", 30, $query_vector) YIELD node, score
        WITH node as chunk, score ORDER BY score DESC
        CALL (chunk) {
            MATCH (chunk)-[r:OVERLAPPING_DEFINITIONS]-(overlapping_chunk:Chunk)
            WHERE r.overlap > 3
            RETURN collect(overlapping_chunk) AS overlapping_chunks
        }
        WITH [chunk] + overlapping_chunks AS chunks
        UNWIND chunks as chunk
        MATCH (d:Document{file_name: $document})<-[:PART_OF]-(chunk)
        RETURN d.file_name as file_name, chunk.id as chunk_id, chunk.page as page, chunk.chunk_eng AS chunk
       """
    results_1 = driver.execute_query(
        similarity_query,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        query_vector=query_vector,
        document=document,
        result_transformer_= lambda r: r.to_df()
    )

    # De-duplicated ids of the chunks found so far; passed to the second query
    # so it does not return them again.
    chunk_ids = list(set(results_1['chunk_id'].to_list()))

    # Step 2: among the 5 nearest definitions from "definition-embeddings",
    # keep the best one with degree < 20, take up to 3 not-yet-seen chunks that
    # mention it, and again filter to the requested document. Note the LIMIT 3
    # is applied before the document filter, so fewer than 3 rows may remain.
    definition_query = """    
       CALL db.index.vector.queryNodes("definition-embeddings", 5, $query_vector) YIELD node, score
            WITH node as definition, score ORDER BY score DESC
            WHERE definition.degree < 20
            WITH definition LIMIT 1
            MATCH (definition)<-[:MENTIONS]-(chunk:Chunk)
            WHERE NOT (chunk.id IN $chunk_ids)
            WITH chunk LIMIT 3
            MATCH (d:Document{file_name: $document})<-[:PART_OF]-(chunk)
            RETURN d.file_name as file_name, chunk.id as chunk_id, chunk.page as page, chunk.chunk_eng AS chunk
    """
    results_2 = driver.execute_query(
        definition_query,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        chunk_ids=chunk_ids,
        document=document,
        query_vector=query_vector,
        result_transformer_= lambda r: r.to_df()
    )

    # Merge both result sets, drop identical rows, and round-trip through JSON
    # to obtain a pretty-printed string for the prompt.
    results = pd.concat([results_1,results_2]).drop_duplicates()
    results = results.to_json(orient="records")
    parsed = loads(results)
    context = dumps(parsed, indent=2)
    return context

In [None]:
def generate_prompt(search_prompt, context):
    """Build the answer-generation prompt for a question and its retrieved context.

    Args:
        search_prompt: The user's question; inserted into the prompt as-is.
        context: The retrieved context text (for example the JSON string
            returned by ``get_context_graphrag``).

    Returns:
        The prompt value produced by ``PromptTemplate.format_prompt``; callers
        convert it with ``.to_messages()`` before sending it to the chat model.
    """
    # The template carries exactly one language instruction (answer in the
    # language of the question) so the model is not given conflicting
    # directives about which language to use.
    prompt_template = """

    You are a chatbot for Rabobank products. Your goal is to help people with questions on product policies.
    A user will come to you with questions on their policy. Their questions must be answered based on the relevant documents of the policy.

    The question is the following:
    {search_prompt}

    Always respond in the language in which the question was asked. So, do not respond in a different language.

    The context is the following:
    {context}

    Please explain your answer as thoroughly as possible based on the context above. Don't come up with anything yourself.

    Please end your message with listing your sources with file name and page number.
    """
    prompt = PromptTemplate.from_template(prompt_template)
    return prompt.format_prompt(search_prompt=search_prompt, context=context)

In [None]:
context = get_context_graphrag("What are the rules for shared savings account?", "Rabo SpaarRekening 2020.pdf")

In [None]:
print(context)

In [None]:
def perform_search_in_document(search_prompt, document) -> str:
    """Perform a search in the document to find relevant text to answer a user question. The document first needs to be determined before a search should be performed.

    Args:
        search_prompt: The user question to search for.
        document: File name of the document the search is restricted to.

    Returns:
        The JSON-formatted context string returned by ``get_context_graphrag``.
    """
    # The parameters are declared in the order every caller in this notebook
    # passes them positionally (question first, document second). Forwarding by
    # keyword keeps each value attached to the right argument, so positional
    # calls behave as before and keyword calls are routed correctly as well.
    return get_context_graphrag(search_prompt=search_prompt, document=document)

In [None]:
def answer_question_in_document(question: str, document: str) -> str:
    """This function is answering a question based on a search in a document (vector search on document). Document and question both need to be provided."""
    # Retrieve the context for the question from the given document, wrap it in
    # the answer prompt, and let the chat model produce the final text.
    context = perform_search_in_document(question, document)
    theprompt = generate_prompt(question, context)
    # Use `invoke` rather than calling the model object directly: direct calls
    # on chat models are deprecated in LangChain, and `invoke` is what the rest
    # of this notebook uses.
    return llm.invoke(theprompt.to_messages()).content

In [None]:
result = answer_question_in_document("What are the rules for shared savings account?", "Rabo SpaarRekening 2020.pdf")

In [None]:
print(result)

### Tool 5 - Retrieve products from Customer

Define a function that retrieves the products from a customer

In [None]:
def retrieve_products_of_customers(customer_id) -> pd.DataFrame:
    """Retrieve the products of a customer in the database. Customers are specified with their id. """
    # One row per product linked to the customer via HAS_PRODUCT, with the
    # columns `product_id` and `product_name`.
    customer_products_query = """
        MATCH (c:Customer)-[:HAS_PRODUCT]->(p:Product)
        WHERE c.id = $customer_id
        RETURN p.id as product_id, p.name as product_name
        """

    def _as_dataframe(result):
        return result.to_df()

    products_df = driver.execute_query(
        customer_products_query,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        customer_id=customer_id,
        result_transformer_=_as_dataframe,
    )
    return products_df

In [None]:
retrieve_products_of_customers(16)

### Tool 6 - Retrieve information from a product

Function retrieving all properties of a product

In [None]:
def retrieve_information_from_product(product_id: str) -> dict:
    """Retrieve the information of a product in the database. Product is specified with id."""
    # Returns the property map of the matching Product node (the `props`
    # column of the first row), i.e. a dict rather than a DataFrame.
    result = driver.execute_query(
        """
        MATCH (p:Product{id: $product_id})
        RETURN properties(p) as props
        """,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        product_id=product_id,
        result_transformer_=lambda r: r.to_df(),
    )
    # An unknown id yields an empty frame; keep the exception type that
    # `.iloc[0]` would raise but make the message explain the cause.
    if result.empty:
        raise IndexError(f"No product found with id '{product_id}'.")
    return result.iloc[0]['props']

In [None]:
retrieve_information_from_product("ef31587f-5c96-4ab2-99e6-99f3b6fa88e8")

### Tool 7 - Retrieve Customer ID based on full name

Retrieve customer id from the database based on name

In [None]:
def retrieve_customer_id_from_database_based_on_full_name(full_name: str) -> pd.DataFrame:
    """Retrieve customers from the database based on full_name. customer_id is returned."""
    # Names are compared case-insensitively, in line with the product lookup in
    # this notebook, so that "Lucas van den Berg" and "Lucas Van den Berg"
    # resolve to the same customer. Exact-case matches behave as before.
    query = """MATCH (c:Customer) WHERE toLower(c.name) = toLower($name) RETURN c.id"""

    # The result is a DataFrame with a single `c.id` column; it is empty when
    # no customer carries that name.
    result = driver.execute_query(
        query,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        name=full_name,
        result_transformer_=lambda r: r.to_df(),
    )
    return result

In [None]:
retrieve_customer_id_from_database_based_on_full_name("Lucas Van den Berg")

## Setting up the Agent

In [None]:
llm = ChatOpenAI(model_name=LLM, temperature=0)

In [None]:
response = llm.invoke([HumanMessage(content="hi!")])
response.content

In [None]:
# Plain Python functions offered to the model as tools. They are passed as-is
# to `bind_tools` below and to `create_react_agent` later in the notebook.
# NOTE(review): LangChain presumably derives each tool's name and description
# from the function name and docstring, so docstring edits change what the
# model sees - confirm before rewording them.
tools = [
    retrieve_products,
    map_product_to_database_products,
    retrieve_document_from_product,
    answer_question_in_document,
    retrieve_products_of_customers,
    retrieve_information_from_product,
    retrieve_customer_id_from_database_based_on_full_name,
]

# Chat model with the tools bound; used further down to inspect which tool
# calls the model proposes for a question without executing them.
llm_with_tools = llm.bind_tools(tools)

## Running Agents with LangGraph

We are creating a simple [ReAct agent](https://langchain-ai.github.io/langgraph/reference/agents/#langgraph.prebuilt.chat_agent_executor.create_react_agent) using LangGraph.

In [None]:
agent_executor = create_react_agent(llm, tools)

In [None]:
response = agent_executor.invoke({"messages": [HumanMessage(content="hi!")]})

In [None]:
response["messages"]

### Test Tool Calling

Check what Tool will be called on which question. 

In [None]:
response = llm_with_tools.invoke([HumanMessage(content="What products does Jan Kok have?")])

print(f"ContentString: {response.content}")
print(f"ToolCalls: {response.tool_calls}")

In [None]:
response = llm_with_tools.invoke([HumanMessage(content="What is my expiration date of my savings account? My name is Jan Kok")])

print(f"ContentString: {response.content}")
print(f"ToolCalls: {response.tool_calls}")

In [None]:
response = llm_with_tools.invoke([HumanMessage(content="I got a question on my savings account")])

print(f"ContentString: {response.content}")
print(f"ToolCalls: {response.tool_calls}")

In [None]:
response = llm_with_tools.invoke([HumanMessage(content="I got a question on my savings account, what are the rules for a joint account?")])

print(f"ContentString: {response.content}")
print(f"ToolCalls: {response.tool_calls}")

#### Run some examples! 

Let's run some examples with the agent.

In [None]:
def ask_to_agent(question):
    """Send a question to the agent and pretty-print the newest message of every streamed step."""
    agent_input = {"messages": [HumanMessage(content=question)]}
    for snapshot in agent_executor.stream(agent_input, stream_mode="values"):
        newest_message = snapshot["messages"][-1]
        newest_message.pretty_print()

In [None]:
question = "What Products does Jan Kok have?"

In [None]:
ask_to_agent(question)

In [None]:
question = "I got a question on my savings account, what are the rules for a joint account?"

In [None]:
ask_to_agent(question)

In [None]:
question = "When is my travel insurance exprired? My name is Daan Visser"

In [None]:
ask_to_agent(question)

In [None]:
question = "When is my IBAN of my Saving account? My name Lucas van den Berg"

In [None]:
ask_to_agent(question)

## Chatbot

Now create a chatbot with the agent providing the responses

In [None]:
def user(user_message, history):
    """Append the user's message to the chat history and clear the textbox.

    Returns a pair: an empty string (the new textbox value) and the history
    list, which is extended in place when one was supplied.
    """
    chat_log = [] if history is None else history
    chat_log.append({"role": "user", "content": user_message})
    return "", chat_log

def get_answer(history):
    """Run the agent on the whole chat history and return the content of its last message.

    The history is flattened into one "Role: content" line per message and sent
    to the agent as a single human message.
    """
    transcript_lines = []
    for msg in history:
        transcript_lines.append(f"{msg['role'].capitalize()}: {msg['content']}")
    full_prompt = "\n".join(transcript_lines)

    agent_input = {"messages": [HumanMessage(content=full_prompt)]}
    collected = []
    for step in agent_executor.stream(agent_input, stream_mode="values"):
        # Each streamed step carries the full message list; only the newest
        # message is printed and recorded.
        step["messages"][-1].pretty_print()
        collected.append(step["messages"][-1].content)

    return collected[-1]

def bot(history):
    """Generate the assistant reply and yield the history as the reply is typed out.

    An empty assistant message is appended first; the answer is then added one
    character at a time with a short pause, yielding the history after each one.
    """
    bot_message = get_answer(history)
    history.append({"role": "assistant", "content": ""})

    typing_delay = 0.01
    for position in range(len(bot_message)):
        history[-1]["content"] = history[-1]["content"] + bot_message[position]
        time.sleep(typing_delay)
        yield history

# Gradio UI: a chat window, a message box and a clear button.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        label="Chatbot on a Graph",
        # Two avatar images, loaded from external URLs.
        avatar_images=[
            "https://png.pngtree.com/png-vector/20220525/ourmid/pngtree-concept-of-facial-animal-avatar-chatbot-dog-chat-machine-illustration-vector-png-image_46652864.jpg",
            "https://d-cb.jc-cdn.com/sites/crackberry.com/files/styles/larger/public/article_images/2023/08/openai-logo.jpg"
        ],
        # History is a list of {"role", "content"} dicts, matching what
        # `user` and `bot` append.
        type="messages", 
    )
    msg = gr.Textbox(label="Message")
    clear = gr.Button("Clear")

    # On submit: first store the user message and empty the textbox (`user`),
    # then stream the assistant reply into the chat (`bot`).
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot], chatbot
    )

    # The clear button resets the chat history to an empty list.
    clear.click(lambda: [], None, chatbot, queue=False)

demo.queue()
# NOTE(review): share=True asks Gradio for a publicly reachable link; this
# chatbot can look up customer data, so confirm that exposure is intended.
demo.launch(share=True)

If you want to have the light-mode for the chatbot paste the following after the URL: /?__theme=light

### LangGraph Workflow

Within LangGraph you can also define [Agent Workflows](https://langchain-ai.github.io/langgraph/tutorials/workflows/). Below we are quickly setting up such a workflow. 

In [None]:
# Few-shot system prompt that extracts either a product or a customer name from
# the user's text. The answer prefixes "Product: " and "Customer: " are parsed
# by the workflow code, so the examples must keep exactly that format.
retrieve_product_customer_prompt = """
As an intelligent assistant, your primary objective is to find either a customer name or a product in the text submitted.
The goal is to retrieve either the product on which someone is asking a question, or the customer name so that we can look up their products.
Please only return the extracted product or customer name, formatted as in the examples below.

Examples:
#####
User: I got a question about my savings account.
Assistant: Product: savings account
#####
#####
User: My name is Jan Blok and I got a question.
Assistant: Customer: Jan Blok
#####
#####
User: What is the policy on account?
Assistant: Need more information
#####
"""

def retrieve_product_or_customer_from_text(question):
    """Retrieve either products or customers from the text given by the user"""
    # System prompt with the few-shot examples, followed by the user's text.
    conversation = [
        {"role": "system", "content": retrieve_product_customer_prompt},
        {"role": "user", "content": question},
    ]
    completion = client.beta.chat.completions.parse(
        model=LLM,
        temperature=0,
        messages=conversation,
    )
    # Only the text of the first choice is returned to the caller.
    first_choice = completion.choices[0]
    return first_choice.message.content

In [None]:
class State(TypedDict):
    """State dictionary shared by the nodes of the LangGraph workflow."""
    # Raw user text the workflow starts from; read by retrieve_product_customer.
    input: str
    # Follow-up question; read by answer_question.
    question: str
    # Product name: first as extracted from the input, later overwritten with
    # the name mapped to the database.
    product: str
    # Customer name extracted from the input, if one was found.
    customer: str
    # File name of the document related to the product.
    document: str
    # Final answer produced by answer_question.
    answer: str

from langgraph.types import interrupt, Command

# no-op node that should be interrupted on
def human_feedback(state: State):
    """Placeholder node that does nothing; it only marks the point in the graph where execution can be interrupted."""
    return None


def retrieve_product_customer(state):
    """Extract a customer or product name from `state['input']` and store it in the state.

    A "Customer: " answer takes precedence over a "Product: " answer; when the
    answer contains neither marker the state is returned unchanged.
    """
    customer_tag = 'Customer: '
    product_tag = 'Product: '
    extracted = retrieve_product_or_customer_from_text(state['input'])

    if customer_tag in extracted:
        customer_name = extracted.split(customer_tag)[1]
        print('It looks like you are the following customer: ' + customer_name)
        state['customer'] = customer_name
        return state

    if product_tag in extracted:
        product_name = extracted.split(product_tag)[1]
        print('It sounds like you got a question on the following product: ' + product_name)
        state['product'] = product_name

    return state

def user_feedback(state):
    """Ask the user for a product or customer name and return the state unchanged."""
    request_for_details = "I couldn't find information to start the flow. Could you specify a product on which you have questions or your customer name?"
    print(request_for_details)
    return state
    
def map_product_to_db(state):
    """Replace `state['product']` with the matching product name from the database.

    The mapping model is expected to answer "Product: <name>". If the prefix is
    missing, the whole answer is used as the product name instead of failing
    with an IndexError; surrounding whitespace is removed in both cases so the
    name can be matched against the database.
    """
    marker = 'Product: '
    result = map_product_to_database_products(state['product'])
    product = result.split(marker)[1] if marker in result else result
    product = product.strip()
    state['product'] = product
    print("I have found the following product in the database: " + product)
    return state

def retrieve_document(state):
    """Look up the document for `state['product']`, store it under `state['document']` and report it."""
    document_name = retrieve_document_from_product(state['product'])
    state['document'] = document_name
    print("I found the following document on the product: " + document_name)
    return state

def tools_condition(state) -> Literal["user_feedback", "map_product"]:
    """Choose the next node: "map_product" when the state holds a product, otherwise "user_feedback"."""
    has_product = state.get('product') is not None
    return "map_product" if has_product else "user_feedback"

def answer_question(state):
    """Answer `state['question']` from `state['document']` and store the text in `state['answer']`."""
    context = perform_search_in_document(state['question'], state['document'])
    theprompt = generate_prompt(state['question'], context)
    # Use `invoke` rather than calling the model object directly: direct calls
    # on chat models are deprecated in LangChain, and `invoke` is what the rest
    # of this notebook uses.
    state['answer'] = llm.invoke(theprompt.to_messages()).content
    return state
    
# Build graph
builder = StateGraph(State)
builder.add_node("retrieve_product_customer", retrieve_product_customer)
builder.add_node("user_feedback", user_feedback)
builder.add_node("map_product", map_product_to_db)
builder.add_node("retrieve_document", retrieve_document)
builder.add_node("human_feedback", human_feedback)
builder.add_node("answer_question", answer_question)

# Logic
builder.add_edge(START, "retrieve_product_customer")
builder.add_conditional_edges("retrieve_product_customer", tools_condition)
builder.add_edge("user_feedback", END)
builder.add_edge("map_product", "retrieve_document")
builder.add_edge("retrieve_document", "human_feedback")
builder.add_edge("human_feedback", "answer_question")
builder.add_edge("answer_question", END)

memory = MemorySaver()

# Add
graph = builder.compile(checkpointer=memory, interrupt_before=["human_feedback"])

In [None]:
graph.get_graph().print_ascii()

In [None]:
# Start the workflow with the user's opening text.
initial_input = {"input" : "I got a question on my savings account, what are the rules for a shared account?"}

# The thread id identifies this conversation in the checkpointer.
thread  = {"configurable": {"thread_id": "25"}}
# First run: streams state values until the graph pauses before
# "human_feedback" (or ends via "user_feedback").
for event in graph.stream(initial_input, thread, stream_mode="values"):
    print(event)

print(graph.get_state(thread))

# Ask for the actual question, naming the product found in the state.
# NOTE(review): this reads values['product'], which is only set when a product
# was extracted - presumably fails with a KeyError on the "user_feedback" path.
user_input = input(f"Please raise your question on {graph.get_state(thread).values['product']}: ")

# Write the question into the state as if the "human_feedback" node produced it.
graph.update_state(thread, {"question": user_input}, as_node="human_feedback")

# Bare expression in the middle of the cell: its value is not displayed.
graph.get_state(thread).next

# Second run: passing None resumes the paused thread.
for event in graph.stream(None, thread, stream_mode="values"):
    print(event)

In [None]:
graph.get_state(thread).values

Final Answer

In [None]:
print(graph.get_state(thread).values['answer'])