<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import statistics
import numpy as np

In [2]:
class Node:
    def __init__(self, val, dim = -1):
        self.right = self.left = None
        self.val = val
        self.dim = dim

class KDTree:
    def __init__(self, axis, points):
        self.axis = axis
        self.points = points
        
    def buildKdTree(points, dim = 0, debug=False):
        nextDim = (dim + 1) % 2

        if len(points) == 1:
            return Node(points[0], nextDim)

        # Get points sorted based on the axis
        if dim % 2 == 0:
            # Sorting points on x-axis
            sortedPoints = points[np.lexsort((points[:,1],points[:,0]))]
        else:
            # Sorting points on y-axis
            sortedPoints = points[np.lexsort((points[:,0],points[:,1]))]

        medianIndex = len(points) // 2
        median = sortedPoints[medianIndex]
        leftPoints = sortedPoints[:medianIndex]
        rightPoints = sortedPoints[medianIndex+1:]

        if (debug): 
            print("sorted based on: ", dim , sortedPoints)
            print("median: ", median)
            print("leftPoints: ", leftPoints)
            print("rightPoints: ", rightPoints)

        root = Node(median, dim)
        root.left = buildKdTree(leftPoints, nextDim) if len(leftPoints > 0) else []
        root.right = buildKdTree(rightPoints, nextDim) if len(rightPoints > 0) else []

        return root

In [3]:
def buildKdTree(points, dim = 0, debug=False):
    nextDim = (dim + 1) % 2
    
    if len(points) == 1:
        return Node(points[0], nextDim)

    # Get points sorted based on the axis
    if dim % 2 == 0:
        # Sorting points on x-axis
        sortedPoints = points[np.lexsort((points[:,1],points[:,0]))]
    else:
        # Sorting points on y-axis
        sortedPoints = points[np.lexsort((points[:,0],points[:,1]))]
    
    medianIndex = len(points) // 2
    median = sortedPoints[medianIndex]
    leftPoints = sortedPoints[:medianIndex]
    rightPoints = sortedPoints[medianIndex+1:]
    
    if (debug): 
        print("sorted based on: ", dim , sortedPoints)
        print("median: ", median)
        print("leftPoints: ", leftPoints)
        print("rightPoints: ", rightPoints)
    
    root = Node(median, dim)
    root.left = buildKdTree(leftPoints, nextDim) if len(leftPoints > 0) else []
    root.right = buildKdTree(rightPoints, nextDim) if len(rightPoints > 0) else []
    
    return root

In [4]:
def preorderTraversal(root):
    res, stack = [], [root]
    while stack:
        node = stack.pop()
        if node:
            res.insert(0, node.val)
            stack.append(node.left)
            stack.append(node.right)
    print("res: ", res)
    return res

In [5]:
def dfs(root):
    res, stack = [], [root]
    while stack:
        for i in range(len(stack)):
            node = stack.pop(0)
            if node:
                stack.append(node.left)
                stack.append(node.right)
                print("root: ",node.val, node.dim if node else "None", "left: ", node.left.val if node.left else "None", "right: ", node.right.val if node.right else "None")    

In [6]:
points = np.array([(86, 338), (164, 360), (75, 58), (5,358),(400, 346), (281, 411), (136, 39),(324, 54),(296,332)])

In [7]:
root = buildKdTree(points, 0)

In [8]:
preorder = preorderTraversal(root)

res:  [array([75, 58]), array([136,  39]), array([  5, 358]), array([ 86, 338]), array([296, 332]), array([324,  54]), array([281, 411]), array([400, 346]), array([164, 360])]


In [9]:
dfs(root)

root:  [164 360] 0 left:  [ 86 338] right:  [400 346]
root:  [ 86 338] 1 left:  [136  39] right:  [  5 358]
root:  [400 346] 1 left:  [324  54] right:  [281 411]
root:  [136  39] 0 left:  [75 58] right:  None
root:  [  5 358] 1 left:  None right:  None
root:  [324  54] 0 left:  [296 332] right:  None
root:  [281 411] 1 left:  None right:  None
root:  [75 58] 0 left:  None right:  None
root:  [296 332] 0 left:  None right:  None


