In [1]:
import pytest
import random
import numpy as np
import dataclasses
from typing import Any, Iterator, Optional, Union



class Counter:
  """An incrementing and decrementing counter metric."""

  def __init__(self) -> None:
    self._count = 0

  def increase(self, n: int = 1) -> None:
    """Increment the counter.

    Parameters
    ----------
    n: `int`
      The count to be increased.
    """
    self._count += n

  def decrease(self, n: int = 1) -> None:
    """Decrement the counter.

    Parameters
    ----------
    n: `int`
      The count to be decreased.
    """
    self._count -= n

  @property
  def count(self) -> int:
    """Return the current count."""
    return self._count



class Histogram:
  """A metric which calculates the distribution of a value."""

  def __init__(self) -> None:
    self._values: list[int] = list()

  def update(self, value: int) -> None:
    """Add a recorded value.

    Parameters
    ----------
    value: `int`
      value to be updated
    """
    self._values.append(value)

  def report(self) -> dict:
    """Return the histogram report."""
    array = np.array(self._values)
    return {
     "min": array.min(),
     "max": array.max(),
     "medium": np.median(array),
     "mean": array.mean(),
     "stdDev": array.std(),
     "percentile": {
        "75": np.percentile(array, 75),
        "95": np.percentile(array, 95),
        "99": np.percentile(array, 99),
      },
    }



MetricType = Union[Counter, Histogram]
"""Alias for the supported metric types."""

class MetricRegistry:
  """A registry for metric instances."""

  def __init__(self) -> None:
    self._registry: dict[str, MetricType] = dict()

  def register(self, name: str, metric: MetricType) -> None:
    """Given a metric, register it under the given name.

    Parameters
    ----------
    name: `str`
      The name of the metric

    metric: `MetricType`
      The type of the metric
    """
    self._registry[name] = metric

  def get_metric(self, name: str) -> MetricType:
    """Return the metric by the given name.

    Parameters
    ----------
    name: `str`
      The name of the metric

    Returns
    -------
    `MetricType`
      The metric instance by the given name.
    """
    return self._registry[name]



"""Tree Exception Definitions."""

class DuplicateKeyError(Exception):
  """Raised when a key already exists."""

  def __init__(self, key: str) -> None:
    Exception.__init__(self, f"{key} already exists.")



@dataclasses.dataclass
class Node:
  """Binary Search Tree node definition."""
  key: Any
  data: Any
  left: Optional["Node"] = None
  right: Optional["Node"] = None
  parent: Optional["Node"] = None


