-
Notifications
You must be signed in to change notification settings - Fork 13.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Neo4j conversation cypher template (#12927)
Adding custom graph memory to Cypher chain --------- Co-authored-by: Erick Friis <erick@langchain.dev>
- Loading branch information
Showing
8 changed files
with
1,770 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
|
||
# neo4j-cypher-memory | ||
|
||
This template allows you to have conversations with a Neo4j graph database in natural language, using an OpenAI LLM. | ||
It transforms a natural language question into a Cypher query (used to fetch data from Neo4j databases), executes the query, and provides a natural language response based on the query results. | ||
Additionally, it features a conversational memory module that stores the dialogue history in the Neo4j graph database. | ||
The conversation memory is uniquely maintained for each user session, ensuring personalized interactions. | ||
To facilitate this, please supply both the `user_id` and `session_id` when using the conversation chain. | ||
|
||
## Environment Setup | ||
|
||
Define the following environment variables: | ||
|
||
``` | ||
OPENAI_API_KEY=<YOUR_OPENAI_API_KEY> | ||
NEO4J_URI=<YOUR_NEO4J_URI> | ||
NEO4J_USERNAME=<YOUR_NEO4J_USERNAME> | ||
NEO4J_PASSWORD=<YOUR_NEO4J_PASSWORD> | ||
``` | ||
|
||
## Neo4j database setup | ||
|
||
There are a number of ways to set up a Neo4j database. | ||
|
||
### Neo4j Aura | ||
|
||
Neo4j AuraDB is a fully managed cloud graph database service. | ||
Create a free instance on [Neo4j Aura](https://neo4j.com/cloud/platform/aura-graph-database?utm_source=langchain&utm_content=langserve). | ||
When you initiate a free database instance, you'll receive credentials to access the database. | ||
|
||
## Populating with data | ||
|
||
If you want to populate the DB with some example data, you can run `python ingest.py`. | ||
This script will populate the database with sample movie data. | ||
|
||
## Usage | ||
|
||
To use this package, you should first have the LangChain CLI installed: | ||
|
||
```shell | ||
pip install -U langchain-cli | ||
``` | ||
|
||
To create a new LangChain project and install this as the only package, you can do: | ||
|
||
```shell | ||
langchain app new my-app --package neo4j-cypher-memory | ||
``` | ||
|
||
If you want to add this to an existing project, you can just run: | ||
|
||
```shell | ||
langchain app add neo4j-cypher-memory | ||
``` | ||
|
||
And add the following code to your `server.py` file: | ||
```python | ||
from neo4j_cypher_memory import chain as neo4j_cypher_memory_chain | ||
|
||
add_routes(app, neo4j_cypher_memory_chain, path="/neo4j-cypher-memory") | ||
``` | ||
|
||
(Optional) Let's now configure LangSmith. | ||
LangSmith will help us trace, monitor and debug LangChain applications. | ||
LangSmith is currently in private beta, you can sign up [here](https://smith.langchain.com/). | ||
If you don't have access, you can skip this section | ||
|
||
```shell | ||
export LANGCHAIN_TRACING_V2=true | ||
export LANGCHAIN_API_KEY=<your-api-key> | ||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default" | ||
``` | ||
|
||
If you are inside this directory, then you can spin up a LangServe instance directly by: | ||
|
||
```shell | ||
langchain serve | ||
``` | ||
|
||
This will start the FastAPI app with a server is running locally at | ||
[http://localhost:8000](http://localhost:8000) | ||
|
||
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) | ||
We can access the playground at [http://127.0.0.1:8000/neo4j_cypher_memory/playground](http://127.0.0.1:8000/neo4j_cypher/playground) | ||
|
||
We can access the template from code with: | ||
|
||
```python | ||
from langserve.client import RemoteRunnable | ||
|
||
runnable = RemoteRunnable("http://localhost:8000/neo4j-cypher-memory") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from langchain.graphs import Neo4jGraph | ||
|
||
graph = Neo4jGraph() | ||
|
||
graph.query( | ||
""" | ||
MERGE (m:Movie {name:"Top Gun"}) | ||
WITH m | ||
UNWIND ["Tom Cruise", "Val Kilmer", "Anthony Edwards", "Meg Ryan"] AS actor | ||
MERGE (a:Actor {name:actor}) | ||
MERGE (a)-[:ACTED_IN]->(m) | ||
WITH a | ||
WHERE a.name = "Tom Cruise" | ||
MERGE (a)-[:ACTED_IN]->(:Movie {name:"Mission Impossible"}) | ||
""" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from neo4j_cypher_memory.chain import chain | ||
|
||
if __name__ == "__main__": | ||
original_query = "Who played in Top Gun?" | ||
print( | ||
chain.invoke( | ||
{ | ||
"question": original_query, | ||
"user_id": "user_123", | ||
"session_id": "session_1", | ||
} | ||
) | ||
) | ||
follow_up_query = "Did they play in any other movies?" | ||
print( | ||
chain.invoke( | ||
{ | ||
"question": follow_up_query, | ||
"user_id": "user_123", | ||
"session_id": "session_1", | ||
} | ||
) | ||
) |
3 changes: 3 additions & 0 deletions
3
templates/neo4j-cypher-memory/neo4j_cypher_memory/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from neo4j_cypher_memory.chain import chain | ||
|
||
__all__ = ["chain"] |
146 changes: 146 additions & 0 deletions
146
templates/neo4j-cypher-memory/neo4j_cypher_memory/chain.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from typing import Any, Dict, List | ||
|
||
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.graphs import Neo4jGraph | ||
from langchain.memory import ChatMessageHistory | ||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | ||
from langchain.pydantic_v1 import BaseModel | ||
from langchain.schema.output_parser import StrOutputParser | ||
from langchain.schema.runnable import RunnablePassthrough | ||
|
||
# Connection to Neo4j | ||
graph = Neo4jGraph() | ||
|
||
# Cypher validation tool for relationship directions | ||
corrector_schema = [ | ||
Schema(el["start"], el["type"], el["end"]) | ||
for el in graph.structured_schema.get("relationships") | ||
] | ||
cypher_validation = CypherQueryCorrector(corrector_schema) | ||
|
||
# LLMs | ||
cypher_llm = ChatOpenAI(model_name="gpt-4", temperature=0.0) | ||
qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0) | ||
|
||
|
||
def convert_messages(input: List[Dict[str, Any]]) -> ChatMessageHistory: | ||
history = ChatMessageHistory() | ||
for item in input: | ||
history.add_user_message(item["result"]["question"]) | ||
history.add_ai_message(item["result"]["answer"]) | ||
return history | ||
|
||
|
||
def get_history(input: Dict[str, Any]) -> ChatMessageHistory: | ||
input.pop("question") | ||
# Lookback conversation window | ||
window = 3 | ||
data = graph.query( | ||
""" | ||
MATCH (u:User {id:$user_id})-[:HAS_SESSION]->(s:Session {id:$session_id}), | ||
(s)-[:LAST_MESSAGE]->(last_message) | ||
MATCH p=(last_message)<-[:NEXT*0..""" | ||
+ str(window) | ||
+ """]-() | ||
WITH p, length(p) AS length | ||
ORDER BY length DESC LIMIT 1 | ||
UNWIND reverse(nodes(p)) AS node | ||
MATCH (node)-[:HAS_ANSWER]->(answer) | ||
RETURN {question:node.text, answer:answer.text} AS result | ||
""", | ||
params=input, | ||
) | ||
history = convert_messages(data) | ||
return history.messages | ||
|
||
|
||
def save_history(input): | ||
input.pop("response") | ||
# store history to database | ||
graph.query( | ||
"""MERGE (u:User {id: $user_id}) | ||
WITH u | ||
OPTIONAL MATCH (u)-[:HAS_SESSION]->(s:Session{id: $session_id}), | ||
(s)-[l:LAST_MESSAGE]->(last_message) | ||
FOREACH (_ IN CASE WHEN last_message IS NULL THEN [1] ELSE [] END | | ||
CREATE (u)-[:HAS_SESSION]->(s1:Session {id:$session_id}), | ||
(s1)-[:LAST_MESSAGE]->(q:Question {text:$question, cypher:$query, date:datetime()}), | ||
(q)-[:HAS_ANSWER]->(:Answer {text:$output})) | ||
FOREACH (_ IN CASE WHEN last_message IS NOT NULL THEN [1] ELSE [] END | | ||
CREATE (last_message)-[:NEXT]->(q:Question | ||
{text:$question, cypher:$query, date:datetime()}), | ||
(q)-[:HAS_ANSWER]->(:Answer {text:$output}), | ||
(s)-[:LAST_MESSAGE]->(q) | ||
DELETE l) """, | ||
params=input, | ||
) | ||
|
||
# Return LLM response to the chain | ||
return input["output"] | ||
|
||
|
||
# Generate Cypher statement based on natural language input | ||
cypher_template = """This is important for my career. | ||
Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question: | ||
{schema} | ||
Question: {question} | ||
Cypher query:""" # noqa: E501 | ||
|
||
cypher_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
( | ||
"system", | ||
"Given an input question, convert it to a Cypher query. No pre-amble.", | ||
), | ||
MessagesPlaceholder(variable_name="history"), | ||
("human", cypher_template), | ||
] | ||
) | ||
|
||
cypher_response = ( | ||
RunnablePassthrough.assign(schema=lambda _: graph.get_schema, history=get_history) | ||
| cypher_prompt | ||
| cypher_llm.bind(stop=["\nCypherResult:"]) | ||
| StrOutputParser() | ||
) | ||
|
||
# Generate natural language response based on database results | ||
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response: | ||
Question: {question} | ||
Cypher query: {query} | ||
Cypher Response: {response}""" # noqa: E501 | ||
|
||
response_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
( | ||
"system", | ||
"Given an input question and Cypher response, convert it to a " | ||
"natural language answer. No pre-amble.", | ||
), | ||
("human", response_template), | ||
] | ||
) | ||
|
||
chain = ( | ||
RunnablePassthrough.assign(query=cypher_response) | ||
| RunnablePassthrough.assign( | ||
response=lambda x: graph.query(cypher_validation(x["query"])), | ||
) | ||
| RunnablePassthrough.assign( | ||
output=response_prompt | qa_llm | StrOutputParser(), | ||
) | ||
| save_history | ||
) | ||
|
||
# Add typing for input | ||
|
||
|
||
class Question(BaseModel): | ||
question: str | ||
user_id: str | ||
session_id: str | ||
|
||
|
||
chain = chain.with_types(input_type=Question) |
Oops, something went wrong.