Skip to content

Commit

Permalink
Add option to save/load graph cypher QA (#6219)
Browse files Browse the repository at this point in the history
Similar as #5818

Added the functionality to save/load Graph Cypher QA Chain due to a user
reporting the following error

> raise NotImplementedError("Saving not supported for this chain
type.")\nNotImplementedError: Saving not supported for this chain
type.\n'
  • Loading branch information
tomasonjo committed Jun 19, 2023
1 parent 495128b commit b3bccab
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
4 changes: 4 additions & 0 deletions langchain/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def output_keys(self) -> List[str]:
_output_keys = [self.output_key]
return _output_keys

@property
def _chain_type(self) -> str:
return "graph_cypher_chain"

@classmethod
def from_llm(
cls,
Expand Down
26 changes: 26 additions & 0 deletions langchain/chains/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain
Expand Down Expand Up @@ -416,6 +417,30 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
)


def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
if "graph" in kwargs:
graph = kwargs.pop("graph")
else:
raise ValueError("`graph` must be present.")
if "cypher_generation_chain" in config:
cypher_generation_chain_config = config.pop("cypher_generation_chain")
cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config)
else:
raise ValueError("`cypher_generation_chain` must be present.")
if "qa_chain" in config:
qa_chain_config = config.pop("qa_chain")
qa_chain = load_chain_from_config(qa_chain_config)
else:
raise ValueError("`qa_chain` must be present.")

return GraphCypherQAChain(
graph=graph,
cypher_generation_chain=cypher_generation_chain,
qa_chain=qa_chain,
**config,
)


def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
if "api_request_chain" in config:
api_request_chain_config = config.pop("api_request_chain")
Expand Down Expand Up @@ -482,6 +507,7 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
"retrieval_qa": _load_retrieval_qa,
"graph_cypher_chain": _load_graph_cypher_chain,
}


Expand Down
29 changes: 29 additions & 0 deletions tests/integration_tests/chains/test_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.loading import load_chain
from langchain.graphs import Neo4jGraph
from langchain.llms.openai import OpenAI

Expand Down Expand Up @@ -168,3 +169,31 @@ def test_cypher_return_direct() -> None:
output = chain.run("Who played in Pulp Fiction?")
expected_output = [{"a.name": "Bruce Willis"}]
assert output == expected_output


def test_cypher_save_load() -> None:
"""Test saving and loading."""

FILE_PATH = "cypher.yaml"

url = os.environ.get("NEO4J_URL")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None

graph = Neo4jGraph(
url=url,
username=username,
password=password,
)

chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_direct=True
)

chain.save(file_path=FILE_PATH)
qa_loaded = load_chain(FILE_PATH, graph=graph)

assert qa_loaded == chain

0 comments on commit b3bccab

Please sign in to comment.