Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Next

### Added

- Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j.


## 1.10.1

### Added
Expand Down
3 changes: 3 additions & 0 deletions docs/source/user_guide_rag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ Weaviate Retrievers
collection="Movies",
id_property_external="neo4j_id",
id_property_neo4j="id",
node_label_neo4j="Document", # optional
)

Internally, this retriever performs the vector search in Weaviate, finds the corresponding node by matching
Expand Down Expand Up @@ -795,6 +796,7 @@ Pinecone Retrievers
index_name="Movies",
id_property_neo4j="id",
embedder=embedder,
node_label_neo4j="Document", # optional
)

Also see :ref:`pineconeneo4jretriever`.
Expand Down Expand Up @@ -825,6 +827,7 @@ Qdrant Retrievers
id_property_external="neo4j_id", # The payload field that contains identifier to a corresponding Neo4j node id property
id_property_neo4j="id",
embedder=embedder,
node_label_neo4j="Document", # optional
)

See :ref:`qdrantneo4jretriever`.
Expand Down
2 changes: 2 additions & 0 deletions src/neo4j_graphrag/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,12 @@ def __init__(
id_property_external: str,
id_property_neo4j: str,
neo4j_database: Optional[str] = None,
node_label_neo4j: Optional[str] = None,
):
super().__init__(driver)
self.id_property_external = id_property_external
self.id_property_neo4j = id_property_neo4j
self.node_label_neo4j = node_label_neo4j
self.neo4j_database = neo4j_database

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
retrieval_query (str): Cypher query that gets appended.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
node_label_neo4j (Optional[str]): The label of the Neo4j node to retrieve. This label must be properly escaped if needed, eg "`Label with spaces`".

