# Binary Search Trees (BSTs)
----

In [1]:
## Define some functions useful for testing
import random

## generate an array of n random integers up to b
def get_random_array(n, b=50):
    """Return a list of ``n`` integers, each drawn with ``random.randint(0, b)``.

    Bounds are inclusive, so every value lies in ``[0, b]``.
    """
    values = []
    for _ in range(n):
        values.append(random.randint(0, b))
    return values

Hashing-based data structures efficiently index a set of keys, providing three operations:
- Insert a new key into the set
- Delete a key from the set
- Search for a key in the set (and return its associated value)

A Binary Search Tree (BST) extends this set of operations with further ones:

- Min/max key in the set
- Predecessor of a value, i.e., the largest key in the set that is smaller than the given one
- Successor of a value, i.e., the smallest key in the set that is greater than the given one

Implementing the above operations gives a **sorted map** (or ordered map).


Note that if the set were **static** (i.e., no insertions or deletions), the problem could easily be solved with
binary search on a sorted array. This is the goal of the first exercise.

---
### Exercise: Static sorted map
Complete and test the implementation below. You have to use binary search to solve predecessor and successor queries on a sorted array.

In [10]:
class StaticSortedMap:
    """Sorted map over a static (insert/delete-free) set of keys.

    The keys are kept in an ascending Python list and every query is answered
    with binary search, so ``search``, ``predecessor`` and ``successor`` take
    O(log n) comparisons, while ``minimo``/``massimo`` are O(1).
    """

    def __init__(self, A):
        # Copy *and* sort the input: all queries below rely on ascending order,
        # so callers must not be required to pre-sort the keys.
        self.sorted_map = sorted(A)

    def minimo(self):
        """Return the smallest key (raises IndexError if the map is empty)."""
        return self.sorted_map[0]

    def massimo(self):
        """Return the largest key (raises IndexError if the map is empty)."""
        return self.sorted_map[-1]

    def _bisect(self, key, right=False):
        """Binary-search the boundary position of ``key``.

        With ``right=False`` return the first index whose element is >= key
        (lower bound); with ``right=True`` return the first index whose element
        is > key (upper bound). The result is always in
        ``[0, len(self.sorted_map)]``.
        """
        lo, hi = 0, len(self.sorted_map)
        while lo < hi:
            mid = (lo + hi) // 2
            item = self.sorted_map[mid]
            # Elements strictly before the boundary: < key (lower bound) or
            # <= key (upper bound). Skip past them to the right half.
            before_boundary = item <= key if right else item < key
            if before_boundary:
                lo = mid + 1
            else:
                hi = mid
        return lo

    def search(self, key):
        """Look up ``key`` in the map.

        Returns:
            ``(True, index)`` if ``key`` is present (index of its first
            occurrence), otherwise ``(False, index)`` where ``index`` is the
            position at which ``key`` would be inserted to keep the list sorted.
        """
        pos = self._bisect(key)
        found = pos < len(self.sorted_map) and self.sorted_map[pos] == key
        return found, pos

    def predecessor(self, key):
        """Return ``(index, value)`` of the largest key strictly smaller than
        ``key``, or ``None`` if no such key exists."""
        # Everything left of the lower bound is < key; take the last of them.
        pos = self._bisect(key) - 1
        if pos < 0:
            return None  # no key is smaller than `key`: predecessor does not exist
        return pos, self.sorted_map[pos]

    def successor(self, key):
        """Return ``(index, value)`` of the smallest key strictly greater than
        ``key``, or ``None`` if no such key exists."""
        # The upper bound skips every copy of `key`, so duplicates are ignored.
        pos = self._bisect(key, right=True)
        if pos >= len(self.sorted_map):
            return None  # no key is greater than `key`: successor does not exist
        return pos, self.sorted_map[pos]

In [11]:
## Test your implementation here
A = [4, 8, 12, 91, 97, 102, 103]
array = StaticSortedMap(A)

# Basic queries on an existing key.
print("Min: ", array.minimo())
print("Max: ", array.massimo())
print("Search: ", array.search(103))
print("Predecessor: ", array.predecessor(91))
print("Successor: ", array.successor(91))

