
# Agentic GraphRAG powered by Memgraph 3.0 

This demo highlights Agentic GraphRAG, a system that harnesses Memgraph 3.0 and its built-in tools—including vector search, BFS, PageRank, and schema management—to power AI-driven graph applications.

The demo classifies user queries, generates Cypher statements, and executes them on Memgraph. Question classification, tool selection, and query parameterization are dynamically handled by an agent that interacts with the OpenAI API.



# Prerequisites  

To try this demo, first start Memgraph with schema info enabled. You can use the following command:

```
docker run -d --name memgraph_graphRAG -p 7687:7687 -p 7444:7444 memgraph/memgraph-mage:3.0-memgraph-3.0 --log-level=TRACE --also-log-to-stderr --schema-info-enabled=True 
```

You should also install the dependencies needed for this demo: 

In [1]:
# Install dependencies from requirements.txt
%pip install -r ../requirements.txt

Note: you may need to restart the kernel to use updated packages.


After the dependencies are installed, create a `.env` file and set `OPENAI_API_KEY` in it to your OpenAI API key.
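
A minimal sketch of failing fast when the key is missing. It assumes the `.env` file has already been loaded into the process environment (for example with `load_dotenv()`). The helper name `require_api_key` is hypothetical and not part of the demo.

```python
import os


def require_api_key(env=None, name="OPENAI_API_KEY"):
    """Return the API key from the environment, or raise a clear error.

    `env` defaults to os.environ; passing a plain dict makes the helper
    easy to try out without touching the real environment.
    """
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; add it to your .env file")
    return key
```

Calling the helper once, before creating the OpenAI client, surfaces a missing key immediately instead of at the first LLM call.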


## GraphRAG demo

First, we import the necessary libraries and modules. We use the OpenAI API for the LLM, Sentence Transformers for vector embeddings, and the Neo4j Python driver for connecting to Memgraph. The remaining libraries are utilities for smaller subtasks that are discussed later.

In [21]:
from openai import OpenAI
from sentence_transformers import SentenceTransformer
import neo4j
from pydantic import BaseModel
from dotenv import load_dotenv
from typing import Dict, List, Any
import os
import json
import logging
import tiktoken

###  Defining the model 

The next step is to define the model we want to use for the LLM API. Feel free to change this to any model you prefer. In this case, we are using `gpt-4o-2024-08-06`. We also initialize logging to track the progress of the agentic GraphRAG pipeline.

In [3]:

# Initialize logging
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Predefined model used.
MODEL = {
    "name": "gpt-4o-2024-08-06",
    "context_window": 128000
}



### Tool response 

First, we define a class `ToolResponse` to represent the response from the tools in the pipeline. This class contains a `status` flag and the `results` of the tool execution.
If `status` is `True`, the tool execution was successful and `results` contains the output. If `status` is `False`, the tool execution failed and `results` holds at most an error message, not valid data.


In [4]:
# Response format
class ToolResponse():
    def __init__(self, status=False, results=""):
        self.status = status
        self.results = results
    
    def __str__(self):
        return f"Status: {self.status}, Results: {self.results}"

    def set_status(self, status: bool):
        self.status = status
        return self  

    def set_results(self, results: str):
        self.results = results
        return self  

### Structured outputs

LLMs can return malformed responses, hallucinations, or unnecessary context. Structured outputs make the responses predictable: we define an output schema, and the model's reply is parsed into it. Each LLM call can use a different schema, depending on what the agent needs at that step.
Read more about structured outputs in the [OpenAI docs](https://platform.openai.com/docs/guides/structured-outputs). 

Each of the classes below defines a Pydantic model for a response expected from the LLM. Given an appropriate prompt and model, the LLM should return output in the type and format we want.

For example, `QuestionType` defines `type` and `explanation` as strings. The LLM fills in both values based on the prompt description, the question, and the Pydantic model.


In [None]:


# Agent generation of a Cypher query
class CypherQuery(BaseModel):
    query: str

# Agent response for tool selection
class ToolSelection(BaseModel):
    first_pick: str
    second_pick: str

# Agent generation for number of similar nodes and number of hops
class StructureQuestionData(BaseModel):
    number_of_similar_nodes: int
    number_of_hops: int

# Agent generation for number of nodes in the PageRank
class PageRankNodes(BaseModel):
    number_of_nodes: int


# Agent response to the user question
class QuestionType(BaseModel):
    type: str
    explanation: str

# Community summary generation
class Community(BaseModel):
    summary: str    
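
To make the idea of a structured output concrete, here is a small standard-library sketch. It uses a dataclass as a stand-in for the `StructureQuestionData` Pydantic model and parses a hand-written JSON reply; the names `StructureParams` and `parse_structure_params` are hypothetical, and the real demo relies on the OpenAI client to do the parsing.

```python
import json
from dataclasses import dataclass


@dataclass
class StructureParams:
    """Stand-in for the StructureQuestionData Pydantic model."""
    number_of_similar_nodes: int
    number_of_hops: int


def parse_structure_params(raw: str) -> StructureParams:
    """Parse a JSON reply and check that both fields are integers."""
    data = json.loads(raw)
    # Unexpected or missing keys raise TypeError in the constructor
    params = StructureParams(**data)
    for field_name, value in vars(params).items():
        # bool is a subclass of int, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{field_name} must be an integer, got {value!r}")
    return params


# Example reply, written by hand for illustration
reply = '{"number_of_similar_nodes": 3, "number_of_hops": 2}'
print(parse_structure_params(reply))
```

A reply with a missing field, an extra field, or a non-integer value is rejected, which is the kind of predictability structured outputs are meant to provide.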




### Question classification 

One of the responses the LLMs need to give in the structured format is the `QuestionType`. The question type defines what **set of tools** can be applied to answer a specific question. 

Question types are defined as **Retrieval**, **Structure**, **Global** and **Database**. 

The function `classify_the_question` takes the user's question and classifies it into one of these types. 

The example below highlights the decision making of LLMs based on the question and the output in the structured format. 


In [6]:
def classify_the_question(openai_client, user_question: str) -> QuestionType:

    prompt = f"""
    Classify the following user question into query type

    Query Types:
    - Retrieval
    - Structure 
    - Global
    - Database

    Each type of question has different characteristics.
    - Retrieval: Direct Lookups, specific and well-defined. The query seeks information about specific entities (nodes or relationships). 
    - Structure: Exploratory, the query seeks information about the structure of the graph, close relationships between entities, or properties of nodes.
    - Global: The query seeks context about the entire graph, community, such as the most important node or global trends in graph. 
    - Database: The query seeks statistical information about the database, such as index information, node count, or relationship count, config etc.

    Examples of questions for each type:
    - Retrieval: How old is a person with the name "John"? 
    - Structure: Does John have a job? Is John a friend of Mary? Are there any people who are friends with John?
    - Global: What is the most important node in the graph? 
    - Database: What indexes does Memgraph have?

    In the explanation, provide a brief description of the type of question, and why you classified it as such. 

    The question is in <Question> </Question> format.
    """


    user_question = f"<Question>{user_question}</Question>"
    completion = openai_client.beta.chat.completions.parse(
        model=MODEL["name"],
        messages=[
            {"role": "developer", "content": prompt},
            {"role": "user", "content": user_question},
        ],
        response_format=QuestionType,
    )

    return completion.choices[0].message.parsed
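
Because `QuestionType.type` is a free-form string, the returned label may differ in case or spacing from the four expected names. The sketch below shows one way to normalize it before using it for tool selection. The helper `normalize_question_type` is hypothetical and not part of the demo.

```python
QUESTION_TYPES = ("Retrieval", "Structure", "Global", "Database")


def normalize_question_type(raw: str) -> str:
    """Map an LLM-provided label to one of the known question types.

    Matching is case-insensitive and ignores surrounding whitespace.
    Unknown labels raise ValueError so the caller can retry or fall back.
    """
    cleaned = raw.strip().lower()
    for known in QUESTION_TYPES:
        if cleaned == known.lower():
            return known
    raise ValueError(f"Unknown question type: {raw!r}")
```

For example, `normalize_question_type(" global ")` returns `"Global"`, while an unexpected label raises an error instead of silently reaching the tool selection step.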



### Tool selection pipe 

Once the question type is known, we can pick a tool suited to that type. Each question type is offered a different set of tools. Ideally, this prompt would be fully dynamic, with every Memgraph tool available in the agent's context so that it can freely pick whichever tool fits the question. In this demo, the list of options is hardcoded.

If you look at `response_format` and `prompt_developer`, you will notice that the LLM picks two options: the first tool, and a second tool as a backup.

In [None]:

def tool_selection_pipe(openai_client, user_question, question_type) -> ToolSelection:
    question = f"<Question>{user_question}</Question>"
    question_type = f"<Type>{question_type}</Type>"
    prompt_developer = f"""

    Based on the question type and the user's question, select the most appropriate tool option and second option as a backup to answer the question:

    Retrieval - direct lookups, specific and well-defined. The query seeks information about specific entities (nodes or relationships).
    Options: 
        - Cypher: A tool that generates a Cypher query based on the user's question and the database schema.
        - Vector Relevance Expansion: A tool that finds the most similar nodes based on the user's question and the database schema.
    Structure - exploratory, the query seeks information about the structure of the graph, close relationships between entities, or properties of nodes.
    Options:
        - Cypher: A tool that generates a Cypher query based on the user's question and the database schema.
        - Vector Relevance Expansion: A tool that finds the most similar nodes based on the user's question and the database schema.
    Global - the query seeks context about the entire graph, community, such as the most important node or global trends in graph.
    Options:
        - PageRank: A tool that provides PageRank information about the graph and its nodes, it can help with identifying the most important nodes.
        - Community: A tool that provides communities information about the graph, it contains the summary of the community, and can help with global insights.
    Database - the query seeks statistical information about the database, such as index information, node count, or relationship count, config etc.
    Options:
        - Schema: A tool that provides schema information about the dataset and datatypes.
        - Config: A tool that provides configuration information about the database.
        - Cypher: A tool that generates a Cypher query based on the user's question and the database schema.

    The question is in <Question> </Question> format, and the type of the question is <Type> </Type>.

    """
    messages = [
        {"role": "developer", "content": prompt_developer},
        {"role": "user", "content": question + question_type},

    ]
    completion = openai_client.beta.chat.completions.parse(
        model=MODEL["name"],
        messages=messages,
        response_format=ToolSelection,
    )

    return completion.choices[0].message.parsed
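
The first-pick and backup idea can be sketched as a small dispatcher. This is an illustration under assumptions: `Result` is a stand-in for `ToolResponse`, the `tools` mapping and `run_with_fallback` are hypothetical names, and the two lambdas simulate tool outcomes rather than calling Memgraph.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class Result:
    """Stand-in for ToolResponse: a status flag plus the payload."""
    status: bool
    results: Any


def run_with_fallback(
    tools: Dict[str, Callable[[], Result]],
    first_pick: str,
    second_pick: str,
) -> Result:
    """Run the first pick; if it fails or is unknown, try the backup."""
    for name in (first_pick, second_pick):
        tool = tools.get(name)
        if tool is None:
            continue  # the LLM named a tool we do not know
        response = tool()
        if response.status:
            return response
    return Result(False, "No selected tool produced a result.")


# Simulated tools: the first one fails, the backup succeeds
tools = {
    "Cypher": lambda: Result(False, "Error in running the query."),
    "Vector Relevance Expansion": lambda: Result(True, ["node data"]),
}
outcome = run_with_fallback(tools, "Cypher", "Vector Relevance Expansion")
print(outcome.status, outcome.results)
```

Keeping the tools in a name-to-callable mapping means the names returned in `first_pick` and `second_pick` can be looked up directly, and an unrecognized name is skipped rather than raising.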

## Tools in the GraphRAG demo

### Schema info

The schema info tool is responsible for getting the graph schema information from the database. Providing a brief summary of the graph schema can help the user understand the structure of the graph and formulate more specific queries. That includes the node, relationship and property types.

The function `get_schema_string` retrieves up-to-date schema information from the database and formats it as a human-readable string. This function [utilizes Memgraph's](https://memgraph.com/docs/querying/schema) `SHOW SCHEMA INFO` query.

In [8]:

def get_schema_string(db_client) -> str:
    
    with db_client.session() as session:
        schema = session.run("SHOW SCHEMA INFO")
        schema_info = json.loads(schema.single().value())
        nodes = schema_info["nodes"]
        edges = schema_info["edges"]
        node_indexes = schema_info["node_indexes"]
        edge_indexes = schema_info["edge_indexes"]

        schema_str = "Nodes:\n"
        for node in nodes:
            properties = ", ".join(
                f"{prop['key']}: {', '.join(t['type'] for t in prop['types'])}"
                for prop in node["properties"]
            )
            schema_str += f"Labels: {node['labels']} | Properties: {properties}\n"

        schema_str += "\nEdges:\n"
        for edge in edges:
            properties = ", ".join(
                f"{prop['key']}: {', '.join(t['type'] for t in prop['types'])}"
                for prop in edge["properties"]
            )
            schema_str += f"Type: {edge['type']} | Start Node Labels: {edge['start_node_labels']} | End Node Labels: {edge['end_node_labels']} | Properties: {properties}\n"

        schema_str += "\nNode Indexes:\n"
        for index in node_indexes:
            schema_str += (
                f"Labels: {index['labels']} | Properties: {index['properties']}\n"
            )

        schema_str += "\nEdge Indexes:\n"
        for index in edge_indexes:
            schema_str += f"Type: {index['type']} | Properties: {index['properties']}\n"

        return schema_str

def schema_tool(db_client) -> ToolResponse:
    return ToolResponse(True, get_schema_string(db_client))
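
To see what one line of the schema string looks like, here is the property formatting step applied to a hand-written node entry. The sample payload is invented for illustration and only mirrors the keys that `get_schema_string` reads (`labels`, `properties`, `key`, `types`, `type`); the helper name `format_properties` is hypothetical.

```python
def format_properties(properties):
    """Render properties as 'key: Type1, Type2', separated by commas."""
    return ", ".join(
        f"{prop['key']}: {', '.join(t['type'] for t in prop['types'])}"
        for prop in properties
    )


# Hypothetical node entry, shaped like the entries the function iterates over
sample_node = {
    "labels": ["Person"],
    "properties": [
        {"key": "name", "types": [{"type": "String"}]},
        {"key": "age", "types": [{"type": "Integer"}, {"type": "String"}]},
    ],
}

line = (
    f"Labels: {sample_node['labels']} | "
    f"Properties: {format_properties(sample_node['properties'])}"
)
print(line)
```

A property that was stored with more than one type, like `age` in this sample, lists every observed type, which gives the LLM a hint that values may need casting.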



### Text to Cypher 

The function `text_to_Cypher` translates a natural language question into a Cypher query using the LLM.
It uses the schema info to generate accurate queries and includes error correction and retry logic.

If the query fails to execute or returns no results, the tool tries to recover by sending the database error message, or a hint about the empty result, back to the LLM, up to three times.


The `generate_cypher_query` function calls the LLM to produce a Cypher query in a structured format from the given prompt messages.


In [9]:

def text_to_Cypher(db_client, openai_client, user_question) -> ToolResponse:
    logger.info("Running text_to_cypher tool")

    schema = get_schema_string(db_client)
    prompt_user = f"""

    User Question: "{user_question}"
    Schema: {schema}

    Based on schema and question, generate a Cypher query that directly corresponds to the user's intent.
    """

    prompt_developer = f"""
    Your task is to directly translate natural language
    inquiry into precise and executable Cypher query for Memgraph database.
    You will utilize a provided database schema to understand the structure,
    nodes and relationships within the Memgraph database.

    Rules:
    - Use provided node and relationship labels and property names from the
    schema which describes the database's structure. Upon receiving a user question, synthesize the
    schema to craft a precise Cypher query that directly corresponds to
    the user's intent.
    - Generate valid executable Cypher queries on top of Memgraph database.
    - Use Memgraph MAGE procedures instead of Neo4j APOC procedures.

    With all the above information and instructions, generate Cypher query
    for the user question.
    """

    encoding = tiktoken.get_encoding("cl100k_base")
    token_count_user = len(encoding.encode(prompt_user))
    token_count_developer = len(encoding.encode(prompt_developer))
    token_count = token_count_user + token_count_developer
    logger.info(f"Token count on prompt : {token_count}")

    prompt_chain = [
            {"role": "developer", "content": prompt_developer},
            {"role": "user", "content": prompt_user},
    ]

    tool_response = ToolResponse()

    query = ""
    if token_count <= MODEL["context_window"]:
        query = generate_cypher_query(openai_client, prompt_chain)
    else:
        return tool_response.set_status(False).set_results("Token count exceeded the limit.")

    logger.info("### Cypher Query:")
    logger.info(query)
    
    res = []
    with db_client.session() as session:
        for _ in range(3):  # Try correction process up to 3 times
            try:
                results = session.run(query)
                if not results.peek():
                    raise ValueError(
                        "The query did not return any results. There is a possible issue with the query "
                        "labels and parameters, if you are matching strings consider matching them in the case-insensitive way."
                    )
                for record in results:
                    res.append(record)

                return tool_response.set_status(True).set_results(res)

            except Exception as e:  # also covers the ValueError raised above
                error_type = "ValueError" if isinstance(e, ValueError) else "Error"
                logger.error(f"{error_type} in running the query:")
                logger.error(e)
                error_message = str(e)

                prompt_correction = f"""
                The following Cypher query generated a {error_type}:
                Query: {query}
                Error: {error_message}
                Question: {user_question}

                Please correct the Cypher query based on the error, schema and question.
                """
                prompt_chain.append({"role": "assistant", "content": query})
                prompt_chain.append({"role": "developer", "content": prompt_correction})

                query = generate_cypher_query(openai_client, prompt_chain)
                logger.info("### Corrected Cypher Query:")
                logger.info(query)

        return tool_response.set_status(False).set_results("Error in running the query.")

def generate_cypher_query(openai_client, prompt_messages):
    completion = openai_client.beta.chat.completions.parse(
        model=MODEL["name"],
        messages=prompt_messages,
        response_format=CypherQuery,
    )
    return completion.choices[0].message.parsed.query
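
The correction loop can be isolated from the database and the LLM to make the control flow easier to follow. The sketch below is a simplified illustration, not the demo's implementation: `run_with_correction`, `fake_execute`, and `fake_correct` are hypothetical names, and the two fake functions stand in for `session.run` and `generate_cypher_query`.

```python
def run_with_correction(query, execute, correct, max_attempts=3):
    """Run a query, asking `correct` for a new one after each failure.

    Returns (rows, history). `rows` is None when every attempt failed;
    `history` lists the (query, error message) pairs that were tried.
    """
    history = []
    for _ in range(max_attempts):
        try:
            rows = execute(query)
            if not rows:
                # Treat an empty result as an error worth correcting
                raise ValueError("The query did not return any results.")
            return rows, history
        except Exception as error:
            history.append((query, str(error)))
            query = correct(query, str(error))
    return None, history


def fake_execute(query):
    """Stand-in for the database: only the label Person exists."""
    if "Person" not in query:
        raise RuntimeError("Unknown label")
    return [{"name": "John"}]


def fake_correct(query, error_message):
    """Stand-in for the LLM: fixes the misspelled label."""
    return query.replace("Persn", "Person")


rows, history = run_with_correction(
    "MATCH (n:Persn) RETURN n.name AS name", fake_execute, fake_correct
)
print(rows, len(history))
```

Returning the history alongside the rows keeps the failed attempts available for logging or for building the follow-up prompt.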



### Config Tool

The config tool retrieves configuration information from the database. If the user has a question about Memgraph's configuration, this tool can help the LLM answer it.

It uses the `SHOW CONFIG` [query](https://memgraph.com/docs/database-management/configuration).


In [10]:

def config_tool(db_client) -> ToolResponse:
    try:
        with db_client.session() as session:
            config = session.run("SHOW CONFIG")
            config_str = "Configurations:\n"
            for record in config:
                config_str += f"Name: {record['name']} | Default Value: {record['default_value']} | Current Value: {record['current_value']} | Description: {record['description']}\n"
            return ToolResponse(True, config_str)
    except Exception as e:
        logger.error("Error in running the Config tool query.")
        return ToolResponse(False, "Error in running the Config tool query.")

### PageRank 

The PageRank tool returns the results of running the [PageRank](https://memgraph.com/docs/advanced-algorithms/available-algorithms/pagerank) algorithm on the graph data in Memgraph. Like any tool, it can be adapted to the particular question being asked.

Here are two example questions:

"Is Coca-Cola the most important company in the dataset?" Here you want the single most important node in the dataset.

"Is Coca-Cola among the 100 most important companies in the dataset?" Here you want the 100 most important nodes (companies) in the dataset.

Both questions use PageRank, but letting the LLM adapt the tool to the question means PageRank is called with the right configuration.

The LLM call in `page_rank_choice` guesses how many nodes to fetch from PageRank, and `page_rank_tool` uses that number when it runs the query.

In [11]:


def page_rank_choice(openai_client, user_question) -> PageRankNodes:
    question = f"<Question>{user_question}</Question>"
    prompt = f"""
    Based on the provided question, try to guess how many nodes should be returned from the PageRank in the assessment. 
    The question is in <Question> </Question> format.
    """
    completion = openai_client.beta.chat.completions.parse(
        model=MODEL["name"],
        messages=[
            {"role": "developer", "content": prompt},
            {"role": "user", "content": question},
        ],
        response_format=PageRankNodes,
    )

    return completion.choices[0].message.parsed

def page_rank_tool(db_client, openai_client, user_question) -> ToolResponse:

    # Ask the LLM how many top-ranked nodes the question calls for
    choice = page_rank_choice(openai_client, user_question)

    logger.info("Running the PageRank tool")
    logger.info(f"Number of nodes: {choice.number_of_nodes}")

    with db_client.session() as session:
        try:
            result = session.run(f"CALL pagerank.get() YIELD node, rank RETURN node, rank ORDER BY rank DESC LIMIT {choice.number_of_nodes};")
            result_str = ""
            for record in result:
                node = record["node"]
                properties = {k: v for k, v in node.items() if k != "embedding"}
                result_str += f"Node: {properties}, Rank: {record['rank']}\n"
            
            logger.info("Page rank successful") 
            logger.info(result_str)
            return ToolResponse(True, result_str)
        except Exception as e:
            logger.error("Error in running the PageRank tool query.")
            return ToolResponse(False, "Error in running the PageRank tool query.")
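
Since the node count comes from the LLM, it can help to bound it before placing it in the query text. The sketch below is an assumption-laden illustration: `build_pagerank_query` and the `max_nodes` cap are hypothetical, and the `ORDER BY rank DESC` clause assumes the procedure does not already return rows sorted by rank.

```python
def build_pagerank_query(number_of_nodes, max_nodes=100):
    """Build the PageRank query with a bounded LIMIT.

    The LLM-provided count is coerced to an integer and clamped to the
    range [1, max_nodes] before it is interpolated into the query.
    """
    limit = max(1, min(int(number_of_nodes), max_nodes))
    return (
        "CALL pagerank.get() YIELD node, rank "
        f"RETURN node, rank ORDER BY rank DESC LIMIT {limit};"
    )


print(build_pagerank_query(1))
```

Clamping keeps a guess such as zero, a negative number, or a very large value from producing an invalid or overly expensive query.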

### Community summarization tool 

If a question addresses a global trend or a larger set of nodes and edges, community detection algorithms can be used to find communities. These communities can then be analyzed further to see how they fit into the bigger picture of the graph.

In this example we use [Louvain community detection](https://memgraph.com/docs/advanced-algorithms/available-algorithms/community_detection) to detect communities. They are precomputed at application start in the `precompute_community_summary` function.

The function computes the communities, extracts each community's subgraph, and passes it to the LLM, which summarizes the community in 5 to 10 sentences. A `Community` node holding the summary is then created for each community.

This approach only demonstrates how questions about global trends can be handled. Leiden with hierarchical modeling is a better choice, since the communities can be explored dynamically.


In [None]:
def precompute_community_summary(db_client, openai_client) -> bool:

    number_of_communities = 0
    try:
        with db_client.session() as session:
            result = session.run("""
            CALL community_detection.get()
            YIELD node, community_id 
            SET node.community_id = community_id;
            """
            )
            result = session.run("""
            MATCH (n)
            RETURN count(distinct n.community_id) as community_count;
            """
            )
            for record in result:
                number_of_communities = record['community_count']
                print(f"Number of communities: {record['community_count']}")
    except Exception as e:
        logger.error("Error in running the community detection query.")
        return False
    
    try:
        with db_client.session() as session:
            communities = []
            for i in range(0, number_of_communities):
                community_string = ""
                community_id = 0
                result = session.run(f"""
                MATCH (start), (end) 
                WHERE start.community_id = {i} AND end.community_id = {i} AND id(start) < id(end)
                MATCH p = (start)-[*..1]-(end)
                RETURN p; 
                """)
                for record in result:
                    path = record['p']
                    for rel in path.relationships:
                        start_node = rel.start_node
                        end_node = rel.end_node
                        start_node_properties = {k: v for k, v in start_node.items() if k != 'embedding'}
                        end_node_properties = {k: v for k, v in end_node.items() if k != 'embedding'}
                        community_string += f"({start_node_properties})-[:{rel.type}]->({end_node_properties})\n"
                        community_id = i
                communities.append({"id": community_id, "data": community_string})
    except Exception as e:
        logger.error("Error in fetching the community subgraphs.")
        return False
        
    logger.info("Total number of communities:")
    logger.info(number_of_communities)
    community_summary = []
    for community in communities:
        community_id = community['id']
        community_string = community['data']
        try:
            logger.info(f"Generating summary for community {community_id}")
            prompt = community_prompt(openai_client, community_string)
            community_summary.append({"id": community_id, "summary": prompt.summary})
        except Exception as e:
            logger.error(f"Error in generating summary for community {community_id} and community string {community_string}")
            logger.error(e)
            return False

    try:
        with db_client.session() as session:
            for community in community_summary:
                community_id = community['id']
                summary = community['summary']
                session.run(
                    "CREATE (c:Community { id: $id, summary: $summary})",
                    summary=summary, 
                    id=community_id
                )
    except Exception as e:
        logger.error("Error in storing the community summaries.")
        return False
    
    return True




def community_prompt(openai_client, community_string) -> Community:
    prompt = "Summarize the following community information into 5 to 10 sentences. You will get the community string in the <Community> </Community> format."
    prompt_community = f"<Community>{community_string}</Community>"

    completion = openai_client.beta.chat.completions.parse(
        model=MODEL["name"],
        messages=[
            {"role": "developer", "content": prompt},
            {"role": "user", "content": prompt_community},
        ],
        response_format=Community,
    )
    return completion.choices[0].message.parsed

def check_if_community_summary_exists(db_client) -> bool:
    try:
        with db_client.session() as session:
            result = session.run("""
            MATCH (c:Community)
            RETURN count(c) as community_count;
            """
            )
            for record in result:
                if record['community_count'] > 0:
                    return True
    except Exception as e:
        logger.error("Error in running the community detection query.")
        logger.error(e) 
    return False



After communities are pre-computed, the community tool calls a simple query that matches all communities, returns them and provides an insight into what is happening in each community. If the dataset scale is large and there are a lot of communities, vector search can be applied on top of the community summary to fetch only the communities that are most similar to the actual question. 

In [13]:
def community_tool(db_client) -> ToolResponse:
    try:
        with db_client.session() as session:
            result = session.run("MATCH (n:Community) RETURN n.id, n.summary;")
            result_str = ""
            for record in result:
                result_str += f"Community ID: {record['n.id']}, Summary: {record['n.summary']}\n"
            return ToolResponse(True, result_str)
    except Exception as e:
        logger.error("Error in running the Community tool query.")
        return ToolResponse(False, "Error in running the Community tool query.")
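
The idea of fetching only the communities closest to the question can be illustrated in plain Python. This is a toy sketch: it assumes each community summary already has an embedding, uses two-dimensional vectors invented for the example, and ranks them with cosine similarity. In Memgraph itself, a vector index over the `Community` nodes would be the more natural fit. The names `cosine_similarity` and `top_communities` are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if one is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_communities(question_embedding, communities, k=2):
    """Return the ids of the k communities most similar to the question."""
    ranked = sorted(
        communities,
        key=lambda c: cosine_similarity(question_embedding, c["embedding"]),
        reverse=True,
    )
    return [c["id"] for c in ranked[:k]]


# Toy embeddings, invented for illustration
communities = [
    {"id": 0, "embedding": [1.0, 0.0]},
    {"id": 1, "embedding": [0.0, 1.0]},
    {"id": 2, "embedding": [0.7, 0.7]},
]
print(top_communities([1.0, 0.1], communities, k=2))
```

Only the summaries of the selected communities would then be passed to the LLM, which keeps the prompt small when the graph has many communities.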



### Vector search and relevance expansion 

If the question is about the graph structure, finding a pivot point (node) in the graph and expanding from it surfaces the surrounding structure that encodes the knowledge. An example would be: "How is Coca-Cola related to the rest of the dataset?"

The question is quite open, but it points toward understanding the relationships around Coca-Cola. To find the pivot node, we create embeddings from node properties in `compute_node_embeddings`. Each node also gets the `Entity` label, since we perform vector search across all nodes in the database. The index details are in `index_setup`.


In [None]:

def index_setup(db_client):
    with db_client.session() as session:
        print("Creating the vector index...")
        session.run(
            """
            CREATE VECTOR INDEX index_name ON :Entity(embedding) WITH CONFIG {"dimension": 384, "capacity": 2000, "metric": "cos","resize_coefficient": 2};
            """
        )

def compute_node_embeddings(db_client):
    model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
    with db_client.session() as session:
        # Retrieve all nodes
        result = session.run("MATCH (n) RETURN n")
        for record in result:
            node = record["n"]
            # Check if the node already has an embedding
            if "embedding" in node:
                print("Embedding already exists")
                return

            # Combine node labels and properties into a single string
            node_data = (
                " ".join(node.labels)
                + " "
                + " ".join(f"{k}: {v}" for k, v in node.items())
            )
            # Compute the embedding for the node
            node_embedding = model.encode(node_data)

            # Store the embedding back into the node
            session.run(
                f"MATCH (n) WHERE id(n) = {node.element_id} SET n.embedding = {node_embedding.tolist()}"
            )

        session.run("MATCH (n) SET n:Entity")
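
The two steps inside the loop, building the text to embed and writing the vector back, can be sketched without a database. This is an illustration under assumptions: `node_text` and `embedding_update` are hypothetical helpers, the update uses query parameters (`$id`, `$embedding`) instead of string interpolation, and it assumes the node id is a numeric string that can be converted with `int`.

```python
def node_text(labels, properties):
    """Combine labels and properties into the string that gets embedded."""
    # Leave out any stored embedding so it never ends up in the text
    visible = {k: v for k, v in properties.items() if k != "embedding"}
    return " ".join(labels) + " " + " ".join(f"{k}: {v}" for k, v in visible.items())


def embedding_update(node_id, embedding):
    """Return a parameterized query and its parameters for storing a vector."""
    query = "MATCH (n) WHERE id(n) = $id SET n.embedding = $embedding"
    return query, {"id": int(node_id), "embedding": list(embedding)}


print(node_text(["Company"], {"name": "Coca-Cola", "embedding": [0.1]}))
query, params = embedding_update("42", [0.1, 0.2])
print(query, params)
```

The returned pair could be passed to the driver as `session.run(query, params)`. Passing the vector as a parameter avoids building a very long query string for every node.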




After setup, we can run the vector search defined in `vector_relevance_expansion` to find the nodes most similar to the question. Here again the LLM makes a decision, in `decide_on_structure_parameters`, about how many similar nodes to consider and how many hops to expand.

Finally, we run the Cypher query in `get_relevant_data` to fetch the data that could potentially answer the question.
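
Because the hop count is chosen by the LLM and each extra hop can widen the result considerably, it may be worth bounding it. The sketch below is a hypothetical helper, `build_expansion_query`, with an assumed cap of three hops. The depth is interpolated into the pattern after clamping, on the assumption that variable-length bounds cannot be passed as query parameters, while the start node id is left as the `$id` parameter.

```python
def build_expansion_query(hops, max_hops=3):
    """Build the relevance expansion query with a bounded hop count.

    The LLM-provided value is coerced to an integer and clamped to the
    range [1, max_hops] before being placed in the path pattern.
    """
    depth = max(1, min(int(hops), max_hops))
    return f"MATCH path=((n)-[r*..{depth}]-(m)) WHERE id(n) = $id RETURN path"


print(build_expansion_query(2))
```

The start node would then be supplied at run time, for example `session.run(build_expansion_query(2), id=node_id)`.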

In [15]:
def vector_relevance_expansion(db_client, openai_client, user_question) -> ToolResponse:

    model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
    question_embedding = model.encode(user_question)

    prompt_parameters = f"""
    You will get a question about the structure of the graph. The vector search
    will find the most similar nodes based on the question embedding and the node
    embeddings, and then return the data connected to those nodes within a given
    number of hops. Your task is to decide how many nodes the vector search should
    return and how many hops should be used to find the relevant data. If the
    question does not specify the number of nodes, guess the intended number;
    by default use 1. If the question does not specify the number of hops,
    guess the intended number; by default use 1.
    """

    messages = [ 
        {"role": "developer", "content": prompt_parameters}, 
        {"role": "user", "content": user_question} 
        ]

    structure_parameters = decide_on_structure_parameters(openai_client, messages)

    logger.info("Structure parameters:")
    logger.info(structure_parameters)

    nodes = find_most_similar_nodes(db_client, user_question, question_embedding, structure_parameters.number_of_similar_nodes)

    # find_most_similar_nodes returns None when nothing matched, so check before iterating
    tool_response = ToolResponse()
    if nodes is None:
        return tool_response.set_status(False).set_results("No similar nodes found.")

    logger.info("Most similar nodes:")
    for node in nodes:
        logger.info(node)

    relevant_data = get_relevant_data(db_client, nodes, structure_parameters.number_of_hops)

    return tool_response.set_status(True).set_results(relevant_data)



def decide_on_structure_parameters(openai_client, messages) -> StructureQuestionData:
    completion = openai_client.beta.chat.completions.parse(
        model=MODEL["name"],
        messages=messages,
        response_format=StructureQuestionData,
    )
    return completion.choices[0].message.parsed


    

def find_most_similar_nodes(db_client, user_question,  question_embedding, number_of_similar_nodes):
        
    with db_client.session() as session:
        result = session.run(
            f"CALL vector_search.search('index_name', {number_of_similar_nodes}, {question_embedding.tolist()}) YIELD * RETURN *;"
        )
        nodes_data = []
        for record in result:
            node = record["node"]
            properties = {k: v for k, v in node.items() if k != "embedding"}
            node_data = {
                "distance": record["distance"],
                "id": node.element_id,
                "labels": list(node.labels),
                "properties": properties,
            }

            nodes_data.append(node_data)
        print("All similar nodes:")
        for node in nodes_data:
            print(node)

        return nodes_data if nodes_data else None


def get_relevant_data(db_client, nodes, hops):
    paths = []
    for node in nodes:
        with db_client.session() as session:
            query = (
                f"MATCH path=((n)-[r*..{hops}]-(m)) WHERE id(n) = {node['id']} RETURN path"
            )
            result = session.run(query)
            
            for record in result:
                path_data = []
                for segment in record["path"]:

                    # Process start node without 'embedding' property
                    start_node_data = {
                        k: v for k, v in segment.start_node.items() if k != "embedding"
                    }

                    # Process relationship data; each segment is a relationship,
                    # so its own items are the relationship properties
                    relationship_data = {
                        "type": segment.type,
                        "properties": dict(segment.items()),
                    }

                    # Process end node without 'embedding' property
                    end_node_data = {
                        k: v for k, v in segment.end_node.items() if k != "embedding"
                    }

                    # Add to path_data as a tuple (start_node, relationship, end_node)
                    path_data.append((start_node_data, relationship_data, end_node_data))

                paths.append(path_data)

    return paths
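
The two queries above are built with f-strings. As a minimal sketch of an alternative, assuming Memgraph accepts query parameters as procedure arguments, the values could be passed as parameters instead. The helper names `build_vector_search_query` and `build_expansion_query` are hypothetical and not part of the demo. The hop bound is part of the pattern itself and cannot be a parameter, so it is coerced to an integer before being formatted into the query text.

```python
from typing import Any, Dict, List, Tuple


def build_vector_search_query(
    index_name: str, limit: int, embedding: List[float]
) -> Tuple[str, Dict[str, Any]]:
    """Return the Cypher text and parameter map for a vector search call."""
    if limit < 1:
        raise ValueError("limit must be a positive integer")
    query = (
        "CALL vector_search.search($index_name, $limit, $embedding) "
        "YIELD node, distance RETURN node, distance;"
    )
    params = {"index_name": index_name, "limit": limit, "embedding": list(embedding)}
    return query, params


def build_expansion_query(node_id: int, hops: int) -> Tuple[str, Dict[str, Any]]:
    """Return the Cypher text and parameter map for the hop expansion."""
    # The hop bound cannot be a query parameter, so it is validated as an
    # integer before it is formatted into the pattern.
    hops = int(hops)
    if hops < 1:
        raise ValueError("hops must be a positive integer")
    query = f"MATCH path=((n)-[r*..{hops}]-(m)) WHERE id(n) = $node_id RETURN path"
    return query, {"node_id": int(node_id)}
```

Each helper returns a `(query, params)` pair that could be handed to the driver as `session.run(query, params)`.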

### Tool execution pipe

Once all the tools have been defined, here is a simple pipeline for tool execution. Each tool returns a status and results that contain the data necessary to answer the question, as described in the `ToolResponse` class.

In [16]:
def tool_execution(tool: str, db_client, openai_client, user_question) -> ToolResponse:

    if tool == "Cypher":
        return text_to_Cypher(db_client, openai_client, user_question)
    elif tool == "Vector Relevance Expansion":
        return vector_relevance_expansion(db_client, openai_client, user_question)
    elif tool == "PageRank":
        return page_rank_tool(db_client, openai_client, user_question)
    elif tool == "Community":
        return community_tool(db_client)
    elif tool == "Schema":
        return  schema_tool(db_client)
    elif tool == "Config":
        return config_tool(db_client)
    else:
        return ToolResponse(False, "Tool execution failed, tool not found.")


def execute_tool(tool: str, user_question: str, db_client, openai_client) -> ToolResponse:
    """
    Executes the given tool based on its name.
    Returns the tool's ToolResponse; its status is False if the tool failed or raised.
    """
    try:
        logger.info(f"Trying tool: {tool}")
        return tool_execution(tool, db_client, openai_client, user_question)
    except Exception as e:
        logger.error(f"Error executing {tool}: {e}")
        # Return a failed response instead of None so callers can safely read .status
        return ToolResponse(False, f"Error executing {tool}: {e}")
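
The same idea can be sketched with a lookup table instead of an `if`/`elif` chain, combined with the "first pick, then second pick" fallback. This is only an illustration: `ToolResult` is a hypothetical stand-in for the notebook's `ToolResponse`, and `run_with_fallback` and the two toy tools are assumptions, not part of the demo.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Sequence


@dataclass
class ToolResult:
    """Stand-in for the notebook's ToolResponse (hypothetical, for illustration)."""
    status: bool
    results: Any


def run_with_fallback(
    registry: Dict[str, Callable[[str], ToolResult]],
    picks: Sequence[str],
    user_question: str,
) -> ToolResult:
    """Try each picked tool in order and return the first successful result."""
    last = ToolResult(False, "No tool was tried.")
    for name in picks:
        handler = registry.get(name)
        if handler is None:
            last = ToolResult(False, f"Tool '{name}' not found.")
            continue
        try:
            last = handler(user_question)
        except Exception as exc:  # a failing tool should not stop the pipeline
            last = ToolResult(False, f"Tool '{name}' raised: {exc}")
        if last.status:
            return last
    return last


# Toy tools: the first one fails, the second one succeeds.
def failing_tool(question: str) -> ToolResult:
    raise RuntimeError("boom")


def echo_tool(question: str) -> ToolResult:
    return ToolResult(True, [question])


registry = {"Cypher": failing_tool, "Vector Relevance Expansion": echo_tool}
result = run_with_fallback(registry, ["Cypher", "Vector Relevance Expansion"], "hi")
```

With a table like this, adding a tool means adding one entry to the registry, and the caller always gets a result object back even when a tool raises.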

### Generating final response 

Once the data has been returned from any of the tools above, the LLM tries to generate an answer to the question based on that data. Since tool selection and tool configuration are performed dynamically by the LLM, it is not possible to be certain what will happen on every run, but the returned data should hold the information related to the question.


In [17]:
def generate_final_response(openai_client, results, user_question: str):
    prompt = f"""
    Using the data and the user's original question, generate a final answer:
    User Question: "{user_question}"
    Data from the database: {results}

    Try to answer the user's question using just the provided data and the user's question.
    
    """
    completion = openai_client.chat.completions.create(
        model=MODEL["name"],
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return completion.choices[0].message
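
Tool results can be large, and interpolating them directly into the prompt sends them in full. As a rough sketch, the results could be serialized to JSON and capped before the prompt is built. The helper `build_final_prompt` and the `max_chars` limit are assumptions made for illustration, and a character cap is only a crude proxy for a token budget.

```python
import json
from typing import Any


def build_final_prompt(user_question: str, results: Any, max_chars: int = 20_000) -> str:
    """Serialize tool results and cap their length before building the prompt."""
    # default=str keeps values that are not JSON-serializable from raising.
    data = json.dumps(results, default=str, ensure_ascii=False)
    if len(data) > max_chars:
        data = data[:max_chars] + "... [truncated]"
    return (
        "Using the data and the user's original question, generate a final answer:\n"
        f'User Question: "{user_question}"\n'
        f"Data from the database: {data}\n\n"
        "Try to answer the user's question using just the provided data."
    )
```

The returned string could then be passed as the `content` of the user message in place of the inline f-string.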

## Running the demo 

This is the end of the core code. Below are the client connections and the data preprocessing described earlier.

In [30]:
def get_openai_client():
    return OpenAI()

def get_db_client():
    return neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("", ""))

def preprocess_data(db_client, openai_client):
    # Use the clients passed in as arguments rather than relying on globals
    if not check_if_community_summary_exists(db_client):
        status = precompute_community_summary(db_client, openai_client)
        if status:
            logger.info("Community summary precomputed.")
        else:
            logger.error("Error in precomputing community summary.")
            logger.error("Community questions will fail")
    else:
        logger.info("Community summary already exists.")

    index_setup(db_client)
    compute_node_embeddings(db_client)
    return "Processing data completed"


Here is the main function. Make sure to change `user_question` to get a response to your own question:

In [42]:



def main(db_client, openai_client):

    # Change it to your dataset
    user_question = "What can you tell me about Coca cola?"


    db_client = get_db_client()
    openai_client = get_openai_client()

    preprocess_data(db_client, openai_client)

    
    question_type = classify_the_question(openai_client, user_question)

    logger.info("Question type:")
    logger.info(question_type)

    logger.info("Tool selection:")
    tools = tool_selection_pipe(openai_client, user_question, question_type)

    logger.info("Tools selected:")
    logger.info(tools)


    response = execute_tool(tools.first_pick, user_question, db_client, openai_client)
    if response.status:
        logger.info(f"First pick: '{tools.first_pick}' succeeded.")
    else:
        response = execute_tool(tools.second_pick, user_question, db_client, openai_client)
        if response.status:
            logger.info(f"Second pick: '{tools.second_pick}' succeeded.")

    # Log the failure once, after both picks have been tried
    if not response.status:
        logger.error(f"Both tools failed for question: {user_question}")
    else:

        logger.info("Generating final response")
        final_response = generate_final_response(
            openai_client, response.results, user_question
        )
        logger.info("Final response: " + final_response.content)
        logger.info("Final response generated.")

        logger.info("Pipeline completed.")
        
      



if __name__ == "__main__":
    load_dotenv()
    
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    openai_client = get_openai_client()

    db_client = get_db_client()

    preprocess_data(db_client, openai_client)

    main(db_client, openai_client)





2025-02-21 10:26:38,614 - INFO - Community summary already exists.
2025-02-21 10:26:38,618 - INFO - Use pytorch device_name: mps
2025-02-21 10:26:38,618 - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2


Creating the vector index...
Embedded data: 
Embedding already exists


2025-02-21 10:26:40,681 - INFO - Community summary already exists.
2025-02-21 10:26:40,684 - INFO - Use pytorch device_name: mps
2025-02-21 10:26:40,684 - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2


Creating the vector index...
Embedded data: 
Embedding already exists


2025-02-21 10:27:06,584 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-02-21 10:27:06,606 - INFO - Question type:
2025-02-21 10:27:06,607 - INFO - type='Retrieval' explanation="This question is seeking specific information about 'Coca Cola'. It is a retrieval query because it directly asks for details about a particular entity (Coca Cola), which could include attributes or direct facts related to this brand in a database or knowledge graph."
2025-02-21 10:27:06,607 - INFO - Tool selection:
2025-02-21 10:27:07,673 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-02-21 10:27:07,674 - INFO - Tools selected:
2025-02-21 10:27:07,675 - INFO - first_pick='Cypher' second_pick='Vector Relevance Expansion'
2025-02-21 10:27:07,675 - INFO - Trying tool: Cypher
2025-02-21 10:27:07,675 - INFO - Running text_to_cypher tool
2025-02-21 10:27:07,731 - INFO - Token count on prompt : 41629
2025-02-21 10:27:10,122 -