diff --git a/source_code/intersection_trees/intersection_tree.py b/source_code/intersection_trees/intersection_tree.py index 437e6fb..6765f98 100644 --- a/source_code/intersection_trees/intersection_tree.py +++ b/source_code/intersection_trees/intersection_tree.py @@ -254,3 +254,69 @@ def populate_db(db: Node | None, intervals: typing.Sequence[Interval]) -> Node: for interval in intervals[start_idx:]: db.insert(interval) return db + + +def plot_intersection_tree(tree: Node) -> None: + """Visualize the intersection tree using :mod:`matplotlib`. + + Each node in the tree is drawn as a horizontal line spanning the interval + ``[start, end]``. The root of the tree is shown at the bottom of the + figure, with each subsequent level plotted above it. The start and end + values of the interval are annotated next to their respective end points + together with the ``max_end`` value for that node. Lines are also drawn + from the midpoint of each interval to the midpoints of its children to + illustrate the tree structure. + + Parameters + ---------- + tree: + Root node of the intersection tree to plot. + """ + + import matplotlib.pyplot as plt + + # Collect all nodes along with their depth in the tree. + nodes: list[tuple[Node, int]] = [] + + def _traverse(node: Node | None, depth: int) -> None: + if node is None: + return + nodes.append((node, depth)) + _traverse(node._left, depth + 1) + _traverse(node._right, depth + 1) + + _traverse(tree, 0) + + if not nodes: + return + + fig, ax = plt.subplots() + max_depth = max(depth for _, depth in nodes) + + # Draw intervals and record midpoints for connecting lines. + midpoints: dict[Node, tuple[float, int]] = {} + for node, depth in nodes: + start, end, max_end = node._start, node._end, node._max_end + y = depth + ax.hlines(y, start, end, colors="tab:blue") + ax.plot([start, end], [y, y], "o", color="tab:blue", markersize=3) + ax.text(start, y + 0.1, f"{start}", ha="center", va="bottom", fontsize=8) + ax.text(end, y + 0.1, f"{end}", ha="center", va="bottom", fontsize=8) + ax.text(end, y - 0.1, f"max={max_end}", ha="left", va="top", fontsize=8) + midpoints[node] = ((start + end) / 2, y) + + # Connect each node's midpoint to its children's midpoints. + for node, depth in nodes: + parent_mid, parent_y = midpoints[node] + if node._left is not None: + child_mid, child_y = midpoints[node._left] + ax.plot([parent_mid, child_mid], [parent_y, child_y], color="tab:gray") + if node._right is not None: + child_mid, child_y = midpoints[node._right] + ax.plot([parent_mid, child_mid], [parent_y, child_y], color="tab:gray") + + ax.set_xlabel("value") + ax.set_ylabel("depth") + ax.set_ylim(-1, max_depth + 1) + ax.set_title("Intersection tree") + plt.show()