In [39]:
import collections
class KDPoint:
    """A k-dimensional point whose coordinates are selected cyclically by tree level."""

    def __init__(self, point):
        # Stores a reference to the caller's coordinate sequence (no copy is made).
        self.point = point
        self.numDims = len(self.point)

    def get(self, level):
        """Return the coordinate used at tree depth ``level`` (axis = level modulo numDims)."""
        axis = level % self.numDims
        return self.point[axis]

    def __str__(self):
        return format(self.point)
    
class KDNode:
    """A k-d tree node: one ``KDPoint`` plus optional left/right child nodes."""

    def __init__(self, point):
        self.node = KDPoint(point)
        # Children whose coordinate on this node's split axis is strictly smaller.
        self.left = None
        # Children whose coordinate on this node's split axis is greater or equal.
        self.right = None

    def add(self, node):
        """Insert ``node`` (a KDNode) into the subtree rooted here, treating this node as depth 0."""
        self._add(node, 0)

    def _add(self, node, level):
        """Insert ``node`` below this node, treating this node as being at depth ``level``.

        At each visited node, the coordinate for the current depth decides the
        side: strictly smaller goes left, anything else (including ties) goes
        right. The walk is iterative so that degenerate insertion orders (e.g.
        sorted points, which build a single long chain) cannot exhaust
        Python's recursion limit.
        """
        current = self
        while True:
            if node.node.get(level) < current.node.get(level):
                if current.left is None:
                    current.left = node
                    return
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    return
                current = current.right
            level += 1

    def __str__(self):
        return "[point: {}]".format(self.node)
    
class KDTree:
    """A k-d tree of ``KDNode`` objects with nearest-neighbour search.

    The coordinate compared at depth ``d`` is ``point[d % k]`` (see
    ``KDPoint.get``); distances are squared Euclidean over the root point's
    number of dimensions.
    """

    def __init__(self, point):
        # The tree is always seeded with one point, so ``self.root`` is never None.
        self.root = KDNode(point)

    def add(self, node):
        """Insert ``node`` (a ``KDNode``) into the tree."""
        self.root.add(node)

    def nearestNeighbor(self, target):
        """Return the stored ``KDNode`` closest to ``target`` (a ``KDNode``).

        ``target`` should have the same number of dimensions as the stored
        points; otherwise its per-depth coordinates cycle with a different
        period than the tree's.
        """
        return self._nearestNeighbor(self.root, target, 0)

    def _nearestNeighbor(self, root, target, depth):
        """Return the node nearest to ``target`` within the subtree ``root`` (None if empty)."""
        if root is None:
            return None

        target_coord = target.node.get(depth)
        split_coord = root.node.get(depth)

        # Search first the side ``target`` itself would be inserted into,
        # mirroring KDNode._add (strictly smaller goes left, otherwise right),
        # so a close candidate is found early and the far side can be pruned.
        if target_coord < split_coord:
            near_branch, far_branch = root.left, root.right
        else:
            near_branch, far_branch = root.right, root.left

        best = self.closest(self._nearestNeighbor(near_branch, target, depth + 1), root, target)

        # The far side can only hold something closer if the sphere of
        # radius |best - target| reaches across the splitting plane.
        radius_squared = self.distSquared(best, target)
        plane_offset = target_coord - split_coord
        if radius_squared >= plane_offset * plane_offset:
            candidate = self._nearestNeighbor(far_branch, target, depth + 1)
            best = self.closest(candidate, best, target)
        return best

    def closest(self, node0, node1, target):
        """Return whichever of ``node0`` / ``node1`` is closer to ``target``.

        If one argument is None the other is returned; on a distance tie
        ``node1`` is returned.
        """
        if node0 is None:
            return node1
        if node1 is None:
            return node0
        if self.distSquared(node0, target) < self.distSquared(node1, target):
            return node0
        return node1

    def distSquared(self, p0, p1):
        """Return the squared Euclidean distance between two ``KDNode``s over the root's dimensions."""
        num_dims = self.root.node.numDims
        return sum((p0.node.get(d) - p1.node.get(d)) ** 2 for d in range(num_dims))

    def __str__(self):
        """Render the tree breadth-first, one line per depth.

        Missing children are shown as "[Point: None ]", including a final
        line made only of the leaves' empty child slots.
        """
        output = ''
        queue = collections.deque([self.root])
        while queue:
            for _ in range(len(queue)):
                node = queue.popleft()
                if node is not None:
                    output += str(node)
                    queue.append(node.left)
                    queue.append(node.right)
                else:
                    output += "[Point: None ]"
            output += "\n"
        return output

