Skip to content

Commit 8f8cde2

Browse files
authored
Merge pull request #8 from gjbex/codex/add-matplotlib-visualization-for-intersection-tree
Add intersection tree plotting capability
2 parents 35ee5bf + 78ef53f commit 8f8cde2

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

source_code/intersection_trees/intersection_tree.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,69 @@ def populate_db(db: Node | None, intervals: typing.Sequence[Interval]) -> Node:
254254
for interval in intervals[start_idx:]:
255255
db.insert(interval)
256256
return db
257+
258+
259+
def plot_intersection_tree(tree: Node) -> None:
260+
"""Visualize the intersection tree using :mod:`matplotlib`.
261+
262+
Each node in the tree is drawn as a horizontal line spanning the interval
263+
``[start, end]``. The root of the tree is shown at the bottom of the
264+
figure, with each subsequent level plotted above it. The start and end
265+
values of the interval are annotated next to their respective end points
266+
together with the ``max_end`` value for that node. Lines are also drawn
267+
from the midpoint of each interval to the midpoints of its children to
268+
illustrate the tree structure.
269+
270+
Parameters
271+
----------
272+
tree:
273+
Root node of the intersection tree to plot.
274+
"""
275+
276+
import matplotlib.pyplot as plt
277+
278+
# Collect all nodes along with their depth in the tree.
279+
nodes: list[tuple[Node, int]] = []
280+
281+
def _traverse(node: Node | None, depth: int) -> None:
282+
if node is None:
283+
return
284+
nodes.append((node, depth))
285+
_traverse(node._left, depth + 1)
286+
_traverse(node._right, depth + 1)
287+
288+
_traverse(tree, 0)
289+
290+
if not nodes:
291+
return
292+
293+
fig, ax = plt.subplots()
294+
max_depth = max(depth for _, depth in nodes)
295+
296+
# Draw intervals and record midpoints for connecting lines.
297+
midpoints: dict[Node, tuple[float, int]] = {}
298+
for node, depth in nodes:
299+
start, end, max_end = node._start, node._end, node._max_end
300+
y = depth
301+
ax.hlines(y, start, end, colors="tab:blue")
302+
ax.plot([start, end], [y, y], "o", color="tab:blue", markersize=3)
303+
ax.text(start, y + 0.1, f"{start}", ha="center", va="bottom", fontsize=8)
304+
ax.text(end, y + 0.1, f"{end}", ha="center", va="bottom", fontsize=8)
305+
ax.text(end, y - 0.1, f"max={max_end}", ha="left", va="top", fontsize=8)
306+
midpoints[node] = ((start + end) / 2, y)
307+
308+
# Connect each node's midpoint to its children's midpoints.
309+
for node, depth in nodes:
310+
parent_mid, parent_y = midpoints[node]
311+
if node._left is not None:
312+
child_mid, child_y = midpoints[node._left]
313+
ax.plot([parent_mid, child_mid], [parent_y, child_y], color="tab:gray")
314+
if node._right is not None:
315+
child_mid, child_y = midpoints[node._right]
316+
ax.plot([parent_mid, child_mid], [parent_y, child_y], color="tab:gray")
317+
318+
ax.set_xlabel("value")
319+
ax.set_ylabel("depth")
320+
ax.set_ylim(-1, max_depth + 1)
321+
ax.set_title("Intersection tree")
322+
plt.show()

0 commit comments

Comments
 (0)