<a href="https://colab.research.google.com/github/gtagoh/Knowledge-Graphs/blob/main/notebooks/M3_enhancing_rag_with_graph_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Enhancing RAG-based Applications with Knowledge Graphs

## A Practical Guide to Constructing and Leveraging Knowledge Graphs in RAG Applications with Neo4j and LangChain

Graph retrieval-augmented generation (Graph RAG) is emerging as a powerful enhancement to traditional vector search methods in RAG-based applications. By combining the power of graph databases and vector-based approaches, you can achieve more accurate, context-rich results.

Graph databases store data as nodes and relationships, representing complex and interconnected information in a structured way. This makes them highly effective for capturing intricate relationships and attributes across diverse data types. On the other hand, vector databases excel at managing unstructured data by converting it into high-dimensional vectors, making it easier to search for semantic similarities. Integrating these two approaches allows RAG applications to leverage the strengths of both graph and vector databases, which we demonstrate in this tutorial.

## Prerequisites

- **Neo4j**: Set up a Neo4j database instance, either using the free Neo4j Aura cloud service or by installing Neo4j Desktop locally.
- **OpenAI API Key**: Obtain an OpenAI API key to access the models we use in this tutorial.
- **LangChain, Neo4j, and other dependencies**: Install the necessary Python packages.

In [1]:
# Install the necessary packages (comment this line out if they are already installed)
!pip install --upgrade --quiet langchain langchain-community langchain-openai langchain-experimental neo4j wikipedia tiktoken yfiles_jupyter_graphs PyPDF2 pypdf


In [None]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import BaseModel, Field
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
import os
from langchain_community.graphs import Neo4jGraph
from langchain_community.document_loaders import WikipediaLoader
from langchain_text_splitters import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_core.runnables import ConfigurableField

# Enable custom widgets when running in Google Colab; skip silently elsewhere
try:
    import google.colab
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass




### Setting Up Environment Variables

You'll need to set up environment variables for your OpenAI API key and Neo4j credentials. Add the following lines to your code to configure them:

In [None]:
os.environ["OPENAI_API_KEY"] = ""
os.environ["NEO4J_URI"] = "neo4j+s://fae385d3.databases.neo4j.io"
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = ""

graph = Neo4jGraph()
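
Because the values above are left empty on purpose, it can help to check that every setting is filled in before opening a connection. The sketch below is only an illustration: `missing_settings` and `REQUIRED` are hypothetical names that are not part of LangChain or Neo4j, and the example mapping stands in for the real environment.

```python
import os

# Names of the settings this tutorial expects (taken from the cell above)
REQUIRED = ["OPENAI_API_KEY", "NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"]

def missing_settings(required, env=None):
    """Return the names in `required` that are unset or empty in `env`.

    Hypothetical helper: `env` defaults to the process environment.
    """
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]

# Example with an explicit mapping instead of the real environment
example_env = {
    "OPENAI_API_KEY": "",               # empty, so it counts as missing
    "NEO4J_URI": "neo4j+s://example",   # placeholder URI, not a real instance
    "NEO4J_USERNAME": "neo4j",
}
print(missing_settings(REQUIRED, example_env))
```

Calling `missing_settings(REQUIRED)` without a mapping checks the real environment, so a non-empty result signals that a credential still needs to be filled in.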

### Data Ingestion and Knowledge Graph Construction

In this tutorial, we will use an academic research paper (for example, an "attention mechanism" paper) as our data source. We will ingest the text and construct a knowledge graph using LangChain and Neo4j.

In [None]:
from langchain_community.document_loaders import PyPDFLoader

# Path to the PDF to ingest; replace it with the path to your own file
pdf_path = "/content/attention.pdf"

# Initialize the loader with the PDF path
loader = PyPDFLoader(pdf_path)

# Load the documents
raw_documents = loader.load()

### Split the Document into Chunks

In [None]:
# Define chunking strategy: 512-token chunks with a 24-token overlap
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
# Only the first three pages of the PDF are split here
documents = text_splitter.split_documents(raw_documents[:3])
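
To see what `chunk_size` and `chunk_overlap` mean, here is a minimal sliding-window sketch. It works on an already tokenized list and uses a hypothetical helper, `split_with_overlap`. The real `TokenTextSplitter` works on model tokens, so treat this as an approximation of the idea, not its implementation.

```python
def split_with_overlap(tokens, chunk_size, chunk_overlap):
    """Split `tokens` into windows of `chunk_size` that share `chunk_overlap` items."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reached the end of the input
    return chunks

# Ten stand-in "tokens", windows of four, one token shared between neighbors
chunks = split_with_overlap(list(range(10)), chunk_size=4, chunk_overlap=1)
print(chunks)
```

Each chunk repeats the last token of the previous one, so text that straddles a chunk boundary is not lost.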

Now it's time to construct a graph from the loaded document chunks. For this purpose, we use the `LLMGraphTransformer` module, which significantly simplifies constructing and storing a knowledge graph in a graph database.

In [None]:
llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125") # gpt-4-0125-preview occasionally has issues
llm_transformer = LLMGraphTransformer(llm=llm)

graph_documents = llm_transformer.convert_to_graph_documents(documents)
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True
)

Note that the quality of the generated graph largely depends on the model used, so ideally you should select the most capable model available. The LLM graph transformer generates graph documents, which can be imported into Neo4j using the `add_graph_documents` method. The `baseEntityLabel` parameter adds an `__Entity__` label to each node, enhancing indexing and query performance. Additionally, the `include_source` parameter helps link nodes to their source documents, supporting data traceability and context.
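
As a mental model of what gets imported, the sketch below flattens a toy graph document into node and relationship rows. Everything in it is an assumption made for illustration: the `Toy*` dataclasses are simplified stand-ins, not LangChain's classes, the node types `Model` and `Component` are invented, and `to_rows` is a hypothetical helper that only mimics the effect of the two flags described above.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class ToyNode:
    id: str
    type: str

@dataclass(frozen=True)
class ToyRelationship:
    source: ToyNode
    target: ToyNode
    type: str

@dataclass
class ToyGraphDocument:
    source_text: str
    nodes: list = field(default_factory=list)
    relationships: list = field(default_factory=list)

def to_rows(doc, base_entity_label=True, include_source=True):
    """Flatten a toy graph document into (labels, id) and (start, type, end) rows."""
    node_rows = []
    for node in doc.nodes:
        # Mimic baseEntityLabel: every node gets an extra shared label
        labels = [node.type] + (["__Entity__"] if base_entity_label else [])
        node_rows.append((labels, node.id))
    edge_rows = [(r.source.id, r.type, r.target.id) for r in doc.relationships]
    if include_source:
        # Mimic include_source: link the source chunk to each extracted node
        edge_rows += [("<source document>", "MENTIONS", n.id) for n in doc.nodes]
    return node_rows, edge_rows

transformer = ToyNode("Transformer", "Model")
encoder = ToyNode("Encoder", "Component")
doc = ToyGraphDocument(
    "toy source text",
    [transformer, encoder],
    [ToyRelationship(transformer, encoder, "HAS_COMPONENT")],
)
node_rows, edge_rows = to_rows(doc)
print(node_rows)
print(edge_rows)
```

The `MENTIONS` rows are the links that the Cypher queries later in this tutorial exclude with `!MENTIONS` when they only want relationships between entities.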


### Constructing and Visualizing the Knowledge Graph in Neo4j

In [None]:
# Directly show the graph resulting from the given Cypher query
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"

def showGraph(cypher: str = default_cypher):
    # Open a Neo4j driver and session to run the query; both are closed
    # again once the result has been loaded into the widget
    with GraphDatabase.driver(
        uri=os.environ["NEO4J_URI"],
        auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
    ) as driver, driver.session() as session:
        widget = GraphWidget(graph=session.run(cypher).graph())
    # Use each node's `id` property as its label in the visualization
    widget.node_label_mapping = 'id'
    return widget

showGraph()

### Hybrid Retrieval for RAG

After generating the knowledge graph, we will implement a hybrid retrieval approach that combines vector and keyword indexes with graph retrieval for RAG applications.

![retrieval](https://raw.githubusercontent.com/tomasonjo/blogs/master/graphhybrid.png)

The diagram illustrates a retrieval process that begins with a user posing a question, which is then directed to a RAG retriever. This retriever uses keyword and vector searches over unstructured text data and combines the results with the information it collects from the knowledge graph. Since Neo4j features both keyword and vector indexes, you can implement all three retrieval options with a single database system. The collected data from these sources is fed into an LLM to generate and deliver the final answer.

### Unstructured data retriever

You can use the `Neo4jVector.from_existing_graph` method to add both keyword and vector retrieval to documents. This method configures keyword and vector search indexes for a hybrid search approach, targeting nodes labeled `Document`. Additionally, it calculates text embedding values if they are missing.


In [None]:
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)

The vector index can then be called with the `similarity_search` method.
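
To build intuition for what a hybrid search combines, the following toy sketch scores documents by vector similarity and by keyword overlap, normalizes each score list, and keeps the better score per document. It is an illustration under stated assumptions, not the query `Neo4jVector` runs: the two-dimensional vectors are made up, and `cosine`, `keyword_score`, `normalize`, and `hybrid_rank` are hypothetical helpers.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def keyword_score(query, text):
    """Fraction of query words that also occur in the text."""
    query_words = set(query.lower().split())
    text_words = set(text.lower().split())
    return len(query_words & text_words) / len(query_words) if query_words else 0.0

def normalize(scores):
    """Scale scores so that the best one becomes 1.0."""
    top = max(scores.values(), default=0.0)
    return {key: (value / top if top else 0.0) for key, value in scores.items()}

def hybrid_rank(query, query_vec, docs, k=2):
    """Rank `docs` (id -> (text, vector)) by the better of the two normalized scores."""
    vec = normalize({i: cosine(query_vec, v) for i, (_, v) in docs.items()})
    kw = normalize({i: keyword_score(query, t) for i, (t, _) in docs.items()})
    merged = {i: max(vec[i], kw[i]) for i in docs}
    return sorted(merged, key=merged.get, reverse=True)[:k]

# Made-up documents with made-up 2-D "embeddings"
docs = {
    "a": ("multi-head attention in the encoder", [1.0, 0.0]),
    "b": ("training schedule and optimizer", [0.0, 1.0]),
    "c": ("attention is all you need", [0.6, 0.8]),
}
print(hybrid_rank("multi-head attention", [1.0, 0.0], docs))
```

With the real index, the equivalent call is `vector_index.similarity_search(question)`, which returns documents whose `page_content` holds the matched text.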

### Graph retriever
Configuring graph retrieval, on the other hand, is more complex but provides greater flexibility. In this example, we will use a full-text index to locate relevant nodes and then retrieve their immediate neighbors.


![graph](https://raw.githubusercontent.com/tomasonjo/blogs/master/neighbor.png)

The graph retriever begins by identifying relevant entities in the input. For simplicity, we instruct the LLM to identify people, organizations, and businesses. To do this, we use LCEL (the LangChain Expression Language) with the `with_structured_output` method.

In [None]:
# Retriever

graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

# Extract entities from text
class Entities(BaseModel):
    """Identifying information about entities."""

    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that "
        "appear in the text",
    )

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are extracting organization and person entities from the text.",
        ),
        (
            "human",
            "Use the given format to extract information from the following "
            "input: {question}",
        ),
    ]
)

entity_chain = prompt | llm.with_structured_output(Entities)

Let's test it out:

In [None]:
entity_chain.invoke({"question": "Which component in the Transformer uses multi-head attention?"}).names

['Transformer', 'multi-head attention']

Great, now that we can detect entities in the question, let's use a full-text index to map them to the knowledge graph. The `entity` full-text index was already created in the cell above, so what remains is a function that generates full-text queries tolerating a bit of misspelling. We won't go into much detail about it here.

In [None]:
def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text search.
    It processes the input string by splitting it into words and appending a
    similarity threshold (~2 changed characters) to each word, then combines
    them using the AND operator. Useful for mapping entities from user questions
    to database values, and allows for some misspellings.
    """
    # Strip Lucene special characters, then drop empty tokens
    words = [el for el in remove_lucene_chars(input).split() if el]
    # Joining handles zero, one, or many words without indexing into the list
    return " AND ".join(f"{word}~2" for word in words)

# Fulltext index query
def structured_retriever(question: str) -> str:
    """
    Collects the neighborhood of entities mentioned
    in the question
    """
    results = []
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        full_text_query = generate_full_text_query(entity)
        if not full_text_query:
            # Nothing searchable is left after removing special characters
            continue
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": full_text_query},
        )
        results.extend(el['output'] for el in response)
    # One relationship per line, including across different entities
    return "\n".join(results)
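
To see what kind of query string the fuzzy-matching step produces without connecting to a database, here is a self-contained sketch. `strip_lucene_chars` is a hypothetical stand-in that only approximates `remove_lucene_chars`, and `fuzzy_and_query` is an illustrative name for the same word-by-word construction.

```python
import re

def strip_lucene_chars(text):
    """Stand-in: replace characters with special meaning in Lucene by spaces."""
    return re.sub(r'[+\-&|!(){}\[\]^"~*?:\\]', " ", text).strip()

def fuzzy_and_query(text, max_edits=2):
    """Append a fuzzy operator to every word and require all words to match."""
    words = strip_lucene_chars(text).split()
    return " AND ".join(f"{word}~{max_edits}" for word in words)

print(fuzzy_and_query("multi-head attention"))
# prints: multi~2 AND head~2 AND attention~2
```

Because the hyphen is treated as a special character, `multi-head` becomes two separately matched words, and the `~2` suffix lets each word differ from the indexed value by up to two character edits.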

The `structured_retriever` function starts by detecting entities in the user question. Next, it iterates over the detected entities and uses a Cypher template to retrieve the neighborhood of relevant nodes. Let's test it out!

In [None]:
print(structured_retriever("Which component in the Transformer uses multi-head attention?"))



Transformer - IMPROVES -> Translation Quality
Transformer - REPLACES -> Attention Mechanisms
Transformer - ALLOWS -> Parallelization
Transformer - RELIES_ON -> Self-Attention
Transformer - RELYING_ON -> Self-Attention
Transformer - HAS_COMPONENT -> Encoder
Transformer - HAS_COMPONENT -> Decoder
Transformer - HAS_STACK -> Encoder
Transformer - HAS_STACK -> Decoder
Transformer - HAS_FUNCTION -> Attention
Transformer - IMPROVE -> Translation Quality
Transformer - IMPROVE -> Training
Transformer - ALLOW -> Parallelization
Transformer - USE -> P100 Gpus
Transformer - REDUCE -> Operations
Transformer - REDUCE -> Effective Resolution
Transformer - COMPUTE -> Input
Transformer - COMPUTE -> Output
Transformer - COMPUTE -> Representations
Transformer - FIRST -> Transduction Model
Transformer - RELY_ON -> Attention Mechanisms
Transformer - RELY_ON -> Self-Attention
Transformer - NOT_USE -> Convolution
Transformer - NOT_USE -> Sequence-Aligned Rnns
Transformer - ALLOW_FOR -> Parallelization
Transf
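
The neighborhood lookup behind this output can be mimicked in memory, which makes the Cypher template easier to follow. The sketch below is an assumption-laden toy: `neighborhood` is a hypothetical helper, two of the triples mirror lines from the output above, and the `USES` and `Document 1` entries are invented for illustration.

```python
def neighborhood(triples, entity, skip_types=("MENTIONS",), limit=50):
    """Return formatted relationships that start or end at `entity`."""
    lines = []
    for source, rel_type, target in triples:
        if rel_type in skip_types:
            continue  # same effect as the `!MENTIONS` filter in the Cypher query
        if entity in (source, target):
            # Incoming and outgoing edges share one "source - TYPE -> target" format
            lines.append(f"{source} - {rel_type} -> {target}")
    return lines[:limit]

triples = [
    ("Transformer", "HAS_COMPONENT", "Encoder"),
    ("Encoder", "USES", "Multi-Head Attention"),   # invented example edge
    ("Document 1", "MENTIONS", "Transformer"),     # invented source link
    ("Transformer", "RELIES_ON", "Self-Attention"),
]
print("\n".join(neighborhood(triples, "Encoder")))
```

Looking up `Encoder` returns both the incoming `HAS_COMPONENT` edge and the outgoing `USES` edge, which corresponds to the two `MATCH` branches joined by `UNION ALL` in the query.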

### Final retriever

As we mentioned at the start, we'll combine the unstructured and graph retrievers to create the final context that will be passed to an LLM.

In [None]:
def retriever(question: str):
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    final_data = f"""Structured data:
{structured_data}
Unstructured data:
{"#Document ".join(unstructured_data)}
    """
    return final_data

Because this is plain Python, we can simply concatenate the outputs with an f-string.
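
To show the shape of the resulting context without calling either retriever, here is a small sketch with hard-coded inputs. `build_context` is a hypothetical variation of the function above: it prefixes every passage with `#Document`, whereas the original places that marker only between passages.

```python
def build_context(structured_lines, passages):
    """Combine graph relationships and text passages into one context string."""
    structured = "\n".join(structured_lines)
    # Variation on the tutorial's join: mark the start of every passage
    unstructured = "\n".join(f"#Document {passage}" for passage in passages)
    return f"Structured data:\n{structured}\nUnstructured data:\n{unstructured}"

context = build_context(
    ["Transformer - HAS_COMPONENT -> Encoder"],
    ["first passage", "second passage"],
)
print(context)
```

Keeping the two sections under separate headings makes it clear to the LLM which lines are graph relationships and which are raw text.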

### Defining the RAG chain

We have now implemented the retrieval component of the RAG application. Before generating answers, we add a query-rewriting step that condenses the chat history and a follow-up question into a standalone question, which makes conversational follow-ups possible.


In [None]:
# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(lambda x : x["question"]),
)
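
The branching logic above can be read as a simple if/else. The plain-Python sketch below mirrors it with hypothetical names: `route_question` plays the role of the `RunnableBranch`, and `fake_condense` is a stand-in for the LLM call, so no API key is needed to run it.

```python
def route_question(inputs, condense):
    """Return the search query for `inputs`, a dict shaped like the chain's input."""
    history = inputs.get("chat_history")
    if history:
        # Follow-up question: rewrite it into a standalone question
        return condense(history, inputs["question"])
    # No chat history: the question is already standalone
    return inputs["question"]

def fake_condense(history, question):
    """Stand-in for the LLM: attach the last human turn as context."""
    last_human, _last_ai = history[-1]
    return f"{question} (in the context of: {last_human})"

print(route_question({"question": "What does it use?"}, fake_condense))
print(route_question(
    {
        "question": "What does it use?",
        "chat_history": [("Which model is described?", "The Transformer")],
    },
    fake_condense,
))
```

An empty or missing `chat_history` takes the pass-through branch, matching the `bool(x.get("chat_history"))` check in the chain.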

Next, we introduce a prompt that leverages the context provided by the integrated hybrid retriever to produce the response, completing the implementation of the RAG chain.

In [None]:
template = """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise.
Answer:"""
prompt = ChatPromptTemplate.from_template(template)

chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            # Pass only the question text (not the whole input dict) to the prompt
            "question": RunnableLambda(lambda x: x["question"]),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)

Finally, we can go ahead and test our hybrid RAG implementation.

In [None]:
chain.invoke({"question": "Which component in the Transformer uses multi-head attention?"})

Search query: Which component in the Transformer uses multi-head attention?




'The encoder and decoder components in the Transformer use multi-head attention.'