# KD Tree: k-Dimensional Tree

In [31]:
class Node:
    '''
    A node of a k-d (k-dimensional) tree.

    The tree is driven through a "manager" instance (e.g. ``Node(None, k)``)
    whose ``k`` decides how many coordinates are compared. The recursive
    methods take the subtree root explicitly and return the (possibly new)
    subtree root.

    Invariant maintained by insertNode/deleteNode: at depth d the splitting
    dimension is ``cd = d % k``; points in the left subtree have
    ``point[cd] < node.point[cd]`` and points in the right subtree have
    ``point[cd] >= node.point[cd]`` (ties go right).
    '''

    def __init__(self, point, k=2):
        '''
        point: list of coordinates [x0, x1, ..., x(k-1)]
        k: number of dimensions compared by this tree
        '''
        self.point = point
        self.left = None
        self.right = None
        self.k = k

    @staticmethod
    def newNode(point, k=2):
        '''
        Return a new leaf node holding ``point`` for a ``k``-dimensional tree.
        '''
        return Node(point, k)

    def insertNode(self, root, point, depth=0):
        '''
        Insert ``point`` into the subtree ``root`` and return the subtree root.

        root: the current node to consider the space of point (None if empty)
        point: data point, ex: [2, 3, 5] - 3-D point
        depth: depth of ``root``; selects the splitting dimension
        '''
        if root is None:
            # Propagate this tree's dimensionality rather than the default k=2.
            return self.newNode(point, self.k)
        cd = depth % self.k  # cd: current (splitting) dimension
        if point[cd] < root.point[cd]:
            root.left = self.insertNode(root.left, point, depth + 1)
        else:
            root.right = self.insertNode(root.right, point, depth + 1)
        return root

    def areSamePoints(self, point1, point2):
        '''Return True if the first ``k`` coordinates of both points are equal.'''
        return all(point1[i] == point2[i] for i in range(self.k))

    def searchNode(self, root, point, depth=0):
        '''
        Return the node in subtree ``root`` whose point equals ``point``,
        or None if no such node exists.
        '''
        if root is None:
            return None
        if self.areSamePoints(root.point, point):
            return root
        cd = depth % self.k
        if point[cd] < root.point[cd]:
            return self.searchNode(root.left, point, depth + 1)
        return self.searchNode(root.right, point, depth + 1)

    def searchWithParent(self, root, point, depth=0, parent=None):
        '''
        Return ``(parent, node)`` for the node matching ``point``.

        ``parent`` is None when the match is ``root`` itself. Returns
        ``(None, None)`` when the point is absent.
        '''
        if root is None:
            return None, None
        if self.areSamePoints(root.point, point):
            return parent, root
        cd = depth % self.k
        if point[cd] < root.point[cd]:
            return self.searchWithParent(root.left, point, depth + 1, root)
        return self.searchWithParent(root.right, point, depth + 1, root)

    def mostRightNode(self, root):
        '''
        Return the node reached by following right links from ``root``.

        Note: this ignores splitting dimensions, so it is NOT the maximum of
        any coordinate in a k-d tree; deleteNode uses findMinNode instead.
        '''
        if not root.right:
            return root
        return self.mostRightNode(root.right)

    def findMinNode(self, root, dim, depth=0):
        '''
        Return the node in subtree ``root`` with the smallest ``point[dim]``.

        root: subtree root located at ``depth`` (None yields None)
        dim: coordinate index to minimise
        '''
        if root is None:
            return None
        cd = depth % self.k
        if cd == dim:
            # Left holds strictly smaller values on this axis; right cannot win.
            if root.left is None:
                return root
            return self.findMinNode(root.left, dim, depth + 1)
        # Splitting on another axis: the minimum may be anywhere below.
        candidates = (
            root,
            self.findMinNode(root.left, dim, depth + 1),
            self.findMinNode(root.right, dim, depth + 1),
        )
        return min((node for node in candidates if node is not None),
                   key=lambda node: node.point[dim])

    def deleteNode(self, root, point, depth=0):
        '''
        Delete one node equal to ``point`` from subtree ``root``.

        Returns the new subtree root (None if the subtree becomes empty);
        callers must store it because the root may change.
        '''
        if root is None:
            return None
        cd = depth % self.k
        if self.areSamePoints(root.point, point):
            if root.right is not None:
                # The axis-minimum of the right subtree keeps "left < node <= right".
                successor = self.findMinNode(root.right, cd, depth + 1)
                root.point = successor.point
                root.right = self.deleteNode(root.right, successor.point, depth + 1)
            elif root.left is not None:
                # No right subtree: take the axis-minimum of the left subtree and
                # move what remains to the right side. Ties with the new value
                # then sit on the right, which the ">=" rule allows.
                successor = self.findMinNode(root.left, cd, depth + 1)
                root.point = successor.point
                root.right = self.deleteNode(root.left, successor.point, depth + 1)
                root.left = None
            else:
                # leaf node
                return None
            return root
        if point[cd] < root.point[cd]:
            root.left = self.deleteNode(root.left, point, depth + 1)
        else:
            root.right = self.deleteNode(root.right, point, depth + 1)
        return root

In [32]:
if __name__ == '__main__':
    # Manager instance: only its k (default 2) is used by the tree methods.
    kd_tree = Node(None)
    root = None
    points = [[3, 6], [17, 15], [13, 15], [6, 12], [9, 1], [2, 7], [10, 19]]
    for point in points:
        root = kd_tree.insertNode(root, point)

    point1 = [10, 19]
    point2 = [12, 19]
    for target in (point1, point2):
        print("Found" if kd_tree.searchNode(root, target) else "Not Found")

    # deleteNode may return a different root (e.g. when the root is removed),
    # so its result must be stored.
    root = kd_tree.deleteNode(root, point1)
    print("Found" if kd_tree.searchNode(root, point1) else "Not Found")

Found
Not Found
Not Found