"""Binary Search Tree."""
class BinarySearchTree:

  """Binary Search Tree.
  Attributes
  ----------
  root: `Optional[Node]`
    The root node of the binary search tree.
  empty: `bool`
    `True` if the tree is empty; `False` otherwise.

  Methods
  -------
  Core Functions
    search(key: `Any`)
      Look for a node based on the given key.
    insert(key: `Any`, data: `Any`)
      Insert a (key, data) pair into a binary tree.
    delete(key: `Any`)
      Delete a node based on the given key from the binary tree.

    Auxiliary Functions
    get_leftmost(node: `Node`)
      Return the node whose key is the smallest from the given subtree.
    get_rightmost(node: `Node` = `None`)
      Return the node whose key is the biggest from the given subtree.
    get_successor(node: `Node`)
      Return the successor node in the in-order order.
    get_predecessor(node: `Node`)
      Return the predecessor node in the in-order order.
    get_height(node: `Optional[Node]`)
      Return the height of the given node.
    """

  def __init__(self, registry: Optional[MetricRegistry] = None) -> None:
    self.root: Optional[Node] = None
    self._metrics_enabled = True if registry else False
    if self._metrics_enabled and registry:
      self._height_histogram = Histogram()
      registry.register(name="bst.height", metric=self._height_histogram)

  def __repr__(self) -> str:
    """Provie the tree representation to visualize its layout."""
    if self.root:
      return (
        f"{type(self)}, root={self.root}, "
        f"tree_height={str(self.get_height(self.root))}"
      )
    return "empty tree"

  @property
  def empty(self) -> bool:
    """bool: `True` if the tree is empty; `False` otherwise.

    Notes
    -----
    The property, `empty`, is read-only.
    """
    return self.root is None

  def search(self, key: Any) -> Optional[Node]:
    """Look for a node by a given key.

    Parameters
    ----------
    key: `Any`
      The key associated with the node.

    Returns
    -------
    `Optional[Node]`
      The node found by the given key.
    If the key does not exist, return `None`.
    """
    return self._search(key=key)

  def _search(self, key: Any) -> Optional[Node]:
    current = self.root

    while current:
      if key < current.key:
        current = current.left
      elif key > current.key:
        current = current.right
      else:
        return current
    return None

  def insert(self, key: Any, data: Any) -> None:
    """Insert a (key, data) pair into the binary search tree.

    Parameters
    ----------
    key: `Any`
      The key associated with the data.

    data: `Any`
      The data to be inserted.

    Raises
    ------
   `DuplicateKeyError`
      Raised if the key to be insted has existed in the tree.
    """
    new_node = Node(key=key, data=data)
    parent: Optional[Node] = None
    current: Optional[Node] = self.root
    while current:
      parent = current
      if new_node.key < current.key:
        current = current.left
      elif new_node.key > current.key:
        current = current.right
      else:
        raise DuplicateKeyError(key=new_node.key)
    new_node.parent = parent
    # If the tree is empty
    if parent is None:
      self.root = new_node
    elif new_node.key < parent.key:
      parent.left = new_node
    else:
      parent.right = new_node

    if self._metrics_enabled and self.root:
      self._height_histogram.update(value=self.get_height(self.root))

  def delete(self, key: Any) -> None:
    """Delete a node according to the given key.

    Parameters
    ---------
    key: `Any`
      The key of the node to be deleted.
    """
    if self.root and (deleting_node := self._search(key=key)):
      # Case 1: no child or Case 2a: only one right child
      if deleting_node.left is None:
        self._transplant(
        deleting_node=deleting_node, replacing_node=deleting_node.right
      )

      # Case 2b: only one left left child
      elif deleting_node.right is None:
        self._transplant(
        deleting_node=deleting_node, replacing_node=deleting_node.left
      )

      # Case 3: two children
      else:
        replacing_node = BinarySearchTree.get_leftmost(node=deleting_node.right)
        # The leftmost node is not the direct child of the deleting node
        if replacing_node.parent != deleting_node:
          self._transplant(
            deleting_node=replacing_node,
            replacing_node=replacing_node.right,
          )
          replacing_node.right = deleting_node.right
          replacing_node.right.parent = replacing_node
        self._transplant(
          deleting_node=deleting_node, replacing_node=replacing_node
        )
        replacing_node.left = deleting_node.left
        replacing_node.left.parent = replacing_node

      if self._metrics_enabled and self.root:
        self._height_histogram.update(value=self.get_height(self.root))

  @staticmethod
  def get_leftmost(node: Node) -> Node:
    """Return the leftmost node from a given subtree.

    The key of the leftmost node is the smallest key in the given subtree.

    Parameters
    ----------
    node: `Node`
    The root of the subtree.

    Returns
    -------
   `Node`
     The node whose key is the smallest from the subtree of the given node.
    """
    current_node = node
    while current_node.left:
      current_node = current_node.left
    return current_node

  @staticmethod
  def get_rightmost(node: Node) -> Node:
    """Return the rightmost node from a given subtree.

    The key of the rightmost node is the biggest key in the given subtree.

    Parameters
    ----------
    node: `Node`
     The root of the subtree.

    Returns
    -------
    `Node`
      The node whose key is the biggest from the subtree of the given node.
    """
    current_node = node
    while current_node.right:
      current_node = current_node.right
    return current_node

  @staticmethod
  def get_successor(node: Node) -> Optional[Node]:
    """Return the successor in the in-order order.

    Parameters
    ----------
    node: `Node`
      The node to get its successor.

    Returns
    -------
   `Optional[Node]`
      The successor node.
    """
    if node.right:  # Case 1: right child is not empty
      return BinarySearchTree.get_leftmost(node=node.right)
    # Case 2: right child is empty
    parent = node.parent
    while parent and (node == parent.right):
      node = parent
      parent = parent.parent
    return parent

  @staticmethod
  def get_predecessor(node: Node) -> Optional[Node]:
    """Return the predecessor in the in-order order.

    Parameters
    ----------
    node: `Node`
      The node to get its predecessor.

    Returns
    -------
    `Optional[Node]`
      The predecessor node.
    """
    if node.left:  # Case 1: left child is not empty
      return BinarySearchTree.get_rightmost(node=node.left)
    # Case 2: left child is empty
    parent = node.parent
    while parent and (node == parent.left):
      node = parent
      parent = parent.parent
    return parent

  @staticmethod
  def get_height(node: Node) -> int:
    """Get the height of the given subtree.

    Parameters
    ----------
    node: `Node`
      The root of the subtree to get its height.

    Returns
    -------
    `int`
      The height of the given subtree. 0 if the subtree has only one node.
    """
    if node.left and node.right:
      return (
        max(
          BinarySearchTree.get_height(node=node.left),
          BinarySearchTree.get_height(node=node.right),
        ) + 1
      )

    if node.left:
      return BinarySearchTree.get_height(node=node.left) + 1

    if node.right:
      return BinarySearchTree.get_height(node=node.right) + 1

    # If reach here, it means the node is a leaf node.
    return 0

  def _transplant(self, deleting_node: Node, replacing_node: Optional[Node]) -> None:
    if deleting_node.parent is None:
      self.root = replacing_node
    elif deleting_node == deleting_node.parent.left:
      deleting_node.parent.left = replacing_node
    else:
      deleting_node.parent.right = replacing_node

    if replacing_node:
      replacing_node.parent = deleting_node.parent


