# In [None]:
from collections import deque

import inquirer
import networkx as nx
from matplotlib import pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout

class Node:
    """A single BST node: a value, a parent link, and two child links."""

    def __init__(self, data):
        self.value = data
        self.prev = None    # parent node; None for the root
        self.left = None    # subtree of values smaller than self.value
        self.right = None   # subtree of values larger than self.value


class BST:
    """Binary search tree of unique, mutually comparable values.

    Nodes keep a ``prev`` (parent) link so ``remove`` can splice a node out
    without re-walking the tree from the root.
    """

    def __init__(self):
        self.root = None

    def find(self, data):
        """Return the node whose value equals *data*.

        Raises:
            ValueError: if *data* is not in the tree.
        """
        # Iterative walk: a recursive one overflows the interpreter stack on
        # degenerate (e.g. sorted-insert) trees deeper than the recursion limit.
        current = self.root
        while current is not None:
            if data == current.value:
                return current
            current = current.left if data < current.value else current.right
        raise ValueError(str(data) + " is not in the tree.")

    def insert(self, data):
        """Insert *data* as a new leaf in its sorted position.

        Raises:
            ValueError: if *data* is already in the tree.
        """
        if self.root is None:
            self.root = Node(data)
            return
        current = self.root
        while True:
            if data == current.value:
                raise ValueError(str(data) + " is already in the tree.")
            if data < current.value:
                if current.left is None:
                    current.left = Node(data)
                    current.left.prev = current
                    return
                current = current.left
            else:
                if current.right is None:
                    current.right = Node(data)
                    current.right.prev = current
                    return
                current = current.right

    def remove(self, data):
        """Remove the node holding *data*, keeping the BST ordering intact.

        Raises:
            ValueError: if *data* is not in the tree.
        """
        target = self.find(data)
        if target.left is not None and target.right is not None:
            # Two children: copy the in-order predecessor (largest value in the
            # left subtree) into target, then unlink the predecessor node
            # instead. It has no right child, so it falls into the case below.
            pred = target.left
            while pred.right is not None:
                pred = pred.right
            target.value = pred.value
            target = pred
        # target now has at most one child (possibly none); splice it out.
        child = target.left if target.left is not None else target.right
        self._replace_in_parent(target, child)

    def _replace_in_parent(self, node, child):
        """Put *child* (which may be None) where *node* hangs under its parent."""
        parent = node.prev
        if child is not None:
            child.prev = parent
        if parent is None:  # node was the root
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child

    def __str__(self):
        r"""String representation: a hierarchical view of the BST.
        Example:
            (3)
            / \
          (2) (5)
          /   / \
        (1) (4) (6)

        '[3]
        [2, 5]
        [1, 4, 6]'
        The nodes of the BST are printed by depth levels. Edges and empty nodes are not printed.
        """
        if self.root is None:
            return "[]"
        levels, current_level = [], [self.root]
        while current_level:
            levels.append([node.value for node in current_level])
            # Next level: children of this level, left to right, skipping gaps.
            current_level = [
                child
                for node in current_level
                for child in (node.left, node.right)
                if child is not None
            ]
        return "\n".join(str(level) for level in levels)

    def draw(self):
        """Use NetworkX and Matplotlib to visualize the tree.

        Does nothing for an empty tree. The layout uses ``graphviz_layout``,
        which requires pygraphviz.
        """
        if self.root is None:
            return
        # Build the directed graph breadth-first (deque gives O(1) popleft).
        G = nx.DiGraph()
        G.add_node(self.root.value)
        queue = deque([self.root])
        while queue:
            current = queue.popleft()
            for child in (current.left, current.right):
                if child is not None:
                    G.add_edge(current.value, child.value)
                    queue.append(child)
        nx.draw(G, pos=graphviz_layout(G, prog="dot"), arrows=True,
                with_labels=True, node_color="C1", font_size=8)
        plt.show()

import random
# Demo: build a small tree from a fixed sequence of values and render it.
my_tree = BST()
for value in (10, 15, 5, 8, 4):
    my_tree.insert(value)
my_tree.draw()
