In [1]:
import random
from math import sqrt
from tqdm import tqdm
import copy

In [2]:
class kd_node():
    def __init__(self, x: float, y: float, dim=0, parent=None):
        self.x = x
        self.y = y
        self.coords = [x, y]
        self.dim = dim
        self.parent = parent

        self.branches = {
            'up': None,
            'low': None
        }

    def in_low_branch(self, x, y):
        p = [x,y]
        return p[self.dim] < self.coords[self.dim]

# =============================================================================

    def add_element(self, x, y):
        flag = 'low' if self.in_low_branch(x, y) else 'up'

        if self.branches[flag] is None :
            dim = (self.dim + 1) % 2
            self.branches[flag] = kd_node(x, y, dim, parent = self)
        else:
            self.branches[flag].add_element(x, y)

# =============================================================================

    def find_min(self, dim):
        if self.dim == dim:
            if self.branches['low'] is None:
                return self
            else:
                return self.branches['low'].find_min(dim)
        else:
            point = self
            if self.branches['low'] is not None:
                min_left = self.branches['low'].find_min(dim)
                point = self.min_on_dimension(point, min_left, dim)
            if self.branches['up'] is not None:
                min_right = self.branches['up'].find_min(dim)
                point = self.min_on_dimension(point, min_right, dim)
            return point

    def min_on_dimension(self, point_a, point_b, dim):
        if point_a.coords[dim] <= point_b.coords[dim]:
            return point_a
        else:
            return point_b

# =============================================================================

    def delete_element(self, x, y):
        if (self.x == x) and (self.y == y):
            if self.branches['up'] is not None:
                min_point = self.branches['up'].find_min(self.dim)
                self.replace_self(min_point)

                self.branches['up'].delete_element(self.x, self.y)

            elif self.branches['low'] is not None:
                min_point = self.branches['low'].find_min(self.dim)
                self.replace_self(min_point)

                self.branches['up'] = self.branches['low']
                self.branches['low'] = None
                self.branches['up'].delete_element(self.x, self.y)

            else:
                self.delete_self()
            return True

        elif self.in_low_branch(x,y):
            if self.branches['low'] is not None:
                return self.branches['low'].delete_element(x,y)
            else:
                return False
        else:
            if self.branches['up'] is not None:
                return self.branches['up'].delete_element(x,y)
            else:
                return False

    def replace_self(self, point):
        self.x, self.y = point.x, point.y
        self.coords = [self.x, self.y]

    def delete_self(self):
        self.parent.cut_leaf(self)

    def cut_leaf(self, leaf):
        for x in self.branches:
            if self.branches[x] is leaf:
                self.branches[x] = None
                return

# =============================================================================

    def search_element(self, x, y):
        if (self.x == x) and (self.y == y):
            return True
        elif self.in_low_branch(x,y):
            if self.branches['low'] is not None:
                return self.branches['low'].search_element(x,y)
            else:
                return False
        else:
            if self.branches['up'] is not None:
                return self.branches['up'].search_element(x,y)
            else:
                return False