"""Traraversal"""



# Alias for the supported node types. For type checking.
SupportedNode = Union[Node]

SupportedTree = Union[BinarySearchTree]
"""Alias for the supported tree types. For type checking."""

Pairs = Iterator[tuple[Any, Any]]
"""Iterator of Key-Value pairs, yield by traversal functions. For type checking"""


def inorder_traverse(tree: SupportedTree, recursive: bool = True) -> Pairs:
  """Perform In-Order traversal.

  In-order traversal traverses a tree by the order:
  left subtree, current node, right subtree (LDR)

  Parameters
  ----------
  tree : `SupportedTree`
    An instance of the supported binary tree types.
  recursive: `bool`
    Perform traversal recursively or not.

  Yields
  ------
  `Pairs`
    The next (key, data) pair in the in-order traversal.
  """
  if recursive:
    return _inorder_traverse(node=tree.root)
  return _inorder_traverse_non_recursive(root=tree.root)


def preorder_traverse(tree: SupportedTree, recursive: bool = True) -> Pairs:
  """Perform Pre-Order traversal.

  Pre-order traversal traverses a tree by the order:
  current node, left subtree, right subtree (DLR)

  Parameters
  ----------
  tree : `SupportedTree`
    An instance of the supported binary tree types.
  recursive: `bool`
    Perform traversal recursively or not.

  Yields
  ------
  `Pairs`
    The next (key, data) pair in the pre-order traversal.
  """
  if recursive:
    return _preorder_traverse(node=tree.root)
  return _preorder_traverse_non_recursive(root=tree.root)


def postorder_traverse(tree: SupportedTree, recursive: bool = True) -> Pairs:
  """Perform Post-Order traversal.

  Post-order traversal traverses a tree by the order:
    left subtree, right subtree, current node (LRD)

  Parameters
  ----------
  tree : `SupportedTree`
    An instance of the supported binary tree types.
  recursive: `bool`
    Perform traversal recursively or not.

  Yields
  ------
  `Pairs`
    The next (key, data) pair in the post-order traversal.
  """
  if recursive:
    return _postorder_traverse(node=tree.root)
  return _postorder_traverse_non_recursive(root=tree.root)


