In [127]:
import numpy as np


training_data = np.array([[1,0], [2,5], [3,1], [4,7], [6,2], [7,4], [10,6]])


class KDTree:
    def __init__(self):
        self._tree = None
    
    def learn(self, data):
        def create_tree(partial_data, dim, max_dim):            
            if len(partial_data) == 0:
                return partial_data
            elif len(partial_data) == 1:
                return partial_data[0]
            
            partial_data = np.array(sorted(partial_data, key=lambda x: x[dim]))

            mid_index = int(np.ceil(len(partial_data) / 2.))
            
            left = partial_data[:mid_index]
            mid = partial_data[mid_index]
            right = partial_data[mid_index + 1:]

            next_dim = dim + 1 if dim < max_dim else 0
            
            node = [mid, 
                    [create_tree(left, next_dim, max_dim) if len(left.shape) == 2 else left, 
                     create_tree(right, next_dim, max_dim)] if len(right.shape) == 2 else right]
            return node
        
        self._tree = create_tree(data, 0, data.shape[1] - 1)
        
    def print_tree(self):
        def _print_tree(node, depth):
            if isinstance(node, np.ndarray):
                print('|' + ''.join(['-'] * depth) + str(node))                
            else:
                if isinstance(node[0], np.ndarray):
                    _print_tree(node[0], depth)
                    _print_tree(node[1], depth if isinstance(node[1], np.ndarray) else depth + 1)
                else:
                    _print_tree(node[0], depth)
                    _print_tree(node[1], depth)
        _print_tree(self._tree, 1)
        
        
kd_tree = KDTree()
kd_tree.learn(sample_data)
kd_tree.print_tree()

|-[6 2]
|--[2 5]
|---[3 1]
|----[1 0]
|----[]
|---[4 7]
|--[10  6]
|---[7 4]
|---[]


[array([6, 2]),
 [[array([2, 5]),
   [[array([3, 1]), [array([1, 0]), array([], shape=(0, 2), dtype=int64)]],
    array([4, 7])]],
  [array([10,  6]), [array([7, 4]), array([], shape=(0, 2), dtype=int64)]]]]