Binary Search Trees, Traversals and Balancing in Python

Implement the solution and test it using example inputs.

In [6]:
class User:
    def __init__(self, username, name, email):
        self.username = username
        self.name = name
        self.email = email
        
    def __repr__(self):
        return "User(username='{}', name='{}', email='{}')".format(self.username, self.name, self.email)
    
    def __str__(self):
        return self.__repr__()

In [7]:
aakash = User('aakash', 'Aakash Rai', 'aakash@example.com')
biraj = User('biraj', 'Biraj Das', 'biraj@example.com')
hemanth = User('hemanth', 'Hemanth Jain', 'hemanth@example.com')
jadhesh = User('jadhesh', 'Jadhesh Verma', 'jadhesh@example.com')
siddhant = User('siddhant', 'Siddhant Sinha', 'siddhant@example.com')
sonaksh = User('sonaksh', 'Sonaksh Kumar', 'sonaksh@example.com')
vishal = User('vishal', 'Vishal Goel', 'vishal@example.com')

In [29]:
class UserDatabase:
    def __init__(self):
        self.users = []
    
    def insert(self, user):
        i = 0
        while i < len(self.users):
            # Find the first username greater than the new user's username
            if self.users[i].username > user.username:
                break
            i += 1
        self.users.insert(i, user)
    
    def find(self, username):
        for user in self.users:
            if user.username == username:
                return user
    
    def update(self, user):
        target = self.find(user.username)
        target.name, target.email = user.name, user.email
    
    def remove(self, username):
        target = self.find(username)
        if target:
            self.users.remove(target)
        
    def list_all(self):
        return self.users
        

In [30]:
database = UserDatabase()

In [31]:
database.insert(hemanth)
database.insert(aakash)
database.insert(siddhant)

In [32]:
user = database.find('siddhant')
user

User(username='siddhant', name='Siddhant U', email='siddhantu@example.com')

In [33]:
database.update(User(username='siddhant', name='Siddhant U', email='siddhantu@example.com'))

In [34]:
user = database.find('siddhant')
user

User(username='siddhant', name='Siddhant U', email='siddhantu@example.com')

In [35]:
database.list_all()

[User(username='aakash', name='Aakash Rai', email='aakash@example.com'),
 User(username='hemanth', name='Hemanth Jain', email='hemanth@example.com'),
 User(username='siddhant', name='Siddhant U', email='siddhantu@example.com')]

In [36]:
database.insert(biraj)

In [37]:
database.list_all()

[User(username='aakash', name='Aakash Rai', email='aakash@example.com'),
 User(username='biraj', name='Biraj Das', email='biraj@example.com'),
 User(username='hemanth', name='Hemanth Jain', email='hemanth@example.com'),
 User(username='siddhant', name='Siddhant U', email='siddhantu@example.com')]

In [38]:
database.remove('siddhant')

In [39]:
database.list_all()

[User(username='aakash', name='Aakash Rai', email='aakash@example.com'),
 User(username='biraj', name='Biraj Das', email='biraj@example.com'),
 User(username='hemanth', name='Hemanth Jain', email='hemanth@example.com')]

