<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import statistics
import numpy as np

In [2]:
class Node:
    def __init__(self, val, dim = -1):
        self.right = self.left = None
        self.val = val
        self.dim = dim

class KDTree:
    def __init__(self, axis, points):
        self.axis = axis
        self.points = points
        
    def buildKdTree(points, dim = 0, debug=False):
        nextDim = (dim + 1) % 2

        if len(points) == 1:
            return Node(points[0], nextDim)

        # Get points sorted based on the axis
        if dim % 2 == 0:
            # Sorting points on x-axis
            sortedPoints = points[np.lexsort((points[:,1],points[:,0]))]
        else:
            # Sorting points on y-axis
            sortedPoints = points[np.lexsort((points[:,0],points[:,1]))]

        medianIndex = len(points) // 2
        median = sortedPoints[medianIndex]
        leftPoints = sortedPoints[:medianIndex]
        rightPoints = sortedPoints[medianIndex+1:]

        if (debug) 
            print("sorted based on: ", dim , sortedPoints)
            print("median: ", median)
            print("leftPoints: ", leftPoints)
            print("rightPoints: ", rightPoints)

        root = Node(median, dim)
        root.left = buildKdTree(leftPoints, nextDim) if len(leftPoints > 0) else []
        root.right = buildKdTree(rightPoints, nextDim) if len(rightPoints > 0) else []

        return root

In [42]:
def buildKdTree(points, dim = 0, debug=False):
    nextDim = (dim + 1) % 2
    
    if len(points) == 1:
        return Node(points[0], nextDim)

    # Get points sorted based on the axis
    if dim % 2 == 0:
        # Sorting points on x-axis
        sortedPoints = points[np.lexsort((points[:,1],points[:,0]))]
    else:
        # Sorting points on y-axis
        sortedPoints = points[np.lexsort((points[:,0],points[:,1]))]
    
    medianIndex = len(points) // 2
    median = sortedPoints[medianIndex]
    leftPoints = sortedPoints[:medianIndex]
    rightPoints = sortedPoints[medianIndex+1:]
    
    if (debug) 
        print("sorted based on: ", dim , sortedPoints)
        print("median: ", median)
        print("leftPoints: ", leftPoints)
        print("rightPoints: ", rightPoints)
    
    root = Node(median, dim)
    root.left = buildKdTree(leftPoints, nextDim) if len(leftPoints > 0) else []
    root.right = buildKdTree(rightPoints, nextDim) if len(rightPoints > 0) else []
    
    return root

In [11]:
def preorderTraversal(root):
    res, stack = [], [root]
    while stack:
        node = stack.pop()
        if node:
            res.insert(0, node.val)
            stack.append(node.left)
            stack.append(node.right)
    print("res: ", res)
    return res

In [17]:
def dfs(root):
    res, stack = [], [root]
    while stack:
        for i in range(len(stack)):
            node = stack.pop(0)
            if node:
                stack.append(node.left)
                stack.append(node.right)
                print("root: ", node.val, node.dim if node else "None", "left: ", node.left.val if node.left else "None", "right: ", node.right.val if node.right else "None")    

In [21]:
points = np.array([(86, 338), (164, 360), (75, 58), (5,358),(400, 346), (281, 411), (136, 39),(324, 54),(296,332)])

In [43]:
root = buildKdTree(points, 0)

sorted based on:  0 [[  5 358]
 [ 75  58]
 [ 86 338]
 [136  39]
 [164 360]
 [281 411]
 [296 332]
 [324  54]
 [400 346]]
median:  [164 360]
leftPoints:  [[  5 358]
 [ 75  58]
 [ 86 338]
 [136  39]]
rightPoints:  [[281 411]
 [296 332]
 [324  54]
 [400 346]]
sorted based on:  1 [[136  39]
 [ 75  58]
 [ 86 338]
 [  5 358]]
median:  [ 86 338]
leftPoints:  [[136  39]
 [ 75  58]]
rightPoints:  [[  5 358]]
sorted based on:  0 [[ 75  58]
 [136  39]]
median:  [136  39]
leftPoints:  [[75 58]]
rightPoints:  []
sorted based on:  1 [[324  54]
 [296 332]
 [400 346]
 [281 411]]
median:  [400 346]
leftPoints:  [[324  54]
 [296 332]]
rightPoints:  [[281 411]]
sorted based on:  0 [[296 332]
 [324  54]]
median:  [324  54]
leftPoints:  [[296 332]]
rightPoints:  []


In [44]:
preorder = preorderTraversal(root)

res:  [array([75, 58]), array([136,  39]), array([  5, 358]), array([ 86, 338]), array([296, 332]), array([324,  54]), array([281, 411]), array([400, 346]), array([164, 360])]


In [45]:
dfs(root)

root:  [164 360] 0 left:  [ 86 338] right:  [400 346]
root:  [ 86 338] 1 left:  [136  39] right:  [  5 358]
root:  [400 346] 1 left:  [324  54] right:  [281 411]
root:  [136  39] 0 left:  [75 58] right:  None
root:  [  5 358] 1 left:  None right:  None
root:  [324  54] 0 left:  [296 332] right:  None
root:  [281 411] 1 left:  None right:  None
root:  [75 58] 0 left:  None right:  None
root:  [296 332] 0 left:  None right:  None


In [32]:
dim1Sorted = np.lexsort((points[:,1],points[:,0])) 
dim2Sorted = np.lexsort((points[:,0],points[:,1]))
points[dim1Sorted], points[dim1Sorted][1:], points[dim2Sorted]

(array([[  5, 358],
        [ 75,  58],
        [ 86, 338],
        [136,  39],
        [164, 360],
        [281, 411],
        [296, 332],
        [324,  54],
        [400, 346]]),
 array([[ 75,  58],
        [ 86, 338],
        [136,  39],
        [164, 360],
        [281, 411],
        [296, 332],
        [324,  54],
        [400, 346]]),
 array([[136,  39],
        [324,  54],
        [ 75,  58],
        [296, 332],
        [ 86, 338],
        [400, 346],
        [  5, 358],
        [164, 360],
        [281, 411]]))