In [1]:
# Copyright(C) 2021 刘珅珅
# Environment: python 3.7
# Date: 2021.3.29
# 二叉查找树的增删改查

In [3]:
class TreeNode:
    """A single node of a binary search tree: a value plus two child links."""

    def __init__(self, val):
        self.val = val
        # Both children start empty; BinaryTree links nodes together on insert.
        self.left = None
        self.right = None

In [4]:
class BinaryTree:
    """Binary search tree with insert, find, update, remove and traversals.

    Values smaller than a node go into its left subtree; values greater than
    or equal to it go into its right subtree (see ``insert``), so duplicate
    values are allowed and end up on the right.
    """

    def __init__(self):
        # The tree starts empty; ``insert`` creates the root lazily.
        self.root = None

    ## Insert
    def insert(self, val):
        """Insert ``val`` and return the root of the tree.

        Equal values are placed in the right subtree.
        """
        if not self.root:
            self.root = TreeNode(val)
            return self.root

        node = self.root
        while node:
            if node.val > val:
                if not node.left:
                    node.left = TreeNode(val)
                    break
                node = node.left
            else:
                if not node.right:
                    node.right = TreeNode(val)
                    break
                node = node.right
        return self.root

    ## Find
    def find(self, root, val):
        """Return the first node holding ``val`` in the subtree at ``root``, or None."""
        node = root
        while node:
            if node.val > val:
                node = node.left
            elif node.val < val:
                node = node.right
            else:
                return node
        return None

    ## Update: remove then re-insert so the result is still a valid BST
    def update(self, target, val):
        """Replace one node holding ``target`` with a node holding ``val``.

        Overwriting the value in place could break BST ordering, so the node
        is removed and ``val`` is inserted again. If ``target`` is not present
        the tree is left unchanged and ``val`` is NOT inserted.

        Returns the root of the tree.
        """
        root, removed = self.remove(self.root, target)
        if removed:
            return self.insert(val)
        return root

    def find_parent(self, parent, root, val):
        """Walk down from ``root`` and return the parent of the node holding ``val``.

        ``parent`` is the parent of ``root`` (the sentinel node in ``remove``).
        It is returned unchanged when ``root`` itself holds ``val`` or ``root``
        is None. If ``val`` is absent, the last node visited is returned.
        """
        while root:
            if root.val > val:
                parent = root
                root = root.left
            elif root.val < val:
                parent = root
                root = root.right
            else:
                break
        return parent

    ## Remove
    def remove(self, root, val):
        """Remove one node holding ``val`` from the subtree rooted at ``root``.

        Returns a tuple ``(new_root, removed)``. ``new_root`` is the root of
        the subtree after removal; ``removed`` is False when ``val`` was not
        found. When ``root`` is this tree's root, ``self.root`` is updated too,
        so later ``insert``/``update`` calls see the new root.
        """
        if not root:
            return None, False

        # Sentinel whose left child is always the subtree root, so removing the
        # root itself needs no special case: its "parent" is the sentinel.
        dummy = TreeNode(0)
        dummy.left = root
        parent = self.find_parent(dummy, root, val)
        if parent.left and parent.left.val == val:
            node = parent.left
        elif parent.right and parent.right.val == val:
            node = parent.right
        else:
            # No node holds val.
            return dummy.left, False

        self.delete_node(parent, node)
        new_root = dummy.left
        # Without this, removing the tree's root (or its last node) would
        # leave self.root pointing at the detached node.
        if root is self.root:
            self.root = new_root
        return new_root, True

    def delete_node(self, parent, node):
        """Unlink ``node`` (a child of ``parent``) while keeping BST order.

        If ``node`` has no right subtree, its left subtree takes its place.
        Otherwise the smallest node of its right subtree (its in-order
        successor) is detached and moved into ``node``'s position.
        """
        if not node.right:
            if parent.left is node:
                parent.left = node.left
            elif parent.right is node:
                parent.right = node.left
            return

        # Find the minimum of the right subtree and its parent.
        successor = node.right
        father = node
        while successor.left:
            father = successor
            successor = successor.left

        # successor has no left child (the loop above guarantees it), so its
        # right subtree simply takes its place under father.
        if father.left is successor:
            father.left = successor.right
        else:
            father.right = successor.right

        # Put successor where node was.
        if parent.left is node:
            parent.left = successor
        else:
            parent.right = successor
        successor.left = node.left
        successor.right = node.right

    ## In-order traversal: left subtree -> root -> right subtree
    def inorder_traversal(self, root):
        """Return the values of the subtree at ``root`` in in-order (sorted) order."""
        result = []
        stack = []
        node = root
        while node or stack:
            # Descend as far left as possible, remembering the path.
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            result.append(node.val)
            node = node.right
        return result

    ## Pre-order traversal: root -> left subtree -> right subtree
    def preorder_traversal(self, root):
        """Return the values of the subtree at ``root`` in pre-order."""
        if not root:
            return []

        stack = [root]
        results = []
        while stack:
            node = stack.pop()
            results.append(node.val)
            # Push right first so the left child is popped (visited) first.
            if node.right:
                stack.append(node.right)
            if node.left:
                stack.append(node.left)
        return results

    ## Post-order traversal: left subtree -> right subtree -> root
    def postorder_traversal(self, root):
        """Return the values of the subtree at ``root`` in post-order."""
        if not root:
            return []

        stack = []
        results = []
        current = root
        last_visited = None
        while current or stack:
            while current:
                stack.append(current)
                current = current.left
            current = stack[-1]

            # If the right subtree exists and has not been visited yet, walk
            # into it and leave current on the stack. Otherwise current is
            # finished: pop it, emit it, and remember it as the last visited
            # node so its parent knows its right subtree is done.
            if current.right and current.right is not last_visited:
                current = current.right
            else:
                current = stack.pop()
                results.append(current.val)
                last_visited = current
                current = None
        return results
                
                
        

        
                



        

In [6]:
tree = BinaryTree()
for value in (6, 5, 8, 1):
    root = tree.insert(value)
# Pre-order pins the exact shape: 6 is the root, 5 (with left child 1) on the
# left, 8 on the right.
assert tree.preorder_traversal(root) == [6, 5, 1, 8]

root, removed = tree.remove(root, 1)
assert removed
assert tree.preorder_traversal(root) == [6, 5, 8]

# update(8, 8) removes 8 and inserts it again, so the shape is unchanged.
root = tree.update(8, 8)
tree.preorder_traversal(root)

[6, 5, 8]
