@@ -254,3 +254,69 @@ def populate_db(db: Node | None, intervals: typing.Sequence[Interval]) -> Node:
254254 for interval in intervals [start_idx :]:
255255 db .insert (interval )
256256 return db
257+
258+
259+ def plot_intersection_tree (tree : Node ) -> None :
260+ """Visualize the intersection tree using :mod:`matplotlib`.
261+
262+ Each node in the tree is drawn as a horizontal line spanning the interval
263+ ``[start, end]``. The root of the tree is shown at the bottom of the
264+ figure, with each subsequent level plotted above it. The start and end
265+ values of the interval are annotated next to their respective end points
266+ together with the ``max_end`` value for that node. Lines are also drawn
267+ from the midpoint of each interval to the midpoints of its children to
268+ illustrate the tree structure.
269+
270+ Parameters
271+ ----------
272+ tree:
273+ Root node of the intersection tree to plot.
274+ """
275+
276+ import matplotlib .pyplot as plt
277+
278+ # Collect all nodes along with their depth in the tree.
279+ nodes : list [tuple [Node , int ]] = []
280+
281+ def _traverse (node : Node | None , depth : int ) -> None :
282+ if node is None :
283+ return
284+ nodes .append ((node , depth ))
285+ _traverse (node ._left , depth + 1 )
286+ _traverse (node ._right , depth + 1 )
287+
288+ _traverse (tree , 0 )
289+
290+ if not nodes :
291+ return
292+
293+ fig , ax = plt .subplots ()
294+ max_depth = max (depth for _ , depth in nodes )
295+
296+ # Draw intervals and record midpoints for connecting lines.
297+ midpoints : dict [Node , tuple [float , int ]] = {}
298+ for node , depth in nodes :
299+ start , end , max_end = node ._start , node ._end , node ._max_end
300+ y = depth
301+ ax .hlines (y , start , end , colors = "tab:blue" )
302+ ax .plot ([start , end ], [y , y ], "o" , color = "tab:blue" , markersize = 3 )
303+ ax .text (start , y + 0.1 , f"{ start } " , ha = "center" , va = "bottom" , fontsize = 8 )
304+ ax .text (end , y + 0.1 , f"{ end } " , ha = "center" , va = "bottom" , fontsize = 8 )
305+ ax .text (end , y - 0.1 , f"max={ max_end } " , ha = "left" , va = "top" , fontsize = 8 )
306+ midpoints [node ] = ((start + end ) / 2 , y )
307+
308+ # Connect each node's midpoint to its children's midpoints.
309+ for node , depth in nodes :
310+ parent_mid , parent_y = midpoints [node ]
311+ if node ._left is not None :
312+ child_mid , child_y = midpoints [node ._left ]
313+ ax .plot ([parent_mid , child_mid ], [parent_y , child_y ], color = "tab:gray" )
314+ if node ._right is not None :
315+ child_mid , child_y = midpoints [node ._right ]
316+ ax .plot ([parent_mid , child_mid ], [parent_y , child_y ], color = "tab:gray" )
317+
318+ ax .set_xlabel ("value" )
319+ ax .set_ylabel ("depth" )
320+ ax .set_ylim (- 1 , max_depth + 1 )
321+ ax .set_title ("Intersection tree" )
322+ plt .show ()
0 commit comments