In [45]:
SAMPLE_POINTS_2D = [[50, 50], [80, 40], [10, 60], [51, 38], [48, 38], [56, 73]]
SAMPLE_POINTS_3D = [[50, 50, 23], [80, 40, 34], [10, 60, 67], [51, 38, 14],
                    [48, 38, 52], [56, 73, 10], [53, 20, 32], [50, 11, 36]]
DEFAULT_QUERY_3D = [50, 78, 14]


def main(points=None, query=None):
    """Build a KDTree from ``points``, print it, query it, and print pairwise squared distances.

    Args:
        points: list of equal-length coordinate sequences; defaults to
            SAMPLE_POINTS_3D (e.g. pass SAMPLE_POINTS_2D with a 2-D query).
        query: coordinates to look up; defaults to DEFAULT_QUERY_3D. Must have
            the same number of dimensions as ``points``.

    Raises:
        ValueError: if ``points`` is empty, or the points / query do not all
            share one dimensionality (mixed dimensions would make each
            KDPoint cycle through axes with a different period).
    """
    import itertools

    if points is None:
        points = SAMPLE_POINTS_3D
    if query is None:
        query = DEFAULT_QUERY_3D
    if not points:
        raise ValueError("points must contain at least one point")
    num_dims = len(points[0])
    if any(len(p) != num_dims for p in points) or len(query) != num_dims:
        raise ValueError("all points and the query must have {} dimensions".format(num_dims))

    tree = KDTree(points[0])
    for coords in points[1:]:
        tree.add(KDNode(coords))
    print(tree)

    test_node = KDNode(query)
    nearest_node = tree.nearestNeighbor(test_node)
    print("nearest of test_node={} is {}".format(test_node, nearest_node))

    # combinations() yields (i, j) with i < j in the same order as the nested loops.
    for i, j in itertools.combinations(range(len(points)), 2):
        dij = tree.distSquared(KDNode(points[i]), KDNode(points[j]))
        print("d_{}_{} = {}".format(i, j, dij))


main()

[point: [50, 50, 23]]
[point: [10, 60, 67]][point: [80, 40, 34]]
[point: [48, 38, 52]][Point: None ][point: [51, 38, 14]][point: [56, 73, 10]]
[Point: None ][Point: None ][Point: None ][point: [53, 20, 32]][Point: None ][Point: None ]
[point: [50, 11, 36]][Point: None ]
[Point: None ][Point: None ]

neearest of test_node=[point: [50, 78, 14]] is [point: [56, 73, 10]]
d_0_1 = 1121
d_0_2 = 3636
d_0_3 = 226
d_0_4 = 989
d_0_5 = 734
d_0_6 = 990
d_0_7 = 1690
d_1_2 = 6389
d_1_3 = 1245
d_1_4 = 1352
d_1_5 = 2241
d_1_6 = 1133
d_1_7 = 1745
d_2_3 = 4974
d_2_4 = 2153
d_2_5 = 5534
d_2_6 = 4674
d_2_7 = 4962
d_3_4 = 1453
d_3_5 = 1266
d_3_6 = 652
d_3_7 = 1214
d_4_5 = 3053
d_4_6 = 749
d_4_7 = 989
d_5_6 = 3302
d_5_7 = 4556
d_6_7 = 106


In [47]:
# bool is a subclass of int, so (3 == 2) -> False counts as 0 and this evaluates to 5.
5 - int(3 == 2)

5