In [52]:
class Node:
    def __init__(self, point=None, axis=-1):
        # The index of the original point in the dataset (initialized to -1 if not set)
        self.point = point
        # Pointers to left and right child nodes, initialized to None
        self.next = [None, None]
        # Axis of comparison (dimension for this node's split)
        self.axis = axis


In [53]:
class Point:
    def __init__(self, coordinates):
        self.coordinates = coordinates

    def __repr__(self):
        return f"Point({', '.join(map(str, self.coordinates))})"
    
    def __getitem__(self, axis):
        return self.coordinates[axis]
    
    def __eq__(self, other):
        return self.coordinates == other.coordinates

    def __lt__(self, other, axis):
        return self[axis] < other[axis]
    def __len__(self):
        return len(self.coordinates)


In [1]:

from graphviz import Digraph
import math
import heapq
class KDTree:
    def __init__(self, points, dim):
        self.dim = dim
        self.root = None
        self.build_tree(points)
        

    def build_tree(self, points):
        def build(points, axis=0):
            if len(points) == 0:
                return None
            
            points.sort(key=lambda point: point[axis])  # Sort points along the current axis
            
            median_idx = len(points) // 2
            median_point = points[median_idx]
            
            node = Node(point=median_point, axis=axis)
            
            node.next[0] = build(points[:median_idx], (axis + 1) % self.dim)
            node.next[1] = build(points[median_idx + 1:], (axis + 1) % self.dim)
            
            return node
        
        self.root = build(points)

    def insert(self, node, point, axis=0):
        if node is None:
            return Node(point=point, axis=axis)
        
        if point[axis] < node.point[axis]:
            node.next[0] = self.insert(node.next[0], point, (axis + 1) % self.dim)
        else:
            node.next[1] = self.insert(node.next[1], point, (axis + 1) % self.dim)

        return node
    
    def nn_search(self, query):
        """
        Perform nearest neighbor search for the query point.
        Returns the closest point and its distance.
        """
        def nn_search_recursive(node, query, guess, min_dist):
            if node is None:
                return guess, min_dist
            
            # Compute squared Euclidean distance to the current node
            dist_sq = sum((query[i] - node.point[i]) ** 2 for i in range(self.dim))
            
            # Update the best guess if current node is closer
            if dist_sq < min_dist:
                guess, min_dist = node.point, dist_sq
            
            axis = node.axis
            diff = query[axis] - node.point[axis]
            
            # Choose the branch to explore first
            primary_branch = node.next[0] if diff < 0 else node.next[1]
            other_branch = node.next[1] if diff < 0 else node.next[0]
            
            # Search the primary branch
            guess, min_dist = nn_search_recursive(primary_branch, query, guess, min_dist)
            
            # Check if the other branch could contain a closer point
            if diff ** 2 < min_dist:
                guess, min_dist = nn_search_recursive(other_branch, query, guess, min_dist)
            
            return guess, min_dist

        # Start recursive search from the root
        guess, min_dist = nn_search_recursive(self.root, query, None, float('inf'))
        return guess, math.sqrt(min_dist)
    def knn_search(self, query, k, return_dist_sq=False):
        """
        Perform k-nearest neighbors search for the query point.
        Returns a list of the k closest points and optionally their squared distances.
        """
        def dist_sq_func(p1, p2):
            return sum((p1[i] - p2[i]) ** 2 for i in range(self.dim))

        def knn_search_recursive(node, query, k, heap, axis=0, tiebreaker=1):
            if node is None:
                return
            
            # Calculate squared distance
            dist_sq = dist_sq_func(query, node.point)
            dx = query[axis] - node.point[axis]

            # Push the current node into the heap if it qualifies
            if len(heap) < k:
                heapq.heappush(heap, (-dist_sq, tiebreaker, node.point))
            elif dist_sq < -heap[0][0]:
                heapq.heappushpop(heap, (-dist_sq, tiebreaker, node.point))

            # Recursively search child nodes
            next_axis = (axis + 1) % self.dim
            primary_branch = node.next[0] if dx < 0 else node.next[1]
            other_branch = node.next[1] if dx < 0 else node.next[0]

            # Search the primary branch
            knn_search_recursive(primary_branch, query, k, heap, next_axis, (tiebreaker << 1))
            
            # Check whether the other branch might contain closer points
            #still could contain third neasert neighbour even if far
            if dx ** 2 < -heap[0][0] or len(heap) < k:
                knn_search_recursive(other_branch, query, k, heap, next_axis, (tiebreaker << 1) | 1)

        # Initialize a max-heap for KNN
        heap = []
        knn_search_recursive(self.root, query, k, heap)
       
        # Sort the results by distance, using the tiebreaker as a secondary sort key
        results = [(-h[0], h[2]) if return_dist_sq else h[2] for h in sorted(heap)][::-1]
        return results
    def radius_search(self, query, radius):
        """
        Perform a radius search for the query point within the specified radius.
        Returns a list of points within the radius.
        """
        results = []

        def radius_search_recursive(node, query, radius, bounds):
            if node is None:
                return
            
            # Check if current node is in the region
            if self.in_region(node.point, query, radius):
                results.append(node.point)
            
            # Update bounds for children
            axis = node.axis
            left_bounds = bounds[:]
            right_bounds = bounds[:]

            # Create new Point objects for the updated bounds
            left_bounds[1] = Point([min(left_bounds[1][i], node.point[i]) if i == axis else left_bounds[1][i] for i in range(self.dim) ] )
            right_bounds[0] = Point([max(right_bounds[0][i], node.point[i]) if i == axis else right_bounds[0][i] for i in range(self.dim)])

            # Search left subtree if it intersects the region
            if node.next[0] is not None and self.bounds_intersect_region(left_bounds, query, radius):
                radius_search_recursive(node.next[0], query, radius, left_bounds)
            
            # Search right subtree if it intersects the region
            if node.next[1] is not None and self.bounds_intersect_region(right_bounds, query, radius):
                radius_search_recursive(node.next[1], query, radius, right_bounds)

        # Define initial bounds as infinite using Point objects
        initial_bounds = [Point([-float('inf')] * self.dim), Point([float('inf')] * self.dim)]
        radius_search_recursive(self.root, query, radius, initial_bounds)
        return results

    
    @staticmethod
    def in_region(point, query, radius):
        """
        Check if a point lies within the radius of the query point.
        """
        dist_sq = sum((query[i] - point[i]) ** 2 for i in range(len(query)))
        print(point)
        print(dist_sq)
        return dist_sq <= radius **2 

    @staticmethod
    def bounds_intersect_region(bounds, query, radius):
        """
        Check if a bounding box intersects the spherical region of the query.
        """
        dist_sq = 0
        for i in range(len(query)):
            if query[i] < bounds[0][i]:
                dist_sq += (bounds[0][i] - query[i]) ** 2
            elif query[i] > bounds[1][i]:
                dist_sq += (query[i] - bounds[1][i]) ** 2
        return dist_sq <= radius **2


    def __repr__(self):
        def recurse(node):
            if node is None:
                return ""
            return f"Node(axis={node.axis}, point={node.point})\n" + \
                self._recurse_children(node, "  ")

        return recurse(self.root)
    
    def _recurse_children(self, node, indent):
        left = self._format_child(node.next[0], indent + "L-- ")
        right = self._format_child(node.next[1], indent + "R-- ")
        return left + right
    
    def _format_child(self, child, indent):
        if child is None:
            return ""
        return indent + f"Node(axis={child.axis}, point={child.point})\n" + \
               self._recurse_children(child, indent)

    def visualize_tree(self):
        dot = Digraph(format='png', engine='dot')
        
        def add_nodes(dot, node, parent=None):
            if node:
                node_name = f"Node{str(node.point)}"
                dot.node(node_name, label=f"{node.point}")
                if parent:
                    dot.edge(parent, node_name)
                if node.next[0]:
                    add_nodes(dot, node.next[0], node_name)
                if node.next[1]:
                    add_nodes(dot, node.next[1], node_name)
        
        # Start adding nodes from the root
        add_nodes(dot, self.root)
        
        return dot