In [64]:
class TreeNode():
    def __init__(self, key):
        self.key, self.left, self.right = key, None, None
    
    def height(self):
        if self is None:
            return 0
        return 1 + max(TreeNode.height(self.left), TreeNode.height(self.right))
    
    def size(self):
        if self is None:
            return 0
        return 1 + TreeNode.size(self.left) + TreeNode.size(self.right)

    def traverse_in_order(self):
        if self is None: 
            return []
        return (TreeNode.traverse_in_order(self.left) + 
                [self.key] + 
                TreeNode.traverse_in_order(self.right))
    
    def traverse_pre_order(self):
        if self is None: 
            return []
        return ([self.key] + TreeNode.traverse_in_order(self.left) + 
                TreeNode.traverse_in_order(self.right))
    
    def traverse_post_order(self):
        if self is None: 
            return []
        return (TreeNode.traverse_in_order(self.left) + 
                TreeNode.traverse_in_order(self.right) + [self.key])
    
    def display_keys(self, space='\t', level=0):
        # If the node is empty
        if self is None:
            print(space*level + '∅')
            return   

        # If the node is a leaf 
        if self.left is None and self.right is None:
            print(space*level + str(self.key))
            return

        # If the node has children
        TreeNode.display_keys(self.right, space, level+1)
        print(space*level + str(self.key))
        TreeNode.display_keys(self.left,space, level+1)    
    
    def to_tuple(self):
        if self is None:
            return None
        if self.left is None and self.right is None:
            return self.key
        return TreeNode.to_tuple(self.left),  self.key, TreeNode.to_tuple(self.right)
    
    def __str__(self):
        return "BinaryTree <{}>".format(self.to_tuple())
    
    def __repr__(self):
        return "BinaryTree <{}>".format(self.to_tuple())
    
    @staticmethod       
    def parse_tuple(data):
        if data is None:
            node = None
        elif isinstance(data, tuple) and len(data) == 3:
            node = TreeNode(data[1])
            node.left = TreeNode.parse_tuple(data[0])
            node.right = TreeNode.parse_tuple(data[2])
        else:
            node = TreeNode(data)
        return node
    
    @staticmethod
    def insert(node, key):
        if node is None:
            node = TreeNode(key)
        elif key < node.key:
            node.left = TreeNode.insert(node.left, key)
        elif key > node.key:
            node.right = TreeNode.insert(node.right, key)
        return node
    
    @staticmethod
    def find(node, key):
        if node is None:
            return None
        elif key == node.key:
            return node
        elif key < node.key:
            return TreeNode.find(node.left, key)
        elif key > node.key:
            return TreeNode.find(node.right, key)
        
    def is_balanced(self):
        if self is None:
            return True, 0
        balanced_l, height_l = TreeNode.is_balanced(self.left)
        balanced_r, height_r = TreeNode.is_balanced(self.right)
        balanced = balanced_l and balanced_r and abs(height_l - height_r) <=1
        height = 1 + max(height_l, height_r)
        return balanced, height
    
    def list_all(self):
        if self is None:
            return []
        return TreeNode.list_all(self.left) + [(self.key)] + TreeNode.list_all(self.right)
    
    @staticmethod
    def make_balanced_bst(data, lo=0, hi=None):
        if hi is None:
            hi = len(data) - 1
        if lo > hi:
            return None
        
        mid = (lo + hi) // 2
        key = data[mid]

        root = TreeNode(key)
        root.left = TreeNode.make_balanced_bst(data, lo, mid-1)
        root.right = TreeNode.make_balanced_bst(data, mid+1, hi)
        
        return root
    

In [65]:
tree_tuple = (1, 2, ((None, 3, 4), 5, (6, 7, 8)))


In [66]:
tree = TreeNode.parse_tuple(tree_tuple)
tree

BinaryTree <(1, 2, ((None, 3, 4), 5, (6, 7, 8)))>

In [70]:
tree.display_keys('  ')

        9
      8
        ∅
    7
      6
  5
      4
    3
      ∅
2
  1


In [71]:
tree.to_tuple()

(1, 2, ((None, 3, 4), 5, (6, 7, (None, 8, 9))))

In [72]:
TreeNode(5).to_tuple()

5

In [73]:
TreeNode.insert(tree, 9).display_keys('  ')

        9
      8
        ∅
    7
      6
  5
      4
    3
      ∅
2
  1


In [74]:
TreeNode.find(tree, 10)

In [75]:
tree.is_balanced()

(False, 5)

In [76]:
tree.height()

5

In [77]:
tree.list_all()

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In [78]:
TreeNode.make_balanced_bst(tree.list_all())

BinaryTree <((1, 2, (None, 3, 4)), 5, (6, 7, (None, 8, 9)))>

In [80]:
tree.traverse_post_order(), tree.traverse_pre_order(), tree.traverse_in_order() 

([1, 3, 4, 5, 6, 7, 8, 9, 2],
 [2, 1, 3, 4, 5, 6, 7, 8, 9],
 [1, 2, 3, 4, 5, 6, 7, 8, 9])