def reverse_inorder_traverse(tree: SupportedTree, recursive: bool = True) -> Pairs:
  """Perform reversed In-Order traversal.

  Reversed in-order traversal traverses a tree by the order:
    right subtree, current node, left subtree (RNL)

  Parameters
  ----------
  tree : `SupportedTree`
    An instance of the supported binary tree types.
  recursive: `bool`
    Perform traversal recursively or not.

  Yields
  ------
  `Pairs`
    The next (key, data) pair in the reversed in-order traversal.
  """
  if recursive:
    return _reverse_inorder_traverse(node=tree.root)
  return _reverse_inorder_traverse_non_recursive(root=tree.root)


def levelorder_traverse(tree: SupportedTree) -> Pairs:
  """Perform Level-Order traversal.

  Level-order traversal traverses a tree:
    level by level, from left to right, starting from the root node.

  Parameters
  ----------
  tree : `SupportedTree`
    An instance of the supported binary tree types.

  Yields
  ------
  `Pairs`
    The next (key, data) pair in the level-order traversal.
  """
  queue = [tree.root]

  while len(queue) > 0:
    temp = queue.pop(0)
    if temp:
      yield (temp.key, temp.data)
      if temp.left:
        queue.append(temp.left)

      if temp.right:
        queue.append(temp.right)


def _inorder_traverse(node: SupportedNode) -> Pairs:
  if node:
    yield from _inorder_traverse(node.left)
    yield (node.key, node.data)
    yield from _inorder_traverse(node.right)


def _inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
  if root is None:
    raise StopIteration

  stack = []
  if root.right:
    stack.append(root.right)
  stack.append(root)
  current = root.left

  while True:
    if current:
      if current.right:
        stack.append(current.right)
        stack.append(current)
        current = current.left
        continue
      stack.append(current)
      current = current.left

    else:  # current is None
      if len(stack) > 0:
        current = stack.pop()
        if current.right is None:
          yield (current.key, current.data)
          current = None
          continue
        else:  # current.right is not None
          if len(stack) > 0:
            if current.right == stack[-1]:
              yield (current.key, current.data)
              current = stack.pop() if len(stack) > 0 else None
              continue
            else:  # current.right != stack[-1]:
              # This case means there are more nodes on the right
              # Keep the current and go back to add them.
              continue

      else:  # stack is empty
        break



def _reverse_inorder_traverse(node: SupportedNode) -> Pairs:
  if node:
    yield from _reverse_inorder_traverse(node.right)
    yield (node.key, node.data)
    yield from _reverse_inorder_traverse(node.left)


def _reverse_inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
  if root is None:
    raise StopIteration

  stack = []
  if root.left:
    stack.append(root.left)
  stack.append(root)
  current = root.right

  while True:
    if current:
      if current.left:
        stack.append(current.left)
        stack.append(current)
        current = current.right
        continue
      stack.append(current)
      current = current.right

    else:  # current is None
      if len(stack) > 0:
        current = stack.pop()
        if current.left is None:
          yield (current.key, current.data)
          current = None
          continue
        else:  # current.right is not None
          if len(stack) > 0:
            if current.left == stack[-1]:
              yield (current.key, current.data)
              current = stack.pop() if len(stack) > 0 else None
              continue
            else:  # current.right != stack[-1]:
              # This case means there are more nodes on the right
              # Keep the current and go back to add them.
              continue

      else:  # stack is empty
        break



def _preorder_traverse(node: SupportedNode) -> Pairs:
  if node:
    yield (node.key, node.data)
    yield from _preorder_traverse(node.left)
    yield from _preorder_traverse(node.right)



def _preorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
  if root is None:
    raise StopIteration

  stack = [root]

  while len(stack) > 0:
    temp = stack.pop()
    yield (temp.key, temp.data)

    # Because stack is FILO, insert right child before left child.
    if temp.right:
      stack.append(temp.right)

    if temp.left:
      stack.append(temp.left)


def _postorder_traverse(node: SupportedNode) -> Pairs:
  if node:
    yield from _postorder_traverse(node.left)
    yield from _postorder_traverse(node.right)
    yield (node.key, node.data)


