Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions source_code/intersection_trees/intersection_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,69 @@ def populate_db(db: Node | None, intervals: typing.Sequence[Interval]) -> Node:
for interval in intervals[start_idx:]:
db.insert(interval)
return db


def plot_intersection_tree(tree: Node) -> None:
"""Visualize the intersection tree using :mod:`matplotlib`.

Each node in the tree is drawn as a horizontal line spanning the interval
``[start, end]``. The root of the tree is shown at the bottom of the
figure, with each subsequent level plotted above it. The start and end
values of the interval are annotated next to their respective end points
together with the ``max_end`` value for that node. Lines are also drawn
from the midpoint of each interval to the midpoints of its children to
illustrate the tree structure.

Parameters
----------
tree:
Root node of the intersection tree to plot.
"""

import matplotlib.pyplot as plt

# Collect all nodes along with their depth in the tree.
nodes: list[tuple[Node, int]] = []

def _traverse(node: Node | None, depth: int) -> None:
if node is None:
return
nodes.append((node, depth))
_traverse(node._left, depth + 1)
_traverse(node._right, depth + 1)

_traverse(tree, 0)

if not nodes:
return

fig, ax = plt.subplots()
max_depth = max(depth for _, depth in nodes)

# Draw intervals and record midpoints for connecting lines.
midpoints: dict[Node, tuple[float, int]] = {}
for node, depth in nodes:
start, end, max_end = node._start, node._end, node._max_end
y = depth
ax.hlines(y, start, end, colors="tab:blue")
ax.plot([start, end], [y, y], "o", color="tab:blue", markersize=3)
ax.text(start, y + 0.1, f"{start}", ha="center", va="bottom", fontsize=8)
ax.text(end, y + 0.1, f"{end}", ha="center", va="bottom", fontsize=8)
ax.text(end, y - 0.1, f"max={max_end}", ha="left", va="top", fontsize=8)
midpoints[node] = ((start + end) / 2, y)

# Connect each node's midpoint to its children's midpoints.
for node, depth in nodes:
parent_mid, parent_y = midpoints[node]
if node._left is not None:
child_mid, child_y = midpoints[node._left]
ax.plot([parent_mid, child_mid], [parent_y, child_y], color="tab:gray")
if node._right is not None:
child_mid, child_y = midpoints[node._right]
ax.plot([parent_mid, child_mid], [parent_y, child_y], color="tab:gray")

ax.set_xlabel("value")
ax.set_ylabel("depth")
ax.set_ylim(-1, max_depth + 1)
ax.set_title("Intersection tree")
plt.show()