diff --git a/source_code/intersection_trees/intersection_tree.py b/source_code/intersection_trees/intersection_tree.py index 6765f98..6b56d7e 100644 --- a/source_code/intersection_trees/intersection_tree.py +++ b/source_code/intersection_trees/intersection_tree.py @@ -66,16 +66,27 @@ def insert(self, interval: Interval) -> None: if interval[0] >= interval[1]: raise ValueError(f"Invalid interval: start ({interval[0]}) must be less than end ({interval[1]})") + self._insert(interval) + + def _insert(self, interval: Interval) -> None: + '''Private method to insert a new interval [start, end) in the tree. + Does not validate the interval. + + Parameters + ---------- + interval: Interval + the interval to insert (assumed to be valid) + ''' if interval[0] < self._start: if self._left is None: self._left = Node(interval) else: - self._left.insert(interval) + self._left._insert(interval) else: if self._right is None: self._right = Node(interval) else: - self._right.insert(interval) + self._right._insert(interval) self._max_end = max(self._max_end, interval[1]) def search(self, interval: Interval, results: list[Interval]) -> None: @@ -88,13 +99,34 @@ def search(self, interval: Interval, results: list[Interval]) -> None: the interval to search for intersections results: list[Interval] list to append the results to + + Raises + ------ + ValueError + if interval start is not less than end + ''' + if interval[0] >= interval[1]: + raise ValueError(f"Invalid interval: start ({interval[0]}) must be less than end ({interval[1]})") + + self._search(interval, results) + + def _search(self, interval: Interval, results: list[Interval]) -> None: + '''Private method to search for all intervals in the tree that intersect with [start, end) + and append them to results. Does not validate the interval. + + Parameters + ---------- + interval: Interval + the interval to search for intersections (assumed to be valid) + results: list[Interval] + list to append the results to ''' if self._start < interval[1] and interval[0] < self._end: results.append((self._start, self._end)) if self._left is not None and self._left._max_end >= interval[0]: - self._left.search(interval, results) + self._left._search(interval, results) if self._right is not None and self._right._max_end >= interval[0]: - self._right.search(interval, results) + self._right._search(interval, results) def to_str(self, prefix: str = '') -> str: '''Return a string representation of the tree.