def _postorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
  if root is None:
    raise StopIteration

  stack = []
  if root.right:
    stack.append(root.right)
  stack.append(root)
  current = root.left

  while True:
    if current:
      if current.right:
        stack.append(current.right)
        stack.append(current)
        current = current.left
        continue
      else:  # current.right is None
        if current.left:
          stack.append(current)
        else:
          yield (current.key, current.data)
        current = current.left

    else:  # current is None
      if len(stack) > 0:
        current = stack.pop()
        if current.right is None:
          yield (current.key, current.data)
          current = None
        else:  # current.right is not None
          if len(stack) > 0:
            if current.right != stack[-1]:
              yield (current.key, current.data)
              current = None
            else:  # current.right == stack[-1]
              temp = stack.pop()
              stack.append(current)
              current = temp
          else:  # stack is empty
            yield (current.key, current.data)
            break
      else:  # stack is empty
        break



"""Test functions"""

def test_simple_case(basic_tree: list) -> None:
  """Test the basic opeartions of a binary search tree."""
  tree = BinarySearchTree()

  print("Before inserting any node:",tree.empty)
  # assert tree.empty

  # 23, 4, 30, 11, 7, 34, 20, 24, 22, 15, 1
  for key, data in basic_tree:
    tree.insert(key=key, data=data)

  print("After inserting first node:",tree.empty)
  # assert tree.empty is False

  with pytest.raises(DuplicateKeyError):
    tree.insert(key=23, data="23")

  print(tree.get_leftmost(node=tree.root))
  # assert tree.get_leftmost(node=tree.root).key == 1
  print(tree.get_leftmost(node=tree.root))
  # assert tree.get_leftmost(node=tree.root).data == "1"
  print(tree.get_rightmost(node=tree.root))
  # assert tree.get_rightmost(node=tree.root).key == 34
  print(tree.get_rightmost(node=tree.root))
  # assert tree.get_rightmost(node=tree.root).data == "34"
  print(tree.search(key=24))
  # assert tree.search(key=24).data == "24"
  print(tree.get_height(node=tree.root))
  # assert tree.get_height(node=tree.root) == 4
  print(tree.get_predecessor(node=tree.root))
  # assert tree.get_predecessor(node=tree.root).key == 22
  temp = tree.search(key=24)
  print(tree.get_predecessor(node=temp))
  # assert tree.get_predecessor(node=temp).key == 23
  print(tree.get_successor(node=tree.root))
  # assert tree.get_successor(node=tree.root).key == 24
  temp = tree.search(key=22)
  print(tree.get_successor(node=temp))
  # assert tree.get_successor(node=temp).key == 23

  tree.delete(key=22)
  tree.delete(key=20)
  tree.delete(key=11)

  print(tree.search(key=22))
  # assert tree.search(key=22) is None


def test_metrics(basic_tree):
  """Test binary search tree with metrics enabled."""
  registry = MetricRegistry()
  tree = BinarySearchTree(registry=registry)

  # 23, 4, 30, 11, 7, 34, 20, 24, 22, 15, 1
  for key, data in basic_tree:
    tree.insert(key=key, data=data)

  print(registry.get_metric(name="bst.height").report())
  # assert registry.get_metric(name="bst.height").report()


def test_empty():
  """Test a tree becomes empty."""
  tree = BinarySearchTree()

  for key in range(10):
    tree.insert(key=key, data=str(key))

  for key in range(10):
    tree.delete(key=key)
  print("After creating bst by inserting 10 nodes:",tree.empty)
  assert tree.empty

  for key in reversed(range(10)):
    tree.insert(key=key, data=str(key))

  for key in reversed(range(10)):
    tree.delete(key=key)

  print("After deleting:",tree.empty)
  assert tree.empty