# =============================================================================

    def nn_search(self, x, y, k=1, bounding_box=None, current_best=None):
        if bounding_box is None:
            bounding_box = dict()
            for i in range(len(self.coords)):
                bounding_box[i] = {'up': None,
                'low': None}

        flag = self.in_low_branch(x,y)
        flag_a = 'low' if flag else 'up'
        flag_b = 'up' if flag else 'low'

        if current_best is None:
            if self.branches[flag_a] is not None:
                bb = self.calc_branch_bounding_box(bounding_box, flag_a)
                print(f"Searching - {flag_a}\n{bb}\n#=================\n#=================")
                current_best = self.branches[flag_a].nn_search(x, y, k, bb)
                current_best = self.check_self_dist(x, y, k, current_best)
            else:
                dist = sqrt((x - self.x)**2 + (y - self.y)**2)
                print(f"found leaf [{self.x}, {self.y}]- dist {dist}")
                
                current_best = [{
                    'point': self,
                    'distance': dist
                }]

        else:
            current_best = self.check_self_dist(x, y, k, current_best)

            if self.branches[flag_a] is not None:
                bb = self.calc_branch_bounding_box(bounding_box, flag_a)
                if self.check_branch(x, y, current_best=current_best, bbox=bb):
                #check branch -- pass cur best
                    current_best = self.branches[flag_a].nn_search(x, y, k, bb, current_best=current_best)
            
        if self.branches[flag_b] is not None:
            bb = self.calc_branch_bounding_box(bounding_box, flag_b)
            if self.check_branch(x, y, current_best=current_best, bbox=bb):
                #check branch -- pass cur best
                current_best = self.branches[flag_b].nn_search(x, y, k, bb, current_best=current_best)
                    
        return current_best

    def calc_branch_bounding_box(self, boundind_box, branch):
        bb = copy.deepcopy(boundind_box)
        bb[self.dim][branch] = self.coords[self.dim]
        return bb

    def check_self_dist(self, x, y, k, current_best):
        if len(current_best) < k:
            point = {
                'point': self,
                'distance': dist
            }
            current_best.append(point)
            return current_best
        else:
            dist = sqrt((x - self.x)**2 + (y - self.y)**2)
            worst_best_d = -1
            worst_id = None
            for i in range(len(current_best)):
                if current_best[i]['distance'] > worst_best_d:
                    worst_best_d = current_best[i]['distance']
                    worst_id = i
            
            if worst_best_d > dist:
                point = {
                    'point': self,
                    'distance': dist
                }
                current_best.append(point)
                del current_best[worst_id]
            return current_best

    def check_branch(self, x, y, current_best, bbox):
        worst_best_d = -1
        for i in range(len(current_best)):
            if current_best[i]['distance'] > worst_best_d:
                worst_best_d = current_best[i]['distance']

        coords = [x, y]
        dist = 0
        for dim in bbox:
            if (bbox[dim]['up'] is not None)  and (bbox[dim]['up'] > coords[dim]):
                dim_dist = (bbox[dim]['up'] - coords[dim]) ** 2
                dist += dim_dist
            elif (bbox[dim]['low'] is not None)  and (bbox[dim]['low'] < coords[dim]):
                dim_dist = (bbox[dim]['low'] - coords[dim]) ** 2
                dist += dim_dist
        dist = sqrt(dist)
        return (dist < worst_best_d)

# =============================================================================

    def print(self, lv=0):
        print(f"[{self.x}, {self.y}]", end='')
        flag = False
        for x in self.branches:
            if self.branches[x] is not None:
                flag = True
                break
        if not flag:
            print('')
            return
        else:
            print(':')
            for x in self.branches:
                blank = ''
                for i in range(lv+1):
                    blank += '  '
                print(f"{blank}{x}:", end='')
                
                if self.branches[x] is None:
                    print("-null-")
                else:
                    self.branches[x].print(lv+1)

In [3]:
class kd_tree():
    def __init__(self):
        self.root = None
    
    def add_element(self, x, y):
        if self.root is not None:
            self.root.add_element(x, y)
        else:
            self.root = kd_node(x, y, dim=0, parent=self)
    
    def delete_element(self, x, y):
        if self.root is not None:
            return self.root.delete_element(x,y)
        else:
            return False
    
    def cut_leaf(self, leaf):
        self.root = None   
    
    def storage(self):
        # huh?
        pass
    def build(self):
        # huh? 
        pass

    def update(self):
        pass

    def search(self):
        if self.root is not None:
            return self.root.search_element(x,y)
        else:
            return False

    def knn_search(self):
        pass

In [4]:
kdtree = kd_node(random.random(),random.random())

In [5]:
for i in range(100):
    x = random.randrange(100)
    y = random.randrange(100)
    print(f"({x}, {y})")
    kdtree.add_element(x, y)


