In [50]:
class Node:
    """A binary-tree node that stores a numeric value.

    Attributes:
        val: The stored number (int or float).
        left_child, right_child, parent: Linked ``Node`` objects, or ``None``
            while unlinked.
        height, level: Cached values filled in by tree code; ``None`` until set.
    """

    def __init__(self, val):
        assert isinstance(val, (int, float)), 'Invalid value'
        self.val = val
        # Chained assignment keeps the original attribute creation order.
        self.left_child = self.right_child = self.parent = None
        self.height = self.level = None

    def __repr__(self):
        return '<{}>'.format(self.val)

In [51]:
# Sample nodes for the demo tree: values 5, 3, 7 and 2, created in that order.
node1, node2, node3, node4 = Node(5), Node(3), Node(7), Node(2)

In [52]:
# Wire up the demo tree: 5 has children 3 (left) and 7 (right); 3 has left child 2.
node1.left_child, node1.right_child = node2, node3
node2.left_child = node4
node2.parent = node3.parent = node1
node4.parent = node2

In [53]:
import math
from collections import defaultdict

class Tree:
    """Binary search tree assembled from linked ``Node`` objects.

    Attributes:
        root: The root node, or ``None`` for an empty tree.
        height: Number of levels in the tree (0 when empty).
    """

    def __init__(self, root=None):
        if root is None:
            self.root = None
            self.height = 0
        else:
            self.root = root
            # Also caches ``height`` on every node reachable from ``root``.
            self.height = self.get_height()

    @staticmethod
    def _set_level(node):
        """Set ``level`` on ``node`` and on each of its ancestors (root is level 1)."""
        if not node.parent:
            node.level = 1
        else:
            # A staticmethod cannot reach itself by bare name; go through the class.
            Tree._set_level(node.parent)
            node.level = node.parent.level + 1

    @staticmethod
    def _get_height(node):
        """Return the height of the subtree rooted at ``node`` (a leaf has height 1).

        The computed height is cached on ``node`` and on every descendant.
        """
        left_height = 1
        right_height = 1
        if node.left_child is not None:
            left_height = Tree._get_height(node.left_child) + 1
        if node.right_child is not None:
            right_height = Tree._get_height(node.right_child) + 1
        height = max(left_height, right_height)
        node.height = height
        return height

    @staticmethod
    def _cached_height(node):
        """Return ``node.height``, computing it if missing; 0 for an absent child."""
        if node is None:
            return 0
        if node.height is None:
            return Tree._get_height(node)
        return node.height

    def get_height(self):
        """Return the tree height, refreshing the cached height of every node."""
        if not self.root:
            return 0
        return self._get_height(self.root)

    @staticmethod
    def _get_nodes(output: defaultdict, node):
        """Append ``node`` and its descendants to ``output[level]`` in pre-order.

        Args:
            output: A ``defaultdict(list)`` filled in place, keyed by level.
            node: Root of the subtree to collect.
        """
        # Check the type only: a node holding 0 (or 0.0) is perfectly valid.
        assert isinstance(node, Node), 'Invalid node'
        assert isinstance(output, defaultdict), 'output must be of type defaultdict(list)'
        if not node.level:
            Tree._set_level(node)  # sets the level of this node and all its ancestors
        output[node.level].append(node)
        if node.left_child is not None:
            Tree._get_nodes(output, node.left_child)
        if node.right_child is not None:
            Tree._get_nodes(output, node.right_child)

    def get_nodes(self):
        """Return a dict mapping each level (root = 1) to its nodes, left to right."""
        output = defaultdict(list)
        if self.root:
            self._get_nodes(output, self.root)
        return dict(output)

    @staticmethod
    def get_pos(node):
        """Set ``node.ho_pos``, the column at which ``draw_tree`` prints the node.

        The root sits at column ``2 * height``; every other node is two columns
        left or right of its parent, so parents must be positioned first.
        """
        if node.parent is None:
            node.ho_pos = 2 * node.height
        elif node is node.parent.left_child:
            node.ho_pos = node.parent.ho_pos - 2
        else:
            node.ho_pos = node.parent.ho_pos + 2

    def draw_tree(self):
        """Print an ASCII sketch of the tree: a value line and a mark line per level.

        NOTE(review): the layout is approximate. Values wider than one
        character shift the rest of their line, cousins can land on the same
        column, and a ``/ \\`` mark is printed under every node, leaf or not.
        """
        # The root column depends on the current height, which can be missing or
        # stale after insert() or manual re-linking, so recompute it first.
        self.height = self.get_height()
        nodes = self.get_nodes()
        lines = ['\n']
        for level in sorted(nodes):
            for node in nodes[level]:
                self.get_pos(node)
            row = sorted(nodes[level], key=lambda n: n.ho_pos)
            # Character lists, because str does not support item assignment.
            value_cells = list(' ' * (row[-1].ho_pos + 1))
            mark_cells = value_cells.copy()
            for count, node in enumerate(row):
                value_cells[node.ho_pos] = str(node.val)
                # '/ \\' fills one cell but prints three characters, so each
                # earlier mark pushes later ones two columns right; shift back.
                mark_cells[node.ho_pos - 1 - count * 2] = '/ \\'
            lines.append(''.join(value_cells) + '\n')
            lines.append(''.join(mark_cells) + '\n')
        print(''.join(lines))

    def _insert(self, root, node):
        """Attach ``node`` below ``root`` in BST order (equal values go left)."""
        if node.val <= root.val:
            if not root.left_child:
                root.left_child = node
                node.parent = root
            else:
                self._insert(root.left_child, node)
        else:
            if not root.right_child:
                root.right_child = node
                node.parent = root
            else:
                self._insert(root.right_child, node)

    def insert(self, node):
        """Insert ``node`` (with any subtree hanging from it) into the tree.

        Keeps ``self.height`` and the cached heights of ``node`` and its new
        ancestors up to date, so ``draw_tree`` can place the root correctly.
        """
        if not self.root:
            self.root = node
            self.height = self._get_height(node)
            return
        self._insert(self.root, node)
        self._get_height(node)
        # Only the heights on the path from the new node up to the root can change.
        ancestor = node.parent
        while ancestor is not None:
            ancestor.height = 1 + max(self._cached_height(ancestor.left_child),
                                      self._cached_height(ancestor.right_child))
            ancestor = ancestor.parent
        self.height = self.root.height
            
      

In [54]:
# Wrap the hand-linked demo nodes in a Tree, rooted at node1 (value 5).
tree = Tree(root=node1)

In [55]:
# Print the ASCII sketch of the demo tree (captured output follows this cell).
tree.draw_tree()


      5
     / \ 
    3   7
   / \ / \   
  2
 / \ 



In [None]:
# https://github.com/bfaure/Python3_Data_Structures/blob/master/AVL_Tree/main.py