In [1]:
class BplusTreeLeafPage:
    def __init__(self, page_id):
        self.page_id = page_id
        self.key_array = [0] * 300
        self.value_array = [0] * 300
        self.page_type = 'leaf'
        self.is_leaf= True
        self.size = 0
        self.max_size = 300
        
    def insert(self, key, value):
        self.size += 1
        self.key_array[self.size - 1] = key
        self.value_array[self.size - 1] = value

class BplusTreeInternalPage:
    def __init__(self, page_id):
        self.page_id = page_id
        self.key_array = [0] * 300
        self.pageId_array = [0] * 300
        self.page_type = 'internal'
        self.is_leaf= False

class BplusTreeHeaderPage:
    def __init__(self, root_page_id):
        self.page_id = 0
        self.root_page_id = root_page_id # this is the root page ID of the Bplus Tree

In [2]:
class Bpm:
    def __init__(self):
        self.max_pages = 100
        self.counter = 0 
        self.page_dict = {} # this will keep the map of pageids to the page objects. 
        
    def get_new_page_id(self):
        self.counter+= 1
        self.page_dict[self.counter] = None 
        return self.counter
    
    def set_page(self, pageid, page):
        self.page_dict[pageid] = page
    
    def get_page(self, page_id):
        return self.page_dict[page_id]

In [3]:
INVALID_PAGE_ID = -1

class BPlusTree:
    def __init__(self, index_name, bpm, leaf_max_size, internal_max_size, header_page_id):
        self.index_name = index_name
        self.bpm = bpm
        self.leaf_max_size = leaf_max_size
        self.internal_max_size = internal_max_size
        self.header_page_id = header_page_id
        self.header_page = BplusTreeHeaderPage(INVALID_PAGE_ID)
        self.parent_map = {}  # child_page_id -> parent_page_id

    def is_empty(self) -> bool:
        return self.header_page.root_page_id == INVALID_PAGE_ID

    def insert(self, key, value):
        if self.is_empty():
            # Create root as a leaf page
            root_page_id = self.bpm.get_new_page_id()
            root = BplusTreeLeafPage(root_page_id)
            self.bpm.set_page(root_page_id, root)
            self.header_page.root_page_id = root_page_id

        # Locate target leaf
        leaf = self.find_leaf(key)

        if self.is_safe_to_insert(leaf):
            leaf.insert(key, value)
        else:
            # Split leaf and insert
            new_leaf, promoted_key = self.split_leaf(leaf)
            if key < promoted_key:
                leaf.insert(key, value)
            else:
                new_leaf.insert(key, value)
            self.insert_into_parent(leaf, promoted_key, new_leaf)

    def find_leaf(self, key):
        # <Optimize>binary search optimization here #
        page = self.bpm.get_page(self.header_page.root_page_id)
        while not page.is_leaf:
            i = 0
            while i < page.size and key >= page.key_array[i]:
                i += 1
            child_id = page.pageId_array[i]
            page = self.bpm.get_page(child_id)
        return page

    def is_safe_to_insert(self, page):
        return page.size < (self.leaf_max_size if page.is_leaf else self.internal_max_size)

    def split_leaf(self, leaf):
        new_leaf_id = self.bpm.get_new_page_id()
        new_leaf = BplusTreeLeafPage(new_leaf_id)
        mid = self.leaf_max_size // 2

        for i in range(mid, leaf.size):
            new_leaf.key_array[i - mid] = leaf.key_array[i]
            new_leaf.value_array[i - mid] = leaf.value_array[i]
            leaf.key_array[i] = 0
            leaf.value_array[i] = 0


        new_leaf.size = leaf.size - mid
        leaf.size = mid

        promoted_key = new_leaf.key_array[0]

        self.bpm.set_page(new_leaf_id, new_leaf)
        self.parent_map[new_leaf_id] = self.parent_map.get(leaf.page_id, None)
        return new_leaf, promoted_key

    def insert_into_parent(self, left, key, right):
        parent_id = self.parent_map.get(left.page_id)
        
        if parent_id is None:
            # No parent exists — create a new root internal node
            root_id = self.bpm.get_new_page_id()
            root = BplusTreeInternalPage(root_id)
            root.key_array[0] = key
            root.pageId_array[0] = left.page_id
            root.pageId_array[1] = right.page_id
            root.size = 1
            self.bpm.set_page(root_id, root)
            self.header_page.root_page_id = root_id
            self.parent_map[left.page_id] = root_id
            self.parent_map[right.page_id] = root_id
            return

        parent = self.bpm.get_page(parent_id)

        # ✅ Pre-check: If inserting would overflow the internal node, split first
        if not self.is_safe_to_insert(parent):
            new_internal, promoted = self.split_internal(parent)
            self.insert_into_parent(parent, promoted, new_internal)
            # Retry insert into the appropriate (possibly new) parent after split
            parent_id = self.parent_map.get(left.page_id)
            parent = self.bpm.get_page(parent_id)

        # Now it's safe to insert the new key + right pointer
        i = parent.size
        while i > 0 and key < parent.key_array[i - 1]:
            parent.key_array[i] = parent.key_array[i - 1]
            parent.pageId_array[i + 1] = parent.pageId_array[i]
            i -= 1

        parent.key_array[i] = key
        parent.pageId_array[i + 1] = right.page_id
        parent.size += 1

        self.parent_map[right.page_id] = parent.page_id


    def split_internal(self, node):
        new_internal_id = self.bpm.get_new_page_id()
        new_internal = BplusTreeInternalPage(new_internal_id)
        mid = self.internal_max_size // 2

        promoted = node.key_array[mid]

        new_internal.size = node.size - mid - 1
        for i in range(new_internal.size):
            new_internal.key_array[i] = node.key_array[mid + 1 + i]
            new_internal.pageId_array[i] = node.pageId_array[mid + 1 + i]
            self.parent_map[new_internal.pageId_array[i]] = new_internal_id

        new_internal.pageId_array[new_internal.size] = node.pageId_array[node.size]
        self.parent_map[new_internal.pageId_array[new_internal.size]] = new_internal_id

        node.size = mid

        self.bpm.set_page(new_internal_id, new_internal)
        return new_internal, promoted

