Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions binarytree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from graphviz import Digraph, nohtml
from pkg_resources import get_distribution

from binarytree.exceptions import (
GraphvizImportError,
NodeIndexError,
NodeModifyError,
NodeNotFoundError,
Expand All @@ -18,6 +18,15 @@
TreeHeightError,
)

try:
from graphviz import Digraph, nohtml

GRAPHVIZ_INSTALLED = True
except ImportError:
GRAPHVIZ_INSTALLED = False
Digraph = Any
from binarytree.layout import generate_svg

__version__ = get_distribution("binarytree").version

LEFT_FIELD = "left"
Expand Down Expand Up @@ -463,29 +472,32 @@ def _repr_svg_(self) -> str:

.. _Jupyter notebooks: https://jupyter.org
"""
# noinspection PyProtectedMember
return str(self.graphviz()._repr_svg_())
if GRAPHVIZ_INSTALLED:
# noinspection PyProtectedMember
return str(self.graphviz()._repr_svg_())
else:
return generate_svg(self.values) # pragma: no cover

def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
"""Return a graphviz.Digraph_ object representing the binary tree.

This method's positional and keyword arguments are passed directly into the
the Digraph's **__init__** method.

:return: graphviz.Digraph_ object representing the binary tree.

:raise binarytree.exceptions.GraphvizImportError: If graphviz is not installed
.. code-block:: python

>>> from binarytree import tree
>>>
>>> t = tree()
>>>
>>> graph = t.graphviz() # Generate a graphviz object
>>> graph.body # Get the DOT body
>>> graph.render() # Render the graph

.. _graphviz.Digraph: https://graphviz.readthedocs.io/en/stable/api.html#digraph
"""
if not GRAPHVIZ_INSTALLED:
raise GraphvizImportError(
"Can't use graphviz method if graphviz module is not installed"
)
if "node_attr" not in kwargs:
kwargs["node_attr"] = {
"shape": "record",
Expand All @@ -494,20 +506,14 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
"fillcolor": "lightgray",
"fontcolor": "black",
}

digraph = Digraph(*args, **kwargs)

for node in self:
node_id = str(id(node))

digraph.node(node_id, nohtml(f"<l>|<v> {node.value}|<r>"))

if node.left is not None:
digraph.edge(f"{node_id}:l", f"{id(node.left)}:v")

if node.right is not None:
digraph.edge(f"{node_id}:r", f"{id(node.right)}:v")

return digraph

def pprint(self, index: bool = False, delimiter: str = "-") -> None:
Expand Down
4 changes: 4 additions & 0 deletions binarytree/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ class NodeValueError(BinaryTreeError):

class TreeHeightError(BinaryTreeError):
"""Tree height was invalid."""


class GraphvizImportError(BinaryTreeError):
"""graphviz module is not installed"""
118 changes: 118 additions & 0 deletions binarytree/layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
""" Module containing layout related algorithms."""
from typing import List, Tuple, Union


def _get_coords(
values: List[Union[float, int, None]]
) -> Tuple[
List[Tuple[int, int, Union[float, int, None]]], List[Tuple[int, int, int, int]]
]:
"""Generate the coordinates used for rendering the nodes and edges.

node and edges are stored as tuples in the form node: (x, y, label) and
edge: (x1, y1, x2, y2)

Each coordinate is relative y is the depth, x is the position of the node
on a level from left to right 0 to 2**depth -1

:param values: Values of the binary tree.
:type values: list of ints
:return: nodes and edges list
:rtype: two lists of tuples

"""
x = 0
y = 0
nodes = []
edges = []

# root node
nodes.append((x, y, values[0]))
# append other nodes and their edges
y += 1
for value in values[1:]:
if value is not None:
nodes.append((x, y, value))
edges.append((x // 2, y - 1, x, y))
x += 1
# check if level is full
if x == 2 ** y:
x = 0
y += 1
return nodes, edges


def generate_svg(values: List[Union[float, int, None]]) -> str:
"""Generate a svg image from a binary tree

A simple layout is used based on a perfect tree of same height in which all
leaves would be regularly spaced.

:param values: Values of the binary tree.
:type values: list of ints
:return: the svg image of the tree.
:rtype: str
"""
node_size = 16.0
stroke_width = 1.5
gutter = 0.5
x_scale = (2 + gutter) * node_size
y_scale = 3.0 * node_size

# retrieve relative coordinates
nodes, edges = _get_coords(values)
y_min = min([n[1] for n in nodes])
y_max = max([n[1] for n in nodes])

# generate the svg string
svg = f"""
<svg width="{x_scale * 2**y_max}" height="{y_scale * (2 + y_max)}"
xmlns="http://www.w3.org/2000/svg">
<style>
.bt-label {{
font: 300 {node_size}px sans-serif;;
text-align: center;
dominant-baseline: middle;
text-anchor: middle;
}}
.bt-node {{
fill: lightgray;
stroke-width: {stroke_width};
}}

</style>
<g stroke="#111">
"""
# scales

def scalex(x: int, y: int) -> float:
depth = y_max - y
# offset
x = 2 ** (depth + 1) * x + 2 ** depth - 1
return 1 + node_size + x_scale * x / 2

def scaley(y: int) -> float:
return float(y_scale * (1 + y - y_min))

# edges
def svg_edge(x1: float, y1: float, x2: float, y2: float) -> str:
"""Generate svg code for an edge"""
return f"""<line x1="{x1}" x2="{x2}" y1="{y1}" y2="{y2}"/>"""

for a in edges:
x1, y1, x2, y2 = a
svg += svg_edge(scalex(x1, y1), scaley(y1), scalex(x2, y2), scaley(y2))

# nodes
def svg_node(x: float, y: float, label: str = "") -> str:
"""Generate svg code for a node and his label"""
return f"""
<circle class="bt-node" cx="{x}" cy="{y}" r="{node_size}"/>
<text class="bt-label" x="{x}" y="{y}">{label}</text>"""

for n in nodes:
x, y, label = n
svg += svg_node(scalex(x, y), scaley(y), str(label))

svg += "</g></svg>"
return svg
17 changes: 17 additions & 0 deletions tests/test_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import xml.etree.ElementTree as ET

from binarytree.layout import _get_coords, generate_svg


def test_get_coords():
values = [0, 6, 5, None, 1, 4, 2]
assert _get_coords(values) == (
[(0, 0, 0), (0, 1, 6), (1, 1, 5), (1, 2, 1), (2, 2, 4), (3, 2, 2)],
[(0, 0, 0, 1), (0, 0, 1, 1), (0, 1, 1, 2), (1, 1, 2, 2), (1, 1, 3, 2)],
)


def test_svg():
svg = generate_svg([0, 1, 2])
svg_tree = ET.fromstring(svg)
assert svg_tree.tag == "{http://www.w3.org/2000/svg}svg"