Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions python-wrapper/src/neo4j_viz/visualization_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from collections.abc import Iterable
from typing import Optional
from typing import Any, Hashable, Optional

from IPython.display import HTML
from pydantic_extra_types.color import Color, ColorType
Expand Down Expand Up @@ -201,7 +201,8 @@ def color_nodes(self, property: str, colors: Optional[ColorsType] = None, overri
Parameters
----------
property:
The property of the nodes to use for coloring.
The property of the nodes to use for coloring. The type of this property must be hashable, or be a
list, set or dict containing only hashable types.
colors:
The colors to use for the nodes. If a dictionary is given, it should map from property to color.
If an iterable is given, the colors are used in order.
Expand Down Expand Up @@ -238,7 +239,11 @@ def _color_nodes_iter(self, property: str, colors: Iterable[ColorType], override
prop_to_color = {}
colors_iter = iter(colors)
for node in self.nodes:
prop = getattr(node, property)
raw_prop = getattr(node, property)
try:
prop = self._make_hashable(raw_prop)
except ValueError:
raise ValueError(f"Unable to color nodes by unhashable property type '{type(raw_prop)}'")

if prop not in prop_to_color:
next_color = next(colors_iter, None)
Expand All @@ -263,3 +268,22 @@ def _color_nodes_iter(self, property: str, colors: Iterable[ColorType], override
f"Ran out of colors for property '{property}'. {len(prop_to_color)} colors were needed, but only "
f"{len(set(prop_to_color.values()))} were given, so reused colors"
)

@staticmethod
def _make_hashable(raw_prop: Any) -> Hashable:
prop = raw_prop
if isinstance(raw_prop, list):
prop = tuple(raw_prop)
elif isinstance(raw_prop, set):
prop = frozenset(raw_prop)
elif isinstance(raw_prop, dict):
prop = tuple(sorted(raw_prop.items()))

try:
hash(prop)
except TypeError:
raise ValueError(f"Unable to convert property '{raw_prop}' of type {type(raw_prop)} to a hashable type")

assert isinstance(prop, Hashable)

return prop
83 changes: 83 additions & 0 deletions python-wrapper/tests/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,86 @@ def test_color_nodes_default() -> None:
assert VG.nodes[1].color == Color(neo4j_colors[1])
assert VG.nodes[2].color == Color(neo4j_colors[1])
assert VG.nodes[3].color == Color(neo4j_colors[2])


def test_color_nodes_lists() -> None:
nodes = [
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", labels=["Person"]),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", labels=["Product"]),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", labels=["Product"]),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Both", labels=["Person", "Product"]),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:2", caption="Both again", labels=["Person", "Product"]),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:3", caption="Both reorder", labels=["Product", "Person"]),
]

VG = VisualizationGraph(nodes=nodes, relationships=[])

VG.color_nodes("labels", ["#000000", "#00FF00", "#FF0000", "#0000FF"])

assert VG.nodes[0].color == Color("#000000")
assert VG.nodes[1].color == Color("#00ff00")
assert VG.nodes[2].color == Color("#00ff00")
assert VG.nodes[3].color == Color("#ff0000")
assert VG.nodes[4].color == Color("#ff0000")
assert VG.nodes[5].color == Color("#0000ff")


def test_color_nodes_sets() -> None:
nodes = [
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", labels={"Person"}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", labels={"Product"}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", labels={"Product"}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Both", labels={"Person", "Product"}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:2", caption="Both again", labels={"Person", "Product"}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:3", caption="Both reorder", labels={"Product", "Person"}),
]

VG = VisualizationGraph(nodes=nodes, relationships=[])

VG.color_nodes("labels", ["#000000", "#00FF00", "#FF0000", "#0000FF"])

assert VG.nodes[0].color == Color("#000000")
assert VG.nodes[1].color == Color("#00ff00")
assert VG.nodes[2].color == Color("#00ff00")
assert VG.nodes[3].color == Color("#ff0000")
assert VG.nodes[4].color == Color("#ff0000")
assert VG.nodes[4].color == Color("#ff0000")


def test_color_nodes_dicts() -> None:
nodes = [
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", config={"age": 18}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", config={"price": 100}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", config={"price": 100}),
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Product", config={"price": 1}),
]

VG = VisualizationGraph(nodes=nodes, relationships=[])

VG.color_nodes("config", ["#000000", "#00FF00", "#FF0000", "#0000FF"])

assert VG.nodes[0].color == Color("#000000")
assert VG.nodes[1].color == Color("#00ff00")
assert VG.nodes[2].color == Color("#00ff00")
assert VG.nodes[3].color == Color("#ff0000")


def test_color_nodes_unhashable() -> None:
nodes = [
Node(
id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0",
caption="Person",
config={"movies": ["Star Wars", "Star Trek"]},
),
]
VG = VisualizationGraph(nodes=nodes, relationships=[])

with pytest.raises(ValueError, match="Unable to color nodes by unhashable property type '<class 'dict'>'"):
VG.color_nodes("config", ["#000000"])

nodes = [
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", list_of_lists=[[1, 2], [3, 4]]),
]
VG = VisualizationGraph(nodes=nodes, relationships=[])
with pytest.raises(ValueError, match="Unable to color nodes by unhashable property type '<class 'list'>'"):
VG.color_nodes("list_of_lists", ["#000000"])