Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions source_code/intersection_trees/intersection_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,27 @@ def insert(self, interval: Interval) -> None:
if interval[0] >= interval[1]:
raise ValueError(f"Invalid interval: start ({interval[0]}) must be less than end ({interval[1]})")

self._insert(interval)

def _insert(self, interval: Interval) -> None:
'''Private method to insert a new interval [start, end) in the tree.
Does not validate the interval.

Parameters
----------
interval: Interval
the interval to insert (assumed to be valid)
'''
if interval[0] < self._start:
if self._left is None:
self._left = Node(interval)
else:
self._left.insert(interval)
self._left._insert(interval)
else:
if self._right is None:
self._right = Node(interval)
else:
self._right.insert(interval)
self._right._insert(interval)
self._max_end = max(self._max_end, interval[1])

def search(self, interval: Interval, results: list[Interval]) -> None:
Expand All @@ -88,13 +99,34 @@ def search(self, interval: Interval, results: list[Interval]) -> None:
the interval to search for intersections
results: list[Interval]
list to append the results to

Raises
------
ValueError
if interval start is not less than end
'''
if interval[0] >= interval[1]:
raise ValueError(f"Invalid interval: start ({interval[0]}) must be less than end ({interval[1]})")

self._search(interval, results)

def _search(self, interval: Interval, results: list[Interval]) -> None:
'''Private method to search for all intervals in the tree that intersect with [start, end)
and append them to results. Does not validate the interval.

Parameters
----------
interval: Interval
the interval to search for intersections (assumed to be valid)
results: list[Interval]
list to append the results to
'''
if self._start < interval[1] and interval[0] < self._end:
results.append((self._start, self._end))
if self._left is not None and self._left._max_end >= interval[0]:
self._left.search(interval, results)
self._left._search(interval, results)
if self._right is not None and self._right._max_end >= interval[0]:
self._right.search(interval, results)
self._right._search(interval, results)

def to_str(self, prefix: str = '') -> str:
'''Return a string representation of the tree.
Expand Down