In [None]:
from IPython.display import HTML, display

class TreeNode:
    """
    Node of a binary tree: a value plus optional left/right children.
    """

    def __init__(self, val, left=None, right=None):
        """
        val:   value stored in this node
        left:  root of the left subtree, or None
        right: root of the right subtree, or None
        """
        self.val = val
        # Store the children that were passed in (they used to be silently
        # discarded and replaced by None).
        self.left = left
        self.right = right

    def __str__(self):
        return "{val: " + str(self.val) + "}"

In [None]:
"""
pre order traverse
"""            
def tree_pre_order(root):
    """
    Lazily yield the values of a binary tree in pre-order:
    the node itself, then its left subtree, then its right subtree.

    root: node exposing val/left/right, or a falsy value for an empty tree
    """
    if not root:
        return
    yield root.val
    # a missing child simply produces an empty generator
    yield from tree_pre_order(root.left)
    yield from tree_pre_order(root.right)

"""
in order traverse
"""
def tree_in_order(root):
    """
    Lazily yield the values of a binary tree in in-order:
    left subtree, then the node itself, then the right subtree.

    root: node exposing val/left/right, or a falsy value for an empty tree
    """
    if not root:
        return
    yield from tree_in_order(root.left)
    yield root.val
    yield from tree_in_order(root.right)

"""
post order traverse
"""
def tree_post_oder(root):
    """
    Lazily yield the values of a binary tree in post-order:
    left subtree, then right subtree, then the node itself.

    root: node exposing val/left/right, or a falsy value for an empty tree
    """
    if not root:
        return
    yield from tree_post_oder(root.left)
    yield from tree_post_oder(root.right)
    yield root.val

"""
level order traverse
"""
def tree_level_order(root, reverse=True):
    """
    Breadth-first traversal that groups the node values level by level.

    root:    node exposing val/left/right, or a falsy value for an empty tree
    reverse: NOTE - the flag reads backwards from its name and is kept as is
             for backward compatibility:
               True  (default) -> levels listed top down (root level first)
               False           -> levels listed bottom up (deepest level first)

    returns: list of lists of values, one inner list per level; [] when empty
    """
    # local import keeps this cell runnable on its own
    from collections import deque

    result = []
    if not root:
        return result

    # deque pops from the front in O(1); list.pop(0) shifts every element
    q = deque([root])
    while q:
        cur_level = []
        # everything currently queued belongs to the same level
        for _ in range(len(q)):
            cur = q.popleft()
            cur_level.append(cur.val)
            if cur.left:
                q.append(cur.left)
            if cur.right:
                q.append(cur.right)
        result.append(cur_level)

    if not reverse:
        # bottom up: one O(levels) reversal instead of insert(0, ...) per level
        result.reverse()
    return result

# Binary Search Tree 