ModuleNotFoundError: No module named 'graphviz'

In [None]:
import os
# Add Graphviz's bin directory to PATH
os.environ["PATH"] += os.pathsep + r"C:\Users\mahah\Graphviz-12.2.1-win64\bin"

In [125]:
# Create points
p1 = Point([3, 6])
p2 = Point([17, 15])
p3 = Point([13, 15])
p4 = Point([6, 12])
p5 = Point([9, 1])
p6 = Point([2, 7])
p7 = Point([10, 19]) 

# Build KD-Tree
points = [p1, p2, p3, p4, p5, p6, p7]
tree = KDTree(points, dim=2)

# Inserting a new point
new_point = Point([12, 10])
tree.insert(tree.root, new_point)
print(0)
print(tree)
# Visualize the k-d tree
dot = tree.visualize_tree()
dot.render('kd_tree', view=True)   # This will save the tree as a PNG image





0
Node(axis=0, point=Point(9, 1))
  L-- Node(axis=1, point=Point(2, 7))
  L-- L-- Node(axis=0, point=Point(3, 6))
  L-- R-- Node(axis=0, point=Point(6, 12))
  R-- Node(axis=1, point=Point(17, 15))
  R-- L-- Node(axis=0, point=Point(13, 15))
  R-- L-- L-- Node(axis=1, point=Point(12, 10))
  R-- R-- Node(axis=0, point=Point(10, 19))



