# Module 4 - RAG-Chatbot 

In this module we create a chatbot using RAG ([Retrieval Augmented Generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation)) and [GraphRAG](https://graphrag.com/). 

Import our usual suspects (and some more...)

In [None]:
import pandas as pd
import os
from langchain_openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from neo4j import Query, GraphDatabase, RoutingControl, Result
from dotenv import load_dotenv
import gradio as gr
import time
from IPython.display import display, HTML
import warnings
from json import loads, dumps
# Suppress every warning category for the rest of the session.
# NOTE: this also hides library deprecation notices.
warnings.simplefilter('ignore')

## Get Credentials

Load env variables

In [None]:
# Path of the dotenv file holding the Neo4j and OpenAI settings.
env_file = "ws.env"

In [None]:
# Load credentials from the env file when it exists. The settings are read
# from the process environment either way, so they can also be supplied as
# regular environment variables; any that are missing end up as None.
if os.path.exists(env_file):
    load_dotenv(env_file, override=True)
else:
    print(f"File {env_file} not found.")

# Neo4j
HOST = os.getenv('NEO4J_URI')
USERNAME = os.getenv('NEO4J_USERNAME')
PASSWORD = os.getenv('NEO4J_PASSWORD')
DATABASE = os.getenv('NEO4J_DATABASE')

# AI
# os.getenv reads os.environ, so writing OPENAI_API_KEY back into os.environ
# would be a no-op when it is set, and a TypeError when it is unset
# (os.environ only accepts str values). It is therefore only read here.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
LLM = os.getenv('LLM')
EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL')

## Setup Connection to Database

Setup connection to the database with the [Python Driver](https://neo4j.com/docs/python-manual/5/).

In [None]:
# Single Neo4j driver instance used by every query below (URI and
# credentials come from the env settings loaded earlier).
driver = GraphDatabase.driver(HOST, auth=(USERNAME, PASSWORD))

Test the connection

In [None]:
# Smoke test: count all nodes (read-only) and show the result as a DataFrame.
count_all_nodes = """
    MATCH (n) RETURN COUNT(n) as Count
    """
driver.execute_query(
    count_all_nodes,
    routing_=RoutingControl.READ,
    database_=DATABASE,
    result_transformer_=Result.to_df,
)

## Create RAG-application

For the chatbot we need both an embedding model and an LLM. Create both below:

In [None]:
# Embedding client used to turn questions into query vectors for the vector
# indexes; model name and API key come from the env settings.
embedding_model = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY, model=EMBEDDINGS_MODEL)

In [None]:
# Display the embedding model name the client was configured with.
embedding_model.model

In [None]:
# Use the maintained `langchain_openai` integration (the package that already
# provides OpenAIEmbeddings in this notebook) instead of the deprecated
# `langchain.chat_models.ChatOpenAI` shim. This rebinds the global name
# `ChatOpenAI` for the rest of the session.
from langchain_openai import ChatOpenAI

# temperature=0 minimises sampling randomness in the answers.
llm = ChatOpenAI(temperature=0, model=LLM)

In [None]:
# Display the chat model name the LLM client was configured with.
llm.model_name

### Retrieval Queries

To illustrate the difference between a "Regular" Vector Search and GraphRAG we create different retrieval queries.

The following function retrieves the context using a regular vector search. 

In [None]:
def get_context_vector_search(search_prompt, top_k=3):
    """Retrieve context for *search_prompt* with a plain vector search.

    The prompt is embedded with ``embedding_model`` and the
    ``chunk-embeddings`` vector index is queried for the ``top_k`` most
    similar chunks; each hit is joined to the Document it is PART_OF.

    Args:
        search_prompt: The user's question.
        top_k: Number of nearest chunks to retrieve. Defaults to 3, the value
            that used to be hard-coded in the query.

    Returns:
        A JSON array string (indent=2) with one object per chunk, holding
        ``score``, ``file_name``, ``chunk_id``, ``page`` and ``chunk``.
    """
    query_vector = embedding_model.embed_query(search_prompt)

    # The neighbour count is a query parameter rather than a literal so
    # callers can widen or narrow the search.
    similarity_query = """
        CALL db.index.vector.queryNodes("chunk-embeddings", $top_k, $query_vector) YIELD node, score
        WITH node as chunk, score ORDER BY score DESC
        MATCH (d:Document)<-[:PART_OF]-(chunk)
        RETURN score, d.file_name as file_name, chunk.id as chunk_id, chunk.page as page, chunk.chunk_eng AS chunk
       """
    results = driver.execute_query(
        similarity_query,
        database_=DATABASE,
        routing_=RoutingControl.READ,
        query_vector=query_vector,
        top_k=top_k,
        result_transformer_=lambda r: r.to_df(),
    )

    # DataFrame.to_json emits compact JSON; re-dump it indented so the
    # context is readable inside the LLM prompt.
    parsed = loads(results.to_json(orient="records"))
    return dumps(parsed, indent=2)

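The following function retrieves the context using the Knowledge Graph (GraphRAG). It starts with a regular vector search and then expands the results through the graph. It adds chunks that share overlapping definitions with the hits. It also adds chunks that mention a definition closely matching the question.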

In [None]:
def get_context_graphrag(search_prompt):
    """Retrieve context for *search_prompt* by expanding a vector search through the graph.

    Two read-only retrieval passes are combined:

    1. The 3 chunks nearest to the prompt in the ``chunk-embeddings`` index,
       plus every chunk connected to them by an ``OVERLAPPING_DEFINITIONS``
       relationship whose ``overlap`` is greater than 3.
    2. Among the 5 nearest entries of ``definition-embeddings``, the
       best-scoring one with ``degree < 20``; up to 3 chunks that MENTION it
       and were not already returned by pass 1 are added.

    Every chunk is joined to the Document it is PART_OF.

    Returns:
        A JSON array string (indent=2) of the de-duplicated rows, each with
        ``file_name``, ``chunk_id``, ``page`` and ``chunk``.
    """
    def run(cypher, **params):
        # Read-only query on the configured database, returned as a DataFrame.
        return driver.execute_query(
            cypher,
            database_=DATABASE,
            routing_=RoutingControl.READ,
            result_transformer_=lambda r: r.to_df(),
            **params,
        )

    embedding = embedding_model.embed_query(search_prompt)

    expansion_cypher = """ 
        CALL db.index.vector.queryNodes("chunk-embeddings", 3, $query_vector) YIELD node, score
        WITH node as chunk, score ORDER BY score DESC
        CALL (chunk) {
            MATCH (chunk)-[r:OVERLAPPING_DEFINITIONS]-(overlapping_chunk:Chunk)
            WHERE r.overlap > 3
            RETURN collect(overlapping_chunk) AS overlapping_chunks
        }
        WITH [chunk] + overlapping_chunks AS chunks
        UNWIND chunks as chunk
        MATCH (d:Document)<-[:PART_OF]-(chunk)
        RETURN d.file_name as file_name, chunk.id as chunk_id, chunk.page as page, chunk.chunk_eng AS chunk
       """
    seed_rows = run(expansion_cypher, query_vector=embedding)

    # Chunk ids already found, so pass 2 only contributes new chunks.
    seen_ids = list(set(seed_rows['chunk_id'].to_list()))

    definition_cypher = """    
       CALL db.index.vector.queryNodes("definition-embeddings", 5, $query_vector) YIELD node, score
            WITH node as definition, score ORDER BY score DESC
            WHERE definition.degree < 20
            WITH definition LIMIT 1
            MATCH (definition)<-[:MENTIONS]-(chunk:Chunk)
            WHERE NOT (chunk.id IN $chunk_ids)
            WITH chunk LIMIT 3
            MATCH (d:Document)<-[:PART_OF]-(chunk)
            RETURN d.file_name as file_name, chunk.id as chunk_id, chunk.page as page, chunk.chunk_eng AS chunk
    """
    definition_rows = run(definition_cypher, chunk_ids=seen_ids, query_vector=embedding)

    combined = pd.concat([seed_rows, definition_rows]).drop_duplicates()
    # Compact JSON from pandas, re-dumped with indentation for the prompt.
    return dumps(loads(combined.to_json(orient="records")), indent=2)

### Prompts 

The following function builds the prompt from the question and the retrieved context. It is used for both vector search and GraphRAG.

In [None]:
def generate_prompt(search_prompt, context):
    """Build the prompt asking the LLM to answer *search_prompt* from *context*.

    Args:
        search_prompt: The user's question.
        context: The retrieved context (the JSON string produced by the
            retrieval functions) that the answer must be based on.

    Returns:
        The prompt value from ``PromptTemplate.format_prompt``; callers
        convert it to chat messages with ``.to_messages()``.
    """
    # The answer language follows the question. The template must not also
    # force English, otherwise the two instructions contradict each other
    # for non-English questions.
    prompt_template = """

    You are a chatbot for Rabobank products. Your goal is to help people with questions on product policies.
    A user will come to you with questions on their policy. Their questions must be answered based on the relevant documents of the policy.

    The question is the following:
    {search_prompt}

    Always respond in the language in which the question was asked. So, do not respond in a different language.

    The context is the following:
    {context}

    Please explain your answer as thoroughly as possible based on the context above. Don't come up with anything yourself.

    Please end your message by listing your sources with file name and page number.
    """
    prompt = PromptTemplate.from_template(prompt_template)

    theprompt = prompt.format_prompt(search_prompt=search_prompt, context=context)
    return theprompt

GraphRAG uses the same prompt. The difference is in the context, which also contains chunks found through related definitions.

## Some examples to test the models

Each example is run twice: first with vector search, then with GraphRAG.

In [None]:
search_prompt = 'What is meant with the Rabofoon?'

# Vector-search context -> prompt -> LLM. `invoke` replaces the deprecated
# direct-call style `llm(...)`.
context = get_context_vector_search(search_prompt)
theprompt = generate_prompt(search_prompt, context)
llm.invoke(theprompt.to_messages()).pretty_print()

In [None]:
search_prompt = 'What is meant with the Rabofoon?'

# GraphRAG context -> prompt -> LLM. `invoke` replaces the deprecated
# direct-call style `llm(...)`.
context = get_context_graphrag(search_prompt)
theprompt = generate_prompt(search_prompt, context)
llm.invoke(theprompt.to_messages()).pretty_print()

In [None]:
search_prompt = 'Are you insured when traveling to a high-risk country?'

# Vector-search context -> prompt -> LLM. `invoke` replaces the deprecated
# direct-call style `llm(...)`.
context = get_context_vector_search(search_prompt)
theprompt = generate_prompt(search_prompt, context)
llm.invoke(theprompt.to_messages()).pretty_print()

In [None]:
search_prompt = 'Are you insured when traveling to a high-risk country?'

# GraphRAG context -> prompt -> LLM. `invoke` replaces the deprecated
# direct-call style `llm(...)`.
context = get_context_graphrag(search_prompt)
theprompt = generate_prompt(search_prompt, context)
llm.invoke(theprompt.to_messages()).pretty_print()

In [None]:
search_prompt = 'What are the rules for a joint investment account?'

# Vector-search context -> prompt -> LLM. `invoke` replaces the deprecated
# direct-call style `llm(...)`.
context = get_context_vector_search(search_prompt)
theprompt = generate_prompt(search_prompt, context)
llm.invoke(theprompt.to_messages()).pretty_print()

In [None]:
search_prompt = 'What are the rules for a joint investment account?'

# GraphRAG context -> prompt -> LLM. `invoke` replaces the deprecated
# direct-call style `llm(...)`.
context = get_context_graphrag(search_prompt)
theprompt = generate_prompt(search_prompt, context)
llm.invoke(theprompt.to_messages()).pretty_print()

## Gradio Chatbot that uses RAG and GraphRAG

The example code comes from the Gradio documentation: [Creating a custom chatbot with blocks](https://www.gradio.app/guides/creating-a-custom-chatbot-with-blocks#add-streaming-to-your-chatbot)

In [None]:
def user(user_message, history):
    """Register a submitted message as a new, not-yet-answered chat turn.

    Returns an empty string (used to clear the textbox) and a new history
    list ending with ``[user_message, None]``; *history* itself is left
    unchanged.
    """
    pending_turn = [user_message, None]
    updated_history = history + [pending_turn]
    return "", updated_history

def get_answer(search_prompt, rag_method):
    """Answer *search_prompt* using the selected retrieval method.

    Args:
        search_prompt: The user's question.
        rag_method: The radio-button value. ``"Vector-Search"`` selects plain
            vector search; any other value falls back to GraphRAG.

    Returns:
        The text content of the LLM's reply.
    """
    if rag_method == "Vector-Search":
        context = get_context_vector_search(search_prompt)
    else:
        # rag_method == "GraphRAG" (or anything else)
        context = get_context_graphrag(search_prompt)
    # Both methods share the same prompt; only the retrieved context differs.
    theprompt = generate_prompt(search_prompt, context)
    # `invoke` is the supported entry point; calling the model directly
    # (`llm(...)`) is deprecated in LangChain.
    return llm.invoke(theprompt.to_messages()).content

def bot(history, rag_method):
    """Answer the latest chat turn and stream the reply into *history*.

    The complete answer is obtained from ``get_answer`` first; it is then
    revealed one character at a time (sleeping 0.01 s per character) by
    mutating the last turn in place and yielding *history* after each step,
    to simulate streaming in the Gradio chatbot.
    """
    latest_turn = history[-1]
    answer = get_answer(latest_turn[0], rag_method)
    latest_turn[1] = ""
    for shown in range(1, len(answer) + 1):
        latest_turn[1] = answer[:shown]
        time.sleep(0.01)
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        label="Chatbot with RAG", 
        avatar_images=["https://png.pngtree.com/png-vector/20220525/ourmid/pngtree-concept-of-facial-animal-avatar-chatbot-dog-chat-machine-illustration-vector-png-image_46652864.jpg","https://d-cb.jc-cdn.com/sites/crackberry.com/files/styles/larger/public/article_images/2023/08/openai-logo.jpg"]
    )
    msg = gr.Textbox(label="Message")
    # Preselect GraphRAG: get_answer uses GraphRAG for every value other than
    # "Vector-Search", including when no option is selected, so the UI should
    # show the method that will actually be used.
    rag_method = gr.Radio(["Vector-Search", "GraphRAG"], value="GraphRAG", label="RAG-method:")
    clear = gr.Button("Clear")

    # On submit: add the user's message to the chat and clear the textbox
    # (unqueued), then run the streaming `bot` generator for the answer.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, rag_method], chatbot
    )

    # Reset the chat history.
    clear.click(lambda: None, None, chatbot, queue=False)


demo.queue()
demo.launch(share=False)

To use light mode for the chatbot, append the following to the URL: `/?__theme=light`