Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions python-wrapper/src/neo4j_viz/neo4j.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Optional, Union
from collections.abc import Iterable
from typing import Any, Optional, Union

import neo4j.graph
from neo4j import Result
Expand All @@ -20,6 +21,9 @@ def from_neo4j(
"""
Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object.

All node and relationship properties will be included in the visualization graph.
If the property names are conflicting with those of `Node` and `Relationship` objects, they will be prefixed
with `__`.

Parameters
----------
Expand Down Expand Up @@ -59,12 +63,13 @@ def from_neo4j(


def _map_node(node: neo4j.graph.Node, size_property: Optional[str], caption_property: Optional[str]) -> Node:
labels = sorted([label for label in node.labels])

if size_property:
size = node.get(size_property)
else:
size = None

labels = sorted([label for label in node.labels])
if caption_property:
if caption_property == "labels":
if len(labels) > 0:
Expand All @@ -74,7 +79,13 @@ def _map_node(node: neo4j.graph.Node, size_property: Optional[str], caption_prop
else:
caption = str(node.get(caption_property))

return Node(id=node.element_id, caption=caption, labels=labels, size=size, **{k: v for k, v in node.items()})
base_node_props = dict(id=node.element_id, caption=caption, labels=labels, size=size)

protected_props = base_node_props.keys()
additional_node_props = {k: v for k, v in node.items()}
additional_node_props = _rename_protected_props(additional_node_props, protected_props)

return Node(**base_node_props, **additional_node_props)


def _map_relationship(rel: neo4j.graph.Relationship, caption_property: Optional[str]) -> Optional[Relationship]:
Expand All @@ -89,11 +100,32 @@ def _map_relationship(rel: neo4j.graph.Relationship, caption_property: Optional[
else:
caption = None

return Relationship(
base_rel_props = dict(
id=rel.element_id,
source=rel.start_node.element_id,
target=rel.end_node.element_id,
type_=rel.type,
_type=rel.type,
caption=caption,
**{k: v for k, v in rel.items()},
)

protected_props = base_rel_props.keys()
additional_rel_props = {k: v for k, v in rel.items()}
additional_rel_props = _rename_protected_props(additional_rel_props, protected_props)

return Relationship(
**base_rel_props,
**additional_rel_props,
)


def _rename_protected_props(
additional_props: dict[str, Any],
protected_props: Iterable[str],
) -> dict[str, Union[str, int, float]]:
for prop in protected_props:
if prop not in additional_props:
continue

additional_props[f"__{prop}"] = additional_props.pop(prop)

return additional_props
82 changes: 72 additions & 10 deletions python-wrapper/tests/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Generator

import neo4j
import pytest
from neo4j import Session

Expand All @@ -10,7 +11,8 @@
@pytest.fixture(scope="class", autouse=True)
def graph_setup(neo4j_session: Session) -> Generator[None, None, None]:
neo4j_session.run(
"CREATE (a:_CI_A {name:'Alice', height:20})-[:KNOWS {year: 2025}]->(b:_CI_A:_CI_B {name:'Bob', height:10}), (b)-[:RELATED {year: 2015}]->(a)"
"CREATE (a:_CI_A {name:'Alice', height:20, id:42, _id: 1337, caption: 'hello'})-[:KNOWS {year: 2025, id: 41, source: 1, target: 2}]->"
"(b:_CI_A:_CI_B {name:'Bob', height:10, id: 84, size: 11, labels: [1,2]}), (b)-[:RELATED {year: 2015, _type: 'A', caption:'hej'}]->(a)"
)
yield
neo4j_session.run("MATCH (n:_CI_A|_CI_B) DETACH DELETE n")
Expand All @@ -22,11 +24,30 @@ def test_from_neo4j_graph(neo4j_session: Session) -> None:

VG = from_neo4j(graph)

node_ids: list[str] = [node.element_id for node in graph.nodes]
sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"])
node_ids: list[str] = [node.element_id for node in sorted_nodes]

expected_nodes = [
Node(id=node_ids[0], caption="_CI_A", labels=["_CI_A"], name="Alice", height=20),
Node(id=node_ids[1], caption="_CI_A:_CI_B", labels=["_CI_A", "_CI_B"], name="Bob", height=10),
Node(
id=node_ids[0],
caption="_CI_A",
labels=["_CI_A"],
name="Alice",
height=20,
__id=42,
_id=1337,
__caption="hello",
),
Node(
id=node_ids[1],
caption="_CI_A:_CI_B",
labels=["_CI_A", "_CI_B"],
name="Bob",
height=10,
__id=84,
__size=11,
__labels=[1, 2],
),
]

assert len(VG.nodes) == 2
Expand All @@ -47,11 +68,31 @@ def test_from_neo4j_result(neo4j_session: Session) -> None:
VG = from_neo4j(result)

graph = result.graph()
node_ids: list[str] = [node.element_id for node in graph.nodes]

sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"])
node_ids: list[str] = [node.element_id for node in sorted_nodes]

expected_nodes = [
Node(id=node_ids[0], caption="_CI_A", labels=["_CI_A"], name="Alice", height=20),
Node(id=node_ids[1], caption="_CI_A:_CI_B", labels=["_CI_A", "_CI_B"], name="Bob", height=10),
Node(
id=node_ids[0],
caption="_CI_A",
labels=["_CI_A"],
name="Alice",
height=20,
__id=42,
_id=1337,
__caption="hello",
),
Node(
id=node_ids[1],
caption="_CI_A:_CI_B",
labels=["_CI_A", "_CI_B"],
name="Bob",
height=10,
__id=84,
__size=11,
__labels=[1, 2],
),
]

assert len(VG.nodes) == 2
Expand All @@ -71,11 +112,32 @@ def test_from_neo4j_graph_full(neo4j_session: Session) -> None:

VG = from_neo4j(graph, node_caption="name", relationship_caption="year", size_property="height")

node_ids: list[str] = [node.element_id for node in graph.nodes]
sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"])
node_ids: list[str] = [node.element_id for node in sorted_nodes]

expected_nodes = [
Node(id=node_ids[0], caption="Alice", labels=["_CI_A"], name="Alice", height=20, size=60.0),
Node(id=node_ids[1], caption="Bob", labels=["_CI_A", "_CI_B"], name="Bob", height=10, size=3.0),
Node(
id=node_ids[0],
caption="Alice",
labels=["_CI_A"],
name="Alice",
height=20,
size=60.0,
__id=42,
_id=1337,
__caption="hello",
),
Node(
id=node_ids[1],
caption="Bob",
labels=["_CI_A", "_CI_B"],
name="Bob",
height=10,
size=3.0,
__id=84,
__size=11,
__labels=[1, 2],
),
]

assert len(VG.nodes) == 2
Expand Down