In [4]:
# ---------- Test Code ----------

if __name__ == "__main__":
    bpm_test = Bpm()
    bplus_tree = BPlusTree("user_index", bpm_test, leaf_max_size=4, internal_max_size=4, header_page_id=1)

    for key in range(1, 100):  # Trigger multiple splits and internal propagation
        bplus_tree.insert(key, key + 100)

    print("\n--- Tree Pages ---")
    def print_tree(bptree):
        def dfs(page_id, depth):
            page = bptree.bpm.get_page(page_id)
            indent = "    " * depth
            if page.is_leaf:
                keys = page.key_array[:page.size]
                print(f"{indent}Leaf[{page_id}]: {keys}")
            else:
                keys = page.key_array[:page.size]
                children = page.pageId_array[:page.size + 1]
                print(f"{indent}Internal[{page_id}]: Keys={keys}")
                for child_id in children:
                    if child_id != 0:  # 0 indicates unused/empty pointer
                        dfs(child_id, depth + 1)

        print("\nB+ Tree Structure:")
        root_id = bptree.header_page.root_page_id
        dfs(root_id, 0)
    print_tree(bplus_tree)


--- Tree Pages ---

B+ Tree Structure:
Internal[27]: Keys=[19, 37, 55, 73]
    Internal[9]: Keys=[7, 13]
        Internal[3]: Keys=[3, 5]
            Leaf[1]: [1, 2]
            Leaf[2]: [3, 4]
            Leaf[4]: [5, 6]
        Internal[8]: Keys=[9, 11]
            Leaf[5]: [7, 8]
            Leaf[6]: [9, 10]
            Leaf[7]: [11, 12]
        Internal[13]: Keys=[15, 17]
            Leaf[10]: [13, 14]
            Leaf[11]: [15, 16]
            Leaf[12]: [17, 18]
    Internal[26]: Keys=[25, 31]
        Internal[17]: Keys=[21, 23]
            Leaf[14]: [19, 20]
            Leaf[15]: [21, 22]
            Leaf[16]: [23, 24]
        Internal[21]: Keys=[27, 29]
            Leaf[18]: [25, 26]
            Leaf[19]: [27, 28]
            Leaf[20]: [29, 30]
        Internal[25]: Keys=[33, 35]
            Leaf[22]: [31, 32]
            Leaf[23]: [33, 34]
            Leaf[24]: [35, 36]
    Internal[40]: Keys=[43, 49]
        Internal[31]: Keys=[39, 41]
            Leaf[28]: [37, 38]
         

In [19]:
key_arr = [None, 3, 5, 7]
pageid_arr = [10, 20, 30, 40]
target = 3

def binary_search(key_arr, pageid_arr, target_key):
    left = 0
    right = len(key_arr) - 1 
    
    #edge case
    if target_key < key_arr[1]:
        return pageid_arr[0]
    # search the key until both of the search pointers collide. 
    # left cannot be greater than right 
    while left <= right:
        mid = left + (right - left) // 2
        print(mid)        
        
        if target_key == key_arr[mid]:
            return pageid_arr[mid]
        elif target_key < key_arr[mid]:
            right = mid - 1
        elif target_key > key_arr[mid]:
            left = mid + 1
        else:
            pass
        
        #print(left, right) 
    return pageid_arr[left-1] 

print("pageid = {}".format(binary_search(key_arr, pageid_arr, target)))

1
pageid = 20
