# Product recommendation using RAG and Dgraph Database

In this notebook we use a Dgraph database to store retail product information (Amazon product data) and language models to answer users asking for a **product recommendation**.

The language models are used in three different ways:
- we use an LLM's text analysis capabilities to craft a graph database query that fetches the information needed to generate a response;
- we use a small model to generate and store text embeddings, to find products, categories, brands, characteristics, etc. based on semantic similarity;
- we use an LLM to generate a response to the user question based on the data retrieved from the graph.

This is an example of Retrieval Augmented Generation (RAG) and Natural Language Processing (NLP) leveraging graph structures.

### Why Dgraph?

Dgraph is particularly suited for knowledge graph and AI applications due to several key features and capabilities:

- Graph Database Structure: Dgraph is designed as a native graph database, which means it stores data in a graph structure consisting of nodes, edges, and properties. This is inherently aligned with the way knowledge graphs represent relationships and entities, making it easier to model complex interconnections.

- Native vector support: Any node can have any number of vector predicates, indexed with the HNSW (Hierarchical Navigable Small World) algorithm for fast similarity retrieval.

- Scalability: Dgraph is built to scale horizontally, handling large volumes of data and high query loads efficiently. This is crucial for AI applications that often require processing vast amounts of interconnected data.

- High Performance: Dgraph provides fast query execution and low latency, which are essential for real-time AI applications. Its performance optimizations, such as parallel query execution and efficient data storage, make it capable of handling demanding workloads.

- Flexible Schema: Dgraph supports flexible schema definitions, allowing for dynamic data models that can evolve over time. This is beneficial for AI applications where the data schema might need to adapt to new requirements or insights.

- Rich Querying Capabilities: Dgraph’s query language, DQL (Dgraph Query Language), is declarative, and query responses come back in the same shape as the query. DQL allows for complex graph traversals and pattern matching, which are essential for extracting insights and relationships in knowledge graphs. It also supports advanced features such as recursive queries, aggregations and, most importantly, vector similarity search.


## Setup

We just need a few Python packages for Dgraph, OpenAI, Hugging Face and some tools we are using.

Create a `.env` file in the folder containing this notebook, with one line for your OpenAI API key:
```
OPENAI_API_KEY=sk-....
```

In [None]:
# Optional script to install all the required packages
!pip3 install pydgraph
!pip3 install openai
!pip3 install sentence_transformers
!pip3 install pybars3
!pip3 install python-dotenv


In [None]:
import os
import json 
import pydgraph
from pybars import Compiler
# Activate the provider you want to use for embeddings and LLM
# from openai import OpenAI
# from mistralai.client import MistralClient

from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
load_dotenv()

assert os.getenv("OPENAI_API_KEY") is not None, "Set OPENAI_API_KEY in your .env file"

## Dataset

Dgraph supports the JSON and RDF formats.
In this notebook we use RDF.
RDF is a powerful notation for knowledge graphs: it describes information as triples of the form Subject - Predicate - Object (S-P-O).
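To make the S-P-O idea concrete, here is a minimal sketch of how one JSON product record could be written as RDF N-Quads. The input field names (`name`, `category`, `color`) and the blank-node labels are assumptions for illustration; the predicate names mirror those used in the queries later in this notebook (`Product.Name`, `Product.category`, `category.Value`). This is not the conversion actually used for the dataset, and the real records carry more attributes.

```python
import json
import re

def slug(text):
    # Make a blank-node-safe label, e.g. "home decoration" -> "home_decoration"
    return re.sub(r"[^a-z0-9]+", "_", text.lower()).strip("_")

def product_to_nquads(product_id, record):
    """Convert one product record (hypothetical JSON shape) into N-Quad lines."""
    p = f"_:product_{product_id}"
    nquads = [
        f'{p} <dgraph.type> "Product" .',
        # json.dumps quotes and escapes the string literal
        f"{p} <Product.Name> {json.dumps(record['name'])} .",
    ]
    # Each linked attribute becomes its own node, so products sharing a value
    # point to the same node (blank nodes are shared within one mutation)
    for attr in ("category", "color"):
        value = record.get(attr)
        if not value:
            continue
        node = f"_:{attr}_{slug(value)}"
        nquads.append(f'{node} <dgraph.type> "{attr}" .')
        nquads.append(f"{node} <{attr}.Value> {json.dumps(value)} .")
        nquads.append(f"{p} <Product.{attr}> {node} .")
    return nquads

for line in product_to_nquads("1", {"name": "Flower Pot Stand",
                                    "category": "home garden balcony decor",
                                    "color": "black"}):
    print(line)
```

Representing categories and colors as shared nodes rather than repeated strings is what lets the graph link every product with the same value, and it also contributes to the compact file size.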


The original dataset is in JSON format and is 2.7 MB. We have generated an RDF file with exactly the same information, and it is only 361 KB!

If you are interested, see [generateRDF](./generateRDF.ipynb) notebook.

## Loading dataset
### Connecting to Dgraph

See [Learning Environment](https://dgraph.io/docs/deploy/installation/single-host-setup/) to set up a Docker container with `dgraph/standalone:latest`, or use your on-prem or cloud instance.

In [None]:
# DB credentials: gRPC endpoint of the Dgraph instance (defaults to a local instance)
dgraph_grpc = os.environ.get("DGRAPH_GRPC", "localhost:9080")

# DGRAPH_ADMIN_KEY must be defined in env variables when connecting to Dgraph Cloud
if "cloud.dgraph" in dgraph_grpc:
    assert "DGRAPH_ADMIN_KEY" in os.environ, "DGRAPH_ADMIN_KEY must be defined"
    APIAdminKey = os.environ["DGRAPH_ADMIN_KEY"]
else:
    APIAdminKey = None

# The client stub for on-prem requires the gRPC host:port without protocol.
# The client stub for cloud requires the gRPC endpoint, the GraphQL endpoint or the base URL of the cluster.
if APIAdminKey is None:
    client_stub = pydgraph.DgraphClientStub(dgraph_grpc)
else:
    client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc, APIAdminKey)
client = pydgraph.DgraphClient(client_stub)
print(f"Connected to Dgraph at {dgraph_grpc}")


### Housekeeping

First we clean the DB. You may want to skip this step. 

In [None]:

# Drop all data, including the schema, from the Dgraph instance. This is useful
# for small examples such as this one, since it puts Dgraph into a clean state.
confirm = input("drop schema and all data (y/n)?")
if confirm == "y":
    op = pydgraph.Operation(drop_all=True)
    client.alter(op)
    print("schema and data deleted")

### Deploying the Graph schema

In [None]:
# add predicates to Dgraph type schema
with open('data/products.schema', 'r') as file:
    dqlschema = file.read()
op = pydgraph.Operation(schema=dqlschema)
client.alter(op)
print("schema updated:")
print(dqlschema)

### Importing data

In [None]:
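def mutate_rdf(nquads, client):
    """Send a list of RDF N-Quad lines to Dgraph in a single transaction."""
    ret = {}
    # lines read with readlines() keep their trailing newline: strip it before joining
    body = "\n".join(line.rstrip("\n") for line in nquads)
    if len(nquads) > 0:
        txn = client.txn()
        try:
            res = txn.mutate(set_nquads=body)
            txn.commit()
            ret["nquads"] = len(nquads)
            ret["total_ns"] = res.latency.total_ns
        except pydgraph.errors.AbortedError as err:
            print("AbortedError %s" % err)
        except Exception as inst:
            print(inst)
        finally:
            # no-op after a successful commit, rolls back otherwise
            txn.discard()

    return ret

with open('data/products.rdf') as f:
    data = f.readlines()
    mutate_rdf(data, client)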
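The cell above sends the whole file in one transaction, which is fine at this size. For a much larger RDF file you might split the lines into batches and call `mutate_rdf` once per batch. A minimal sketch of the batching step, assuming one triple per line; `batch_nquads` is a hypothetical helper. Keep in mind that Dgraph resolves blank nodes (`_:name`) per mutation, so a blank node referenced in two different batches would become two different nodes unless you use explicit UIDs or upserts.

```python
from typing import Iterable, Iterator, List

def batch_nquads(lines: Iterable[str], batch_size: int = 1000) -> Iterator[List[str]]:
    """Yield lists of at most batch_size non-empty, non-comment N-Quad lines."""
    batch = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and N-Quads comments
        batch.append(line)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, partially filled batch

# Possible usage with the notebook's client (not executed here):
# with open('data/products.rdf') as f:
#     for batch in batch_nquads(f, batch_size=1000):
#         mutate_rdf(batch, client)
```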

### Simple graph query
As our data is now in a graph database, we can traverse the graph, search for nodes, count relationships, etc.
To verify that we have data in the database, let's run a simple query that finds the top 5 categories by number of products.

In [301]:
query = '''
{ 
    var(func:type(category)) { 
        np as count(~Product.category)
    }
    productsPerCategory(func:uid(np), orderdesc:val(np), first:5){
        category:category.Value
        number_of_products:val(np)
    }
}
'''

res = client.txn(read_only=True).query(query)
res = json.loads(res.json)
print("Top 5 categories with the most products:")
print (json.dumps(res, indent=4))

Top 5 categories with the most products:
{
    "productsPerCategory": [
        {
            "category": "home decoration",
            "number_of_products": 20
        },
        {
            "category": "books",
            "number_of_products": 17
        },
        {
            "category": "women clothing",
            "number_of_products": 17
        },
        {
            "category": "book",
            "number_of_products": 14
        },
        {
            "category": "footwear",
            "number_of_products": 12
        }
    ]
}


## Similarity search with vector embeddings

We don't want to constrain the question to terms present in the database. For example, the user may ask for "some clothes of dark color". We need to search our graph by similarity, not only by exact terms.
For this we use Dgraph vectors and language model embeddings.

### Creating vector indexes
Dgraph is a graph database with native vector support, HNSW indexes, and similarity search. For this use case, we use a Python script shared in the blog post [Add OpenAI, Mistral or open-source embeddings to your knowledge graph.](https://dgraph.io/blog/post/embeddings/) to compute vector embeddings and add them to all our entities.

For example, with an embedding on the `color` entities, we will be able to search for colors `similar_to` "dark color".
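Conceptually, `similar_to(predicate, k, vector)` returns the k nodes whose stored vectors are closest to the query vector; the HNSW index finds them approximately instead of comparing against every node. The brute-force sketch below computes the exact answer on toy data, to show what is being approximated. The 3-dimensional vectors and color names are made up (all-MiniLM-L6-v2 produces 384-dimensional vectors), and the distance metric is an assumption: the HNSW index accepts a metric option, so check which one your schema uses. For unit-length vectors, ranking by Euclidean distance and by cosine similarity gives the same order.

```python
import math

def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def top_k(query, nodes, k, metric="euclidean"):
    """Exact k-nearest neighbours over {name: vector}; an HNSW index approximates this."""
    if metric == "euclidean":
        ranked = sorted(nodes, key=lambda n: euclidean(query, nodes[n]))
    else:  # cosine: highest similarity first
        ranked = sorted(nodes, key=lambda n: -cosine(query, nodes[n]))
    return ranked[:k]

# Toy "color" embeddings (hypothetical values, not real model output)
colors = {
    "black": [0.9, 0.1, 0.0],
    "navy":  [0.75, 0.25, 0.1],
    "pink":  [0.1, 0.9, 0.3],
    "white": [0.0, 0.2, 0.9],
}
dark_color = [0.85, 0.15, 0.05]  # pretend embedding of "dark color"
print(top_k(dark_color, colors, 2))  # -> ['black', 'navy']
```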

In [None]:
# Generic functions to manage embeddings
compiler = Compiler()
def clearIndex(predicate):
    print(f"remove index for {predicate}")
    schema = f"{predicate}: float32vector ."
    op = pydgraph.Operation(schema=schema)
    alter = client.alter(op)
    print(alter)
def computeIndex(predicate,index):
    print(f"create index for {predicate} {index}")
    schema = f"{predicate}: float32vector @index({index}) ."
    op = pydgraph.Operation(schema=schema)
    alter = client.alter(op)
    print(alter) 

def huggingfaceEmbeddings(model,sentences):
    embeddings = model.encode(sentences)
    return embeddings.tolist()


def computeEmbedding(predicate,data, template,provider,modelName,model,llm, dimensions):
    # data is a list of objects with a uid and the predicates used in the template.
    # Render one text per object, get the embeddings, and produce RDF N-Quads.

    nquad_list = []
    sentences = [template(e) for e in data]

    if "huggingface"== provider:
        embeddings = huggingfaceEmbeddings(model,sentences)
    elif "openai" == provider:
        if dimensions != None:
            openaidata = llm.embeddings.create(
            input=sentences,
            model=modelName,
            encoding_format="float",
            dimensions=dimensions)
        else:
            openaidata = llm.embeddings.create(
            input=sentences,
            model=modelName,
            encoding_format="float")    
        embeddings = [e.embedding for e in openaidata.data]
    elif "mistral" == provider:
        mistraldata = llm.embeddings(
            model=modelName,
            input=sentences)
        embeddings = [e.embedding for e in mistraldata.data]
    else:
        raise ValueError(f"unknown embedding provider: {provider}")

    # embeddings is a list of vectors in the same order as the input data
    try:
        for i in range(len(data)):
            uid = data[i]['uid']
            nquad_list.append(f'<{uid}> <{predicate}> "{embeddings[i]}" .')
    except Exception as inst:
        print(inst)
        print(embeddings)
    return nquad_list


def buildEmbeddings(embedding_def,only_missing=True):
    global client
    predicate = f"{embedding_def['entityType']}.{embedding_def['attribute']}"
    if 'disabled' in embedding_def and  embedding_def['disabled'] == True:
        print(f"Predicate {predicate} is disabled")
        return 0
    else:
        entity = embedding_def['entityType']
        config = embedding_def['config']
        provider = embedding_def['provider']
        modelName = embedding_def['model']
        dimensions = embedding_def['dimensions'] if "dimensions" in embedding_def else None
        index = embedding_def['index']

        if "huggingface" == provider:
            model = SentenceTransformer(modelName)
            llmclient = None
        else:   
            model = None
            if "openai" == provider:
                assert "OPENAI_API_KEY" in os.environ, "OPENAI_API_KEY must be defined"
                llmclient = OpenAI(
                    # This is the default and can be omitted
                    api_key=os.environ.get("OPENAI_API_KEY"),
                )
            elif "mistral" == provider:
                assert "MISTRAL_API_KEY" in os.environ, "MISTRAL_API_KEY must be defined"
                llmclient = MistralClient(
                    api_key=os.environ.get("MISTRAL_API_KEY")
                )
        total = 0
        
        template = compiler.compile(config['template'])
        # inject uid in the query so that each result can be mutated later
        querypart = config['dqlQuery']
        querypart = querypart.replace("{", "{ uid ", 1)
        
        # remove index by updating DQL schema
        clearIndex(predicate)
        print(f"compute embeddings for {predicate} using  model {modelName} from {provider}")
        if only_missing == True:
            filter = f"@filter( NOT has({predicate}))"
        else:
            filter = "" 
        # Run query.
        after = ""
        while True:
            print(f"\r{total} processed", end = '')
            txn = client.txn(read_only=True)
            query = f"{{list(func: type({entity}),first:100 {after}) {filter}  {querypart} }}"
            try:
                res = txn.query(query)
                data = json.loads(res.json)
            except Exception as inst:
                print(type(inst))    # the exception type
                print(inst.args)     # arguments stored in .args
                print(inst)
                break
            finally:
                txn.discard()
            
            if len(data['list']) > 0:
                last_uid = data['list'][-1]['uid']
                after = f",after:{last_uid}"
            else: 
                break

            nquads = computeEmbedding(predicate,data['list'],template,provider,modelName,model,llmclient, dimensions)
            mutate_rdf(nquads,client)
            total += len(data['list'])
        print(f"\r{total} processed")
        computeIndex(predicate,index)
        
    return total

With the embedding logic in place we can now define the embeddings we want to compute.

The embeddings are computed using the Hugging Face Sentence Transformer model `all-MiniLM-L6-v2`. 

For each entry, `dqlQuery` selects the predicates to fetch, and `template` is a Handlebars template that renders each query result into the text to be embedded.
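To see what a configuration entry turns into, the helper below reproduces, as a standalone function, the query string that `buildEmbeddings` assembles for one page of results: it injects `uid` into the `dqlQuery` selection, adds the `NOT has(...)` filter when `only_missing` is set, and pages with `first`/`after`. The name `build_page_query` is hypothetical; it only mirrors the string-building lines of `buildEmbeddings` and does not contact Dgraph.

```python
def build_page_query(entity, attribute, dql_query, after_uid=None,
                     only_missing=True, page_size=100):
    """Rebuild the per-page DQL query used by buildEmbeddings (string only)."""
    predicate = f"{entity}.{attribute}"  # e.g. color.embedding
    # Inject uid into the first selection set so each result can be mutated later
    querypart = dql_query.replace("{", "{ uid ", 1)
    flt = f"@filter( NOT has({predicate}))" if only_missing else ""
    after = f",after:{after_uid}" if after_uid else ""
    return f"{{list(func: type({entity}),first:{page_size} {after}) {flt}  {querypart} }}"

# First page, then the page after a hypothetical last uid 0x2a
print(build_page_query("color", "embedding", "{ value: color.Value}"))
print(build_page_query("color", "embedding", "{ value: color.Value}", after_uid="0x2a"))
```

Each page returns objects like `{"uid": ..., "value": ...}`, which is exactly the context the `{{value}}` template expects.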

In [None]:
embedding_config =  [
        {
            "entityType":"Product",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ title: Product.Title}",
                "template": "{{title}} "
            }
        },
        {
            "entityType":"age_group",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: age_group.Value}",
                "template": "{{value}} "
            }
        },
        {
            "entityType":"brand",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: brand.Value}",
                "template": "{{value}} "
            }
        },
        
        {
            "entityType":"category",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: category.Value}",
                "template": "{{value}} "
            }
        },
        {
            "entityType":"characteristic",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: characteristic.Value}",
                "template": "{{value}} "
            }
        },
        {
            "entityType":"color",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: color.Value}",
                "template": "{{value}} "
            }
        },
        {
            "entityType":"material",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: material.Value}",
                "template": "{{value}} "
            }
        },
        {
            "entityType":"measurement",
            "attribute":"embedding",
            "index":"hnsw",
            "provider": "huggingface",
            "model":"sentence-transformers/all-MiniLM-L6-v2",
            "config" : {
                "dqlQuery" : "{ value: measurement.Value}",
                "template": "{{value}} "
            }
        }
    ]

for embedding_def in embedding_config:                
    buildEmbeddings(
            embedding_def,
            only_missing = True
            )
    print(f"Embeddings done for {embedding_def['entityType']}.{embedding_def['attribute']}")

### Querying the graph using Dgraph similarity function

In [304]:
sentence = "looking for something to make my home pretty"

# Get the sentence embedding with the same model
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
sentence_embedding = model.encode(sentence).tolist()

# Use Dgraph similar_to function to find similar categories and use Graph relations to get the products for this category
txn = client.txn(read_only=True)
query = f'''
  {{  
    result(func: similar_to(category.embedding,3,"{sentence_embedding}")) {{ 
        category:category.Value 
        products:~Product.category (first:3){{ 
            name:Product.Name
        }} 
    }} 
  }}'''
try:
    res = txn.query(query)
    data = json.loads(res.json)
    print(json.dumps(data,indent=4))
finally:
    txn.discard()


{
    "result": [
        {
            "category": "wedding decor",
            "products": [
                {
                    "name": "Romantic LED Light Valentine's Day Sign"
                }
            ]
        },
        {
            "category": "home decor",
            "products": [
                {
                    "name": "Fall Pillow Covers"
                }
            ]
        },
        {
            "category": "home garden balcony decor",
            "products": [
                {
                    "name": "Flower Pot Stand"
                }
            ]
        }
    ]
}


### Extracting entities from the prompt

In the previous example, we assumed that the question was about a category.

We can go further and use an LLM to analyze the user prompt and determine the right criteria before querying the graph structure.

We use OpenAI with a prompt built from our knowledge of the graph structure, i.e., the description of the entities and predicates that can be found in the graph (aka the ontology).

In [None]:
entities = [
    {
        "entity_name": "Product",
        "description": "Item detailed type",
        "predicates": {
            "category" : {"description": "Item category, for example 'home decoration', 'women clothing', 'office supply'"},
            "color" : {"description": "color of the item"},
            "brand": {"description": "if present, brand of the item"},
            "characteristic": {"description": "if present, item characteristics, for example 'waterproof', 'adhesive', 'easy to use'"},
            "measurement": {"description": "if present, dimensions of the item"},
            "age_group": {"description": "target age group for the product, one of 'babies', 'children', 'teenagers', 'adults'."}
        }    
    }
]
def ontologyPrompt(ontology):
    # Create a textual description of the ontology (the entities and predicates
    # present in the graph database) to help prompt the LLM
    entities = [f'\'{e["entity_name"]}\'' for e in ontology]
    list_entities = ", ".join(entities)
    s = f"Identify if the user question is about one of the entities {list_entities}."
    s += "\nIdentify criteria about predicates depending on the entity."
    for e in ontology:
        s += f'\nFor \'{e["entity_name"]}\' look for:'
        for p in e["predicates"]:
            s += f'\n- \'{p}\': {e["predicates"][p]["description"]}'
    return s


In [None]:
system_prompt = f'''
    You are analyzing a user prompt to fetch information from a knowledge graph. 

    {ontologyPrompt(entities)}

    Return a json object following the example:
    {{
        "entity": "product",
        "intent": "one of 'list', 'count'",
        "criteria": [
        {{ "predicate": "category", "value": "clothing"}},
        {{ "predicate": "color", "value": "blue"}},
        {{ "predicate": "age_group", "value": "adults"}}   
        ]
    }}
    
    If there are no relevant entities in the user prompt, return an empty json object.
'''
print(system_prompt)

In [None]:
from openai import OpenAI
llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

# Use the LLM to extract the entity, intent and criteria from a user prompt
def text_to_intent(prompt, model="gpt-4o-mini"):
    completion = llm.chat.completions.create(
        model=model,
        temperature=0,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )
    intent = json.loads(completion.choices[0].message.content)
    intent['prompt'] = prompt
    return intent
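The system prompt asks for an empty object when nothing relevant is found, and an LLM may still return a predicate that has no embedding in the graph. Before building a query, it can help to normalize the intent. A minimal sketch, assuming the allowed predicates are exactly those declared in the `entities` ontology above; `normalize_intent` is a hypothetical helper, not part of the original flow.

```python
# Assumption: these match the predicates declared in the `entities` ontology
ALLOWED_PREDICATES = {"category", "color", "brand", "characteristic",
                      "measurement", "age_group"}

def normalize_intent(raw, allowed=ALLOWED_PREDICATES):
    """Keep only well-formed criteria whose predicate exists in the ontology."""
    criteria = []
    for c in raw.get("criteria") or []:
        if not isinstance(c, dict):
            continue
        predicate = c.get("predicate")
        value = str(c.get("value", "")).strip()
        if predicate in allowed and value:
            criteria.append({"predicate": predicate, "value": value})
    return {
        "entity": raw.get("entity", "product"),
        "intent": raw.get("intent", "list"),
        "criteria": criteria,
        "prompt": raw.get("prompt", ""),
    }
```

An empty `criteria` list after normalization is a clear signal to skip the graph query entirely.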

In [305]:
example_queries = [
    "do you have clothes for teenagers in dark colors?",
    "Which pink items are suitable for children?",
    "Help me find gardening gear that is waterproof",
    "I'm looking for a bench with dimensions 100x50 for my living room"
]

intent_list = [text_to_intent(q) for q in example_queries]

print(json.dumps(intent_list, indent=4))

[
    {
        "entity": "product",
        "intent": "list",
        "criteria": [
            {
                "predicate": "category",
                "value": "clothing"
            },
            {
                "predicate": "color",
                "value": "dark"
            },
            {
                "predicate": "age_group",
                "value": "teenagers"
            }
        ],
        "prompt": "do you have clothes for teenagers in dark colors?"
    },
    {
        "entity": "product",
        "intent": "list",
        "criteria": [
            {
                "predicate": "color",
                "value": "pink"
            },
            {
                "predicate": "age_group",
                "value": "children"
            }
        ],
        "prompt": "Which pink items are suitable for children?"
    },
    {
        "entity": "product",
        "intent": "list",
        "criteria": [
            {
                "predicate": "characteristic",
 

### Generating queries

Now that we know what to look for, we can generate the corresponding DQL queries to query our database.

However, the extracted entities might not exactly match the data we have, so we use the Dgraph `similar_to` function on the vector indexes to find the entities most similar to what the user is asking for, and return the products linked to them.
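The algorithm used by `intent_to_dql` can be shown in isolation: one `var` block per criterion selects the single most similar node with `similar_to`, and the `products` block keeps only products linked to every selected node through `uid_in`. In this sketch the embedding function is passed in as a parameter (`fake_embed` is a hypothetical stand-in for the Sentence Transformer model), `assemble_dql` is a hypothetical name, and only the product name is returned.

```python
def assemble_dql(criteria, embed):
    """Build a parameterised DQL query and its variables from intent criteria."""
    params, var_blocks, filters, variables = [], [], [], {}
    for c in criteria:
        p = c["predicate"]
        # pydgraph expects variable values as strings
        variables[f"${p}vect"] = str(embed(c["value"]))
        params.append(f"${p}vect: float32vector")
        # `p as var(...)` stores the uid of the closest node in variable p
        var_blocks.append(f"{p} as var(func: similar_to({p}.embedding, 1, ${p}vect))")
        filters.append(f"uid_in(Product.{p}, uid({p}))")
    if not filters:
        raise ValueError("intent has no criteria to filter on")
    query = (
        f"query q({', '.join(params)}) {{\n  "
        + "\n  ".join(var_blocks)
        + f"\n  products(func: type(Product)) @filter({' AND '.join(filters)}) {{\n"
        + "    name: Product.Name\n  }\n}"
    )
    return {"query": query, "variables": variables}

def fake_embed(text):
    return [0.1, 0.2, 0.3]  # hypothetical 3-d embedding

dql = assemble_dql([{"predicate": "color", "value": "pink"},
                    {"predicate": "age_group", "value": "children"}], fake_embed)
print(dql["query"])
```

Passing the vectors as query variables, rather than pasting them into the query text as in the earlier `similar_to` example, keeps the query string short and readable.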

In [309]:
# use same embedding model for user input and for the searched entities

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def create_embedding(text):
    # print (f"create embedding for {text}")
    return huggingfaceEmbeddings(model,[text])[0]

# Algorithm to create the query:
# - for each criterion, compute an embedding of the criterion value;
# - build a sequence of var blocks to find the most similar node (category, characteristic, brand, etc.);
# - build a filter to keep only the products linked to those nodes.

def intent_to_dql(intent):
    vect = []
    vars = []
    filters = []
    variables = {}
    for criteria in intent['criteria']:
        variables[f"${criteria['predicate']}vect"] = f"{create_embedding(criteria['value'])}"
        vect.append(f"${criteria['predicate']}vect: float32vector")
        vars.append(f"{criteria['predicate']} as var(func:similar_to({criteria['predicate']}.embedding,1,${criteria['predicate']}vect))")
        filters.append(f"uid_in(Product.{criteria['predicate']}, uid({criteria['predicate']}))")
    all_filters = "\n AND ".join(filters)
    all_vars = "\n".join(vars)    
    query = f"""
      query test({','.join(vect)}){{
          {all_vars}
          products(func:type(Product)) @filter( {all_filters} ) {{
            name:Product.Name
            title:Product.Title
            age_group:Product.age_group  {{
               value:age_group.Value
            }}
            brand:Product.brand  {{
               value:brand.Value
            }}
            color:Product.color  {{
               value:color.Value
            }}
            category:Product.category  {{
               value:category.Value
            }}
            characteristic:Product.characteristic  {{
               value:characteristic.Value
            }}
            material:Product.material  {{
               value:material.Value
            }}
            measurement:Product.measurement  {{
               value:measurement.Value
            }}

          }}
      }}
    """
    return {"query":query,"variables":variables}


### Generating a response from the retrieved sub-graph

In [None]:


def rag(prompt, payload):
    model="gpt-4o-mini"
    rag_prompt = f'''
        You are suggesting products based on user input and available items.
        Reply to the user with suggestions from the following data that match the criteria
        {payload}
        If possible, explain why the items are suggested.
        If there are no relevant items reply that we don't have any items that match the criteria.
    '''
    completion = llm.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": rag_prompt
            },
            {
                "role": "user",
                "content": prompt
            }
            ]
        )
    return completion.choices[0].message.content
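`rag` inserts the raw query result (a Python dict) into the system prompt. One possible refinement, not part of the original notebook, is to flatten each product into a short line of text first, which drops JSON punctuation and keeps the prompt smaller. A minimal sketch, assuming the response shape produced by the `intent_to_dql` query (a `products` list whose edges hold `{"value": ...}` objects); `payload_to_text` is a hypothetical helper.

```python
def payload_to_text(payload):
    """Flatten a Dgraph products response into one line per product."""
    lines = []
    for product in payload.get("products", []):
        parts = [product.get("title") or product.get("name", "unnamed product")]
        for field in ("category", "color", "age_group", "brand",
                      "characteristic", "material", "measurement"):
            raw = product.get(field, [])
            # [uid] edges come back as lists, single uid edges as one object
            items = raw if isinstance(raw, list) else [raw]
            values = [item["value"] for item in items if "value" in item]
            if values:
                parts.append(f"{field}: {', '.join(values)}")
        lines.append("- " + "; ".join(parts))
    return "\n".join(lines) if lines else "No matching items."
```

It could then be used as `rag(sentence, payload_to_text(payload))`.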

In [None]:

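def reply(sentence):
    intent = text_to_intent(sentence)
    if not intent.get("criteria"):
        # nothing to search for: let the LLM answer from an empty result
        return rag(sentence, {"products": []})
    dql = intent_to_dql(intent)
    res = client.txn(read_only=True).query(dql["query"], variables=dql["variables"])
    payload = json.loads(res.json)
    return rag(sentence, payload)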


In [313]:
example_queries = [
    "Which pink items are suitable for children?",
    "Do you have a helmet with anti allergic padding?",
]
for q in example_queries:
    print()
    print(f"> {q}")
    print()
    r = reply(q)
    print(r)


> Which pink items are suitable for children?

I have two great suggestions for pink items that are suitable for children:

1. **Suitcase Music Box**
   - **Title**: Suitcase Music Box, Mini Music Box Clockwork Music Box for Children
   - **Age Group**: Children
   - **Color**: Pink
   - **Category**: Toys & Games
   - **Characteristics**: This music box features a make-up mirror, jewelry box functionality, requires no batteries, operates with a clockwork mechanism, and has a storage compartment. It's a delightful toy that can also serve as a charming decorative piece for a child's room.

2. **Unicorn Curtains**
   - **Title**: Eiichuang Unicorn Curtains Rod Pocket Blackout Cute Cartoon Pink Unicorn Wearing a Crown Mermaid Pattern Print Room Darkening Window Drapes for Kids Girls Bedroom Nursery, 2 Panels Set, 29 x 63 Inch
   - **Age Group**: Children
   - **Color**: Pink
   - **Category**: Home Decoration
   - **Characteristics**: These curtains feature a fun unicorn design and are r

## Conclusion
You can leverage the knowledge stored in a graph database to offer accurate responses or recommendations to your users' requests.

In this example, we combined AI techniques and graph capabilities to achieve this result:

- use a Large Language Model and the graph metadata (ontology) to analyze the user prompt and build an "intent".
- use a small language model to generate text embeddings for graph entities.
- use Dgraph's vector similarity search to find the entities matching each criterion.
- use graph traversal and filters to identify matching items.
- use an LLM to generate a textual response based on the user prompt and the information retrieved from the graph database.

This example provides a general workflow for the RAG over a graph use case. It can be improved in several ways, including:
- use a different embedding model
- create a more complex intent structure covering aggregation, counting and complex criteria (e.g., could you build an intent for the question "How many products do you have in home decoration under $100?")
- train a model to generate the query instead of crafting it
- train a model to generate the intent from user input instead of using an LLM
- ...
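As a starting point for the counting case, the sketch below switches the selection set on the `intent` field that the system prompt already asks for: a `count` intent returns `total: count(uid)` instead of the product fields. It reuses the var/filter pattern of `intent_to_dql`; the function names are hypothetical, and a price constraint such as "under $100" is left out because no price predicate is used anywhere in this notebook.

```python
def selection_for(intent):
    """Return the DQL selection set for a 'list' or 'count' intent."""
    if intent == "count":
        # count(uid) in a block returns the number of matching nodes
        return "total: count(uid)"
    if intent == "list":
        return "name: Product.Name\ntitle: Product.Title"
    raise ValueError(f"unsupported intent: {intent!r}")

def products_block(intent, filters):
    """Wrap the selection in the products block used by intent_to_dql."""
    where = " AND ".join(filters)
    return f"products(func: type(Product)) @filter({where}) {{\n{selection_for(intent)}\n}}"

print(products_block("count", ["uid_in(Product.category, uid(category))"]))
```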