In [10]:
dim1Sorted = np.lexsort((points[:,1],points[:,0])) 
dim2Sorted = np.lexsort((points[:,0],points[:,1]))
points[dim1Sorted], points[dim1Sorted][1:], points[dim2Sorted]

(array([[  5, 358],
        [ 75,  58],
        [ 86, 338],
        [136,  39],
        [164, 360],
        [281, 411],
        [296, 332],
        [324,  54],
        [400, 346]]),
 array([[ 75,  58],
        [ 86, 338],
        [136,  39],
        [164, 360],
        [281, 411],
        [296, 332],
        [324,  54],
        [400, 346]]),
 array([[136,  39],
        [324,  54],
        [ 75,  58],
        [296, 332],
        [ 86, 338],
        [400, 346],
        [  5, 358],
        [164, 360],
        [281, 411]]))

In [11]:
def checkRight(node : Node, bottom_left : list, top_right : list) -> bool:
    if (node.val[0] > bottom_left[0]) and (node.val[0] < top_right[0]) and (node.val[1] > bottom_left[1]) and (node.val[1] < top_right[1]):
        return True
    else:
        return False

In [12]:
def checkLeft(node : Node, top_left : list, bottom_right : list) -> bool:
    if (node.val[0] > bottom_right[0]) and (node.val[0] < top_left[0]) and (node.val[1] > top_left[1]) and (node.val[1] < bottom_right[1]):
        return True
    else:
        return False

In [13]:
def isLeaf(node : Node) -> bool:
    if node==None:
        return False
    if node.left==None and node.right==None:
        return True
    return False

In [14]:
def ToNode(node : list) -> Node:
    N = Node(node)
    return N

In [15]:
def CountQueryPoints(node : Node, p1 : list, p2 : list) ->int:
    
    NoOfPoints = 0
    if p1[0]==p2[0]:
        print('Collinear points')
        return
    else:
        slope = (p2[1] - p1[1])/(p2[0] - p1[0])
        
        if slope == 0:
            print('Collinear points')
            return
        
        else:
            #RIGHT DIAGONAL
            if isLeaf(node) and slope > 0:
                if checkRight(node, p1, p2):
                    NoOfPoints += 1
                    
            #LEFT DIAGONAL
            elif isLeaf(node) and slope < 0:
                if checkLeft(node, p1, p2):
                    NoOfPoints += 1
                    
            else:
                #RIGHT DIAGONAL
                if slope > 0:
                    NoOfPoints += int(checkRight(node,p1,p2)==True)
                    
                #LEFT DIAGONAL    
                if slope < 0:
                    NoOfPoints += int(checkLeft(node,p1,p2)==True)
                    
                if type(node.left) == list: 
                    Nl = ToNode(node.left)
                else: 
                    Nl = node.left
                if type(node.right) == list: 
                    Nr = ToNode(node.right)
                else: 
                    Nr = node.right
                NoOfPoints += CountQueryPoints(Nl, p1, p2)
                NoOfPoints += CountQueryPoints(Nr, p1, p2)
            
    return NoOfPoints

In [16]:
points = np.array([(0,1),(3,3),(4,6),(5,5),(10,10),(12,13),(13,14)])

In [17]:
root = buildKdTree(points, 0)

In [18]:
dfs(root)

root:  [5 5] 0 left:  [3 3] right:  [12 13]
root:  [3 3] 1 left:  [0 1] right:  [4 6]
root:  [12 13] 1 left:  [10 10] right:  [13 14]
root:  [0 1] 1 left:  None right:  None
root:  [4 6] 1 left:  None right:  None
root:  [10 10] 1 left:  None right:  None
root:  [13 14] 1 left:  None right:  None


In [19]:
#Case: Right Diagonal
print(CountQueryPoints(root,[-1,-1],[15,15]))

7


In [20]:
#Case: Left Diagonal
print(CountQueryPoints(root,[-15,15],[-1,1]))

0


In [21]:
#Case: Collinear
print(CountQueryPoints(root,[-1,-1],[15,-1]))

Collinear points
None


In [22]:
#Case: Collinear
print(CountQueryPoints(root,[-1,-1],[-1,15]))

Collinear points
None
