Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,11 @@

## New features

- Add convenience method `add_data` and `remove_data` to `GraphWidget`.
- Added a selection button to the toolbar.
- Added a layout button to the toolbar if `VG.render_widget` is used.
- Support the new circular layout.

## Bug fixes

- Fixed a bug with the theme detection inn VSCode.

## Improvements

- Allow setting the theme manually in `VG.render(theme="light")` and `VG.render_widget(theme="dark")`.
- Use typed nodes and relationship traitlets in GraphWidget, i.e., list of Node and Relationship instead of dictionaries.
- `render` now allows to pass `layout` as a string as well. Previously expected to be a typed `neo4j_viz.Layout`.
- Fixed rendering in Marimo notebooks
- Support `neo4j.EagerResult` in the `from_neo4j` integration which is the default return type by `neo4j.Driver.execute_query()`.


## Other changes
7 changes: 3 additions & 4 deletions docs/antora/modules/ROOT/pages/integration/neo4j.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pip install neo4j-viz[neo4j]
Once you have installed the additional dependency, you can use the link:{api-docs-uri}/from_neo4j[`from_neo4j`] method
to import query results from Neo4j.

The `from_neo4j` method takes one mandatory positional parameter: A `data` argument representing either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in which case a simple default query will be executed internally to retrieve the graph data.
The `from_neo4j` method takes one mandatory positional parameter: A `data` argument representing either a query result in the shape of a `neo4j.graph.Graph`, `neo4j.EagerResult` or `neo4j.Result`, or a `neo4j.Driver` in which case a simple default query will be executed internally to retrieve the graph data.

The optional `max_rows` parameter can be used to limit the number of relationships shown in the visualization.
By default, it is set to 10.000, meaning that if the database has more than 10.000 rows, a warning will be raised.
Expand All @@ -35,11 +35,10 @@ with GraphDatabase.driver(URI, auth=auth) as driver:
result = driver.execute_query(
"MATCH (n)-[r]->(m) RETURN n,r,m",
database_="neo4j",
routing_=RoutingControl.READ,
result_transformer_=Result.graph,
routing_=RoutingControl.READ
)

VG = from_neo4j(result)
----

See the link:{tutorials-docs-uri}/neo4j-example[Visualizing Neo4j Graphs tutorial] for a more extensive example.
See the link:{tutorials-docs-uri}/neo4j-example[Visualizing Neo4j Graphs tutorial] for a more extensive example.
84 changes: 73 additions & 11 deletions python-wrapper/src/neo4j_viz/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Union

import neo4j.graph
from neo4j import Driver, Result, RoutingControl
from neo4j import Driver, EagerResult, Result, RoutingControl
from pydantic import BaseModel, ValidationError

from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace
Expand All @@ -21,12 +21,57 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) ->
)


def find_graph_entity(value: object) -> neo4j.graph.Graph | None:
"""Recursively traverse lists and dicts to find a Graph entity."""
if isinstance(value, neo4j.graph.Entity):
return value.graph
elif isinstance(value, neo4j.graph.Path):
return value.graph
elif isinstance(value, list):
for item in value:
G = find_graph_entity(item)
if G:
return G
elif isinstance(value, dict):
for item in value.values():
G = find_graph_entity(item)
if G:
return G
return None


def _graph_from_eager_result(data: EagerResult) -> neo4j.graph.Graph:
"""Return the bolt hydration Graph shared by all entities in an EagerResult.

Every Node/Relationship produced by the same query references the same
internal Graph that the driver built during bolt hydration — identical to
what Result.graph() returns. We find the first entity in the records and
return its .graph. If the result contains no graph entities at all we fall
back to walking the records manually.
"""
for record in data.records:
for value in record.values():
if isinstance(value, (neo4j.graph.Node, neo4j.graph.Relationship)):
return value.graph
if isinstance(value, neo4j.graph.Path) and value.nodes:
return value.nodes[0].graph

# Fallback: no direct entity columns — walk everything recursively and return the first graph we find.
for record in data.records:
for value in record.values():
G = find_graph_entity(value)
if G:
return G

raise ValueError("No graph entities found in eager result")


def from_neo4j(
data: Union[neo4j.graph.Graph, Result, Driver],
data: Union[neo4j.graph.Graph, Result, EagerResult, Driver],
Comment thread
FlorentinD marked this conversation as resolved.
row_limit: int = 10_000,
) -> VisualizationGraph:
"""
Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result` or Neo4j `Driver`.
Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result`, Neo4j `EagerResult` or Neo4j `Driver`.

By default:

Expand All @@ -39,18 +84,30 @@ def from_neo4j(

Parameters
----------
data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver]
Either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in
which case a simple default query will be executed internally to retrieve the graph data.
data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.EagerResult, neo4j.Driver]
Either a query result in the shape of a `neo4j.graph.Graph`, `neo4j.Result`, or `neo4j.EagerResult`
(as returned by `driver.execute_query()`), or a `neo4j.Driver` in which case a simple default query
will be executed internally to retrieve the graph data.
row_limit : int, optional
Maximum number of rows to return from the query, by default 10_000.
This is only used if a `neo4j.Driver` is passed as `result` argument, otherwise the limit is ignored.
"""

