diff --git a/changelog.md b/changelog.md index 3dc0199..61bb9e5 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,9 @@ ## New features +- Allow to include db node properties in addition to the properties in the GDS Graph. Specify `additional_db_node_properties` in `from_gds`. + + ## Bug fixes - fixed a bug in `from_neo4j`, where the node size would always be set to the `size` property. diff --git a/docs/source/integration.rst b/docs/source/integration.rst index e71609e..3a8a6c2 100644 --- a/docs/source/integration.rst +++ b/docs/source/integration.rst @@ -120,6 +120,7 @@ The default is ``None``, which means that all properties of the nodes in the pro Apart from being visible through on-hover tooltips, these properties could be used to color the nodes, or give captions to them in the visualization, or simply included in the nodes' ``Node.properties`` maps without directly impacting the visualization. +If you want to include node properties stored at the Neo4j database, you can include them in the visualization by using the `additional_db_node_properties` parameter. The last optional property, ``node_radius_min_max``, can be used (and is used by default) to scale the node sizes for the visualization. @@ -401,4 +402,3 @@ In this small example, we import a toy graph representing a social network from For a full example of the ``from_snowflake`` importer in action, please see the :doc:`Visualizing Snowflake Tables tutorial <./tutorials/snowflake-example>`. - diff --git a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py index 310c621..6297375 100644 --- a/python-wrapper/src/neo4j_viz/gds.py +++ b/python-wrapper/src/neo4j_viz/gds.py @@ -13,11 +13,19 @@ def _fetch_node_dfs( - gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str] + gds: GraphDataScience, + G: Graph, + node_properties_by_label: dict[str, list[str]], + node_labels: list[str], + additional_db_node_properties: list[str], ) -> dict[str, pd.DataFrame]: return { lbl: gds.graph.nodeProperties.stream( - G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True + G, + node_properties=node_properties_by_label[lbl], + node_labels=[lbl], + separate_property_columns=True, + db_node_properties=additional_db_node_properties, ) for lbl in node_labels } @@ -49,6 +57,7 @@ def from_gds( G: Graph, size_property: Optional[str] = None, additional_node_properties: Optional[list[str]] = None, + additional_db_node_properties: Optional[list[str]] = None, node_radius_min_max: Optional[tuple[float, float]] = (3, 60), max_node_count: int = 10_000, ) -> VisualizationGraph: @@ -71,7 +80,9 @@ def from_gds( Property to use for node size, by default None. additional_node_properties : list[str], optional Additional properties to include in the visualization node, by default None which means that all node - properties will be fetched. + properties from the Graph will be fetched. + additional_db_node_properties : list[str], optional + Additional node properties to fetch from the database, by default None. Only works if the graph was projected from the database. node_radius_min_max : tuple[float, float], optional Minimum and maximum node radius, by default (3, 60). To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. @@ -79,6 +90,9 @@ def from_gds( The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts if its node count exceeds this number. """ + if additional_db_node_properties is None: + additional_db_node_properties = [] + node_properties_from_gds = G.node_properties() assert isinstance(node_properties_from_gds, pd.Series) actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict()) @@ -102,9 +116,8 @@ def from_gds( } if size_property is not None: - # For some reason mypy are unable to understand that this is dict[str, set[str]] - for label, props in node_properties_by_label_sets.items(): # type: ignore - props.add(size_property) # type: ignore + for label, label_props in node_properties_by_label_sets.items(): + label_props.add(size_property) node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()} @@ -129,7 +142,9 @@ def from_gds( for props in node_properties_by_label.values(): props.append(property_name) - node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels()) + node_dfs = _fetch_node_dfs( + gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), additional_db_node_properties + ) if property_name is not None: for df in node_dfs.values(): df.drop(columns=[property_name], inplace=True) diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py index 3de3822..ce9e1c5 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -1,5 +1,5 @@ import re -from typing import Any +from typing import Any, Generator import pandas as pd import pytest @@ -8,6 +8,19 @@ from neo4j_viz import Node +@pytest.fixture(scope="class") +def db_setup(gds: Any) -> Generator[None, None, None]: + gds.run_cypher( + "CREATE " + " (a:_CI_A {name:'Alice', height:20, id:42, _id: 1337, caption: 'hello'})" + " ,(b:_CI_A:_CI_B {name:'Bob', height:10, id: 84, size: 11, labels: [1,2]})" + " ,(a)-[:KNOWS {year: 2025, id: 41, source: 1, target: 2}]->(b)" + " ,(b)-[:RELATED {year: 2015, _type: 'A', caption:'hej'}]->(a)" + ) + yield + gds.run_cypher("MATCH (n:_CI_A|_CI_B) DETACH DELETE n") + + @pytest.mark.requires_neo4j_and_gds def test_from_gds_integration_size(gds: Any) -> None: from neo4j_viz.gds import from_gds @@ -74,6 +87,18 @@ def test_from_gds_integration_size(gds: Any) -> None: ] +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.requires_neo4j_and_gds +def test_from_gds_integration_all_db_properties(gds: Any, db_setup: None) -> None: + from neo4j_viz.gds import from_gds + + with gds.graph.project("g2", ["_CI_A", "_CI_B"], "*") as G: + VG = from_gds(gds, G, node_radius_min_max=None, additional_db_node_properties=["name"]) + + assert len(VG.nodes) == 2 + assert {n.properties["name"] for n in VG.nodes} == {"Alice", "Bob"} + + @pytest.mark.requires_neo4j_and_gds def test_from_gds_integration_all_properties(gds: Any) -> None: from neo4j_viz.gds import from_gds