# Edge cases: last key, first key, a missing key in a gap, a key past the end.
for probe in (103, 4, 7, 200):
    print()
    print(array.predecessor(probe))
    print(array.successor(probe))

Min:  4
Max:  103
Search:  (True, 6)
Predecessor:  (2, 12)
Successor:  (4, 97)

(5, 102)
None

None
(1, 8)

(0, 4)
(1, 8)

(6, 103)
None


---
## Sorted map with Binary Search Tree

In [4]:
class BinarySearchTree:
    """Unbalanced binary search tree.

    Values smaller than a node go to its left subtree; values greater than or
    equal to it go to its right subtree, so duplicates are kept. All operations
    are iterative, so a degenerate (list-shaped) tree, e.g. one built from
    already-sorted input, does not hit Python's recursion limit.
    """

    # Node class internal to BinarySearchTree (the double underscore triggers
    # name mangling, keeping it out of the public API).
    class __Node:
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

        def getVal(self):
            return self.val

        def setVal(self, newval):
            self.val = newval

        def getLeft(self):
            return self.left

        def getRight(self):
            return self.right

        def setLeft(self, newleft):
            self.left = newleft

        def setRight(self, newright):
            self.right = newright

        def __iter__(self):
            """Yield the values of this subtree in ascending (in-order) order.

            Uses an explicit stack rather than nested recursive generators,
            which would raise RecursionError on deep trees and pay O(depth)
            per yielded value.
            """
            stack = []
            node = self
            while stack or node is not None:
                # Descend as far left as possible, remembering the path.
                while node is not None:
                    stack.append(node)
                    node = node.left
                node = stack.pop()
                yield node.val
                # Then visit the right subtree of the node just emitted.
                node = node.right

    # Below, methods of the BinarySearchTree class.
    def __init__(self):
        self.root = None

    def __iter__(self):
        """Yield all stored values in ascending order (nothing if empty)."""
        if self.root is not None:
            yield from self.root

    def __contains__(self, val):
        """Support ``val in tree`` by delegating to :meth:`search`."""
        return self.search(val)

    def insert(self, val):
        """Insert ``val``; equal values go to the right subtree."""
        new_node = BinarySearchTree.__Node(val)
        if self.root is None:
            self.root = new_node
            return
        node = self.root
        while True:
            if val < node.val:
                if node.left is None:
                    node.left = new_node
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = new_node
                    return
                node = node.right

    def search(self, val):
        """Return True if ``val`` is stored in the tree, False otherwise."""
        node = self.root
        while node is not None:
            current = node.val
            if val == current:
                return True
            if val < current:
                node = node.left
            elif val > current:
                node = node.right
            else:
                # Neither equal, smaller nor greater (e.g. NaN): the value
                # cannot be located by ordering, so report it as absent.
                return False
        return False
        

In [5]:
from random import seed

In [6]:
seed(42)

# Build a BST from 10 random integers and check its in-order traversal.
a = get_random_array(10)

bst = BinarySearchTree()
for x in a:
    bst.insert(x)

print(list(bst.root)[:10])
assert list(bst.root) == sorted(a), "FAIL insert!"


## It works with strings as well
stringhe = ["ciao", "aaa", "zzz", "zzzW"]

bst_strings = BinarySearchTree()
for string in stringhe:
    bst_strings.insert(string)

# In-order traversal must match lexicographic sorting.
print(list(bst_strings.root))
assert list(bst_strings.root) == sorted(stringhe), "FAIL!"

[1, 6, 7, 8, 14, 15, 17, 40, 47, 47]
['aaa', 'ciao', 'zzz', 'zzzW']


### Exercise: 
Extend the previous implementation to support the **search(x)** operation. Test your implementation.

In [7]:
# Show the random values that were inserted into `bst`.
print(a)

[40, 7, 1, 47, 17, 15, 14, 8, 47, 6]


In [8]:
# Your implementation goes here
# One key present in `a` and one key absent from it.
key_true = 15
key_false = 10

print(bst.search(key_true))
print(bst.search(key_false))

True
False


In [9]:
# Test your implementation here
# BST membership must agree with plain list membership for both probe keys.
assert all(
    bst.search(k) == (k in a) for k in (key_true, key_false)
), "Fail search!"