diff --git a/binarytree/__init__.py b/binarytree/__init__.py index 7e20116..a144511 100644 --- a/binarytree/__init__.py +++ b/binarytree/__init__.py @@ -321,6 +321,57 @@ def _get_tree_properties(root): } +def get_parent_node(root, node): + """Search from the binary tree and return the parent node for require node. + + :param root: Root node of the binary tree. + :rtype: binarytree.Node + :param node: Require node you want to get its parent node. + :rtype: binarytree.Node + :return: The parent node of require node. + :rtype: binarytree.Node + + **Example**: + + .. doctest:: + + >>> from binarytree import Node, get_parent_node + >>> root = Node(0) + >>> root.left = Node(1) + >>> root.right = Node(2) + >>> root.left.left = Node(3) + >>> print (root) + >>> 0 + / \ + 1 2 + / + 3 + >>> print (get_parent_node(root, root.left.left)) + >>> 1 + / + 3 + """ + if root is node or root is None or node is None: + return None + node_stack = [] + while True: + if root is not None: + node_stack.append(root) + if root.left is node: + return root + else: + root = root.left + elif len(node_stack) > 0: + root = node_stack.pop() + if root.right is node: + return root + else: + root = root.right + else: + break + return None + + class Node(object): """Represents a binary tree node. diff --git a/tests/test_tree.py b/tests/test_tree.py index d3e681f..8def0cf 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -6,6 +6,7 @@ import pytest from binarytree import Node, build, tree, bst, heap +from binarytree import get_parent_node from binarytree.exceptions import ( NodeValueError, NodeIndexError, @@ -884,3 +885,18 @@ def test_heap_float_values(): assert root.min_leaf_depth == root_copy.min_leaf_depth assert root.min_node_value == root_copy.min_node_value + 0.1 assert root.size == root_copy.size + + +@pytest.mark.order14 +def test_get_parent_node(): + root = Node(0) + root.left = Node(1) + root.right = Node(2) + root.left.left = Node(3) + root.right.right = Node(4) + assert get_parent_node(root, root.left.left) == root.left + assert get_parent_node(root, root.left) == root + assert get_parent_node(root, root) is None + assert get_parent_node(root, root.right.right) == root.right + assert get_parent_node(root, root.right) == root + assert get_parent_node(root, Node(5)) is None