diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cdb942d..f14f845f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Next +### Fixed + +- Fixed an edge case where the LLM could output a property of type 'map', which caused errors during import because map is not a valid property type in Neo4j. + + ## 1.9.1 ### Fixed diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index c8bf647f..62ec209f 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import enum +import json import logging from typing import Optional, Any, TypeVar, Generic, Union @@ -391,11 +392,15 @@ def _enforce_properties( ) -> dict[str, Any]: """ Enforce properties: + - Ensure property type: for now, just prevent having invalid property types (e.g. map) - Filter out those that are not in schema (i.e., valid properties) if allowed properties is False. - Check that all required properties are present and not null. 
""" - filtered_properties = self._filter_properties( + type_safe_properties = self._ensure_property_types( item.properties, + ) + filtered_properties = self._filter_properties( + type_safe_properties, schema_item.properties, schema_item.additional_properties, item.token, # label or type @@ -453,3 +458,19 @@ def _check_required_properties( if filtered_properties.get(req_prop) is None: missing_required_properties.append(req_prop) return missing_required_properties + + def _ensure_property_types( + self, + filtered_properties: dict[str, Any], + ) -> dict[str, Any]: + type_safe_properties = {} + for prop_name, prop_value in filtered_properties.items(): + if isinstance(prop_value, dict): + # just ensure the type will not raise error on insert, while preserving data + type_safe_properties[prop_name] = json.dumps(prop_value, default=str) + continue + + # this is where we could check types of other properties + # but keep it simple for now + type_safe_properties[prop_name] = prop_value + return type_safe_properties diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 8e3855dd..7826cbd1 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -111,6 +111,12 @@ def validate_additional_properties(self) -> Self: ) return self + def property_type_from_name(self, name: str) -> Optional[PropertyType]: + for prop in self.properties: + if prop.name == name: + return prop + return None + class RelationshipType(BaseModel): """ @@ -141,6 +147,12 @@ def validate_additional_properties(self) -> Self: ) return self + def property_type_from_name(self, name: str) -> Optional[PropertyType]: + for prop in self.properties: + if prop.name == name: + return prop + return None + class GraphSchema(DataModel): """This model represents the expected diff --git a/tests/unit/experimental/components/test_graph_pruning.py 
@pytest.mark.parametrize(
    "properties, expected_filtered_properties",
    [
        # all good, no bad types
        (
            {"name": "John Does", "age": 25, "is_active": True},
            {"name": "John Does", "age": 25, "is_active": True},
        ),
        # map must be serialized
        (
            {"age": {"dob": datetime.date(2000, 1, 1), "age_in_2025": 25}},
            {"age": '{"dob": "2000-01-01", "age_in_2025": 25}'},
        ),
    ],
)
def test_graph_pruning_ensure_property_type(
    properties: dict[str, Any],
    expected_filtered_properties: dict[str, Any],
) -> None:
    """Check that plain values pass through and that maps become JSON strings."""
    assert (
        GraphPruning()._ensure_property_types(properties)
        == expected_filtered_properties
    )


@pytest.fixture(scope="module")
def node_type_no_properties() -> NodeType:
    """Provide a node type labelled "Person" built with only a label."""
    return NodeType(label="Person")