(79, 52)
(61, 8)
(98, 0)
(5, 96)
(83, 7)
(32, 80)
(6, 39)
(1, 27)
(76, 75)
(12, 86)
(2, 17)
(91, 11)
(88, 19)
(91, 11)
(6, 54)
(44, 84)
(70, 40)
(1, 92)
(22, 49)
(64, 34)
(95, 89)
(33, 17)
(83, 93)
(99, 23)
(67, 28)
(14, 60)
(53, 15)
(92, 18)
(79, 17)
(57, 43)
(39, 71)
(61, 39)
(50, 62)
(59, 57)
(74, 77)
(65, 21)
(59, 62)
(33, 20)
(44, 83)
(21, 5)
(65, 94)
(22, 98)
(77, 39)
(99, 67)
(42, 46)
(98, 25)
(57, 60)
(84, 42)
(94, 94)
(64, 28)
(28, 27)
(83, 44)
(87, 16)
(18, 15)
(46, 98)
(67, 33)
(71, 56)
(63, 93)
(72, 1)
(2, 55)
(44, 6)
(27, 77)
(95, 0)
(53, 52)
(75, 34)
(4, 42)
(93, 33)
(1, 28)
(60, 63)
(97, 71)
(65, 44)
(52, 79)
(66, 95)
(78, 7)
(58, 81)
(99, 16)
(20, 16)
(27, 59)
(21, 6)
(62, 23)
(40, 58)
(95, 81)
(3, 64)
(49, 17)
(7, 39)
(52, 82)
(24, 36)
(49, 0)
(58, 20)
(70, 56)
(57, 78)
(87, 86)
(26, 17)
(88, 76)
(19, 23)
(39, 87)
(84, 98)
(19, 21)
(53, 0)
(86, 78)


In [6]:
kdtree.print()

[0.009144861671059767, 0.15529158421052114]:
  up:[79, 52]:
    up:[5, 96]:
      up:[32, 80]:
        up:[12, 86]:
          up:[44, 84]:
            up:[95, 89]:
              up:-null-
              low:[83, 93]:
                up:[65, 94]:
                  up:[94, 94]:
                    up:[66, 95]:
                      up:[84, 98]
                      low:-null-
                    low:-null-
                  low:[22, 98]:
                    up:[46, 98]
                    low:[63, 93]
                low:[87, 86]:
                  up:-null-
                  low:[39, 87]
            low:[44, 83]:
              up:[58, 81]:
                up:[95, 81]:
                  up:-null-
                  low:[52, 82]
                low:-null-
              low:-null-
          low:-null-
        low:[76, 75]:
          up:[99, 67]:
            up:[97, 71]:
              up:-null-
              low:[88, 76]:
                up:[86, 78]
                low:-null-
            low:

In [7]:
if kdtree.delete_element(77, 56):
    print("DUCK yes")
else:
    print("dammit")

dammit


In [8]:
print (kdtree.nn_search(2,53.5))

Searching - up
{0: {'up': 0.009144861671059767, 'low': None}, 1: {'up': None, 'low': None}}
Searching - up
{0: {'up': 0.009144861671059767, 'low': None}, 1: {'up': 52, 'low': None}}
Searching - low
{0: {'up': 0.009144861671059767, 'low': 5}, 1: {'up': 52, 'low': None}}
Searching - low
{0: {'up': 0.009144861671059767, 'low': 5}, 1: {'up': 52, 'low': 92}}
Searching - up
{0: {'up': 2, 'low': 5}, 1: {'up': 52, 'low': 92}}
found leaf [3, 64]- dist 10.547511554864494
[{'point': <__main__.kd_node object at 0x0000028D4B5C5A58>, 'distance': 1.5}]


In [9]:
test_ar = []
for i in range(5):
    point = {'a': i,
    't': 'hi'}
    test_ar.append(point)
print(test_ar)
del test_ar[2]
print("\n==========================================================\n")
print(test_ar)

[{'a': 0, 't': 'hi'}, {'a': 1, 't': 'hi'}, {'a': 2, 't': 'hi'}, {'a': 3, 't': 'hi'}, {'a': 4, 't': 'hi'}]


[{'a': 0, 't': 'hi'}, {'a': 1, 't': 'hi'}, {'a': 3, 't': 'hi'}, {'a': 4, 't': 'hi'}]