To visualize the operations, visit [https://www.cs.usfca.edu/~galles/visualization/BST.html](https://www.cs.usfca.edu/~galles/visualization/BST.html)

In [None]:
class BinarySearchTree:
    """
    Unbalanced binary search tree that stores each value at most once.

    Nodes are objects exposing `val`, `left` and `right` (new nodes are
    created as TreeNode). For every node, all values in the left subtree are
    smaller and all values in the right subtree are larger.
    """

    def __init__(self, root=None):
        self.root = root

    def first(self):
        """Return the smallest value in the tree, or None if the tree is empty."""
        if self.is_empty():
            return None
        return self._min(self.root).val

    def _min(self, node):
        """Return the leftmost (smallest) node of the subtree rooted at node."""
        while node.left is not None:
            node = node.left
        return node

    def last(self):
        """Return the largest value in the tree, or None if the tree is empty."""
        if self.is_empty():
            return None
        return self._max(self.root).val

    def _max(self, node):
        """Return the rightmost (largest) node of the subtree rooted at node."""
        while node.right is not None:
            node = node.right
        return node

    def before(self, val):
        """
        Return the node holding the in-order predecessor of val (the largest
        value smaller than val), or None when val is not in the tree or is
        already the smallest value.
        """
        node = self.root
        pred = None  # deepest ancestor from which we turned right so far
        while node is not None:
            if val < node.val:
                node = node.left
            elif val > node.val:
                pred = node
                node = node.right
            else:
                # found: predecessor is the max of the left subtree if there is
                # one, otherwise the closest ancestor smaller than val
                if node.left is not None:
                    return self._max(node.left)
                return pred
        return None

    def after(self, val):
        """
        Return the node holding the in-order successor of val (the smallest
        value larger than val), or None when val is not in the tree or is
        already the largest value.
        """
        node = self.root
        succ = None  # deepest ancestor from which we turned left so far
        while node is not None:
            if val < node.val:
                succ = node
                node = node.left
            elif val > node.val:
                node = node.right
            else:
                # found: successor is the min of the right subtree if there is
                # one, otherwise the closest ancestor larger than val
                if node.right is not None:
                    return self._min(node.right)
                return succ
        return None

    def search(self, val):
        """Return the node holding val, or None if val is not in the tree."""
        if self.is_empty():
            return None
        return self._search(val, self.root)

    def _search(self, val, node):
        """
        search val in a BST, node is the root
        return the node if found, otherwise None
        """
        # iterative descent: no recursion-depth limit on degenerate trees
        while node is not None:
            if val < node.val:
                node = node.left
            elif val > node.val:
                node = node.right
            else:
                return node
        return None

    def insert(self, val):
        """Insert val into the tree; a value already present is ignored."""
        self.root = self._insert(val, self.root)

    def _insert(self, val, node):
        """
        insert val in a BST, node is the root
        return the root of the new BST
        """
        if node is None:
            return TreeNode(val)

        # iterative descent to the free slot; duplicates leave the tree untouched
        cur = node
        while True:
            if val < cur.val:
                if cur.left is None:
                    cur.left = TreeNode(val)
                    break
                cur = cur.left
            elif val > cur.val:
                if cur.right is None:
                    cur.right = TreeNode(val)
                    break
                cur = cur.right
            else:
                break
        return node

    def delete(self, val):
        """Remove val (if present) and return the root of the resulting tree."""
        self.root = self._delete(val, self.root)
        return self.root

    def _delete(self, val, node):
        """
        del val from a BST, node is the root
        return the root of the new BST
        """
        if node is None:
            return None

        if val < node.val:
            # recursive call: left subtree
            node.left = self._delete(val, node.left)
            return node
        if val > node.val:
            # recursive call: right subtree
            node.right = self._delete(val, node.right)
            return node

        # val == node.val
        if node.left is None:
            # no left child: the right child (possibly None) takes the place
            right = node.right
            node.right = None
            return right
        if node.right is None:
            # only a left child: it takes the place
            left = node.left
            node.left = None
            return left

        # node has both children: replace it by the max node of its left
        # subtree (the min node of the right subtree, detached with
        # _delete_min, would work equally well)
        repl_node = self._max(node.left)
        repl_node.left = self._delete_max(node.left)
        repl_node.right = node.right
        node.left = node.right = None  # fully detach the removed node
        return repl_node

    def _delete_min(self, node):
        """
        delete min node from a BST, node is the root
        return the root of the new BST
        """
        if node.left is None:
            # the leftmost node in the tree
            right = node.right
            node.right = None
            return right

        # recursive call
        node.left = self._delete_min(node.left)
        return node

    def _delete_max(self, node):
        """
        delete max node from a BST, node is the root
        return the root of the new BST
        """
        if node.right is None:
            # the rightmost node in the tree
            left = node.left
            node.left = None
            return left

        # recursive call
        node.right = self._delete_max(node.right)
        return node

    def is_empty(self):
        """Return True when the tree holds no node."""
        return self.root is None

# Build the demo tree from a fixed insertion sequence
data = [50, 77, 55, 29, 10, 30, 66, 18, 80, 51, 90, 17, 88, 79]
bst = BinarySearchTree()
for v in data:
    bst.insert(v)

print("      BST:", tree_level_order(bst.root))
# an in-order traversal of a BST yields the values in sorted order
print("   sorted: ", " → ".join(map(str, tree_in_order(bst.root))))
print("      min: ", bst.first())
print("      max: ", bst.last())
v = 50
print(f"search {v}: ", bst.search(v))

html = "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree1.png' style='width:70%'>"
display(HTML(html))

# Delete two values in turn; after each deletion print the tree level by level
# and show the matching illustration
for v, html in (
    (10, "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree2.png' style='width:70%'>"),
    (77, "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree3.png' style='width:70%'>"),
):
    print(f"delete {v}:", tree_level_order(bst.delete(v)))
    display(HTML(html))


# AVL Tree
To visualize the operations, visit [https://www.cs.usfca.edu/~galles/visualization/AVLtree.html](https://www.cs.usfca.edu/~galles/visualization/AVLtree.html)

In [None]:
class AvlNode:
    """
    Node of an AVL tree: a value, two optional children and the cached
    height of the subtree rooted at this node (a leaf has height 1).
    """

    def __init__(self, val, left=None, right=None):
        """
        val:   value stored in this node
        left:  root of the left subtree (AvlNode), or None
        right: root of the right subtree (AvlNode), or None
        """
        self.val = val
        # Store the children that were passed in (they used to be silently
        # discarded and replaced by None).
        self.left = left
        self.right = right
        # leaf node by default (height 1); otherwise one more than the taller child
        self.height = 1 + max(
            left.height if left is not None else 0,
            right.height if right is not None else 0,
        )

    def __str__(self):
        return "{val: " + str(self.val) + " height: " + str(self.height) + "}"

In [None]:
class AvlTree(BinarySearchTree):
    """
    Self-balancing binary search tree (AVL tree) built from AvlNode objects.

    Every node caches the height of its subtree. After each insertion or
    deletion the nodes on the way back to the root get their height refreshed
    and are rotated whenever the balance factor (left height - right height)
    leaves the range [-1, 1].

    The constructor, insert(), delete(), search(), first(), last() and
    is_empty() are inherited from BinarySearchTree; insert() and delete()
    dispatch to the _insert() / _delete() overrides defined here.
    """

    @staticmethod
    def node_height(node):
        """Return the cached height of node; an empty subtree has height 0."""
        if node:
            return node.height
        return 0

    @staticmethod
    def node_bf(node):
        """Return the balance factor (left height - right height); 0 for an empty tree."""
        if node:
            return AvlTree.node_height(node.left) - AvlTree.node_height(node.right)
        return 0

    @staticmethod
    def _update_height(node):
        """Recompute node.height from the cached heights of its children."""
        node.height = 1 + max(AvlTree.node_height(node.left), AvlTree.node_height(node.right))

    def right_rotate(self, x):
        """
        Rotate the subtree rooted at x to the right: the left child y of x
        becomes the new subtree root, x becomes its right child and the former
        right subtree of y becomes the left subtree of x.
        Return the new subtree root y.
        """
        y = x.left
        t3 = y.right

        # right rotate
        y.right = x
        x.left = t3

        # x is now below y, so its height has to be refreshed first
        AvlTree._update_height(x)
        AvlTree._update_height(y)

        return y

    def left_rotate(self, x):
        """
        Rotate the subtree rooted at x to the left: the right child y of x
        becomes the new subtree root, x becomes its left child and the former
        left subtree of y becomes the right subtree of x.
        Return the new subtree root y.
        """
        y = x.right
        t3 = y.left

        # left rotate
        y.left = x
        x.right = t3

        # x is now below y, so its height has to be refreshed first
        AvlTree._update_height(x)
        AvlTree._update_height(y)

        return y

    def _rebalance(self, node):
        """
        Refresh the height of node, then restore the AVL balance at node with
        at most two rotations. Return the root of the (possibly rotated) subtree.
        """
        AvlTree._update_height(node)
        bf = AvlTree.node_bf(node)

        if bf > 1:
            # left heavy
            if AvlTree.node_bf(node.left) < 0:
                # LR → LL
                node.left = self.left_rotate(node.left)
            # LL
            return self.right_rotate(node)

        if bf < -1:
            # right heavy
            if AvlTree.node_bf(node.right) > 0:
                # RL → RR
                node.right = self.right_rotate(node.right)
            # RR
            return self.left_rotate(node)

        return node

    def _insert(self, val, node):
        """
        insert val in a AVL, node is the root
        return the root of the new AVL
        """
        if node is None:
            return AvlNode(val)
        if val < node.val:
            # insert to the left
            node.left = self._insert(val, node.left)
        elif val > node.val:
            # insert to the right
            node.right = self._insert(val, node.right)
        # an equal value is already present: nothing is inserted

        return self._rebalance(node)

    def _delete(self, val, node):
        """
        del val from a AVL, node is the root
        return the root of the new AVL
        """
        if node is None:
            return None

        if val < node.val:
            # recursive call: left subtree
            node.left = self._delete(val, node.left)
            reb_node = node
        elif val > node.val:
            # recursive call: right subtree
            node.right = self._delete(val, node.right)
            reb_node = node
        elif node.left is None:
            # val == node.val, no left child: the right child (possibly None) takes over
            reb_node = node.right
            node.right = None
        elif node.right is None:
            # val == node.val, only a left child: it takes over
            reb_node = node.left
            node.left = None
        else:
            # val == node.val with two children: replace the node by the max
            # node of its left subtree; that node is removed from the left
            # subtree with a rebalancing _delete so the heights stay correct
            repl_node = self._max(node.left)
            repl_node.left = self._delete(repl_node.val, node.left)
            repl_node.right = node.right
            node.left = node.right = None  # fully detach the removed node
            reb_node = repl_node

        if reb_node is None:
            # the subtree became empty
            return None

        return self._rebalance(reb_node)
    
"""
use inorder to check whether the tree is Binary Search Tree 
"""
def tree_is_bst(root):
    """
    Return True when the in-order sequence of the tree never decreases,
    i.e. when the tree satisfies the binary-search-tree ordering
    (equal neighbours are accepted). An empty tree is a BST.
    """
    values = list(tree_in_order(root))
    # compare every value with its in-order predecessor
    return not any(cur < prev for prev, cur in zip(values, values[1:]))

"""
check that the balance factor (left height - right height) of every node is within [-1, 1]
"""    
def tree_is_balanced(root):
    """
    Return True when every node of the tree has a balance factor
    (as reported by AvlTree.node_bf, which reads the cached node heights)
    between -1 and 1. An empty tree is balanced.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        if abs(AvlTree.node_bf(node)) > 1:
            return False
        # left subtree is examined before the right one
        pending.append(node.right)
        pending.append(node.left)
    return True

   
# Reference picture: the unbalanced BST built from the same insertion sequence
html = "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree1.png' style='width:70%'>"
display(HTML(html))

# Build the AVL tree from the same fixed insertion sequence
data = [50, 77, 55, 29, 10, 30, 66, 18, 80, 51, 90, 17, 88, 79]
avl = AvlTree()
for v in data:
    avl.insert(v)
print("      AVL:", tree_level_order(avl.root))
# an in-order traversal of a BST yields the values in sorted order
print("   sorted: ", " → ".join(map(str, tree_in_order(avl.root))))
print("      min: ", avl.first())
print("      max: ", avl.last())

v = 50
print(f"search {v}: ", avl.search(v))
print("      BST? ", tree_is_bst(avl.root))
print(" balanced? ", tree_is_balanced(avl.root))

html = "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree_balanced1.png' style='width:70%'>"
display(HTML(html))

# Delete two values in turn; after each deletion re-check both invariants
# and show the matching illustration
for v, html in (
    (10, "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree_balanced2.png' style='width:70%'>"),
    (55, "<img src='https://mth252.fastzhong.com/notebooks/binary_search_tree_balanced3.png' style='width:70%'>"),
):
    print(f"delete {v}:", tree_level_order(avl.delete(v)))
    print("      BST? ", tree_is_bst(avl.root))
    print(" balanced? ", tree_is_balanced(avl.root))
    display(HTML(html))


## 2-3 Tree

To visualize the operations, visit [https://people.ksp.sk/~kuko/gnarley-trees/23tree.html](https://people.ksp.sk/~kuko/gnarley-trees/23tree.html)

In [None]:
# 88, 66, 50, 48, 42, 17, 6, 12, 18, 33, 37, 70

## Red Black Tree

In [None]:
# Node colours of the red-black tree; a colour is stored as a plain bool
RED = True
BLACK = False

class RBNode:
    """
    Node of a red-black tree: a value, two optional children and a colour.
    Every newly created node is RED.
    """

    def __init__(self, val, left=None, right=None):
        """
        val:   value stored in this node
        left:  root of the left subtree, or None
        right: root of the right subtree, or None
        """
        self.val = val
        # Store the children that were passed in (they used to be silently
        # discarded and replaced by None).
        self.left = left
        self.right = right
        self.color = RED

    def __str__(self):
        # Use a local name that does not shadow the builtin str() and read the
        # `color` attribute set in __init__ (the old code read `colr`).
        label = "Red" if self.color else "Black"
        return "{" + label + " val: " + str(self.val) + "}"

In [None]:
class RBTree(BinarySearchTree):
    """
    Red-black tree built from RBNode objects on top of BinarySearchTree.
    Only insertion is implemented; delete() is still a placeholder.
    """

    def __init__(self, root=None):
        self.root = root

    @staticmethod
    def is_red(node):
        """Return the colour flag of node; a missing node counts as BLACK."""
        return BLACK if node is None else node.color

    def left_rotate(self, node):
        """
        Rotate left around node: its right child becomes the subtree root and
        inherits the colour of node, while node turns RED.
        Return the new subtree root.
        """
        pivot = node.right
        node.right, pivot.left = pivot.left, node
        pivot.color, node.color = node.color, RED
        return pivot

    def flip_colors(self, node):
        """Colour node RED and both of its children BLACK; return node."""
        node.color = RED
        for child in (node.left, node.right):
            child.color = BLACK
        return node

    def right_rotate(self, node):
        """
        Rotate right around node: its left child becomes the subtree root and
        inherits the colour of node, while node turns RED.
        Return the new subtree root.
        """
        pivot = node.left
        node.left, pivot.right = pivot.right, node
        pivot.color, node.color = node.color, RED
        return pivot

    def insert(self, val):
        """Insert val, then force the root to be BLACK."""
        self.root = self._insert(val, self.root)
        self.root.color = BLACK

    def _insert(self, val, node):
        """
        insert val in a RBTree, node is the root
        return the root of the new RBTree
        """
        if node is None:
            return RBNode(val)

        if val < node.val:
            # insert to the left
            node.left = self._insert(val, node.left)
        elif val > node.val:
            # insert to the right
            node.right = self._insert(val, node.right)

        red = RBTree.is_red

        # red link leaning right: rotate it to the left
        if red(node.right) and not red(node.left):
            node = self.left_rotate(node)

        # two consecutive red links on the left: rotate right
        if red(node.left) and red(node.left.left):
            node = self.right_rotate(node)

        # both children red: push the red colour up
        if red(node.left) and red(node.right):
            node = self.flip_colors(node)

        return node

    def delete(self, val):
        """Placeholder: deletion is not implemented; prints a note and returns the root."""
        print("TBD")
        return self.root
    

In [None]:
# test case BST vs. AVL 

import timeit
from random import randint

def gen_nums(n):
    """Return the list [0, 1, ..., n-1] (already sorted in ascending order)."""
    return list(range(n))

def test_bst(nums): 
    bst = BinarySearchTree()
    for v in nums:
        if bst.search(v):
            continue
        bst.insert(v)

def test_avl(nums): 
    avl = AvlTree()
    for v in nums:
        if avl.search(v):
            continue
        avl.insert(v)

def test_rb(nums): 
    rbt = RBTree()
    for v in nums:
        if rbt.search(v):
            continue
        rbt.insert(v)
        
import time

# Time the three tree implementations on the same input.
# gen_nums returns ascending numbers, so every insert into the plain BST
# goes to the far right of the tree.
n = 2000
nums = gen_nums(n)
print()
print("test1 size:", len(nums))
for label, run in ((" BST", test_bst), (" AVL", test_avl), ("  RB", test_rb)):
    # perf_counter is monotonic and high resolution, unlike the wall clock
    # returned by time.time()
    start = time.perf_counter()
    run(nums)
    end = time.perf_counter()
    print("%s: %.2f secs" % (label, end - start))

# Skip List

In [None]:
from IPython.display import HTML, display
# Illustration of the skip-list layout, rendered inline in the notebook
html = "<img src='https://mth252.fastzhong.com/notebooks/skip_dsa1.webp' style='width:80%'>"
display(HTML(html))

import random

class SkipNode:
    """
    Node of a skip list: a key/value pair plus two links,
    `nxt` to the next node on the same level and `down` to the node
    one level below.
    """

    def __init__(self, k=None, v=None):
        # the value attribute is not important in this demo
        self.k, self.v = k, v
        # both links start out unset
        self.nxt = self.down = None

    def __str__(self):
        return f"{{{self.k!s}: {self.v!s}}}"

class SkipList:
    """
    Skip list keyed by comparable keys.

    Each level is a singly linked list that starts at a key-less head node
    (SkipNode with k=None); head nodes are chained through `down`, and
    `self.head` is the head of the topmost level. A key present on a level
    is also present on every level below it.
    """

    # max level allowed
    MAX_LEVEL = 32

    def __init__(self):
        self.head = SkipNode()
        self.size = 0      # number of distinct keys stored
        self.levels = 1    # levels of this skip list

    def find(self, k):
        """
        Return the topmost node holding key k, or None if k is not stored.
        """
        cur = self.head
        while cur:
            if cur.k == k:
                return cur
            nxt = cur.nxt
            if nxt and k >= nxt.k:
                # k is not smaller than the next key:
                # move to the next node on the same level
                cur = nxt
                continue
            # reached the end of the level, or k is smaller than the next key:
            # move one level down
            cur = cur.down
        return None

    def insert(self, k, v):
        """
        Store value v under key k. An existing key gets its value replaced on
        every level; a new key is added to the bottom level and promoted to
        higher levels by fair coin flips.
        """
        found_node = self.find(k)
        if found_node:
            # key already present: update the value on all levels
            # (the node attribute is `v`; the old code wrote to `value`)
            cur = found_node
            while cur:
                cur.v = v
                cur = cur.down
            return

        # Walk from top to bottom and remember, per level, the node after
        # which the new key belongs. The new node is then linked in
        # "from bottom to top" along this path.
        path = []
        cur = self.head
        while cur:
            nxt = cur.nxt
            if nxt and k > nxt.k:
                # k is bigger than the next key: move right on the same level
                cur = nxt
                continue
            # reached the end of the level, or k is smaller than the next key:
            # remember this node and move one level down
            path.append(cur)
            cur = cur.down

        self.size += 1  # a new key is being added (was never counted before)

        level = 1
        down_node = None
        tail = 0  # number of levels on which the new node became the last node
        while path:
            new_node = SkipNode(k, v)
            new_node.down = down_node
            down_node = new_node  # the next (upper) node links down to this one
            cur = path.pop()
            # insert to the right of cur
            at_tail = cur.nxt is None
            new_node.nxt = cur.nxt
            cur.nxt = new_node
            if at_tail:
                tail += 1
                if tail > 1:
                    # existing heuristic, kept unchanged: stop promoting once
                    # the node has been appended at the end of two levels
                    break
            if level >= SkipList.MAX_LEVEL:
                break  # already at the max level (check used to be off by one)
            # flip a fair coin to decide whether to promote one level up;
            # randint(0, 1) is 50/50, the old randint(0, 2) promoted only 1 in 3
            if random.randint(0, 1):
                break  # bad luck
            level += 1
            # create a new level if necessary
            if level > self.levels:
                self.levels = level
                # create a head for this new level
                new_head = SkipNode()
                new_head.down = self.head
                self.head = new_head
                path.append(new_head)

    def delete(self, k):
        """
        Remove key k from every level (no effect if k is not stored) and drop
        the top levels that became empty.
        """
        cur = self.head
        found = False
        while cur:
            nxt = cur.nxt
            if nxt:
                if nxt.k < k:
                    # k is still further right: move to the next node
                    cur = nxt
                    continue
                if nxt.k == k:
                    # k is found on this level: unlink it
                    found = True
                    cur.nxt = nxt.nxt
            # move down one level and continue to remove k there
            cur = cur.down
        if found:
            self.size -= 1
            # Drop empty top levels one at a time. The old code assigned the
            # head two levels down per step, which could skip a non-empty
            # level or leave self.head set to None.
            while self.levels > 1 and self.head.nxt is None:
                self.head = self.head.down
                self.levels -= 1

    def print_me(self):
        """
        Print every level from top to bottom, aligning each key with its
        column on the bottom level.
        """
        print("skip list: %d level" % self.levels)
        size = 4  # fixed column width for a key
        cur_head = self.head
        # head of the bottom level: it lists every key and defines the columns
        last_head = cur_head
        while last_head.down:
            last_head = last_head.down
        while cur_head:
            cur = cur_head.nxt
            last_cur = last_head.nxt
            keys = []
            while last_cur:
                if cur and cur.k == last_cur.k:
                    keys.append(str(cur.k).rjust(size) + " → ")
                    cur = cur.nxt
                    last_cur = last_cur.nxt
                else:
                    # key not present on this level: leave its column blank
                    keys.append(" " * size + "   ")
                    last_cur = last_cur.nxt
            # print current level
            print("head → " + "".join(keys) + "null")
            # move down one level
            cur_head = cur_head.down

    def __is_empty__(self):
        # NOTE: not a Python protocol method despite the dunder name;
        # it has to be called explicitly
        return self.size == 0

    def __len__(self):
        return self.size
        
# Demo: fill a skip list, look keys up, delete one key and look it up again
skipL = SkipList()
v = "node" # a static value, just for demo

for key in (3, 4, 6, 7, 8, 10, 12):
    skipL.insert(key, v)
skipL.print_me()

# one key that is present, one that is not
for key in (10, 100):
    print(f"search key={key}: ", skipL.find(key))

key = 10
print(f"delete key={key}")
skipL.delete(key)
skipL.print_me()
print(f"search key={key}: ", skipL.find(key))

# Exercise 

In [None]:
# Map backed by AVL tree

class AvlMap:
    """
    Exercise skeleton: a map meant to be backed by an AVL tree.
    None of the operations is implemented yet; each one returns None.
    """

    def __init__(self):
        self.avl = None   # placeholder for the backing tree
        self.size = 0     # number of entries

    def insert(self, key, val):
        """Not implemented yet - returns None."""
        return None

    def get(self, key):
        """Not implemented yet - returns None."""
        return None

    def remove(self, key):
        """Not implemented yet - returns None."""
        return None
