# Retrieval Augmented Generation with a Graph Database

This notebook shows how to use LLMs in combination with [Neo4j](https://neo4j.com/), a graph database, to perform Retrieval Augmented Generation (RAG).

### Why use RAG?

If you want to use LLMs to generate answers based on your own content or knowledge base, instead of providing a large context when prompting the model, you can fetch the relevant information from a database and use it to generate a response.

This allows you to:
- Reduce hallucinations
- Provide relevant, up-to-date information to your users
- Leverage your own content/knowledge base

### Why use a graph database?

If the relationships between data points matter in your data and you want to leverage them, it may be worth considering a graph database instead of a traditional relational database.

Graph databases are well suited to the following tasks:
- Navigating deep hierarchies
- Finding hidden connections between items
- Discovering relationships between items

### Use cases 

Graph databases are particularly relevant for recommendation systems, network relationships or analysing correlation between data points.  

Example use cases for RAG with graph databases include:
- Recommendation chatbot
- AI-augmented CRM 
- Tool to analyse customer behavior with natural language

Depending on your use case, you can assess whether using a graph database makes sense. 

In this notebook, we will build a **product recommendation chatbot** backed by a graph database that contains Amazon product data.


## Setup

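We will start by importing the relevant libraries. This notebook assumes `pandas`, `python-dotenv`, `neo4j`, `langchain`, `langchain-community`, `langchain-openai`, `openai` and `openai-agents` are installed.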

Make sure you have your OpenAI account set up and you have your OpenAI API key handy. 

In [None]:
import os
import json 
import pandas as pd

In [None]:
# Optional: run to load environment variables from a .env file.
# This is not required if you have exported your env variables in another way or if you set them manually
from dotenv import load_dotenv
load_dotenv()

# Set the OpenAI API key env variable manually
# os.environ["OPENAI_API_KEY"] = "<your_api_key>"

# print(os.environ["OPENAI_API_KEY"])

## Dataset

We will use a dataset that was created from a relational database and converted to JSON format, with the relationships between entities generated using the Completions API.

We will then load this data into the graph database to be able to query it.

### Loading dataset

In [None]:
# Loading a json dataset from a file
file_path = 'amazon_product_kg.json'

with open(file_path, 'r') as file:
    jsonData = json.load(file)

In [None]:
df = pd.read_json(file_path)
df.head()
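Before importing, it can help to confirm that every record has the fields the import loop reads. The sketch below assumes those field names (taken from the import code) are the complete set needed; `REQUIRED_FIELDS` and `find_incomplete_records` are hypothetical names, not part of any library.

```python
# Fields read by the import loop for every record (assumed to be the full set needed)
REQUIRED_FIELDS = {
    "product_id", "product", "TITLE", "BULLET_POINTS",
    "PRODUCT_LENGTH", "entity_type", "entity_value", "relationship",
}

def find_incomplete_records(records):
    """Return (index, missing_fields) for each record lacking a required field."""
    problems = []
    for index, record in enumerate(records):
        # Iterating a dict yields its keys, so this is "required minus present"
        missing = REQUIRED_FIELDS.difference(record)
        if missing:
            problems.append((index, sorted(missing)))
    return problems
```

For example, `find_incomplete_records(jsonData)` returns an empty list when every record is complete.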

### Connecting to db

#### Start Neo4j with required plugins
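The APOC plugin is **essential** for LangChain's `Neo4jGraph` to work properly. The Graph Data Science (GDS) plugin is also needed, since the queries later in this notebook use `gds.similarity.cosine`.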

```bash
docker run -d \
  --name neo4j-rag \
  -p 7474:7474 \
  -p 7687:7687 \
  -e NEO4J_AUTH=neo4j/your_password_here \
  -e NEO4J_PLUGINS='["apoc", "graph-data-science"]' \
  neo4j:5.15
```

#### Verify the setup
- **HTTP interface**: http://localhost:7474 (Neo4j Browser)
- **Bolt connection**: bolt://localhost:7687 (for Python connections)

#### Default credentials
- **Username**: `neo4j`
- **Password**: `your_password_here` (change this to something secure)

#### Management commands
```bash
# Stop the container
docker stop neo4j-rag
# Remove the container
docker rm neo4j-rag
```

#### Explore database
If you want to explore the database, you can download Neo4j Desktop (https://neo4j.com/download/) and add your local address as a new connection.

![Neo4j Connection](neo4j_add_connection.png)


In [None]:
# DB credentials
url = "bolt://localhost:7687"
username = "neo4j"
password = "your_password_here"

In [None]:
from langchain_community.graphs import Neo4jGraph

graph = Neo4jGraph(
    url=url, 
    username=username, 
    password=password
)

### Importing data

In [None]:
def sanitize(text):
    text = str(text).replace("'","").replace('"','').replace('{','').replace('}', '')
    return text

# Loop through each JSON object and add it to the db
for i, obj in enumerate(jsonData, 1):
    print(f"{i}. {obj['product_id']} -{obj['relationship']}-> {obj['entity_value']}")
    query = f'''
        MERGE (product:Product {{id: {obj['product_id']}}})
        ON CREATE SET product.name = "{sanitize(obj['product'])}", 
                       product.title = "{sanitize(obj['TITLE'])}", 
                       product.bullet_points = "{sanitize(obj['BULLET_POINTS'])}", 
                       product.size = {sanitize(obj['PRODUCT_LENGTH'])}

        MERGE (entity:{obj['entity_type']} {{value: "{sanitize(obj['entity_value'])}"}})

        MERGE (product)-[:{obj['relationship']}]->(entity)
        '''
    graph.query(query)
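The loop above escapes values by stripping quotes and braces with `sanitize`, which also alters the stored text. A sketch of an alternative, assuming the dataset only uses the entity types and relationships listed later in this notebook: Cypher cannot take labels or relationship types as parameters, so those are checked against an allowlist, while every other value is passed as a query parameter. `build_import_query` is a hypothetical helper.

```python
# Assumed schema: the entity types and relationships used later in this notebook
ALLOWED_LABELS = {"category", "characteristic", "measurement", "brand", "color", "age_group"}
ALLOWED_RELATIONSHIPS = {"hasCategory", "hasCharacteristic", "hasMeasurement",
                         "hasBrand", "hasColor", "isFor"}

def build_import_query(obj):
    """Return (query, params) for one JSON record (hypothetical helper)."""
    label, rel = obj["entity_type"], obj["relationship"]
    # Labels and relationship types are interpolated, so only allow known names
    if label not in ALLOWED_LABELS or rel not in ALLOWED_RELATIONSHIPS:
        raise ValueError(f"Unexpected entity type or relationship: {label!r}, {rel!r}")
    query = f"""
        MERGE (product:Product {{id: $product_id}})
        ON CREATE SET product.name = $name,
                      product.title = $title,
                      product.bullet_points = $bullet_points,
                      product.size = $size
        MERGE (entity:{label} {{value: $entity_value}})
        MERGE (product)-[:{rel}]->(entity)
    """
    # Values travel as parameters, so quotes and braces no longer need stripping
    params = {
        "product_id": obj["product_id"],
        "name": str(obj["product"]),
        "title": str(obj["TITLE"]),
        "bullet_points": str(obj["BULLET_POINTS"]),
        "size": obj["PRODUCT_LENGTH"],
        "entity_value": str(obj["entity_value"]),
    }
    return query, params
```

Inside the loop, `query, params = build_import_query(obj)` followed by `graph.query(query, params=params)` would replace the f-string query.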

## Querying the database

### Creating vector indexes

To efficiently search our database for terms closely related to user queries, we need embeddings. To do this, we will create vector indexes on the product nodes and on each entity type.

We will be using LangChain's `OpenAIEmbeddings` utility. Note that LangChain adds a pre-processing step, so the embeddings may differ slightly from those generated directly with the OpenAI embeddings API.

In [None]:
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
embeddings_model = "text-embedding-3-small"

In [None]:
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(model=embeddings_model),
    url=url,
    username=username,
    password=password,
    index_name='products',
    node_label="Product",
    text_node_properties=['name', 'title'],
    embedding_node_property='embedding',
)

In [None]:
def embed_entities(entity_type):
    # Build a vector index for this entity type on each node's 'value' property
    return Neo4jVector.from_existing_graph(
        OpenAIEmbeddings(model=embeddings_model),
        url=url,
        username=username,
        password=password,
        index_name=entity_type,
        node_label=entity_type,
        text_node_properties=['value'],
        embedding_node_property='embedding',
    )
    
entities_list = df['entity_type'].unique()

for t in entities_list:
    embed_entities(t)

A useful Cypher query to list the indexes that were created:
```
SHOW INDEXES;
```

### Querying the database directly

Using `GraphCypherQAChain`, we can generate queries against the database using natural language.

In [None]:
from langchain.chains import GraphCypherQAChain
from langchain_openai import ChatOpenAI

chain = GraphCypherQAChain.from_llm(
    ChatOpenAI(temperature=0), graph=graph, verbose=True, allow_dangerous_requests=True
)

In [None]:
chain.invoke({"query": "Help me find curtains"})

### Extracting entities from the prompt

However, there is little added value here compared to just writing the Cypher queries ourselves, and it is prone to error.

Indeed, asking an LLM to generate a Cypher query directly might result in the wrong parameters being used, whether it's the entity type or the relationship type, as is the case above.
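One lightweight guard, sketched here under the assumption that the graph uses exactly the relationship types defined below in `relation_types`, is to scan a generated Cypher string for relationship types the graph does not contain before running it. The regular expression only handles simple patterns such as `-[:TYPE]->` or `-[r:TYPE]-`, and `unknown_relationships` is a hypothetical helper.

```python
import re

# Relationship types present in this graph (same keys as relation_types)
KNOWN_RELATIONSHIPS = {"hasCategory", "hasCharacteristic", "hasMeasurement",
                       "hasBrand", "hasColor", "isFor"}

def unknown_relationships(cypher):
    """Return relationship types used in a Cypher string that are not in the schema."""
    # Capture the type after the colon inside [...], with an optional variable name
    used = set(re.findall(r"\[\s*\w*\s*:\s*`?(\w+)`?", cypher))
    return used - KNOWN_RELATIONSHIPS
```

A non-empty result signals that the generated query should be rejected or regenerated rather than executed.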

We will instead use LLMs to decide what to search for, and then generate the corresponding Cypher queries using templates.

For this purpose, we will instruct our model to find relevant entities in the user prompt that can be used to query our database.

In [None]:
entity_types = {
    "product": "Item detailed type, for example 'high waist pants', 'outdoor plant pot', 'chef kitchen knife'",
    "category": "Item category, for example 'home decoration', 'women clothing', 'office supply'",
    "characteristic": "if present, item characteristics, for example 'waterproof', 'adhesive', 'easy to use'",
    "measurement": "if present, dimensions of the item", 
    "brand": "if present, brand of the item",
    "color": "if present, color of the item",
    "age_group": "target age group for the product, one of 'babies', 'children', 'teenagers', 'adults'. If suitable for multiple age groups, pick the oldest (latter in the list)."
}

relation_types = {
    "hasCategory": "item is of this category",
    "hasCharacteristic": "item has this characteristic",
    "hasMeasurement": "item is of this measurement",
    "hasBrand": "item is of this brand",
    "hasColor": "item is of this color", 
    "isFor": "item is for this age_group"
}

entity_relationship_match = {
    "category": "hasCategory",
    "characteristic": "hasCharacteristic",
    "measurement": "hasMeasurement", 
    "brand": "hasBrand",
    "color": "hasColor",
    "age_group": "isFor"
}
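These three dictionaries must stay in sync: every entity type except `product` needs a relationship, and every relationship must be described to the model. A small check, sketched below with a hypothetical `check_schema` helper, can catch typos early.

```python
def check_schema(entity_types, relation_types, entity_relationship_match):
    """Return a list of inconsistencies between the three schema dictionaries."""
    issues = []
    for entity in entity_types:
        # 'product' is matched on the Product node itself, not through a relationship
        if entity != "product" and entity not in entity_relationship_match:
            issues.append(f"no relationship for entity type '{entity}'")
    for entity, rel in entity_relationship_match.items():
        if entity not in entity_types:
            issues.append(f"unknown entity type '{entity}'")
        if rel not in relation_types:
            issues.append(f"unknown relationship '{rel}'")
    return issues
```

With the dictionaries above, `check_schema(entity_types, relation_types, entity_relationship_match)` returns an empty list.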

In [None]:
system_prompt = f'''
    You are a helpful agent designed to fetch information from a graph database. 
    
    The graph database links products to the following entity types:
    {json.dumps(entity_types)}
    
    Each link has one of the following relationships:
    {json.dumps(relation_types)}

    Depending on the user prompt, determine if it is possible to answer with the graph database.
        
    The graph database can match products with multiple relationships to several entities.
    
    Example user input:
    "Which blue clothing items are suitable for adults?"
    
    There are three relationships to analyse:
    1. The mention of the blue color means we will search for a color similar to "blue"
    2. The mention of the clothing items means we will search for a category similar to "clothing"
    3. The mention of adults means we will search for an age_group similar to "adults"
    
    
    Return a JSON object following these rules:
    For each relationship to analyse, add a key value pair with the key being an exact match for one of the entity types provided, and the value being the value relevant to the user query.
    
    For the example provided, the expected output would be:
    {{
        "color": "blue",
        "category": "clothing",
        "age_group": "adults"
    }}
    
    If there are no relevant entities in the user prompt, return an empty json object.
'''

print(system_prompt)

In [None]:
from openai import OpenAI
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

# Define the entities to look for
def define_query(prompt, model="gpt-4o"):
    completion = client.chat.completions.create(
        model=model,
        temperature=0,
        # Ask for a JSON object so the output can be parsed with json.loads
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )
    return completion.choices[0].message.content

In [None]:
example_queries = [
    "Which pink items are suitable for children?",
    "Help me find gardening gear that is waterproof",
    "I'm looking for a bench with dimensions 100x50 for my living room"
]

for q in example_queries:
    print(f"Q: '{q}'\n{define_query(q)}\n")
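JSON mode guarantees valid JSON, not a particular schema, so the model may still return keys that are not entity types, or empty values. A defensive parsing step, sketched below with a hypothetical `parse_entities` helper, keeps only known keys with non-empty string values.

```python
import json

def parse_entities(raw, allowed_keys):
    """Parse the model's JSON output, keeping only known keys with non-empty string values."""
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # Invalid JSON, or no content at all (None)
        return {}
    if not isinstance(data, dict):
        return {}
    return {
        key: value.strip()
        for key, value in data.items()
        if key in allowed_keys and isinstance(value, str) and value.strip()
    }
```

Since `query_graph` expects a JSON string, the filtered result would be passed on as `json.dumps(parse_entities(define_query(q), entity_types))`.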


### Generating queries

Now that we know what to look for, we can generate the corresponding Cypher queries to query our database. 

However, the entities extracted might not be an exact match with the data we have, so we will use the GDS cosine similarity function to return products that have relationships with entities similar to what the user is asking.
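For intuition about the threshold, here is a plain Python version of cosine similarity, the measure `gds.similarity.cosine` returns: the dot product of two vectors divided by the product of their lengths, so 1 means the same direction and 0 means orthogonal. This is an illustrative sketch, not the GDS implementation; returning 0.0 for a zero vector is an assumption of this sketch.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Assumption of this sketch: a zero vector is treated as dissimilar to everything
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
```

For instance, `cosine_similarity([1, 1], [1, 0])` is about 0.707, which would pass the default threshold of 0.6 used below.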

In [None]:
def create_embedding(text):
    result = client.embeddings.create(model=embeddings_model, input=text)
    return result.data[0].embedding

In [None]:
# The threshold defines how closely related words should be. Adjust the threshold to return more or fewer results
def create_query(text, threshold=0.6):
    query_data = json.loads(text)
    # Keep only entity types that map to a relationship ('product' has none and is ignored here)
    keys = [key for key in query_data if key in entity_relationship_match]
    # Pass the embedding of each extracted value in as a query parameter
    query = "WITH " + ",\n".join(f"${key}Embedding AS {key}Embedding" for key in keys)
    # Match products to each entity through the corresponding relationship
    query += "\nMATCH (p:Product)\nMATCH "
    query += ",\n".join(
        f"(p)-[:{entity_relationship_match[key]}]->({key}Var:{key})" for key in keys
    )
    # Keep only entities whose embedding is close enough to the extracted value
    query += "\nWHERE " + " AND ".join(
        f"gds.similarity.cosine({key}Var.embedding, ${key}Embedding) > {threshold}"
        for key in keys
    )
    # DISTINCT avoids duplicates when several entities of one type match the same product
    query += "\nRETURN DISTINCT p"
    return query

In [None]:
def query_graph(response):
    query = create_query(response)
    query_data = json.loads(response)
    # Only embed values for entity types used in the query (skips 'product' and unknown keys)
    embeddings_params = {
        f"{key}Embedding": create_embedding(val)
        for key, val in query_data.items()
        if key in entity_relationship_match
    }
    return graph.query(query, params=embeddings_params)

In [None]:
example_response = '''{
    "category": "clothes"
}'''

result = query_graph(example_response)

In [None]:
# Result
print(f"Found {len(result)} matching product(s):\n")
for r in result:
    print(f"{r['p']['name']} ({r['p']['id']})")

### Finding similar items 

We can then leverage the graph database to find similar products based on common characteristics.

This is where using a graph database really pays off.

For example, we can look for products that are in the same category and have another characteristic in common, or find products that have relationships to the same entities.

These criteria are arbitrary and depend entirely on what is most relevant to your use case.
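The "entities in common" criterion can be illustrated without a database. In this sketch, a hypothetical `edges` dictionary maps each product id to the set of entities it links to, standing in for the `(Product)-->(entity)` relationships; `common_entity_matches` mirrors the intent of the `query_common_entities` Cypher query below, with `threshold` playing the role of `relationships_threshold`.

```python
def common_entity_matches(edges, product_id, threshold=3):
    """Return ids of other products sharing at least `threshold` entities with product_id.

    edges: hypothetical in-memory graph, {product_id: {(entity_type, value), ...}}
    """
    target = edges.get(product_id, set())
    return sorted(
        other_id
        for other_id, entities in edges.items()
        # Set intersection counts the distinct entities both products link to
        if other_id != product_id and len(target & entities) >= threshold
    )
```

Raising `threshold` makes matches stricter, just as increasing `relationships_threshold` does in the Cypher version.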

In [None]:
# Adjust the relationships_threshold to return products that have more or fewer relationships in common
def query_similar_items(product_id, relationships_threshold=3):
    
    similar_items = []
        
    # Fetching items in the same category with at least 1 other entity in common
    query_category = '''
            MATCH (p:Product {id: $product_id})-[:hasCategory]->(c:category)
            MATCH (p)-->(entity)
            WHERE NOT entity:category
            MATCH (n:Product)-[:hasCategory]->(c)
            MATCH (n)-->(commonEntity)
            WHERE commonEntity = entity AND p.id <> n.id
            RETURN DISTINCT n;
        '''
    

    result_category = graph.query(query_category, params={"product_id": int(product_id)})
    #print(f"{len(result_category)} similar items of the same category were found.")
          
    # Fetching items with at least n (= relationships_threshold) entities in common
    query_common_entities = '''
        MATCH (p:Product {id: $product_id})-->(entity),
            (n:Product)-->(entity)
            WHERE p.id <> n.id
            WITH n, COUNT(DISTINCT entity) AS commonEntities
            WHERE commonEntities >= $threshold
            RETURN n;
        '''
    result_common_entities = graph.query(query_common_entities, params={"product_id": int(product_id), "threshold": relationships_threshold})
    #print(f"{len(result_common_entities)} items with at least {relationships_threshold} things in common were found.")

    for i in result_category:
        similar_items.append({
            "id": i['n']['id'],
            "name": i['n']['name']
        })
            
    for i in result_common_entities:
        result_id = i['n']['id']
        if not any(item['id'] == result_id for item in similar_items):
            similar_items.append({
                "id": result_id,
                "name": i['n']['name']
            })
    return similar_items

In [None]:
product_ids = ['1519827', '2763742']

for product_id in product_ids:
    print(f"Similar items for product #{product_id}:\n")
    result = query_similar_items(product_id)
    print("\n")
    for r in result:
        print(f"{r['name']} ({r['id']})")
    print("\n\n")



## Final result

Now that we have all the pieces working, we will stitch everything together. 

We can also add a fallback option to do a product name/title similarity search if we can't find relevant entities in the user prompt.
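The fallback logic can be expressed as a small wrapper, sketched below with a hypothetical `search_with_fallback` helper: try the entity-based search, and if it fails or returns nothing, run the similarity search instead. Treating any exception as a miss is a design choice of this sketch, since an empty entity extraction makes `create_query` produce an invalid Cypher query.

```python
def search_with_fallback(prompt, by_entities, by_similarity, min_results=1):
    """Try entity-based search first, then fall back to similarity search.

    by_entities and by_similarity are callables taking the prompt and returning a list.
    Returns (matches, strategy_used).
    """
    try:
        matches = by_entities(prompt)
    except Exception:
        # For example, no entities extracted -> malformed Cypher -> database error
        matches = []
    if len(matches) >= min_results:
        return matches, "entities"
    return by_similarity(prompt), "similarity"
```

With the functions defined in this notebook, it could be called as `search_with_fallback(prompt, lambda p: query_db(define_query(p)), similarity_search)`.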

In [None]:
def query_db(params):
    matches = []
    # Querying the db
    result = query_graph(params)
    for r in result:
        product_id = r['p']['id']
        matches.append({
            "id": product_id,
            "name":r['p']['name']
        })
    return matches    

In [None]:
def similarity_search(prompt, k=10):
    matches = []
    embedding = create_embedding(prompt)
    query = """
    CALL db.index.vector.queryNodes('products', $k, $embedding)
    YIELD node AS p, score
    RETURN p.id AS id, p.name AS name, score
    ORDER BY score DESC
    """
    result = graph.query(query, params={'embedding': embedding, 'k': k})
    for r in result:
        matches.append({
            "id": r["id"],
            "name": r["name"],
            "score": r["score"]
        })
    return matches


In [None]:
prompt_similarity = "I'm looking for nice curtains"
print(similarity_search(prompt_similarity))

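### Using the tools with an OpenAI agent

The cells below use the OpenAI Agents SDK, installed with `pip install openai-agents` and imported as `agents`, to let an agent decide which search function to call.
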
In [None]:
from agents import Agent, Runner, function_tool
from typing import List
from pydantic import BaseModel

#### Tool Functions

First, we'll create proper tool functions using the `@function_tool` decorator:

In [None]:
class ProductMatch(BaseModel):
    id: int
    name: str

class SimilarityMatch(BaseModel):
    id: int
    name: str
    score: float

@function_tool
def search_products_by_entities(search_query: str) -> List[ProductMatch]:
    """
    Search for products based on extracted entities (color, category, etc.).
    
    Args:
        search_query: The user's search query to extract entities from
    
    Returns:
        List of matching products with id and name
    """
    try:
        # Use existing entity extraction logic
        params = define_query(search_query)
        
        # If no entities found, return empty list
        if not params or params.strip() == "{}":
            return []
            
        # Query the database using existing logic
        result = query_graph(params)
        
        matches = []
        for r in result:
            matches.append(ProductMatch(
                id=r['p']['id'],
                name=r['p']['name']
            ))
        
        return matches
    except Exception as e:
        print(f"Error in search_products_by_entities: {str(e)}")
        return []

@function_tool  
def search_products_by_similarity(search_query: str, max_results: int = 10) -> List[SimilarityMatch]:
    """
    Perform similarity search against product names and titles.
    
    Args:
        search_query: The user's search query
        max_results: Maximum number of results to return
    
    Returns:
        List of similar products with id, name, and similarity score
    """
    try:
        matches = []
        embedding = create_embedding(search_query)
        query = """
        CALL db.index.vector.queryNodes('products', $k, $embedding)
        YIELD node AS p, score
        RETURN p.id AS id, p.name AS name, score
        ORDER BY score DESC
        """
        result = graph.query(query, params={'embedding': embedding, 'k': max_results})
        
        for r in result:
            matches.append(SimilarityMatch(
                id=r["id"],
                name=r["name"],
                score=r["score"]
            ))
        
        return matches
    except Exception as e:
        print(f"Error in search_products_by_similarity: {str(e)}")
        return []

@function_tool
def find_similar_products(product_id: int, relationships_threshold: int = 3) -> List[ProductMatch]:
    """
    Find products similar to a given product based on graph relationships.
    
    Args:
        product_id: ID of the product to find similarities for
        relationships_threshold: Minimum number of shared entities for similarity
    
    Returns:
        List of similar products with id and name
    """
    try:
        similar_items = query_similar_items(str(product_id), relationships_threshold)
        
        matches = []
        for item in similar_items:
            matches.append(ProductMatch(
                id=item['id'],
                name=item['name']
            ))
        
        return matches
    except Exception as e:
        print(f"Error in find_similar_products: {str(e)}")
        return []

#### Product Recommendation Agent

Now we'll create the agent with clear instructions and tool integration:

In [None]:
def create_product_recommendation_agent() -> Agent:
    """
    Create a product recommendation agent with graph database tools.
    
    Returns:
        Agent configured with product search and recommendation tools
    """
    
    instructions = f"""
    You are a product recommendation agent that helps users find products from a graph database.
    
    Your workflow:
    1. First, use search_products_by_entities to extract entities (color, category, brand, etc.) 
       from the user's query and find matching products
    2. If that returns no results, use search_products_by_similarity to find products based on 
       semantic similarity to the user's query
    3. If you find products, optionally use find_similar_products to suggest additional 
       related items based on graph relationships
    4. Present results clearly with product names and IDs
    
    Available entity types for searching:
    {json.dumps(entity_types, indent=2)}
    
    Search Strategy:
    - Always try entity-based search first as it's more precise
    - Fall back to similarity search if no entities are found or no results returned
    - Use similar products feature to enhance recommendations
    - Present actual product names and IDs from the database results
    - If no products are found, suggest the user provide more specific details
    
    Format your responses clearly:
    - List the number of products found
    - Show product names with their IDs in parentheses
    - If suggesting similar products, clearly separate them from main results
    
    Be helpful and accurate - only return products that actually exist in the database.
    """
    
    recommendation_agent = Agent(
        name="product_recommendation_agent",
        instructions=instructions,
        tools=[
            search_products_by_entities,
            search_products_by_similarity, 
            find_similar_products
        ]
    )
    
    return recommendation_agent

#### Agent Execution

Finally, we'll create a wrapper function for easy agent interaction:

In [None]:
async def recommend_products(user_query: str) -> str:
    """
    Get product recommendations using the OpenAI agent.
    
    Args:
        user_query: The user's product search query
        
    Returns:
        Agent's response with product recommendations
    """
    agent = create_product_recommendation_agent()
    
    try:
        result = await Runner.run(agent, user_query)
        return result.final_output
    except Exception as e:
        return f"❌ Error getting product recommendations: {str(e)}"

#### Testing the OpenAI Agent

Let's test the OpenAI agent with a few example queries:

In [None]:
test_queries = [
    "I'm searching for pink shirts",
    "Can you help me find toys for my niece, she's 8",
    "I'm looking for nice curtains"
]

print("🤖 OpenAI Agent Product Recommendations")
print("=" * 60)

for i, query in enumerate(test_queries, 1):
    print(f"\n[{i}/{len(test_queries)}] Query: {query}")
    print("-" * 50)
    
    try:
        result = await recommend_products(query)
        print(result)
    except Exception as e:
        print(f"❌ Error: {str(e)}")
    
    if i < len(test_queries):
        print("\n" + "=" * 60)
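Top-level `await` works here because Jupyter already runs an event loop. In a plain Python script you would drive the coroutine with `asyncio.run` instead, which cannot be called from inside a running loop such as a notebook cell. The sketch below uses a hypothetical `run_queries` helper; the Agents SDK also provides `Runner.run_sync` for synchronous use.

```python
import asyncio

async def run_queries(queries, recommend):
    """Await an async `recommend` function once per query, in order."""
    results = []
    for query in queries:
        results.append(await recommend(query))
    return results

# Hypothetical usage in a script where recommend_products and test_queries are defined:
# outputs = asyncio.run(run_queries(test_queries, recommend_products))
```

Running the queries sequentially keeps the output order predictable; `asyncio.gather` would run them concurrently at the cost of interleaved API calls.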

## Tracing
The agent makes multiple calls under the hood that are not visible in the code. All of them can be traced in the OpenAI logs console: https://platform.openai.com/logs?api=traces

## Conclusion

### User experience

When the primary objective is to extract specific information from our database, Large Language Models (LLMs) can significantly enhance our querying capabilities.

However, it's crucial to base much of this process on robust code logic to ensure a foolproof user experience.

For crafting a genuinely conversational chatbot, further exploration in prompt engineering is necessary, possibly incorporating few-shot examples. This approach helps mitigate the risk of generating inaccurate or misleading information and ensures more precise responses.

Ultimately, the design choice depends on the desired user experience. For instance, if the aim is to create a visual recommendation system, a conversational interface matters less.

### Working with a knowledge graph 

Retrieving content from a knowledge graph adds complexity but can be useful if you want to leverage connections between items. 

The querying part of this notebook would work on a relational database as well; the knowledge graph comes in handy when we want to couple the results with similar items that the graph surfaces.

Considering the added complexity, make sure a knowledge graph is the best option for your use case.
If it is, feel free to refine what this cookbook presents to match your needs and perform even better!