# Graph RAG

Example of how to build and query a graph database for RAG.

# Setup

# install poetry if you don't have it yet
`brew install poetry`

# install docker
If you're on a Mac, you'll want the docker desktop application, available here: https://www.docker.com/products/docker-desktop/

# install langchain, etc. 
Make sure you're in the directory that holds `poetry.lock`

`poetry install`

`source $(poetry env info --path)/bin/activate`

Copy `.env.example` to `.env` and update the variables as documented in that file.

# run neo4j
```
docker run \
    -p 7474:7474 -p 7687:7687 \
    -v $PWD/data:/data -v $PWD/plugins:/plugins \
    --name neo4j-apoc \
    -e NEO4J_AUTH=$NEO4J_USERNAME/$NEO4J_PASSWORD \
    -e NEO4J_apoc_export_file_enabled=true \
    -e NEO4J_apoc_import_file_enabled=true \
    -e NEO4J_apoc_import_file_use__neo4j__config=true \
    -e NEO4JLABS_PLUGINS=\[\"apoc\"\] \
    -e NEO4J_dbms_security_procedures_unrestricted=apoc.\\\* \
    neo4j:latest
```
(this command is in `run_neo4j_docker.sh`)

In [1]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
import os
# from langchain_community.graphs import GraphDatabase
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_core.runnables import ConfigurableField, RunnableParallel, RunnablePassthrough

try:
  import google.colab
  from google.colab import output
  output.enable_custom_widget_manager()
except:
  pass
from dotenv import load_dotenv
import os

In [21]:
# validate neo4j
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
NEO4J_KEY = os.getenv('NEO4J_KEY')
NEO4J_AUTH = (NEO4J_USERNAME, NEO4J_PASSWORD)

with GraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH) as driver:
    driver.verify_connectivity()

In [53]:
# Read the wikipedia articles
article_keys = [
  "Star Trek: The Next Generation", 
  "List of Star Trek: The Next Generation episodes",
]
extended_keys = [
  "Data",
  "Wesley Crusher",
  "Ro Laren",
  "William Riker",
  "Geordi La Forge",
  "Deanna Troi",
  "Guinan (Star Trek)",
  "Beverly Crusher",
  "Worf",
  "Tasha Yar",
  "Spock",
  "Jean-Luc Picard",
  "Miles O'Brien",
  "Reginald Barclay",
  "Deep Space Nine"
]

# if you're a real Star Trek nerd, uncomment this line:
#
# article_keys.extend(extended_keys)
#
# be aware that this will add an hour or more of additional
# node generation time in the "convert_to_graph_documents"
# step, and it will cost some actual money if you're paying
# for LLM access.

raw_documents = []
batch_size = 2
current_key_index = 0
while current_key_index < len(article_keys):
  for i in range(current_key_index, batch_size)
    query = article_keys[i]
    raw_documents.extend(WikipediaLoader(query=query).load())
    current_key_index += 1

# Define chunking strategy
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
documents = text_splitter.split_documents(raw_documents)



  lis = BeautifulSoup(html).find_all('li')


In [11]:
OPENAI_MODEL_NAME="llama3-8b-8192"
llm=ChatOpenAI(
    temperature=0.0,
    model_name=OPENAI_MODEL_NAME,
    openai_api_key=OPENAI_API_KEY,
    openai_api_base=OPENAI_BASE_URL,
  )

In [17]:
llm_transformer = LLMGraphTransformer(llm=llm)
from langchain_community.graphs.neo4j_graph import Neo4jGraph

graph = Neo4jGraph(URI, USERNAME, PASSWORD)

def add_query(query):
  text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
  raw_documents = WikipediaLoader(query=query).load()
  print(f'{len(raw_documents)} raw documents for "{query}"')
  documents = text_splitter.split_documents(raw_documents)
  print(f'calling transformer to convert {len(documents)} documents')
  graph_documents = llm_transformer.convert_to_graph_documents(documents)
  print(f'Adding {len(graph_documents)} graph documents')
  graph.add_graph_documents(
      graph_documents,
      baseEntityLabel=True,
      include_source=True
  )


In [19]:
def build_graph(queries):
  from time import sleep
  sleep_time = 1
  attempting = True
  for query in extended_keys:
    attempting = True
    sleep_time = 1
    while attempting and sleep_time <= 32:
      try:
        add_query(query)
        attempting = False
      except:
        print(f'error, sleeping for {sleep_time} seconds')
        sleep(sleep_time)
        sleep_time *= 2

extended_keys = [
  "Data",
  "Wesley Crusher",
  "Ro Laren",
  "William Riker",
  "Geordi La Forge",
  "Deanna Troi",
  "Guinan (Star Trek)",
  "Beverly Crusher",
  "Worf",
  "Tasha Yar",
  "Spock",
  "Jean-Luc Picard",
  "Miles O'Brien",
  "Reginald Barclay",
  "Deep Space Nine"
]

# build_graph(extended_keys)

