In [10]:
import random
from math import sqrt
from tqdm import tqdm
import copy

In [11]:
class kd_node():
    def __init__(self, x: float, y: float, dim=0, parent=None):
        self.x = x
        self.y = y
        self.coords = [x, y]
        self.dim = dim
        self.parent = parent

        self.branches = {
            'up': None,
            'low': None
        }

    def in_low_branch(self, x, y):
        p = [x,y]
        return p[self.dim] < self.coords[self.dim]

# =============================================================================

    def add_element(self, x, y):
        flag = 'low' if self.in_low_branch(x, y) else 'up'

        if self.branches[flag] is None :
            dim = (self.dim + 1) % 2
            self.branches[flag] = kd_node(x, y, dim, parent = self)
        else:
            self.branches[flag].add_element(x, y)

# =============================================================================

    def find_min(self, dim):
        if self.dim == dim:
            if self.branches['low'] is None:
                return self
            else:
                return self.branches['low'].find_min(dim)
        else:
            point = self
            if self.branches['low'] is not None:
                min_left = self.branches['low'].find_min(dim)
                point = self.min_on_dimension(point, min_left, dim)
            if self.branches['up'] is not None:
                min_right = self.branches['up'].find_min(dim)
                point = self.min_on_dimension(point, min_right, dim)
            return point

    def min_on_dimension(self, point_a, point_b, dim):
        if point_a.coords[dim] <= point_b.coords[dim]:
            return point_a
        else:
            return point_b

# =============================================================================

    def delete_element(self, x, y):
        if (self.x == x) and (self.y == y):
            if self.branches['up'] is not None:
                min_point = self.branches['up'].find_min(self.dim)
                self.replace_self(min_point)

                self.branches['up'].delete_element(self.x, self.y)

            elif self.branches['low'] is not None:
                min_point = self.branches['low'].find_min(self.dim)
                self.replace_self(min_point)

                self.branches['up'] = self.branches['low']
                self.branches['low'] = None
                self.branches['up'].delete_element(self.x, self.y)

            else:
                self.delete_self()
            return True

        elif self.in_low_branch(x,y):
            if self.branches['low'] is not None:
                return self.branches['low'].delete_element(x,y)
            else:
                return False
        else:
            if self.branches['up'] is not None:
                return self.branches['up'].delete_element(x,y)
            else:
                return False

    def replace_self(self, point):
        self.x, self.y = point.x, point.y
        self.coords = [self.x, self.y]

    def delete_self(self):
        self.parent.cut_leaf(self)

    def cut_leaf(self, leaf):
        for x in self.branches:
            if self.branches[x] is leaf:
                self.branches[x] = None
                return

# =============================================================================

    def search_element(self, x, y):
        if (self.x == x) and (self.y == y):
            return True
        elif self.in_low_branch(x,y):
            if self.branches['low'] is not None:
                return self.branches['low'].search_element(x,y)
            else:
                return False
        else:
            if self.branches['up'] is not None:
                return self.branches['up'].search_element(x,y)
            else:
                return False

# =============================================================================

    def nn_search(self, x, y, bounding_box=None, current_best=None):
        #if (self.x == x) and (self.y == y):
        #    
        #    return True
        if bounding_box is None:
            bounding_box = dict()
            for i in range(len(self.coords)):
                bounding_box[i] = {'up': None,
                'low': None}

        flag = self.in_low_branch(x,y)
        flag_a = 'low' if flag else 'up'
        flag_b = 'up' if flag else 'low'

        if current_best is None:
            if self.branches[flag_a] is not None:
                bb = self.calc_branch_bounding_box(bounding_box, flag_a)
                print(f"Searching - {flag_a}\n{bb}\n#=================\n#=================")
                current_best = self.branches[flag_a].nn_search(x, y, bb)
                current_best = self.check_self_dist(x, y, current_best)
            else:
                dist = sqrt((x - self.x)**2 + (y - self.y)**2)
                print(f"found leaf [{self.x}, {self.y}]- dist {dist}")
                current_best = {
                    'point': self,
                    'distance': dist
                }

        else:
            current_best = self.check_self_dist(x, y, current_best)

            if self.branches[flag_a] is not None:
                bb = self.calc_branch_bounding_box(bounding_box, flag_a)
                if self.check_branch(x, y, current_best=current_best, bbox=bb):
                #check branch -- pass cur best
                    current_best = self.branches[flag_a].nn_search(x, y, bb, current_best=current_best)
            
        if self.branches[flag_b] is not None:
            bb = self.calc_branch_bounding_box(bounding_box, flag_b)
            if self.check_branch(x, y, current_best=current_best, bbox=bb):
                #check branch -- pass cur best
                current_best = self.branches[flag_b].nn_search(x, y, bb, current_best=current_best)
                    
        return current_best

    def calc_branch_bounding_box(self, boundind_box, branch):
        bb = copy.deepcopy(boundind_box)
        bb[self.dim][branch] = self.coords[self.dim]
        return bb

    def check_self_dist(self, x, y, current_best):
        dist = sqrt((x - self.x)**2 + (y - self.y)**2)
        if current_best['distance'] > dist:
            current_best = {
                'point': self,
                'distance': dist
            }
        return current_best

    def check_branch(self, x, y, current_best, bbox):
        coords = [x, y]
        dist = 0
        for dim in bbox:
            if (bbox[dim]['up'] is not None)  and (bbox[dim]['up'] > coords[dim]):
                dim_dist = (bbox[dim]['up'] - coords[dim]) ** 2
                dist += dim_dist
            elif (bbox[dim]['low'] is not None)  and (bbox[dim]['low'] < coords[dim]):
                dim_dist = (bbox[dim]['low'] - coords[dim]) ** 2
                dist += dim_dist
        dist = sqrt(dist)
        return (dist < current_best['distance'])