if isinstance(data, Result):
graph = data.graph()
raw_nodes = graph.nodes
raw_relationships = graph.relationships
elif isinstance(data, neo4j.graph.Graph):
graph = data
raw_nodes = data.nodes
raw_relationships = data.relationships
elif isinstance(data, EagerResult):
# Every Node/Relationship hydrated from the same query shares one Graph
# object (the bolt hydration graph). Grabbing it from the first entity
# gives us the complete graph — including start/end nodes of
# relationships that were never returned as explicit columns.
graph = _graph_from_eager_result(data)
raw_nodes = graph.nodes
raw_relationships = graph.relationships
elif isinstance(data, Driver):
rel_count = data.execute_query(
"MATCH ()-[r]->() RETURN count(r) as count",
Expand All @@ -62,18 +119,23 @@ def from_neo4j(
f"Database relationship count ({rel_count}) exceeds `row_limit` ({row_limit}), so limiting will be applied. Increase the `row_limit` if needed"
)
graph = data.execute_query(
f"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT {row_limit}",
"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT $rowLimit",
routing_=RoutingControl.READ,
parameters_={"rowLimit": row_limit},
result_transformer_=Result.graph,
)
raw_nodes = graph.nodes
raw_relationships = graph.relationships
else:
raise ValueError(f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`")
raise ValueError(
f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result`, `neo4j.EagerResult` or `neo4j.Driver`"
)

nodes = [_map_node(node) for node in graph.nodes]
nodes = [_map_node(node) for node in raw_nodes]

relationships = []

for rel in graph.relationships:
for rel in raw_relationships:
mapped_rel = _map_relationship(rel)
if mapped_rel:
relationships.append(mapped_rel)
Expand Down
110 changes: 110 additions & 0 deletions python-wrapper/tests/test_find_graph_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import neo4j.graph

from neo4j_viz.neo4j import find_graph_entity


def _make_graph() -> neo4j.graph.Graph:
return neo4j.graph.Graph()


def _make_node(
graph: neo4j.graph.Graph, element_id: str, labels: list[str], props: dict[str, object]
) -> neo4j.graph.Node:
return neo4j.graph.Node(graph, element_id, hash(element_id), labels, props)


def _make_rel(
graph: neo4j.graph.Graph,
element_id: str,
rel_type: str,
start: neo4j.graph.Node,
end: neo4j.graph.Node,
props: dict[str, object] | None = None,
) -> neo4j.graph.Relationship:
RelType = graph.relationship_type(rel_type)
rel = RelType.__new__(RelType)
rel.__dict__.update(
{
"_graph": graph,
"_element_id": element_id,
"_id": hash(element_id),
"_properties": props or {},
"_start_node": start,
"_end_node": end,
}
)
return rel


def test_plain_node() -> None:
g = _make_graph()
node = _make_node(g, "n1", ["A"], {"x": 1})
assert find_graph_entity(node) is g


def test_plain_relationship() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
b = _make_node(g, "b", ["B"], {})
rel = _make_rel(g, "r1", "KNOWS", a, b)
assert find_graph_entity(rel) is g


def test_path() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
b = _make_node(g, "b", ["B"], {})
rel = _make_rel(g, "r1", "KNOWS", a, b)
path = neo4j.graph.Path(a, rel)
assert find_graph_entity(path) is g


def test_list_of_nodes() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
b = _make_node(g, "b", ["B"], {})
value = [a, b]
assert find_graph_entity(value) is g


def test_nested_list() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
value = [[a]]
assert find_graph_entity(value) is g


def test_dict_of_nodes() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
value = {"key": a}
assert find_graph_entity(value) is g


def test_deduplication() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
value = [a, a]
assert find_graph_entity(value) is g


def test_scalar_ignored() -> None:
assert find_graph_entity("hello") is None
assert find_graph_entity(42) is None
assert find_graph_entity(None) is None


def test_mixed_list_with_graph_entities_and_scalars() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
b = _make_node(g, "b", ["B"], {})
rel = _make_rel(g, "r1", "KNOWS", a, b)
value = ["hello", rel, 42, None]
assert find_graph_entity(value) is g


def test_mixed_dict_with_graph_entities_and_scalars() -> None:
g = _make_graph()
a = _make_node(g, "a", ["A"], {})
value = {"text": "hello", "node": a, "count": 42, "empty": None}
assert find_graph_entity(value) is g
54 changes: 53 additions & 1 deletion python-wrapper/tests/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import neo4j
import pytest
from neo4j import Driver, Session
from neo4j import Driver, EagerResult, Session

from neo4j_viz.colors import NEO4J_COLORS_DISCRETE
from neo4j_viz.neo4j import from_neo4j
Expand Down Expand Up @@ -123,6 +123,58 @@ def test_from_neo4j_result(neo4j_session: Session) -> None:
]


@pytest.mark.requires_neo4j_and_gds
def test_from_neo4j_eager_result(neo4j_session: Session, neo4j_driver: Driver) -> None:
graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()

eager_result: EagerResult = neo4j_driver.execute_query("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a")
assert isinstance(eager_result, EagerResult)

VG = from_neo4j(eager_result)

sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"])
node_ids: list[str] = [node.element_id for node in sorted_nodes]

expected_nodes = [
Node(
id=node_ids[0],
caption="_CI_A",
color=NEO4J_COLORS_DISCRETE[0],
properties=dict(
labels=["_CI_A"],
name="Alice",
height=20,
id=42,
_id=1337,
caption="hello",
),
),
Node(
id=node_ids[1],
caption="_CI_A:_CI_B",
color=NEO4J_COLORS_DISCRETE[1],
properties=dict(
size=11,
labels=["_CI_A", "_CI_B"],
name="Bob",
height=10,
id=84,
__labels=[1, 2],
),
),
]

assert len(VG.nodes) == 2
assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes

assert len(VG.relationships) == 2
vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo")
assert vg_rels == [
(node_ids[0], node_ids[1], "KNOWS"),
(node_ids[1], node_ids[0], "RELATED"),
]


@pytest.mark.requires_neo4j_and_gds
def test_from_neo4j_graph_driver(neo4j_session: Session, neo4j_driver: Driver) -> None:
graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()
Expand Down
4 changes: 2 additions & 2 deletions python-wrapper/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading