In [1]:
import numpy as np
import ctypes as ct
import sys

In [2]:
class TreeNode(ct.Structure):
    pass

TreeNode._fields_ = [
    ("key", ct.c_int),
    ("left", ct.POINTER(TreeNode)),
    ("right", ct.POINTER(TreeNode))
]

class BST_ctype(ct.Structure):
    _fields_ = [
        ("root", ct.POINTER(TreeNode))
    ]

class BST:
    def __init__(self):
        self.bst = BST_ctype(root = ct.POINTER(TreeNode)(), dead = ct.POINTER(TreeNode)(), a_count = 0, tot_count = 1, cfl_min_dt = sys.float_info.max)
        
    def insert(self, new_node):
        if not self.bst.root:
            self.bst.root = ct.pointer(new_node)
        else:
            self._insert_recursive(self.bst.root, new_node)

    def _insert_recursive(self, node, new_node):
        if new_node.key < node.contents.key:
            if not node.contents.left:
                node.contents.left = ct.pointer(new_node)
            else:
                self._insert_recursive(node.contents.left, new_node)
        elif new_node.key > node.contents.key:
            if not node.contents.right:
                node.contents.right = ct.pointer(new_node)
            else:
                self._insert_recursive(node.contents.right, new_node)

    def search(self, key):
        return bool(self._search_recursive(self.bst.root, key))

    def _search_recursive(self, node, key):
        if not node or node.contents.key == key:
            return node
        if key < node.contents.key:
            return self._search_recursive(node.contents.left, key)
        else:
            return self._search_recursive(node.contents.right, key)
        
    def _min_value_node(self, node):
        current = node
        while(current.contents.left):
            current = current.contents.left
        return current
    
    def delete(self, key):
        if self.bst.root:
            self.bst.root = self._delete_recursive(self.bst.root, key)
    
    def _delete_recursive(self, node, key):
        if not node: 
            return node
        
        if key < node.contents.key:
            node.contents.left = self._delete_recursive(node.contents.left, key)
        elif key > node.contents.key: 
            node.contents.right = self._delete_recursive(node.contents.right, key)
        else:
            if not node.contents.left:
                return node.contents.right
            elif not node.contents.left:
                return node.contents.left
            
            temp = self._min_value_node(node.contents.right)
            node.contents.key = temp.contents.key
            node.contents.right = self._delete_recursive(node.contents.right, temp.contents.key)
        
        return node
         
    def _get_height(self, node):
        if not node:
            return 0 
        
        return 1 + max(self._get_height(node.contents.left), self._get_height(node.contents.right))
    
    def _get_difference(self, node):
        l_height = self._get_height(node.contents.left)
        r_height = self._get_height(node.contents.right)
        return l_height - r_height
    
    def _rr_rotate(self, node):
        t = node.contents.right 
        node.contents.right = t.contents.left
        t.contents.left = node
        return t
    
    def _ll_rotate(self, node):
        t = node.contents.left
        node.contents.left = t.contents.right
        t.contents.right = node
        return t
    
    def _lr_rotate(self, node):
        t = node.contents.left
        node.contents.left = self._rr_rotate(t)
        return self._ll_rotate(node)
    
    def  _rl_rotate(self, node):
        t = node.contents.right
        node.contents.right = self._ll_rotate(t)
        return self._rr_rotate(node)
    
    def balance(self):
        node = self.bst.root
        bal_factor = self._get_difference(node)
        while(abs(bal_factor) > 1):
            if bal_factor > 1:
                if self._get_difference(node.contents.left) > 0:
                    node = self._ll_rotate(node)
                else:
                    node = self._lr_rotate(node)
            elif bal_factor < -1:
                if self._get_difference(node.contents.right) > 0:
                    node = self._rl_rotate(node)
                else:
                    node = self._rr_rotate(node)
            bal_factor = self._get_difference(node)
        self.bst.root = node

    def print_tree(self):
        if self.bst.root:
            self._print_tree_recursive(self.bst.root, "", True)
        else:
            print("BST is empty")
   
    def _print_tree_recursive(self, node, indent, last):
        if node:
            print(indent, end="")
            if indent == "":
                print("ROOT----", end="")
                indent += "        "
            else:
                if last:
                    print("R----", end="")
                    indent += "     "
                else:
                    print("L----", end="")
                    indent += "|    "
            print(node.contents.key)
            self._print_tree_recursive(node.contents.left, indent, False)
            self._print_tree_recursive(node.contents.right, indent, True)

In [20]:
import random

root = TreeNode(key=0)
P = BST()
P.insert(root)

# print(P.bst.root.contents.left.contents)
# print(P.bst.root.contents.right.contents)

for i in range(2):
    P.insert(TreeNode(key=i+1))

P.print_tree()

P.balance()

P.print_tree()

ROOT----0
        R----1
             R----2


ValueError: NULL pointer access

In [3]:
import ctypes as ct

class TreeNode(ct.Structure):
    pass

TreeNode._fields_ = [
    ("key", ct.c_int),
    ("left", ct.POINTER(TreeNode)),
    ("right", ct.POINTER(TreeNode))
]

node_0 = TreeNode(key = 0)
node_1 = TreeNode(key = 1)
node_2 = TreeNode(key = 2)

node_0.right = ct.pointer(node_1)
node_1.right = ct.pointer(node_2)

node = ct.pointer(node_0)

t = node.contents.right
t_copy = ct.POINTER(TreeNode)(t.contents)
node.contents.right = t.contents.left
t_copy.contents.left = node

print(t_copy.contents.key)
print(t_copy.contents.left.contents.key)
print(t_copy.contents.right.contents.key)



1
0
2
