From c4a25770fa2d6146969c41ea73f1648a7fb31e60 Mon Sep 17 00:00:00 2001 From: Benjamin Abel Date: Thu, 11 Feb 2021 21:21:51 +0100 Subject: [PATCH] Fallback to raw svg if graphviz is not installed --- binarytree/__init__.py | 34 ++++++----- binarytree/exceptions.py | 4 ++ binarytree/layout.py | 118 +++++++++++++++++++++++++++++++++++++++ tests/test_layout.py | 17 ++++++ 4 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 binarytree/layout.py create mode 100644 tests/test_layout.py diff --git a/binarytree/__init__.py b/binarytree/__init__.py index c7d6717..50b93b5 100644 --- a/binarytree/__init__.py +++ b/binarytree/__init__.py @@ -5,10 +5,10 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -from graphviz import Digraph, nohtml from pkg_resources import get_distribution from binarytree.exceptions import ( + GraphvizImportError, NodeIndexError, NodeModifyError, NodeNotFoundError, @@ -18,6 +18,15 @@ TreeHeightError, ) +try: + from graphviz import Digraph, nohtml + + GRAPHVIZ_INSTALLED = True +except ImportError: + GRAPHVIZ_INSTALLED = False + Digraph = Any + from binarytree.layout import generate_svg + __version__ = get_distribution("binarytree").version LEFT_FIELD = "left" @@ -463,19 +472,19 @@ def _repr_svg_(self) -> str: .. _Jupyter notebooks: https://jupyter.org """ - # noinspection PyProtectedMember - return str(self.graphviz()._repr_svg_()) + if GRAPHVIZ_INSTALLED: + # noinspection PyProtectedMember + return str(self.graphviz()._repr_svg_()) + else: + return generate_svg(self.values) # pragma: no cover def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: """Return a graphviz.Digraph_ object representing the binary tree. - This method's positional and keyword arguments are passed directly into the the Digraph's **__init__** method. - :return: graphviz.Digraph_ object representing the binary tree. - + :raise binarytree.exceptions.GraphvizImportError: If graphviz is not installed .. code-block:: python - >>> from binarytree import tree >>> >>> t = tree() @@ -483,9 +492,12 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: >>> graph = t.graphviz() # Generate a graphviz object >>> graph.body # Get the DOT body >>> graph.render() # Render the graph - .. _graphviz.Digraph: https://graphviz.readthedocs.io/en/stable/api.html#digraph """ + if not GRAPHVIZ_INSTALLED: + raise GraphvizImportError( + "Can't use graphviz method if graphviz module is not installed" + ) if "node_attr" not in kwargs: kwargs["node_attr"] = { "shape": "record", @@ -494,20 +506,14 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: "fillcolor": "lightgray", "fontcolor": "black", } - digraph = Digraph(*args, **kwargs) - for node in self: node_id = str(id(node)) - digraph.node(node_id, nohtml(f"| {node.value}|")) - if node.left is not None: digraph.edge(f"{node_id}:l", f"{id(node.left)}:v") - if node.right is not None: digraph.edge(f"{node_id}:r", f"{id(node.right)}:v") - return digraph def pprint(self, index: bool = False, delimiter: str = "-") -> None: diff --git a/binarytree/exceptions.py b/binarytree/exceptions.py index 0676e24..e5b22ca 100644 --- a/binarytree/exceptions.py +++ b/binarytree/exceptions.py @@ -28,3 +28,7 @@ class NodeValueError(BinaryTreeError): class TreeHeightError(BinaryTreeError): """Tree height was invalid.""" + + +class GraphvizImportError(BinaryTreeError): + """graphviz module is not installed""" diff --git a/binarytree/layout.py b/binarytree/layout.py new file mode 100644 index 0000000..7d200d4 --- /dev/null +++ b/binarytree/layout.py @@ -0,0 +1,118 @@ +""" Module containing layout related algorithms.""" +from typing import List, Tuple, Union + + +def _get_coords( + values: List[Union[float, int, None]] +) -> Tuple[ + List[Tuple[int, int, Union[float, int, None]]], List[Tuple[int, int, int, int]] +]: + """Generate the coordinates used for rendering the nodes and edges. + + node and edges are stored as tuples in the form node: (x, y, label) and + edge: (x1, y1, x2, y2) + + Each coordinate is relative y is the depth, x is the position of the node + on a level from left to right 0 to 2**depth -1 + + :param values: Values of the binary tree. + :type values: list of ints + :return: nodes and edges list + :rtype: two lists of tuples + + """ + x = 0 + y = 0 + nodes = [] + edges = [] + + # root node + nodes.append((x, y, values[0])) + # append other nodes and their edges + y += 1 + for value in values[1:]: + if value is not None: + nodes.append((x, y, value)) + edges.append((x // 2, y - 1, x, y)) + x += 1 + # check if level is full + if x == 2 ** y: + x = 0 + y += 1 + return nodes, edges + + +def generate_svg(values: List[Union[float, int, None]]) -> str: + """Generate a svg image from a binary tree + + A simple layout is used based on a perfect tree of same height in which all + leaves would be regularly spaced. + + :param values: Values of the binary tree. + :type values: list of ints + :return: the svg image of the tree. + :rtype: str + """ + node_size = 16.0 + stroke_width = 1.5 + gutter = 0.5 + x_scale = (2 + gutter) * node_size + y_scale = 3.0 * node_size + + # retrieve relative coordinates + nodes, edges = _get_coords(values) + y_min = min([n[1] for n in nodes]) + y_max = max([n[1] for n in nodes]) + + # generate the svg string + svg = f""" + + + + """ + # scales + + def scalex(x: int, y: int) -> float: + depth = y_max - y + # offset + x = 2 ** (depth + 1) * x + 2 ** depth - 1 + return 1 + node_size + x_scale * x / 2 + + def scaley(y: int) -> float: + return float(y_scale * (1 + y - y_min)) + + # edges + def svg_edge(x1: float, y1: float, x2: float, y2: float) -> str: + """Generate svg code for an edge""" + return f"""""" + + for a in edges: + x1, y1, x2, y2 = a + svg += svg_edge(scalex(x1, y1), scaley(y1), scalex(x2, y2), scaley(y2)) + + # nodes + def svg_node(x: float, y: float, label: str = "") -> str: + """Generate svg code for a node and his label""" + return f""" + + {label}""" + + for n in nodes: + x, y, label = n + svg += svg_node(scalex(x, y), scaley(y), str(label)) + + svg += "" + return svg diff --git a/tests/test_layout.py b/tests/test_layout.py new file mode 100644 index 0000000..0a92258 --- /dev/null +++ b/tests/test_layout.py @@ -0,0 +1,17 @@ +import xml.etree.ElementTree as ET + +from binarytree.layout import _get_coords, generate_svg + + +def test_get_coords(): + values = [0, 6, 5, None, 1, 4, 2] + assert _get_coords(values) == ( + [(0, 0, 0), (0, 1, 6), (1, 1, 5), (1, 2, 1), (2, 2, 4), (3, 2, 2)], + [(0, 0, 0, 1), (0, 0, 1, 1), (0, 1, 1, 2), (1, 1, 2, 2), (1, 1, 3, 2)], + ) + + +def test_svg(): + svg = generate_svg([0, 1, 2]) + svg_tree = ET.fromstring(svg) + assert svg_tree.tag == "{http://www.w3.org/2000/svg}svg"