# Unstructured data retriever
You can use the `Neo4jVector.from_existing_graph` method to add both keyword and vector retrieval to documents. This method configures keyword and vector search indexes for a hybrid search approach, targeting nodes labeled `Document`. Additionally, it calculates text embedding values if they are missing.

In [22]:
# existing_graph = Neo4jVector.from_existing_graph(
#     embedding=OpenAIEmbeddings(),
#     url=url,
#     username=username,
#     password=password,
#     index_name="person_index",
#     node_label="Person",
#     text_node_properties=["name", "location"],
#     embedding_node_property="embedding",
# )

embeddings = OpenAIEmbeddings()
vector_index = Neo4jVector.from_existing_graph(
    embeddings,
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,    
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)

BadRequestError: Error code: 400 - {'error': {'message': "'input' : input must be a string or an array of strings", 'type': 'invalid_request_error'}}

# Graph retriever
On the other hand, configuring a graph retrieval is more involved but offers more freedom. In this example, we will use a full-text index to identify relevant nodes and then return their direct neighborhood.

![Graph retriever. Image from LangChain.](./1_z0pYA_dSNG_yTYE6Rr7CQA.png)

The graph retriever starts by identifying relevant entities in the input. For simplicity, we instruct the LLM to identify people, organizations, and locations. To achieve this, we will use LCEL with the newly added `with_structured_output` method to achieve this.

In [13]:
# Extract entities from text
class Entities(BaseModel):
    """Identifying information about entities."""

    names: List[str] = Field(
        ...,
        description="All the person, character, organization, or starships \
        that appear in the text",
    )

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are extracting character and person entities from the \
            text.",
        ),
        (
            "human",
            "Use the given format to extract information from the \
             following"
            "input: {question}",
        ),
    ]
)

entity_chain = prompt | llm.with_structured_output(Entities)

In [17]:
# gets an entity
entity_chain.invoke({"question": "Who is Captain Jean-Luc Picard?"}).names

['Captain Jean-Luc Picard']

now that we can detect entities in the question, let’s use a full-text index to map them to the knowledge graph. First, we need to define a full-text index and a function that will generate full-text queries that allow a bit of misspelling, which we won’t go into much detail here.

In [18]:
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text
    search. It processes the input string by splitting it into words and 
    appending a similarity threshold (~2 changed characters) to each
    word, then combines them using the AND operator. Useful for mapping
    entities from user questions to database values, and allows for some 
    misspelings.
    """
    full_text_query = ""
    words = [el for el in remove_lucene_chars(input).split() if el]
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()

In [19]:
# Fulltext index query
def structured_retriever(question: str) -> str:
    """
    Collects the neighborhood of entities mentioned
    in the question
    """
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, 
            {limit:2})
            YIELD node,score
            CALL {
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS 
              output
              UNION
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS 
              output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
        result += "\n".join([el['output'] for el in response])
    return result

The `structured_retriever` function starts by detecting entities in the user question. Next, it iterates over the detected entities and uses a Cypher template to retrieve the neighborhood of relevant nodes. Let’s test it out!

In [30]:
answers = structured_retriever("Who is Picard?")
print('\n'.join(answers.split('\n')[:10]))

Star Trek: The Next Generation - CREATED_BY -> Gene Roddenberry
Star Trek: The Next Generation - INSPIRED_BY -> Star Trek: The Original Series
Uss Enterprise (Ncc-1701-D) - FEATURED_IN -> Star Trek: The Next Generation
Jean-Luc Picard - PORTRAYED_BY -> Patrick Stewart
William Riker - PORTRAYED_BY -> Jonathan Frakes
Data - PORTRAYED_BY -> Brent Spiner
Worf - PORTRAYED_BY -> Michael Dorn
Geordi La Forge - PORTRAYED_BY -> Levar Burton
Deanna Troi - PORTRAYED_BY -> Marina Sirtis
Beverly Crusher - PORTRAYED_BY -> Gates Mcfadden


In [31]:
#Combine the unstructured and graph retriever to create the final context that will be passed to an LLM.

def retriever(question: str):
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    final_data = f"""Structured data:
{structured_data}
Unstructured data:
{"#Document ". join(unstructured_data)}
    """
    return final_data

This allows for follow-up questions:

In [42]:
# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(lambda x : x["question"]),
)

# Defining the RAG chain
We have successfully implemented the retrieval component of the RAG. Next, we introduce a prompt that leverages the context provided by the integrated hybrid retriever to produce the response, completing the implementation of the RAG chain.

In [44]:
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)



In [50]:
chain.invoke({"question": "Who was Picard's security officer?"})

Search query: Who was Picard's security officer?


"Picard's security officer was Worf, portrayed by Michael Dorn."

In [51]:
# follow-up questions:
chain.invoke(
    {
        "question": "What race is he?",
        "chat_history": [("Who was Picard's security officer?", "Picard's security officer was Worf, portrayed by Michael Dorn.")],
    }
)

Search query: What race is Worf?


"Worf, Picard's security officer, is a Klingon."