Raises:
RetrieverInitializationError: If validation of the input arguments fail.
Expand All @@ -101,6 +102,7 @@ def __init__(
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
node_label_neo4j: Optional[str] = None,
):
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -116,6 +118,7 @@ def __init__(
retrieval_query=retrieval_query,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
node_label_neo4j=node_label_neo4j,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e
Expand All @@ -125,6 +128,7 @@ def __init__(
id_property_external="id",
id_property_neo4j=validated_data.id_property_neo4j,
neo4j_database=neo4j_database,
node_label_neo4j=node_label_neo4j,
)
self.driver = validated_data.driver_model.driver
self.client = validated_data.client_model.client
Expand Down Expand Up @@ -172,7 +176,8 @@ def get_search_results(
driver=neo4j_driver,
client=pc_client,
index_name="jeopardy",
id_property_neo4j="id"
id_property_neo4j="id",
node_label_neo4j="Document",
)
biology_embedding = ...
retriever.search(query_vector=biology_embedding, top_k=2)
Expand Down Expand Up @@ -223,6 +228,7 @@ def get_search_results(
search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
node_label=self.node_label_neo4j,
)

parameters = {
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/retrievers/external/pinecone/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ class PineconeNeo4jRetrieverModel(BaseModel):
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
node_label_neo4j: Optional[str] = None
8 changes: 7 additions & 1 deletion src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class QdrantNeo4jRetriever(ExternalRetriever):
return_properties (Optional[list[str]]): List of node properties to return.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
node_label_neo4j (Optional[str]): The label of the Neo4j node to retrieve. This label must be properly escaped if needed, eg "`Label with spaces`".

Raises:
RetrieverInitializationError: If validation of the input arguments fail.
Expand All @@ -99,6 +100,7 @@ def __init__(
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
node_label_neo4j: Optional[str] = None,
):
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -116,6 +118,7 @@ def __init__(
retrieval_query=retrieval_query,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
node_label_neo4j=node_label_neo4j,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e
Expand All @@ -125,6 +128,7 @@ def __init__(
id_property_external=validated_data.id_property_external,
id_property_neo4j=validated_data.id_property_neo4j,
neo4j_database=neo4j_database,
node_label_neo4j=node_label_neo4j,
)
self.driver = validated_data.driver_model.driver
self.client = validated_data.client_model.client
Expand Down Expand Up @@ -169,7 +173,8 @@ def get_search_results(
driver=neo4j_driver,
client=client,
collection_name="my_collection",
id_property_external="neo4j_id"
id_property_external="neo4j_id",
node_label_neo4j="Document",
)
embedding = ...
retriever.search(query_vector=embedding, top_k=2)
Expand Down Expand Up @@ -223,6 +228,7 @@ def get_search_results(
search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
node_label=self.node_label_neo4j,
)

parameters = {
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/retrievers/external/qdrant/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@ class QdrantNeo4jRetrieverModel(BaseModel):
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
node_label_neo4j: Optional[str] = None
10 changes: 8 additions & 2 deletions src/neo4j_graphrag/retrievers/external/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@


def get_match_query(
return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None
return_properties: Optional[list[str]] = None,
retrieval_query: Optional[str] = None,
node_label: Optional[str] = None,
) -> str:
# node_label is not escaped on purpose, allowing users to use any valid
# node label expression, e.g. "Actor|Director". It's up to the user to ensure
# labels are properly escaped, i.e. "`My label with space`".
node_label_expression = f":{node_label}" if node_label else ""
match_query = (
"UNWIND $match_params AS match_param "
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
"MATCH (node) "
f"MATCH (node{node_label_expression}) "
"WHERE node[$id_property] = match_id_value "
)
return match_query + get_query_tail(
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/retrievers/external/weaviate/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
node_label_neo4j: Optional[str] = None


class WeaviateNeo4jSearchModel(VectorSearchModel):
Expand Down
11 changes: 10 additions & 1 deletion src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
return_properties (Optional[list[str]]): List of node properties to return.
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
node_label_neo4j (Optional[str]): The label of the Neo4j node to retrieve. This label must be properly escaped if needed, eg "`Label with spaces`".

Raises:
RetrieverInitializationError: If validation of the input arguments fail.
Expand All @@ -100,6 +101,7 @@ def __init__(
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
node_label_neo4j: Optional[str] = None,
):
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -116,12 +118,17 @@ def __init__(
retrieval_query=retrieval_query,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
node_label_neo4j=node_label_neo4j,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e

super().__init__(
driver, id_property_external, id_property_neo4j, neo4j_database
driver,
id_property_external,
id_property_neo4j,
neo4j_database,
node_label_neo4j,
)
self.client = validated_data.client_model.client
collection = validated_data.collection
Expand Down Expand Up @@ -164,6 +171,7 @@ def get_search_results(
collection="Jeopardy",
id_property_external="neo4j_id",
id_property_neo4j="id",
node_label_neo4j="Document",
)

biology_embedding = ...
Expand Down Expand Up @@ -234,6 +242,7 @@ def get_search_results(
search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
node_label=self.node_label_neo4j,
)

parameters = {
Expand Down
6 changes: 5 additions & 1 deletion tests/e2e/pinecone_e2e/test_pinecone_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def populate_neo4j_db(driver: MagicMock) -> None:
@pytest.mark.usefixtures("populate_neo4j_db")
def test_pinecone_neo4j_vector_input(driver: MagicMock, client: MagicMock) -> None:
retriever = PineconeNeo4jRetriever(
driver=driver, client=client, index_name="jeopardy", id_property_neo4j="id"
driver=driver,
client=client,
index_name="jeopardy",
id_property_neo4j="id",
node_label_neo4j="`Question`",
)
with mock.patch.object(retriever, "index") as mock_index:
top_k = 2
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/qdrant_e2e/test_qdrant_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_qdrant_neo4j_vector_input(driver: Driver, qdrant_client: QdrantClient)
collection_name="Jeopardy",
id_property_external="neo4j_id",
id_property_neo4j="id",
node_label_neo4j="Question",
)

top_k = 1
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/weaviate_e2e/test_weaviate_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_weaviate_neo4j_vector_input(
collection="Jeopardy",
id_property_external="neo4j_id",
id_property_neo4j="id",
node_label_neo4j="Question|Answer",
)

top_k = 2
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/retrievers/external/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,20 @@ def test_match_query_with_both_return_properties_and_retrieval_query() -> None:
assert match_query.strip() == expected.strip()


def test_match_query_with_custom_node_label() -> None:
match_query = get_match_query(
return_properties=["name", "age"], node_label="`MyNodeLabel`"
)
expected = (
"UNWIND $match_params AS match_param "
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
"MATCH (node:`MyNodeLabel`) "
"WHERE node[$id_property] = match_id_value "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score "
)
assert match_query.strip() == expected.strip()


def test_weaviate_retriever_with_result_format_function(
driver: MagicMock, neo4j_record: MagicMock, result_formatter: MagicMock
) -> None:
Expand Down