# Data structures: Binary Search Tree

A binary search tree is a binary tree with the following property: every node in the left subtree of a node is smaller than that node, and every node in its right subtree is larger.

Search, insertion and deletion take O(log n) time on average (for random insertion orders) and O(n) in the worst case (degenerate, list-shaped trees), where n is the number of nodes.


In [653]:
class BstNode(object):
    """Node of an unbalanced binary search tree.

    Attributes (all plain instance attributes):
        key: ordering key; keys in ``left`` are smaller, keys in ``right`` larger.
        record: payload stored alongside the key (``None`` by default).
        left, right: child nodes, or ``None``.
        parent: parent node, or ``None``.

    Deletion unlinks a node through its ``parent``; a node without a parent
    cannot be unlinked. ``Bst`` therefore keeps a sentinel node (``key=None``)
    above the real root.

    Search, insertion, deletion and traversal are iterative, so a degenerate
    (list-shaped) tree cannot exhaust Python's recursion limit. Only the
    pretty-printer still recurses.
    """

    def __init__(self, key, record=None, parent=None):
        self.key = key
        self.record = record
        self.right = None
        self.left = None
        self.parent = parent

    def __str__(self):
        return 'Node(%r)' % self.key

    def __repr__(self):
        # Neighbours are formatted with %s (i.e. __str__), so repr never recurses.
        return 'Node(key=%r, left=%s, right=%s, parent=%s)' % (
            self.key, self.left, self.right, self.parent)

    def search(self, key):
        """Returns node by key or None if not found."""
        node = self
        while node is not None:
            if node.key == key:
                return node
            node = node.right if node.key < key else node.left
        return None

    def insert(self, key, record=None):
        """Inserts or updates a key with given record.

        An existing key just has its record replaced (we could also raise).
        """
        node = self
        while True:
            if node.key == key:
                node.record = record
                return
            if node.key < key:
                if node.right is None:
                    node.right = BstNode(key, record, parent=node)
                    return
                node = node.right
            else:  # node.key > key
                if node.left is None:
                    node.left = BstNode(key, record, parent=node)
                    return
                node = node.left

    def delete(self, key=None):
        """Deletes a node by key from the subtree rooted here.

        Args:
            key: key to delete; when omitted (``None``), this node's own key is
                used. Falsy keys such as ``0`` are honoured.

        Raises:
            KeyError: if ``key`` is not in the subtree.
        """
        if key is None:
            key = self.key

        node = self
        while True:
            if node.key < key:
                child = node.right
            elif node.key > key:
                child = node.left
            else:
                break  # This is the node to delete.
            if child is None:
                raise KeyError('Key %r not found' % (key,))
            node = child
        node._delete_self()

    def _delete_self(self):
        """Unlink this node from the tree while preserving BST order."""
        if self.left is not None and self.right is not None:
            # Two children. The in-order successor is the leftmost node of the
            # right subtree: it has no left child but may have a right one, so it
            # is spliced out by promoting its right child. Its key and record then
            # overwrite this node's, which keeps this node's parent/children links.
            successor = self.right._min_element()
            successor._replace_in_parent(successor.right)
            self.key = successor.key
            self.record = successor.record
        else:
            # Zero or one child: promote the only child (or None) into our place.
            only_child = self.left if self.left is not None else self.right
            self._replace_in_parent(only_child)

    def traverse(self, func):
        """Apply a callback ``func(key, record)`` to all nodes in-order."""
        stack = []
        node = self
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            func(node.key, node.record)
            node = node.right

    def display(self):
        """Pretty-prints the BST."""
        lines, _, _, _ = self._display_aux()
        for line in lines:
            print(line)

    def _min_element(self):
        """Return the node with the smallest key in this subtree."""
        node = self
        while node.left is not None:
            node = node.left
        return node

    def _replace_in_parent(self, replace_with=None):
        """Put ``replace_with`` (may be None) where this node hangs in its parent.

        Does nothing when this node has no parent. Children are matched by
        identity, so a missing sibling (``None``) is handled.
        """
        parent = self.parent
        if parent is None:
            return
        if parent.left is self:
            parent.left = replace_with
        else:
            parent.right = replace_with
        if replace_with is not None:
            replace_with.parent = parent

    def _display_aux(self):
        """Returns list of strings, width, height, and horizontal coordinate of the root."""
        # No child.
        if not self.right and not self.left:
            line = '%s' % self.key
            width = len(line)
            height = 1
            middle = width // 2
            return [line], width, height, middle

        # Only left child.
        if not self.right:
            lines, n, p, x = self.left._display_aux()
            s = '%s' % self.key
            u = len(s)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s
            second_line = x * ' ' + '/' + (n - x - 1 + u) * ' '
            shifted_lines = [line + u * ' ' for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, n + u // 2

        # Only right child.
        if not self.left:
            lines, n, p, x = self.right._display_aux()
            s = '%s' % self.key
            u = len(s)
            first_line = s + x * '_' + (n - x) * ' '
            second_line = (u + x) * ' ' + '\\' + (n - x - 1) * ' '
            shifted_lines = [u * ' ' + line for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, u // 2

        # Two children.
        left, n, p, x = self.left._display_aux()
        right, m, q, y = self.right._display_aux()
        s = '%s' % self.key
        u = len(s)
        first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        if p < q:
            left += [n * ' '] * (q - p)
        elif q < p:
            right += [m * ' '] * (p - q)
        zipped_lines = zip(left, right)
        lines = [first_line, second_line] + [a + u * ' ' + b for a, b in zipped_lines]
        return lines, n + m + u, max(p, q) + 2, n + u // 2

    ### The following is for priority queues.

    def find_min(self):
        """Return the node with the smallest key in this subtree."""
        return self._min_element()

    def find_max(self):
        """Return the node with the largest key in this subtree."""
        node = self
        while node.right is not None:
            node = node.right
        return node

    def delete_min(self):
        self.find_min().delete()

    def delete_max(self):
        self.find_max().delete()

    def pop_min(self):
        """Remove and return the minimum node (it has no left child)."""
        node = self.find_min()
        node.delete()
        return node

    def pop_max(self):
        """Remove and return the maximum node (it has no right child)."""
        node = self.find_max()
        node.delete()
        return node
    

class Bst(object):
    """Binary search tree of key/record pairs built from ``BstNode`` objects.

    ``self.root`` is a sentinel ``BstNode`` with ``key=None``. The real root is
    ``self.root.left`` (``None`` while the tree is empty). The sentinel gives
    every real node a parent, so deleting the real root goes through the same
    code path as deleting any other node.
    """

    def __init__(self):
        self.root = BstNode(key=None)

    def empty(self):
        """Return True if the tree holds no keys."""
        return self.root.left is None

    def search(self, key):
        """Return the node holding ``key``, or None if it is absent.

        Raises:
            KeyError: if the tree is empty.
        """
        if self.empty():
            raise KeyError('Key %r not found' % (key,))
        return self.root.left.search(key)

    def insert(self, key, record=None):
        """Insert ``key`` with ``record``, or replace the record of an existing key."""
        if self.empty():
            self.root.left = BstNode(key, record, parent=self.root)
        else:
            self.root.left.insert(key, record)

    def delete(self, key):
        """Remove ``key`` from the tree.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.empty():
            raise KeyError('Key %r not found' % (key,))
        self.root.left.delete(key)

    def display(self):
        """Pretty-print the tree (an empty tree prints a blank line)."""
        if self.empty():
            print()
        else:
            self.root.left.display()

    def traverse(self, func):
        """Call ``func(key, record)`` for every entry in ascending key order."""
        if not self.empty():
            self.root.left.traverse(func)

    def find_min(self):
        """Return the node with the smallest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.find_min()

    def find_max(self):
        """Return the node with the largest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.find_max()

    def delete_min(self):
        """Remove the smallest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        self.root.left.delete_min()

    def delete_max(self):
        """Remove the largest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        self.root.left.delete_max()

    def pop_min(self):
        """Remove and return the node with the smallest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.pop_min()

    def pop_max(self):
        """Remove and return the node with the largest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.pop_max()

In [654]:
import random

# Build a BST from 70 random keys in [0, 100]; repeated keys only update
# their record, so the tree usually ends up with fewer than 70 nodes.
b = Bst()
for value in [random.randint(0, 100) for _ in range(70)]:
    b.insert(value)
b.display()

 _3________________________________________________________________________                 
/                                                                          \                
0                                                   ______________________86_               
 \                                                 /                         \              
 1        ________________________________________54_______                 88___           
         /                                                 \                     \          
     ___16___________                                 ____64___________         90_____     
    /                \                               /                 \       /       \    
    6__     ________25_                             56___       ______78_     89    __95_   
   /   \   /           \                                 \     /         \         /     \  
   5  12  17_____     26_______                         61    70_     

In [655]:
b = Bst()
b.display()  # empty tree: prints a blank line

# This insertion order produces a perfectly balanced tree of 15 keys.
for key in (50, 25, 75, 12, 36, 64, 82, 90, 80, 70, 60, 40, 30, 20, 10):
    b.insert(key)
b.display()

# Exercise every deletion case (leaf, one child, two children, the root),
# showing the node being removed and the tree afterwards.
for n in [90, 82, 60, 64, 25, 12, 50, 36]:
    print('Deleting %r' % b.search(n))
    b.delete(n)
    b.display()


        ______50_______       
       /               \      
    __25___         __75___   
   /       \       /       \  
  12_     36_     64_     82_ 
 /   \   /   \   /   \   /   \
10  20  30  40  60  70  80  90
Deleting Node(key=90, left=None, right=None, parent=Node(82))
        ______50_______     
       /               \    
    __25___         __75___ 
   /       \       /       \
  12_     36_     64_     82
 /   \   /   \   /   \   /  
10  20  30  40  60  70  80  
Deleting Node(key=82, left=Node(80), right=None, parent=Node(75))
        ______50_______   
       /               \  
    __25___         __75_ 
   /       \       /     \
  12_     36_     64_   80
 /   \   /   \   /   \    
10  20  30  40  60  70    
Deleting Node(key=60, left=None, right=None, parent=Node(64))
        ______50_____   
       /             \  
    __25___       __75_ 
   /       \     /     \
  12_     36_   64_   80
 /   \   /   \     \    
10  20  30  40    70    
Deleting Node(key=64, lef

## Application: Tree sort

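We insert all values into a BST, then read them back with an in-order traversal. Each insertion is O(log n) on average and the traversal is O(n), so the total is O(n log n) on average (O(n²) in the worst case, when the tree degenerates).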

In [656]:
def tree_sort(values):
    """Return the items of ``values`` (any iterable) in ascending order via a BST.

    Each value is stored under the key ``(value, index)``. Equal values
    therefore get distinct keys (duplicates are kept) and ties come out in
    input order. The value itself is the record, read back by an in-order
    traversal.

    The plain BST never rebalances, so inserting already-sorted input would
    build a list-shaped tree and take quadratic time. The (index, value) pairs
    are therefore inserted in shuffled order. That changes only the tree's
    shape, never the in-order result, and gives O(n log n) expected time for
    every input.
    """
    import random  # local import keeps the function self-contained

    indexed = list(enumerate(values))
    if not indexed:
        return []
    # A private generator, so the caller's global random state is left untouched.
    random.Random().shuffle(indexed)
    b = Bst()
    for i, value in indexed:
        b.insert(key=(value, i), record=value)  # (value, i) allows duplicates
    sorted_values = []
    b.traverse(lambda key, value: sorted_values.append(value))
    return sorted_values

In [657]:
import random

# 100 random integers in [1, 1000]; with this many draws, duplicates are very
# likely, which makes a good check that the sort keeps them.
values = [random.randint(1, 1000) for _ in range(100)]
print(values)

[29, 291, 705, 318, 224, 227, 726, 744, 945, 896, 363, 120, 416, 695, 276, 982, 967, 491, 954, 305, 236, 334, 750, 534, 250, 28, 11, 959, 149, 875, 333, 33, 875, 475, 822, 132, 158, 440, 818, 370, 984, 123, 53, 820, 260, 590, 330, 832, 543, 140, 806, 963, 7, 409, 44, 938, 525, 779, 126, 743, 317, 287, 862, 44, 778, 723, 176, 806, 973, 989, 555, 483, 752, 527, 135, 299, 489, 94, 644, 667, 207, 431, 1, 740, 208, 587, 918, 155, 83, 713, 212, 499, 661, 795, 136, 732, 944, 298, 163, 770]


In [658]:
# Should print the same list as sorted(values), duplicates included.
print(tree_sort(values))

[1, 7, 11, 28, 29, 33, 44, 44, 53, 83, 94, 120, 123, 126, 132, 135, 136, 140, 149, 155, 158, 163, 176, 207, 208, 212, 224, 227, 236, 250, 260, 276, 287, 291, 298, 299, 305, 317, 318, 330, 333, 334, 363, 370, 409, 416, 431, 440, 475, 483, 489, 491, 499, 525, 527, 534, 543, 555, 587, 590, 644, 661, 667, 695, 705, 713, 723, 726, 732, 740, 743, 744, 750, 752, 770, 778, 779, 795, 806, 806, 818, 820, 822, 832, 862, 875, 875, 896, 918, 938, 944, 945, 954, 959, 963, 967, 973, 982, 984, 989]


## Application: double-ended priority queue

In [659]:
# Keys are priorities: the smallest key is the most urgent task, so pop_min()
# serves tasks in priority order.
queue = Bst()
tasks = [
    (50, 'Call mom'),
    (30, 'File taxes'),
    (60, 'Buy milk'),
    (90, 'Read Tolstoy'),
    (10, 'Urgent! Call Pete'),
]
for priority, task in tasks:
    queue.insert(priority, task)

def length(bst):
    """Return the number of entries in ``bst``.

    Counts by visiting every entry with ``bst.traverse`` (O(n)). The counter
    lives in this function's own scope rather than in a module-level global,
    so calling ``length`` leaves other globals (such as a loop variable ``n``)
    untouched.
    """
    count = 0

    def add_one(key, val):
        nonlocal count
        count += 1

    bst.traverse(add_one)
    return count

# Serve tasks most-urgent first, reporting the queue size before each pop.
while not queue.empty():
    remaining = length(queue)
    top = queue.pop_min()
    print('Remaining: %d. Top: "%s"' % (remaining, top.record))

Remaining: 5. Top: "Urgent! Call Pete"
Remaining: 4. Top: "File taxes"
Remaining: 3. Top: "Call mom"
Remaining: 2. Top: "Buy milk"
Remaining: 1. Top: "Read Tolstoy"


## AVL tree: self-balancing BST

In [660]:
# Inserting keys in sorted order makes a plain BST degenerate into a
# right-leaning chain, so every operation becomes O(n).
b = Bst()
for i in range(50):
    b.insert(key=i)
b.display()

0                                                                                         
 \                                                                                        
 1                                                                                        
  \                                                                                       
  2                                                                                       
   \                                                                                      
   3                                                                                      
    \                                                                                     
    4                                                                                     
     \                                                                                    
     5                                                                                    

In [661]:
class AvlNode(object):
    
    def __init__(self, key, record=None, parent=None):
        self.key = key
        self.record = record
        self.right = None
        self.left = None
        self.parent = parent
        self.height = 1
    
    def __str__(self):
        return 'Node(%r)' % self.key
        
    def __repr__(self):
        return 'Node(key=%r, left=%s, right=%s, parent=%s)' % (self.key, self.left, self.right, self.parent)
        
    def search(self, key):
        """Returns node by key or None if not found."""
        if self.key == key:
            return self
        elif self.key < key:
            if self.right:
                return self.right.search(key)
            else:
                return None
        else: # self.key > key
            if self.left:
                return self.left.search(key)
            else:
                return None

    def insert(self, key, record=None):
        """Inserts or updates a key with given record."""
        if self.key == key:
            self.record = record # we could also raise an exception
        elif self.key < key:
            if self.right:
                self.right.insert(key, record)
            else:
                self.right = AvlNode(key, record, parent=self)
                self.right.update_heights()
                self.right.rebalance()
        else: # self.key > key
            if self.left:
                self.left.insert(key, record)
            else:
                self.left = AvlNode(key, record, parent=self)
                self.left.update_heights()
                self.left.rebalance()
    
    def update_heights(self):
        self.height = 1
        if self.right:
            self.height = self.right.height + 1
        if self.left:
            self.height = max(self.height, self.left.height + 1)
        if self.parent:
            self.parent.update_heights()
            
    def balance(self):
        if self.right:
            right_height = self.right.height
        else:
            right_height = 0
        if self.left:
            left_height = self.left.height
        else:
            left_height = 0
        return right_height - left_height
            
    def rebalance(self):
        if not self.parent or not self.parent.parent or self.parent.parent.key is None:
            return # nothing to do
        x = self
        y = self.parent
        z = self.parent.parent
        assert abs(x.balance()) < 2
        assert abs(y.balance()) < 2
        assert abs(z.balance()) < 3
        if z.balance() == -2:
            if x.key == y.left.key:
                # Left left case.
                z.right_rotate()
            else:
                # Left right case:
                y.left_rotate()
                z.right_rotate()
        elif z.balance() == 2:
            if x.key == y.right.key:
                # Right right case.
                z.left_rotate()
            else:
                # Right left case.
                y.right_rotate()
                z.left_rotate()
        else:
            y.rebalance()

    def right_rotate(self):
        y = self
        x = self.left
        assert x
        a = x.left
        b = x.right
        c = y.right
        
        if y.parent.left.key == y.key:
            y.parent.left = x
        else:
            y.parent.right = x
        
        x.parent = y.parent
        x.left = a
        x.right = y
        
        y.parent = x
        y.left = b
        y.right = c
        
        if a:
            a.parent = x
        if b:
            b.parent = y
        if c:
            c.parent = y
            
        y.height = max(1, 1 + b.height if b else 0, 1 + c.height if c else 0)
        x.height = max(1 + y.height, 1 + a.height if a else 0)
        x.update_heights()
        
    def left_rotate(self):
        x = self
        y = self.right
        assert y
        a = x.left
        b = y.left
        c = y.right
        
        if x.parent.left.key == x.key:
            x.parent.left = y
        else:
            x.parent.right = y
        
        y.parent = x.parent
        y.left = x
        y.right = c
        
        x.parent = y
        x.left = a
        x.right = b
        
        if a:
            a.parent = x
        if b:
            b.parent = x
        if c:
            c.parent = y
            
        x.height = max(1, 1 + a.height if a else 0, 1 + b.height if b else 0)
        y.height = max(1 + x.height, 1 + c.height if c else 0)
        y.update_heights()

    # TODO: Balance after delete.
    def delete(self, key=None):
        """Deletes a node by key, raises KeyError if not found."""
        key = key or self.key
        
        # Node to delete is on the right.
        if self.key < key:
            if self.right:
                self.right.delete(key)
            else:
                raise KeyError('Key %r not found' % key)

        # Node to delete is on the left.
        elif self.key > key:
            if self.left:
                self.left.delete(key)
            else:
                raise KeyError('Key %r not found' % key)

        # This is the node to delete.
        else:
            # First case, no children. Just delete the node.
            if not self.right and not self.left:
                if self.parent.left and self.parent.left.key == self.key:
                    self.parent.left = None
                else:
                    assert self.parent.right and self.parent.right.key == self.key
                    self.parent.right = None

            # Second case, only a left child. Just replace the node with the child.
            elif not self.right:
                self._replace_in_parent(self.left)
            
            # Third case, only a right child. Just replace the node with the child.
            elif not self.left:
                self._replace_in_parent(self.right)
            
            # Fourth case, two children. We find the in-order successor, which is necessarily the right
            # child of its parent, and which might have up to one left child. If it has no child, we delete it.
            # If it has a left child, we replace it with its left child. And finally we replace the node to
            # delete with its in-order successor, which means replacing key and record while keeping left and
            # right children as well as parent.
            else:
                successor = self.right._min_element()
                successor._replace_in_parent(successor.left) # can be None
                self.key = successor.key
                self.record = successor.record

    def traverse(self, func):
        """Apply a callback to all nodes in-order."""
        if self.left:
            self.left.traverse(func)
        func(self.key, self.record)
        if self.right:
            self.right.traverse(func)
        
    def display(self, fmt):
        """Pretty-prints the BST."""
        lines, _, _, _ = self._display_aux(fmt)
        for line in lines:
            print(line)
            
    def _min_element(self):
        if self.left:
            return self.left._min_element()
        else:
            return self
            
    def _replace_in_parent(self, replace_with=None):
        if self.parent:
            if self.parent.left.key == self.key:
                self.parent.left = replace_with
            else:
                self.parent.right = replace_with
            if replace_with:
                replace_with.parent = self.parent
    
    def _display_aux(self, fmt):
        """Returns list of strings, width, height, and horizontal coordinate of the root."""
        # No child.
        if not self.right and not self.left:
            line = fmt(self)
            width = len(line)
            height = 1
            middle = width // 2
            return [line], width, height, middle
        
        # Only left child.
        if not self.right:
            lines, n, p, x = self.left._display_aux(fmt)
            s = fmt(self)
            u = len(s)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s
            second_line = x * ' ' + '/' + (n - x - 1 + u) * ' '
            shifted_lines = [line + u * ' ' for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, n + u // 2
        
        # Only right child.
        if not self.left:
            lines, n, p, x = self.right._display_aux(fmt)
            s = fmt(self)
            u = len(s)
            first_line = s + x * '_' + (n - x) * ' '
            second_line = (u + x) * ' ' + '\\' + (n - x - 1) * ' '
            shifted_lines = [u * ' ' + line for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, u // 2
        
        # Two children.
        left, n, p, x = self.left._display_aux(fmt)
        right, m, q, y = self.right._display_aux(fmt)
        s = fmt(self)
        u = len(s)
        first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        if p < q:
            left += [n * ' '] * (q - p)
        elif q < p:
            right += [m * ' '] * (p - q)
        zipped_lines = zip(left, right)
        lines = [first_line, second_line] + [a + u * ' ' + b for a, b in zipped_lines]
        return lines, n + m + u, max(p, q) + 2, n + u // 2
    
    ### The following is for priority queues.
    
    def find_min(self):
        if self.left:
            return self.left.find_min()
        return self
    
    def find_max(self):
        if self.right:
            return self.right.find_max()
        return self
    
    def delete_min(self):
        self.find_min().delete()
    
    def delete_max(self):
        self.find_max().delete()
    
    def pop_min(self):
        node = self.find_min()
        node.delete()
        return node
    
    def pop_max(self):
        node = self.find_max()
        node.delete()
        return node
    

class AvlTree(object):
    """Self-balancing binary search tree built from ``AvlNode`` objects.

    ``self.root`` is a sentinel ``AvlNode`` with ``key=None``. The real root is
    ``self.root.left`` (``None`` while the tree is empty). The sentinel gives
    the real root a parent, which node rotations and deletion rely on.
    """

    def __init__(self):
        self.root = AvlNode(key=None)

    def empty(self):
        """Return True if the tree holds no keys."""
        return self.root.left is None

    def search(self, key):
        """Return the node holding ``key``, or None if it is absent.

        Raises:
            KeyError: if the tree is empty.
        """
        if self.empty():
            raise KeyError('Key %r not found' % (key,))
        return self.root.left.search(key)

    def insert(self, key, record=None):
        """Insert ``key`` with ``record``, or replace the record of an existing key."""
        if self.empty():
            self.root.left = AvlNode(key, record, parent=self.root)
        else:
            self.root.left.insert(key, record)

    def delete(self, key):
        """Remove ``key`` from the tree.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.empty():
            raise KeyError('Key %r not found' % (key,))
        self.root.left.delete(key)

    def display(self, fmt=None):
        """Pretty-print the tree; ``fmt(node)`` gives each label (default: repr of key)."""
        if self.empty():
            print()
        else:
            if not fmt:
                fmt = lambda node: '%r' % node.key
            self.root.left.display(fmt)

    def traverse(self, func):
        """Call ``func(key, record)`` for every entry in ascending key order."""
        if not self.empty():
            self.root.left.traverse(func)

    def find_min(self):
        """Return the node with the smallest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.find_min()

    def find_max(self):
        """Return the node with the largest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.find_max()

    def delete_min(self):
        """Remove the smallest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        self.root.left.delete_min()

    def delete_max(self):
        """Remove the largest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        self.root.left.delete_max()

    def pop_min(self):
        """Remove and return the node with the smallest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.pop_min()

    def pop_max(self):
        """Remove and return the node with the largest key; ValueError if empty."""
        if self.empty():
            raise ValueError('Empty BST')
        return self.root.left.pop_max()

In [662]:
# Watch the AVL tree rebalance after each insertion; labels show "key (height)".
b = AvlTree()
for i in range(15):
    print('\n=== Inserting %d ===' % i)
    b.insert(key=i)
    b.display(fmt=lambda node: '%r (%d)' % (node.key, node.height))


=== Inserting 0 ===
0 (1)

=== Inserting 1 ===
0 (2)__   
       \  
     1 (1)

=== Inserting 2 ===
   __1 (2)__   
  /         \  
0 (1)     2 (1)

=== Inserting 3 ===
   __1 (3)__        
  /         \       
0 (1)     2 (2)__   
                 \  
               3 (1)

=== Inserting 4 ===
   __1 (3)_______        
  /              \       
0 (1)        __3 (2)__   
            /         \  
          2 (1)     4 (1)

=== Inserting 5 ===
        _______3 (3)__        
       /              \       
   __1 (2)__        4 (2)__   
  /         \              \  
0 (1)     2 (1)          5 (1)

=== Inserting 6 ===
        _______3 (3)_______        
       /                   \       
   __1 (2)__           __5 (2)__   
  /         \         /         \  
0 (1)     2 (1)     4 (1)     6 (1)

=== Inserting 7 ===
        _______3 (4)_______             
       /                   \            
   __1 (2)__           __5 (3)__        
  /         \         /         \       
0 (1)     2

In [663]:
# The same 50 sorted insertions that degenerate a plain BST now yield a
# balanced tree.
b = AvlTree()
for i in range(50):
    b.insert(key=i)
b.display()

                      ______________________________31_______________                     
                     /                                               \                    
        ____________15_______________                         ______39_______             
       /                             \                       /               \            
    ___7_____                 ______23_______             __35___         __43_______     
   /         \               /               \           /       \       /           \    
  _3_     __11___         __19___         __27___       33_     37_     41_       __47_   
 /   \   /       \       /       \       /       \     /   \   /   \   /   \     /     \  
 1   5   9_     13_     17_     21_     25_     29_   32  34  36  38  40  42    45_   48_ 
/ \ / \ /  \   /   \   /   \   /   \   /   \   /   \                           /   \     \
0 2 4 6 8 10  12  14  16  18  20  22  24  26  28  30                          44  46    49