'kd_tree.png'

In [78]:
import heapq

# Start with an empty heap
heap = [(1,4,4), (9,5,0), (-1,7,3), (-1,10,4)]

# Convert list into a heap
heapq.heapify(heap)



print("Heap after heappushpop:", heap ) # Heap after adding 3 and popping smallest element


Heap after heappushpop: [(-1, 7, 3), (-1, 10, 4), (1, 4, 4), (9, 5, 0)]


In [126]:
# Nearest neighbor test
query_point = Point([10, 10])  # Query point
nearest_point, distance = tree.nn_search(query_point)

# Print results
print(f"Query point: {query_point}")
print(f"Nearest point: {nearest_point}")
print(f"Distance: {distance:.2f}")

Query point: Point(10, 10)
Nearest point: Point(12, 10)
Distance: 2.00


In [128]:
radius = 5
print("Points within radius:", tree.radius_search(query_point, radius))

Point(9, 1)
82
Point(2, 7)
73
Point(3, 6)
65
Point(6, 12)
20
Point(17, 15)
74
Point(13, 15)
34
Point(12, 10)
4
Point(10, 19)
81
Points within radius: [Point(6, 12), Point(12, 10)]


In [129]:
result= tree.knn_search(query_point,4,True)
print(result)

[(4, Point(12, 10)), (20, Point(6, 12)), (34, Point(13, 15)), (65, Point(3, 6))]


In [None]:
class KDTree(object):

    def __init__(self, points, dim, dist_sq_func=None):
        if dist_sq_func is None:
            dist_sq_func = lambda a, b: sum((x - b[i]) ** 2 
                for i, x in enumerate(a))
                
        def make(points, i=0):
            if len(points) > 1:
                points.sort(key=lambda x: x[i])
                i = (i + 1) % dim
                m = len(points) >> 1
                return [make(points[:m], i), make(points[m + 1:], i), 
                    points[m]]
            if len(points) == 1:
                return [None, None, points[0]]
        
        def add_point(node, point, i=0):
            if node is not None:
                dx = node[2][i] - point[i]
                for j, c in ((0, dx >= 0), (1, dx < 0)):
                    if c and node[j] is None:
                        node[j] = [None, None, point]
                    elif c:
                        add_point(node[j], point, (i + 1) % dim)

        import heapq
        def get_knn(node, point, k, return_dist_sq, heap, i=0, tiebreaker=1):
            if node is not None:
                dist_sq = dist_sq_func(point, node[2])
                dx = node[2][i] - point[i]
                if len(heap) < k:
                    heapq.heappush(heap, (-dist_sq, tiebreaker, node[2]))
                elif dist_sq < -heap[0][0]:
                    heapq.heappushpop(heap, (-dist_sq, tiebreaker, node[2]))
                i = (i + 1) % dim
                # Goes into the left branch, then the right branch if needed
                for b in (dx < 0, dx >= 0)[:1 + (dx * dx < -heap[0][0])]:
                    get_knn(node[b], point, k, return_dist_sq, 
                        heap, i, (tiebreaker << 1) | b)
            if tiebreaker == 1:
                return [(-h[0], h[2]) if return_dist_sq else h[2] 
                    for h in sorted(heap)][::-1]

        def walk(node):
            if node is not None:
                for j in 0, 1:
                    for x in walk(node[j]):
                        yield x
                yield node[2]

        self._add_point = add_point
        self._get_knn = get_knn 
        self._root = make(points)
        self._walk = walk

    def __iter__(self):
        return self._walk(self._root)
        
    def add_point(self, point):
        if self._root is None:
            self._root = [None, None, point]
        else:
            self._add_point(self._root, point)

    def get_knn(self, point, k, return_dist_sq=True):
        return self._get_knn(self._root, point, k, return_dist_sq, [])

    def get_nearest(self, point, return_dist_sq=True):
        l = self._get_knn(self._root, point, 1, return_dist_sq, [])
        return l[0] if len(l) else None



In [4]:
m = len([1,2,3]) >> 1
print(m)

1
