diff --git a/python-wrapper/src/neo4j_viz/neo4j.py b/python-wrapper/src/neo4j_viz/neo4j.py index c6cec8c..72c350e 100644 --- a/python-wrapper/src/neo4j_viz/neo4j.py +++ b/python-wrapper/src/neo4j_viz/neo4j.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import neo4j.graph from neo4j import Result @@ -20,6 +21,9 @@ def from_neo4j( """ Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object. + All node and relationship properties will be included in the visualization graph. + If the property names are conflicting with those of `Node` and `Relationship` objects, they will be prefixed + with `__`. Parameters ---------- @@ -59,12 +63,13 @@ def from_neo4j( def _map_node(node: neo4j.graph.Node, size_property: Optional[str], caption_property: Optional[str]) -> Node: + labels = sorted([label for label in node.labels]) + if size_property: size = node.get(size_property) else: size = None - labels = sorted([label for label in node.labels]) if caption_property: if caption_property == "labels": if len(labels) > 0: @@ -74,7 +79,13 @@ def _map_node(node: neo4j.graph.Node, size_property: Optional[str], caption_prop else: caption = str(node.get(caption_property)) - return Node(id=node.element_id, caption=caption, labels=labels, size=size, **{k: v for k, v in node.items()}) + base_node_props = dict(id=node.element_id, caption=caption, labels=labels, size=size) + + protected_props = base_node_props.keys() + additional_node_props = {k: v for k, v in node.items()} + additional_node_props = _rename_protected_props(additional_node_props, protected_props) + + return Node(**base_node_props, **additional_node_props) def _map_relationship(rel: neo4j.graph.Relationship, caption_property: Optional[str]) -> Optional[Relationship]: @@ -89,11 +100,32 @@ def _map_relationship(rel: neo4j.graph.Relationship, caption_property: Optional[ else: caption = None - return Relationship( + base_rel_props = dict( id=rel.element_id, source=rel.start_node.element_id, target=rel.end_node.element_id, - type_=rel.type, + _type=rel.type, caption=caption, - **{k: v for k, v in rel.items()}, ) + + protected_props = base_rel_props.keys() + additional_rel_props = {k: v for k, v in rel.items()} + additional_rel_props = _rename_protected_props(additional_rel_props, protected_props) + + return Relationship( + **base_rel_props, + **additional_rel_props, + ) + + +def _rename_protected_props( + additional_props: dict[str, Any], + protected_props: Iterable[str], +) -> dict[str, Union[str, int, float]]: + for prop in protected_props: + if prop not in additional_props: + continue + + additional_props[f"__{prop}"] = additional_props.pop(prop) + + return additional_props diff --git a/python-wrapper/tests/test_neo4j.py b/python-wrapper/tests/test_neo4j.py index 6ad8062..9c80766 100644 --- a/python-wrapper/tests/test_neo4j.py +++ b/python-wrapper/tests/test_neo4j.py @@ -1,5 +1,6 @@ from typing import Generator +import neo4j import pytest from neo4j import Session @@ -10,7 +11,8 @@ @pytest.fixture(scope="class", autouse=True) def graph_setup(neo4j_session: Session) -> Generator[None, None, None]: neo4j_session.run( - "CREATE (a:_CI_A {name:'Alice', height:20})-[:KNOWS {year: 2025}]->(b:_CI_A:_CI_B {name:'Bob', height:10}), (b)-[:RELATED {year: 2015}]->(a)" + "CREATE (a:_CI_A {name:'Alice', height:20, id:42, _id: 1337, caption: 'hello'})-[:KNOWS {year: 2025, id: 41, source: 1, target: 2}]->" + "(b:_CI_A:_CI_B {name:'Bob', height:10, id: 84, size: 11, labels: [1,2]}), (b)-[:RELATED {year: 2015, _type: 'A', caption:'hej'}]->(a)" ) yield neo4j_session.run("MATCH (n:_CI_A|_CI_B) DETACH DELETE n") @@ -22,11 +24,30 @@ def test_from_neo4j_graph(neo4j_session: Session) -> None: VG = from_neo4j(graph) - node_ids: list[str] = [node.element_id for node in graph.nodes] + sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"]) + node_ids: list[str] = [node.element_id for node in sorted_nodes] expected_nodes = [ - Node(id=node_ids[0], caption="_CI_A", labels=["_CI_A"], name="Alice", height=20), - Node(id=node_ids[1], caption="_CI_A:_CI_B", labels=["_CI_A", "_CI_B"], name="Bob", height=10), + Node( + id=node_ids[0], + caption="_CI_A", + labels=["_CI_A"], + name="Alice", + height=20, + __id=42, + _id=1337, + __caption="hello", + ), + Node( + id=node_ids[1], + caption="_CI_A:_CI_B", + labels=["_CI_A", "_CI_B"], + name="Bob", + height=10, + __id=84, + __size=11, + __labels=[1, 2], + ), ] assert len(VG.nodes) == 2 @@ -47,11 +68,31 @@ def test_from_neo4j_result(neo4j_session: Session) -> None: VG = from_neo4j(result) graph = result.graph() - node_ids: list[str] = [node.element_id for node in graph.nodes] + + sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"]) + node_ids: list[str] = [node.element_id for node in sorted_nodes] expected_nodes = [ - Node(id=node_ids[0], caption="_CI_A", labels=["_CI_A"], name="Alice", height=20), - Node(id=node_ids[1], caption="_CI_A:_CI_B", labels=["_CI_A", "_CI_B"], name="Bob", height=10), + Node( + id=node_ids[0], + caption="_CI_A", + labels=["_CI_A"], + name="Alice", + height=20, + __id=42, + _id=1337, + __caption="hello", + ), + Node( + id=node_ids[1], + caption="_CI_A:_CI_B", + labels=["_CI_A", "_CI_B"], + name="Bob", + height=10, + __id=84, + __size=11, + __labels=[1, 2], + ), ] assert len(VG.nodes) == 2 @@ -71,11 +112,32 @@ def test_from_neo4j_graph_full(neo4j_session: Session) -> None: VG = from_neo4j(graph, node_caption="name", relationship_caption="year", size_property="height") - node_ids: list[str] = [node.element_id for node in graph.nodes] + sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"]) + node_ids: list[str] = [node.element_id for node in sorted_nodes] expected_nodes = [ - Node(id=node_ids[0], caption="Alice", labels=["_CI_A"], name="Alice", height=20, size=60.0), - Node(id=node_ids[1], caption="Bob", labels=["_CI_A", "_CI_B"], name="Bob", height=10, size=3.0), + Node( + id=node_ids[0], + caption="Alice", + labels=["_CI_A"], + name="Alice", + height=20, + size=60.0, + __id=42, + _id=1337, + __caption="hello", + ), + Node( + id=node_ids[1], + caption="Bob", + labels=["_CI_A", "_CI_B"], + name="Bob", + height=10, + size=3.0, + __id=84, + __size=11, + __labels=[1, 2], + ), ] assert len(VG.nodes) == 2