# Implement a Binary Search Tree
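A BST is either
- an empty tree, or
- a tree with a root TreeNode, where every key in a node's left subtree is smaller than the node's key and every key in its right subtree is larger.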

A TreeNode is either
- empty, or
- a key-value pair together with three TreeNode references: leftChild, rightChild, and parent.

In [46]:
class BinarySearchTree(object):
    """An unbalanced binary search tree mapping keys to values.

    Keys must be mutually comparable with ``<`` and ``==``. Keys in a node's
    left subtree are smaller than the node's key; keys in its right subtree
    are larger. Nodes are ``TreeNode`` objects linked through ``leftChild``,
    ``rightChild`` and ``parent``.

    Supports ``len(t)``, ``t[k] = v``, ``t[k]`` (``None`` for a missing key),
    ``del t[k]`` (a no-op for a missing key), ``k in t`` and in-order
    iteration over the keys.
    """

    def __init__(self):
        self.root = None  # root TreeNode, or None for an empty tree
        self.size = 0     # number of distinct keys currently stored

    # length: self --> int
    # returns the number of distinct keys in the tree
    def length(self):
        return self.size

    # __len__: self --> int
    # returns the number of distinct keys in the tree
    def __len__(self):
        return self.size

    def __iter__(self):
        """Yield the keys in ascending order; an empty tree yields nothing."""
        # Iterative in-order walk: works on an empty tree, does not rely on
        # TreeNode implementing __iter__, and cannot hit the recursion limit
        # on a degenerate (list-shaped) tree.
        stack = []
        node = self.root
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.leftChild
            node = stack.pop()
            yield node.key
            node = node.rightChild

    def __contains__(self, k):
        """Return True if key k is stored (even when its value is None)."""
        return self._findNode(k, self.root) is not None

    def __str__(self):
        if self.root is None:
            return ''
        return self.root.__str__(self.root)

    def put(self, k, v):
        """Insert key k with value v, or overwrite the value if k exists.

        ``size`` grows only when a new key is actually inserted.
        """
        if self.root is None:
            self.root = TreeNode(k, v)
            self.size += 1
        elif self._put(k, v, self.root):
            self.size += 1

    def _put(self, k, v, currentNode):
        """Place k/v in the subtree rooted at currentNode.

        Returns True when a new node was created and False when an existing
        key's value was overwritten. The walk is iterative so that inserting
        many keys in sorted order does not exceed the recursion limit.
        """
        node = currentNode
        while True:
            if k == node.key:
                node.val = v
                return False
            if k < node.key:
                if node.leftChild is None:
                    node.leftChild = TreeNode(k, v, parent=node)
                    return True
                node = node.leftChild
            else:
                if node.rightChild is None:
                    node.rightChild = TreeNode(k, v, parent=node)
                    return True
                node = node.rightChild

    def __setitem__(self, k, v):
        self.put(k, v)

    def get(self, k):
        """Return the value stored under k, or None if k is absent."""
        return self._get(k, self.root)

    def _get(self, k, currentNode):
        """Return the value for k in the subtree at currentNode, else None."""
        node = self._findNode(k, currentNode)
        return None if node is None else node.val

    def _findNode(self, k, currentNode):
        """Return the node holding k in the subtree at currentNode, or None."""
        node = currentNode
        while node is not None and k != node.key:
            node = node.leftChild if k < node.key else node.rightChild
        return node

    def __getitem__(self, k):
        return self.get(k)

    def remove(self, k):
        """Delete key k if present; a missing key is silently ignored."""
        # Only shrink the size when a node was really removed.
        if self.root is not None and self._remove(k, self.root):
            self.size -= 1

    # there are 3 cases:
    # the node is a leaf
    # the node has 1 child
    # the node has 2 children
    def _remove(self, k, currentNode):
        """Remove key k from the subtree rooted at currentNode.

        Returns True if a node was removed and False if k was not found.
        """
        node = self._findNode(k, currentNode)
        if node is None:
            return False
        if node.leftChild is not None and node.rightChild is not None:
            # Two children: copy the in-order successor (the leftmost node of
            # the right subtree, which has no left child) into this node, then
            # unlink the successor instead.
            successor = self.findSuccessor(k, node.rightChild)
            node.key = successor.key
            node.val = successor.val
            node = successor
        # node now has at most one child (a leaf is the case child=None).
        child = node.leftChild if node.leftChild is not None else node.rightChild
        self._replaceInParent(node, child)
        return True

    def _replaceInParent(self, node, child):
        """Make child (possibly None) take node's place in the tree."""
        parent = node.parent
        if child is not None:
            # Must be updated even when node is the root; otherwise the new
            # root keeps a stale parent link and isRoot() becomes wrong.
            child.parent = parent
        if parent is None:
            self.root = child
        elif parent.leftChild is node:
            parent.leftChild = child
        else:
            parent.rightChild = child
        node.parent = None  # detach the removed node

    def findSuccessor(self, k, currentNode):
        """Return the leftmost (minimum-key) node of the subtree at currentNode.

        ``k`` is unused and kept only for backward compatibility.
        """
        while currentNode.leftChild is not None:
            currentNode = currentNode.leftChild
        return currentNode

    def __delitem__(self, k):
        self.remove(k)
                
    

In [47]:
class TreeNode(object):
    """A single node of a binary search tree.

    Holds a key/value pair plus links to its left child, right child and
    parent; each link is another TreeNode or None.
    """

    def __init__(self, key, val, leftChild=None, rightChild=None, parent=None):
        self.key = key
        self.val = val
        self.leftChild = leftChild
        self.rightChild = rightChild
        self.parent = parent

    def hasLeftChild(self):
        """Return True if this node has a left child."""
        return self.leftChild is not None

    def hasRightChild(self):
        """Return True if this node has a right child."""
        return self.rightChild is not None

    def isLeftChild(self):
        """Return True if this node is its parent's left child."""
        return self.parent is not None and self.parent.leftChild is self

    def isRightChild(self):
        """Return True if this node is its parent's right child."""
        return self.parent is not None and self.parent.rightChild is self

    def isRoot(self):
        """Return True if this node has no parent."""
        return self.parent is None

    def isLeaf(self):
        """Return True if this node has no children."""
        return self.leftChild is None and self.rightChild is None

    def hasAnyChildren(self):
        """Return True if this node has at least one child."""
        return not self.isLeaf()

    def hasBothChildren(self):
        """Return True if this node has both a left and a right child."""
        return self.leftChild is not None and self.rightChild is not None

    def replaceNodeData(self, key, value, lc, rc):
        """Overwrite this node's key, value and children, re-parenting them."""
        self.key = key
        self.val = value
        self.leftChild = lc
        self.rightChild = rc
        if self.hasLeftChild():
            self.leftChild.parent = self
        if self.hasRightChild():
            self.rightChild.parent = self

    def __iter__(self):
        """Yield the keys of the subtree rooted here in ascending order."""
        # Iterative in-order walk so deep, unbalanced subtrees do not hit
        # the recursion limit.
        stack = []
        node = self
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.leftChild
            node = stack.pop()
            yield node.key
            node = node.rightChild

    def __str__(self, currentNode=None, depth=0):
        """Return an indented, multi-line rendering of a subtree.

        Each node is shown as ``[key, val]``. Children are prefixed with
        ``L: `` / ``R: `` and indented two spaces per level.

        currentNode: node to render; defaults to self so that plain
            ``str(node)`` works (callers passing the node explicitly are
            still supported).
        depth: nesting level of currentNode, used for indentation.
        """
        node = self if currentNode is None else currentNode
        indent = ' ' * (depth * 2 + 2)
        ans = '[' + str(node.key) + ', ' + str(node.val) + ']\n'
        if node.leftChild is not None:
            ans += indent + 'L: ' + node.leftChild.__str__(node.leftChild, depth=depth + 1)
        if node.rightChild is not None:
            ans += indent + 'R: ' + node.rightChild.__str__(node.rightChild, depth=depth + 1)
        return ans
        

In [48]:
def test():
    """Exercise BinarySearchTree by inserting, looking up and deleting keys.

    Prints the tree after each phase and returns it for further inspection.
    """
    x = BinarySearchTree()

    initial_keys = [5, 2, 7, 1, 3, 4, 6, 8]
    for i in initial_keys:
        x.put(i, i)
    print(initial_keys)
    print(x)

    # 1 and 4 are already present, so those puts overwrite rather than insert.
    extra_keys = [1, 100, 23, 4, -12]
    for i in extra_keys:
        x.put(i, i)
    print(extra_keys)
    print(x)

    # 10 was never inserted, so its lookup prints None.
    for i in [1, 2, 3, 5, 8, -12, 100, 10]:
        print(i, x[i])

    # 9 is absent, so deleting it leaves the tree unchanged.
    for i in [1, 5, 7, 9, 100, -12]:
        del x[i]
        print(i, x)

    return x
    
    

In [49]:
# Run the demo; the returned tree stays bound to x for inspection in later cells.
x = test()

[5, 2, 7, 1, 3, 4, 6, 8]
[5, 5]
  L: [2, 2]
    L: [1, 1]
    R: [3, 3]
      R: [4, 4]
  R: [7, 7]
    L: [6, 6]
    R: [8, 8]

[1, 100, 23, 4, -12]
[5, 5]
  L: [2, 2]
    L: [1, 1]
      L: [-12, -12]
    R: [3, 3]
      R: [4, 4]
  R: [7, 7]
    L: [6, 6]
    R: [8, 8]
      R: [100, 100]
        L: [23, 23]

1 1
2 2
3 3
5 5
8 8
-12 -12
100 100
10 None
1 [5, 5]
  L: [2, 2]
    L: [-12, -12]
    R: [3, 3]
      R: [4, 4]
  R: [7, 7]
    L: [6, 6]
    R: [8, 8]
      R: [100, 100]
        L: [23, 23]

5 [6, 6]
  L: [2, 2]
    L: [-12, -12]
    R: [3, 3]
      R: [4, 4]
  R: [7, 7]
    R: [8, 8]
      R: [100, 100]
        L: [23, 23]

7 [6, 6]
  L: [2, 2]
    L: [-12, -12]
    R: [3, 3]
      R: [4, 4]
  R: [8, 8]
    R: [100, 100]
      L: [23, 23]

9 [6, 6]
  L: [2, 2]
    L: [-12, -12]
    R: [3, 3]
      R: [4, 4]
  R: [8, 8]
    R: [100, 100]
      L: [23, 23]

100 [6, 6]
  L: [2, 2]
    L: [-12, -12]
    R: [3, 3]
      R: [4, 4]
  R: [8, 8]
    R: [23, 23]

-12 [6, 6]
  L: [2, 