diff --git a/python-wrapper/src/neo4j_viz/visualization_graph.py b/python-wrapper/src/neo4j_viz/visualization_graph.py index a21bea3..4576c69 100644 --- a/python-wrapper/src/neo4j_viz/visualization_graph.py +++ b/python-wrapper/src/neo4j_viz/visualization_graph.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Iterable -from typing import Optional +from typing import Any, Hashable, Optional from IPython.display import HTML from pydantic_extra_types.color import Color, ColorType @@ -201,7 +201,8 @@ def color_nodes(self, property: str, colors: Optional[ColorsType] = None, overri Parameters ---------- property: - The property of the nodes to use for coloring. + The property of the nodes to use for coloring. The type of this property must be hashable, or be a + list, set or dict containing only hashable types. colors: The colors to use for the nodes. If a dictionary is given, it should map from property to color. If an iterable is given, the colors are used in order. @@ -238,7 +239,11 @@ def _color_nodes_iter(self, property: str, colors: Iterable[ColorType], override prop_to_color = {} colors_iter = iter(colors) for node in self.nodes: - prop = getattr(node, property) + raw_prop = getattr(node, property) + try: + prop = self._make_hashable(raw_prop) + except ValueError: + raise ValueError(f"Unable to color nodes by unhashable property type '{type(raw_prop)}'") if prop not in prop_to_color: next_color = next(colors_iter, None) @@ -263,3 +268,22 @@ def _color_nodes_iter(self, property: str, colors: Iterable[ColorType], override f"Ran out of colors for property '{property}'. {len(prop_to_color)} colors were needed, but only " f"{len(set(prop_to_color.values()))} were given, so reused colors" ) + + @staticmethod + def _make_hashable(raw_prop: Any) -> Hashable: + prop = raw_prop + if isinstance(raw_prop, list): + prop = tuple(raw_prop) + elif isinstance(raw_prop, set): + prop = frozenset(raw_prop) + elif isinstance(raw_prop, dict): + prop = tuple(sorted(raw_prop.items())) + + try: + hash(prop) + except TypeError: + raise ValueError(f"Unable to convert property '{raw_prop}' of type {type(raw_prop)} to a hashable type") + + assert isinstance(prop, Hashable) + + return prop diff --git a/python-wrapper/tests/test_colors.py b/python-wrapper/tests/test_colors.py index 3c1a477..950d150 100644 --- a/python-wrapper/tests/test_colors.py +++ b/python-wrapper/tests/test_colors.py @@ -103,3 +103,86 @@ def test_color_nodes_default() -> None: assert VG.nodes[1].color == Color(neo4j_colors[1]) assert VG.nodes[2].color == Color(neo4j_colors[1]) assert VG.nodes[3].color == Color(neo4j_colors[2]) + + +def test_color_nodes_lists() -> None: + nodes = [ + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", labels=["Person"]), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", labels=["Product"]), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", labels=["Product"]), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Both", labels=["Person", "Product"]), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:2", caption="Both again", labels=["Person", "Product"]), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:3", caption="Both reorder", labels=["Product", "Person"]), + ] + + VG = VisualizationGraph(nodes=nodes, relationships=[]) + + VG.color_nodes("labels", ["#000000", "#00FF00", "#FF0000", "#0000FF"]) + + assert VG.nodes[0].color == Color("#000000") + assert VG.nodes[1].color == Color("#00ff00") + assert VG.nodes[2].color == Color("#00ff00") + assert VG.nodes[3].color == Color("#ff0000") + assert VG.nodes[4].color == Color("#ff0000") + assert VG.nodes[5].color == Color("#0000ff") + + +def test_color_nodes_sets() -> None: + nodes = [ + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", labels={"Person"}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", labels={"Product"}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", labels={"Product"}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Both", labels={"Person", "Product"}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:2", caption="Both again", labels={"Person", "Product"}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:3", caption="Both reorder", labels={"Product", "Person"}), + ] + + VG = VisualizationGraph(nodes=nodes, relationships=[]) + + VG.color_nodes("labels", ["#000000", "#00FF00", "#FF0000", "#0000FF"]) + + assert VG.nodes[0].color == Color("#000000") + assert VG.nodes[1].color == Color("#00ff00") + assert VG.nodes[2].color == Color("#00ff00") + assert VG.nodes[3].color == Color("#ff0000") + assert VG.nodes[4].color == Color("#ff0000") + assert VG.nodes[4].color == Color("#ff0000") + + +def test_color_nodes_dicts() -> None: + nodes = [ + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", config={"age": 18}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", config={"price": 100}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", config={"price": 100}), + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Product", config={"price": 1}), + ] + + VG = VisualizationGraph(nodes=nodes, relationships=[]) + + VG.color_nodes("config", ["#000000", "#00FF00", "#FF0000", "#0000FF"]) + + assert VG.nodes[0].color == Color("#000000") + assert VG.nodes[1].color == Color("#00ff00") + assert VG.nodes[2].color == Color("#00ff00") + assert VG.nodes[3].color == Color("#ff0000") + + +def test_color_nodes_unhashable() -> None: + nodes = [ + Node( + id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", + caption="Person", + config={"movies": ["Star Wars", "Star Trek"]}, + ), + ] + VG = VisualizationGraph(nodes=nodes, relationships=[]) + + with pytest.raises(ValueError, match="Unable to color nodes by unhashable property type ''"): + VG.color_nodes("config", ["#000000"]) + + nodes = [ + Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", list_of_lists=[[1, 2], [3, 4]]), + ] + VG = VisualizationGraph(nodes=nodes, relationships=[]) + with pytest.raises(ValueError, match="Unable to color nodes by unhashable property type ''"): + VG.color_nodes("list_of_lists", ["#000000"])