def test_binary_search_tree_traversal_random():
  """Test binary search tree traversal with random sampling."""
  for _ in range(0, 10):
    insert_data = random.sample(range(1, 2000), 500)

  tree = BinarySearchTree()
  for key in insert_data:
    tree.insert(key=key, data=str(key))

  preorder = [item for item in preorder_traverse(tree, False)]
  preorder_recursive = [item for item in preorder_traverse(tree, True)]
  print("Non-recursive preorder:",preorder)
  print("Recursive preorder:",preorder_recursive)
  assert preorder_recursive == preorder

  inorder = [item for item in inorder_traverse(tree, False)]
  inorder_recursive = [item for item in inorder_traverse(tree, True)]
  print("Non-recursive inorder:",inorder)
  print("Recursive inorder:",inorder_recursive)
  assert inorder_recursive == inorder

  rinorder = [item for item in reverse_inorder_traverse(tree, False)]
  rinorder_recursive = [item for item in reverse_inorder_traverse(tree, True)]
  print("Non-recursive reverse inorder:",rinorder)
  print("Recursive reverse inorder:",rinorder_recursive)
  assert rinorder_recursive == rinorder

  postorder_recursive = [item for item in postorder_traverse(tree, True)]
  postorder = [item for item in postorder_traverse(tree, False)]
  print("Non-recursive postorder:",postorder)
  print("Recursive postorder:",postorder_recursive)
  assert postorder_recursive == postorder

In [2]:
test_simple_case([(23,"23"),
                  (4,"4"),
                  (30,"30"),
                  (11,"11"),
                  (7,"7"),
                  (34,"34"),
                  (20,"20"),
                  (24,"24"),
                  (22,"22"),
                  (15,"15"),
                  (1,"1"),
                  ])

Before inserting any node: True
After inserting first node: False
Node(key=1, data='1', left=None, right=None, parent=Node(key=4, data='4', left=..., right=Node(key=11, data='11', left=Node(key=7, data='7', left=None, right=None, parent=...), right=Node(key=20, data='20', left=Node(key=15, data='15', left=None, right=None, parent=...), right=Node(key=22, data='22', left=None, right=None, parent=...), parent=...), parent=...), parent=Node(key=23, data='23', left=..., right=Node(key=30, data='30', left=Node(key=24, data='24', left=None, right=None, parent=...), right=Node(key=34, data='34', left=None, right=None, parent=...), parent=...), parent=None)))
Node(key=1, data='1', left=None, right=None, parent=Node(key=4, data='4', left=..., right=Node(key=11, data='11', left=Node(key=7, data='7', left=None, right=None, parent=...), right=Node(key=20, data='20', left=Node(key=15, data='15', left=None, right=None, parent=...), right=Node(key=22, data='22', left=None, right=None, parent=...), pa

In [3]:
test_metrics([(23,"23"),
                  (4,"4"),
                  (30,"30"),
                  (11,"11"),
                  (7,"7"),
                  (34,"34"),
                  (20,"20"),
                  (24,"24"),
                  (22,"22"),
                  (15,"15"),
                  (1,"1"),
                  ])

{'min': 0, 'max': 4, 'medium': 3.0, 'mean': 2.5454545454545454, 'stdDev': 1.304790917673393, 'percentile': {'75': 3.5, '95': 4.0, '99': 4.0}}


In [4]:
test_empty()

After creating bst by inserting 10 nodes: True
After deleting: True


In [5]:
test_binary_search_tree_traversal_random()

Non-recursive preorder: [(430, '430'), (306, '306'), (285, '285'), (135, '135'), (2, '2'), (110, '110'), (3, '3'), (95, '95'), (75, '75'), (71, '71'), (20, '20'), (12, '12'), (8, '8'), (4, '4'), (15, '15'), (22, '22'), (24, '24'), (23, '23'), (54, '54'), (25, '25'), (36, '36'), (26, '26'), (31, '31'), (27, '27'), (38, '38'), (44, '44'), (66, '66'), (74, '74'), (91, '91'), (89, '89'), (84, '84'), (93, '93'), (101, '101'), (96, '96'), (102, '102'), (106, '106'), (121, '121'), (111, '111'), (116, '116'), (123, '123'), (265, '265'), (189, '189'), (154, '154'), (146, '146'), (140, '140'), (147, '147'), (159, '159'), (175, '175'), (170, '170'), (160, '160'), (166, '166'), (174, '174'), (176, '176'), (184, '184'), (192, '192'), (191, '191'), (233, '233'), (208, '208'), (194, '194'), (199, '199'), (221, '221'), (211, '211'), (215, '215'), (260, '260'), (239, '239'), (257, '257'), (280, '280'), (278, '278'), (275, '275'), (282, '282'), (284, '284'), (304, '304'), (293, '293'), (291, '291'), (28