Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

## New features

- Allow to include db node properties in addition to the properties in the GDS Graph. Specify `additional_db_node_properties` in `from_gds`.


## Bug fixes

- fixed a bug in `from_neo4j`, where the node size would always be set to the `size` property.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ The default is ``None``, which means that all properties of the nodes in the pro
Apart from being visible through on-hover tooltips, these properties could be used to color the nodes, or give captions
to them in the visualization, or simply included in the nodes' ``Node.properties`` maps without directly impacting the
visualization.
If you want to include node properties stored at the Neo4j database, you can include them in the visualization by using the `additional_db_node_properties` parameter.

The last optional property, ``node_radius_min_max``, can be used (and is used by default) to scale the node sizes for
the visualization.
Expand Down Expand Up @@ -401,4 +402,3 @@ In this small example, we import a toy graph representing a social network from

For a full example of the ``from_snowflake`` importer in action, please see the
:doc:`Visualizing Snowflake Tables tutorial <./tutorials/snowflake-example>`.

29 changes: 22 additions & 7 deletions python-wrapper/src/neo4j_viz/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,19 @@


def _fetch_node_dfs(
gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str]
gds: GraphDataScience,
G: Graph,
node_properties_by_label: dict[str, list[str]],
node_labels: list[str],
additional_db_node_properties: list[str],
) -> dict[str, pd.DataFrame]:
return {
lbl: gds.graph.nodeProperties.stream(
G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True
G,
node_properties=node_properties_by_label[lbl],
node_labels=[lbl],
separate_property_columns=True,
db_node_properties=additional_db_node_properties,
)
for lbl in node_labels
}
Expand Down Expand Up @@ -49,6 +57,7 @@ def from_gds(
G: Graph,
size_property: Optional[str] = None,
additional_node_properties: Optional[list[str]] = None,
additional_db_node_properties: Optional[list[str]] = None,
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
max_node_count: int = 10_000,
) -> VisualizationGraph:
Expand All @@ -71,14 +80,19 @@ def from_gds(
Property to use for node size, by default None.
additional_node_properties : list[str], optional
Additional properties to include in the visualization node, by default None which means that all node
properties will be fetched.
properties from the Graph will be fetched.
additional_db_node_properties : list[str], optional
Additional node properties to fetch from the database, by default None. Only works if the graph was projected from the database.
node_radius_min_max : tuple[float, float], optional
Minimum and maximum node radius, by default (3, 60).
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
max_node_count : int, optional
The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts
if its node count exceeds this number.
"""
if additional_db_node_properties is None:
additional_db_node_properties = []

node_properties_from_gds = G.node_properties()
assert isinstance(node_properties_from_gds, pd.Series)
actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict())
Expand All @@ -102,9 +116,8 @@ def from_gds(
}

if size_property is not None:
# For some reason mypy are unable to understand that this is dict[str, set[str]]
for label, props in node_properties_by_label_sets.items(): # type: ignore
props.add(size_property) # type: ignore
for label, label_props in node_properties_by_label_sets.items():
label_props.add(size_property)

node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()}

Expand All @@ -129,7 +142,9 @@ def from_gds(
for props in node_properties_by_label.values():
props.append(property_name)

node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels())
node_dfs = _fetch_node_dfs(
gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), additional_db_node_properties
)
if property_name is not None:
for df in node_dfs.values():
df.drop(columns=[property_name], inplace=True)
Expand Down
27 changes: 26 additions & 1 deletion python-wrapper/tests/test_gds.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any
from typing import Any, Generator

import pandas as pd
import pytest
Expand All @@ -8,6 +8,19 @@
from neo4j_viz import Node


@pytest.fixture(scope="class")
def db_setup(gds: Any) -> Generator[None, None, None]:
gds.run_cypher(
"CREATE "
" (a:_CI_A {name:'Alice', height:20, id:42, _id: 1337, caption: 'hello'})"
" ,(b:_CI_A:_CI_B {name:'Bob', height:10, id: 84, size: 11, labels: [1,2]})"
" ,(a)-[:KNOWS {year: 2025, id: 41, source: 1, target: 2}]->(b)"
" ,(b)-[:RELATED {year: 2015, _type: 'A', caption:'hej'}]->(a)"
)
yield
gds.run_cypher("MATCH (n:_CI_A|_CI_B) DETACH DELETE n")


@pytest.mark.requires_neo4j_and_gds
def test_from_gds_integration_size(gds: Any) -> None:
from neo4j_viz.gds import from_gds
Expand Down Expand Up @@ -74,6 +87,18 @@ def test_from_gds_integration_size(gds: Any) -> None:
]


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.requires_neo4j_and_gds
def test_from_gds_integration_all_db_properties(gds: Any, db_setup: None) -> None:
from neo4j_viz.gds import from_gds

with gds.graph.project("g2", ["_CI_A", "_CI_B"], "*") as G:
VG = from_gds(gds, G, node_radius_min_max=None, additional_db_node_properties=["name"])

assert len(VG.nodes) == 2
assert {n.properties["name"] for n in VG.nodes} == {"Alice", "Bob"}


@pytest.mark.requires_neo4j_and_gds
def test_from_gds_integration_all_properties(gds: Any) -> None:
from neo4j_viz.gds import from_gds
Expand Down