Neo4j Advanced RAG template #12794

Merged · 9 commits · Nov 3, 2023
97 changes: 97 additions & 0 deletions templates/neo4j-advanced-rag/README.md
@@ -0,0 +1,97 @@
# neo4j-advanced-rag

This template allows you to balance precise embeddings and context retention by implementing advanced retrieval strategies.

## Strategies

1. **Typical RAG**:
- Traditional method where the exact data indexed is the data retrieved.
2. **Parent retriever**:
- Instead of indexing entire documents, data is divided into smaller chunks, referred to as Parent and Child documents.
   - Child documents are indexed for better representation of specific concepts, while parent documents are retrieved to ensure context retention.
3. **Hypothetical Questions**:
- Documents are processed to determine potential questions they might answer.
- These questions are then indexed for better representation of specific concepts, while parent documents are retrieved to ensure context retention.
4. **Summaries**:
- Instead of indexing the entire document, a summary of the document is created and indexed.
- Similarly, the parent document is retrieved in a RAG application.
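
All four strategies are wired into a single chain with a configurable retriever (see `neo4j_advanced_rag/chain.py` below), so the strategy can be chosen per request. A minimal sketch of switching strategies at query time, using the configuration keys defined in `chain.py`:

```python
from neo4j_advanced_rag import chain

# The "strategy" field selects the retriever. Keys defined in chain.py:
# "typical_rag" (default), "parent_strategy", "hypothetical_questions",
# "summary_strategy".
answer = chain.invoke(
    {"question": "What is the plot of Dune?"},
    {"configurable": {"strategy": "summary_strategy"}},
)
print(answer)
```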

## Environment Setup

You need to define the following environment variables:

```
OPENAI_API_KEY=<YOUR_OPENAI_API_KEY>
NEO4J_URI=<YOUR_NEO4J_URI>
NEO4J_USERNAME=<YOUR_NEO4J_USERNAME>
NEO4J_PASSWORD=<YOUR_NEO4J_PASSWORD>
```

## Populating with data

If you want to populate the DB with some example data, you can run `python ingest.py`.
The script processes and stores sections of the text from the file `dune.txt` in a Neo4j graph database.
First, the text is divided into larger chunks ("parents") and then further subdivided into smaller chunks ("children"), where both parent and child chunks overlap slightly to maintain context.
After storing these chunks in the database, embeddings for the child nodes are computed using OpenAI's embeddings and stored back in the graph for future retrieval or analysis.
For every parent node, hypothetical questions and summaries are generated, embedded, and added to the database.
Additionally, a vector index for each retrieval strategy is created for efficient querying of these embeddings.

*Note that ingestion can take a minute or two, as the LLM has to generate hypothetical questions and summaries for every parent node.*
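
To verify the ingested data, you can query one of the newly created vector indexes directly. A minimal sketch, assuming the environment variables above are set and a Neo4j 5.x instance that supports the `db.index.vector.queryNodes` procedure (the index names match those created by `ingest.py`):

```python
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph

graph = Neo4jGraph()  # reads NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD
embedding = OpenAIEmbeddings().embed_query("Who is Paul Atreides?")

# Fetch the 3 most similar child chunks from the 'parent_document' index
results = graph.query(
    """
    CALL db.index.vector.queryNodes('parent_document', 3, $embedding)
    YIELD node, score
    RETURN node.text AS text, score
    """,
    {"embedding": embedding},
)
print(results)
```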

## Usage

To use this package, you should first have the LangChain CLI installed:

```shell
pip install -U "langchain-cli[serve]"
```

To create a new LangChain project and install this as the only package, you can do:

```shell
langchain app new my-app --package neo4j-advanced-rag
```

If you want to add this to an existing project, you can just run:

```shell
langchain app add neo4j-advanced-rag
```

And add the following code to your `server.py` file:
```python
from neo4j_advanced_rag import chain as neo4j_advanced_chain

add_routes(app, neo4j_advanced_chain, path="/neo4j-advanced-rag")
```
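
If you are creating `server.py` from scratch rather than using the scaffold generated by `langchain app new`, a minimal file might look like this (FastAPI and `add_routes` from langserve are assumptions based on the standard template setup):

```python
from fastapi import FastAPI
from langserve import add_routes

from neo4j_advanced_rag import chain as neo4j_advanced_chain

app = FastAPI()

add_routes(app, neo4j_advanced_chain, path="/neo4j-advanced-rag")
```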

(Optional) Let's now configure LangSmith.
LangSmith will help us trace, monitor, and debug LangChain applications.
LangSmith is currently in private beta; you can sign up [here](https://smith.langchain.com/).
If you don't have access, you can skip this section.

```shell
export LANGCHAIN_TRACING_V2=true
export LANGCHAIN_API_KEY=<your-api-key>
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
```

If you are inside this directory, then you can spin up a LangServe instance directly by:

```shell
langchain serve
```

This will start the FastAPI app with a server running locally at
[http://localhost:8000](http://localhost:8000).

We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs).
We can access the playground at [http://127.0.0.1:8000/neo4j-advanced-rag/playground](http://127.0.0.1:8000/neo4j-advanced-rag/playground).

We can access the template from code with:

```python
from langserve.client import RemoteRunnable

runnable = RemoteRunnable("http://localhost:8000/neo4j-advanced-rag")
```
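
The retrieval strategy can also be selected over the wire by passing a config to the remote runnable. A sketch, assuming the server above is running and using the strategy keys defined in `chain.py`:

```python
from langserve.client import RemoteRunnable

runnable = RemoteRunnable("http://localhost:8000/neo4j-advanced-rag")
response = runnable.invoke(
    {"question": "What is the plot of Dune?"},
    config={"configurable": {"strategy": "parent_strategy"}},
)
print(response)
```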
95 changes: 95 additions & 0 deletions templates/neo4j-advanced-rag/dune.txt

Large diffs are not rendered by default.

203 changes: 203 additions & 0 deletions templates/neo4j-advanced-rag/ingest.py
@@ -0,0 +1,203 @@
from pathlib import Path
from typing import List

from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
from langchain.text_splitter import TokenTextSplitter
from neo4j.exceptions import ClientError

txt_path = Path(__file__).parent / "dune.txt"

graph = Neo4jGraph()

# Embeddings & LLM models
embeddings = OpenAIEmbeddings()
embedding_dimension = 1536
llm = ChatOpenAI(temperature=0)

# Load the text file
loader = TextLoader(str(txt_path))
documents = loader.load()

# Ingest Parent-Child node pairs
parent_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
child_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=24)
parent_documents = parent_splitter.split_documents(documents)

for i, parent in enumerate(parent_documents):
    child_documents = child_splitter.split_documents([parent])
    params = {
        "parent_text": parent.page_content,
        "parent_id": i,
        "parent_embedding": embeddings.embed_query(parent.page_content),
        "children": [
            {
                "text": c.page_content,
                "id": f"{i}-{ic}",
                "embedding": embeddings.embed_query(c.page_content),
            }
            for ic, c in enumerate(child_documents)
        ],
    }
    # Ingest data
    graph.query(
        """
        MERGE (p:Parent {id: $parent_id})
        SET p.text = $parent_text
        WITH p
        CALL db.create.setVectorProperty(p, 'embedding', $parent_embedding)
        YIELD node
        WITH p
        UNWIND $children AS child
        MERGE (c:Child {id: child.id})
        SET c.text = child.text
        MERGE (c)<-[:HAS_CHILD]-(p)
        WITH c, child
        CALL db.create.setVectorProperty(c, 'embedding', child.embedding)
        YIELD node
        RETURN count(*)
        """,
        params,
    )

# Create vector index for children
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('parent_document', "
        "'Child', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass

# Create vector index for parents
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('typical_rag', "
        "'Parent', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass
# Ingest hypothetical questions


class Questions(BaseModel):
    """Generating hypothetical questions about text."""

    questions: List[str] = Field(
        ...,
        description=(
            "Generated hypothetical questions based on "
            "the information from the text"
        ),
    )


questions_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            (
                "You are generating hypothetical questions based on the information "
                "found in the text. Make sure to provide full context in the generated "
                "questions."
            ),
        ),
        (
            "human",
            (
                "Use the given format to generate hypothetical questions from the "
                "following input: {input}"
            ),
        ),
    ]
)

question_chain = create_structured_output_chain(Questions, llm, questions_prompt)

for i, parent in enumerate(parent_documents):
    questions = question_chain.run(parent.page_content).questions
    params = {
        "parent_id": i,
        "questions": [
            {"text": q, "id": f"{i}-{iq}", "embedding": embeddings.embed_query(q)}
            for iq, q in enumerate(questions)
            if q
        ],
    }
    graph.query(
        """
        MERGE (p:Parent {id: $parent_id})
        WITH p
        UNWIND $questions AS question
        CREATE (q:Question {id: question.id})
        SET q.text = question.text
        MERGE (q)<-[:HAS_QUESTION]-(p)
        WITH q, question
        CALL db.create.setVectorProperty(q, 'embedding', question.embedding)
        YIELD node
        RETURN count(*)
        """,
        params,
    )

# Create vector index
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('hypothetical_questions', "
        "'Question', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass

# Ingest summaries

summary_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            (
                "You are generating concise and accurate summaries based on the "
                "information found in the text."
            ),
        ),
        (
            "human",
            "Generate a summary of the following input: {question}\nSummary:",
        ),
    ]
)

summary_chain = summary_prompt | llm

for i, parent in enumerate(parent_documents):
    summary = summary_chain.invoke({"question": parent.page_content}).content
    params = {
        "parent_id": i,
        "summary": summary,
        "embedding": embeddings.embed_query(summary),
    }
    graph.query(
        """
        MERGE (p:Parent {id: $parent_id})
        MERGE (p)-[:HAS_SUMMARY]->(s:Summary)
        SET s.text = $summary
        WITH s
        CALL db.create.setVectorProperty(s, 'embedding', $embedding)
        YIELD node
        RETURN count(*)
        """,
        params,
    )

# Create vector index
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('summary', "
        "'Summary', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass
10 changes: 10 additions & 0 deletions templates/neo4j-advanced-rag/main.py
@@ -0,0 +1,10 @@
from neo4j_advanced_rag.chain import chain

if __name__ == "__main__":
    original_query = "What is the plot of Dune?"
    print(
        chain.invoke(
            {"question": original_query},
            {"configurable": {"strategy": "parent_strategy"}},
        )
    )
3 changes: 3 additions & 0 deletions templates/neo4j-advanced-rag/neo4j_advanced_rag/__init__.py
@@ -0,0 +1,3 @@
from neo4j_advanced_rag.chain import chain

__all__ = ["chain"]
51 changes: 51 additions & 0 deletions templates/neo4j-advanced-rag/neo4j_advanced_rag/chain.py
@@ -0,0 +1,51 @@
from operator import itemgetter

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import ConfigurableField, RunnableParallel

from neo4j_advanced_rag.retrievers import (
    hypothetic_question_vectorstore,
    parent_vectorstore,
    summary_vectorstore,
    typical_rag,
)

template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

model = ChatOpenAI()

retriever = typical_rag.as_retriever().configurable_alternatives(
    ConfigurableField(id="strategy"),
    default_key="typical_rag",
    parent_strategy=parent_vectorstore.as_retriever(),
    hypothetical_questions=hypothetic_question_vectorstore.as_retriever(),
    summary_strategy=summary_vectorstore.as_retriever(),
)

chain = (
    RunnableParallel(
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
    )
    | prompt
    | model
    | StrOutputParser()
)


# Add typing for input
class Question(BaseModel):
    question: str


chain = chain.with_types(input_type=Question)