# =============================================================================

    def print(self, lv=0):
        print(f"[{self.x}, {self.y}]", end='')
        flag = False
        for x in self.branches:
            if self.branches[x] is not None:
                flag = True
                break
        if not flag:
            print('')
            return
        else:
            print(':')
            for x in self.branches:
                blank = ''
                for i in range(lv+1):
                    blank += '  '
                print(f"{blank}{x}:", end='')
                
                if self.branches[x] is None:
                    print("-null-")
                else:
                    self.branches[x].print(lv+1)

In [12]:
class kd_tree():
    def __init__(self):
        self.root = None
    
    def add_element(self, x, y):
        if self.root is not None:
            self.root.add_element(x, y)
        else:
            self.root = kd_node(x, y, dim=0, parent=self)
    
    def delete_element(self, x, y):
        if self.root is not None:
            return self.root.delete_element(x,y)
        else:
            return False
    
    def cut_leaf(self, leaf):
        self.root = None   
    
    def storage(self):
        # huh?
        pass
    def build(self):
        # huh? 
        pass

    def update(self):
        pass

    def search(self):
        if self.root is not None:
            return self.root.search_element(x,y)
        else:
            return False

    def knn_search(self):
        pass

In [13]:
kdtree = kd_node(random.random(),random.random())

In [14]:
for i in range(100):
    x = random.randrange(100)
    y = random.randrange(100)
    print(f"({x}, {y})")
    kdtree.add_element(x, y)


(37, 66)
(67, 0)
(81, 97)
(10, 49)
(16, 59)
(1, 78)
(16, 94)
(71, 94)
(38, 56)
(15, 52)
(20, 10)
(6, 66)
(87, 43)
(1, 80)
(2, 53)
(53, 40)
(87, 39)
(11, 44)
(0, 46)
(20, 57)
(71, 29)
(6, 79)
(77, 87)
(46, 30)
(68, 44)
(84, 51)
(95, 21)
(33, 10)
(23, 38)
(17, 7)
(99, 51)
(34, 87)
(92, 72)
(90, 27)
(14, 56)
(9, 21)
(51, 98)
(94, 99)
(47, 92)
(36, 71)
(18, 25)
(9, 52)
(12, 70)
(14, 65)
(69, 47)
(71, 55)
(52, 41)
(8, 94)
(44, 73)
(7, 5)
(79, 24)
(82, 66)
(58, 48)
(90, 62)
(39, 62)
(52, 40)
(82, 89)
(47, 39)
(63, 91)
(37, 65)
(43, 47)
(14, 91)
(31, 24)
(14, 55)
(73, 32)
(99, 42)
(44, 94)
(87, 22)
(83, 25)
(71, 43)
(63, 17)
(31, 13)
(42, 73)
(24, 67)
(0, 2)
(77, 46)
(95, 9)
(39, 7)
(1, 42)
(75, 20)
(48, 52)
(74, 77)
(10, 25)
(2, 95)
(26, 73)
(32, 43)
(97, 21)
(2, 99)
(24, 72)
(84, 52)
(69, 32)
(15, 50)
(76, 4)
(84, 20)
(13, 89)
(56, 48)
(80, 81)
(24, 73)
(15, 85)
(80, 35)


In [15]:
kdtree.print()

[0.7198352510531157, 0.47107488383552765]:
  up:[37, 66]:
    up:[81, 97]:
      up:[92, 72]:
        up:[94, 99]:
          up:-null-
          low:[82, 89]
        low:[82, 66]
      low:[1, 78]:
        up:[16, 94]:
          up:[71, 94]:
            up:[51, 98]:
              up:-null-
              low:[44, 94]
            low:[77, 87]:
              up:[80, 81]
              low:[34, 87]:
                up:[47, 92]:
                  up:[63, 91]
                  low:-null-
                low:-null-
          low:[1, 80]:
            up:[8, 94]:
              up:[14, 91]:
                up:-null-
                low:[13, 89]:
                  up:[15, 85]
                  low:-null-
              low:[2, 95]:
                up:[2, 99]
                low:-null-
            low:[6, 79]
        low:[6, 66]:
          up:[36, 71]:
            up:[44, 73]:
              up:[74, 77]
              low:[42, 73]:
                up:[26, 73]:
                  up:-null-
             

In [16]:
if kdtree.delete_element(77, 56):
    print("DUCK yes")
else:
    print("dammit")

dammit


In [18]:
print (kdtree.nn_search(2,53.5))

Searching - up
{0: {'up': 0.7198352510531157, 'low': None}, 1: {'up': None, 'low': None}}
Searching - low
{0: {'up': 0.7198352510531157, 'low': None}, 1: {'up': None, 'low': 66}}
Searching - low
{0: {'up': 0.7198352510531157, 'low': 67}, 1: {'up': None, 'low': 66}}
Searching - up
{0: {'up': 0.7198352510531157, 'low': 67}, 1: {'up': 49, 'low': 66}}
Searching - low
{0: {'up': 0.7198352510531157, 'low': 16}, 1: {'up': 49, 'low': 66}}
Searching - up
{0: {'up': 0.7198352510531157, 'low': 16}, 1: {'up': 52, 'low': 66}}
Searching - up
{0: {'up': 2, 'low': 16}, 1: {'up': 52, 'low': 66}}
Searching - low
{0: {'up': 2, 'low': 16}, 1: {'up': 52, 'low': 56}}
found leaf [9, 52]- dist 7.158910531638177
{'point': <__main__.kd_node object at 0x000001B10BC86F98>, 'distance': 0.5}
