In [1]:
import random
from math import sqrt
from tqdm import tqdm
import copy

In [2]:
class kd_node():
    def __init__(self, x: float, y: float, dim=0, parent=None):
        self.x = x
        self.y = y
        self.coords = [x, y]
        self.dim = dim
        self.parent = parent

        self.branches = {
            'up': None,
            'low': None
        }

    def in_low_branch(self, x, y):
        p = [x,y]
        return p[self.dim] < self.coords[self.dim]

# =============================================================================

    def add_element(self, x, y):
        flag = 'low' if self.in_low_branch(x, y) else 'up'

        if self.branches[flag] is None :
            dim = (self.dim + 1) % 2
            self.branches[flag] = kd_node(x, y, dim, parent = self)
        else:
            self.branches[flag].add_element(x, y)

# =============================================================================

    def delete_element(self, x, y):
        if (self.x == x) and (self.y == y):
            if self.branches['up'] is not None:
                min_point = self.branches['up'].find_min(self.dim)
                self.replace_self(min_point)

                self.branches['up'].delete_element(self.x, self.y)

            elif self.branches['low'] is not None:
                min_point = self.branches['low'].find_min(self.dim)
                self.replace_self(min_point)

                self.branches['up'] = self.branches['low']
                self.branches['low'] = None
                self.branches['up'].delete_element(self.x, self.y)

            else:
                self.delete_self()
            return True

        elif self.in_low_branch(x,y):
            if self.branches['low'] is not None:
                return self.branches['low'].delete_element(x,y)
            else:
                return False
        else:
            if self.branches['up'] is not None:
                return self.branches['up'].delete_element(x,y)
            else:
                return False

    # -----------------------------------------------

    def replace_self(self, point):
        self.x, self.y = point.x, point.y
        self.coords = [self.x, self.y]

    def delete_self(self):
        self.parent.cut_leaf(self)

    def cut_leaf(self, leaf):
        for x in self.branches:
            if self.branches[x] is leaf:
                self.branches[x] = None
                return

    # -----------------------------------------------

    def find_min(self, dim):
        if self.dim == dim:
            if self.branches['low'] is None:
                return self
            else:
                return self.branches['low'].find_min(dim)
        else:
            point = self
            if self.branches['low'] is not None:
                min_left = self.branches['low'].find_min(dim)
                point = self.min_on_dimension(point, min_left, dim)
            if self.branches['up'] is not None:
                min_right = self.branches['up'].find_min(dim)
                point = self.min_on_dimension(point, min_right, dim)
            return point

    def min_on_dimension(self, point_a, point_b, dim):
        if point_a.coords[dim] <= point_b.coords[dim]:
            return point_a
        else:
            return point_b

# =============================================================================

    def search_element(self, x, y):
        if (self.x == x) and (self.y == y):
            return True
        elif self.in_low_branch(x,y):
            if self.branches['low'] is not None:
                return self.branches['low'].search_element(x,y)
            else:
                return False
        else:
            if self.branches['up'] is not None:
                return self.branches['up'].search_element(x,y)
            else:
                return False

# =============================================================================

    def knn_search(self, x, y, k=1, bounding_box=None, current_best=None):
        if bounding_box is None:
            bounding_box = dict()
            for i in range(len(self.coords)):
                bounding_box[i] = {'up': None,
                'low': None}

        flag = self.in_low_branch(x,y)
        flag_a = 'low' if flag else 'up'
        flag_b = 'up' if flag else 'low'

        if current_best is None:
            if self.branches[flag_a] is not None:
                bb = self.calc_branch_bounding_box(bounding_box, flag_a)
                current_best = self.branches[flag_a].knn_search(x, y, k, bb)
                current_best = self.check_self_dist(x, y, k, current_best)
            else:
                dist = sqrt((x - self.x)**2 + (y - self.y)**2)
                current_best = [{
                    'point': self,
                    'distance': dist
                }]

        else:
            current_best = self.check_self_dist(x, y, k, current_best)

            if self.branches[flag_a] is not None:
                bb = self.calc_branch_bounding_box(bounding_box, flag_a)
                if self.check_branch(x, y, current_best=current_best, bbox=bb):
                #check branch -- pass cur best
                    current_best = self.branches[flag_a].knn_search(x, y, k, bb, current_best=current_best)
            
        if self.branches[flag_b] is not None:
            bb = self.calc_branch_bounding_box(bounding_box, flag_b)
            if self.check_branch(x, y, current_best=current_best, bbox=bb):
                #check branch -- pass cur best
                current_best = self.branches[flag_b].knn_search(x, y, k, bb, current_best=current_best)
                    
        return current_best

    # -----------------------------------------------

    def calc_branch_bounding_box(self, boundind_box, branch):
        bb = copy.deepcopy(boundind_box)
        bb[self.dim][branch] = self.coords[self.dim]
        return bb

    def check_self_dist(self, x, y, k, current_best):
        dist = sqrt((x - self.x)**2 + (y - self.y)**2)
        if len(current_best) < k:
            point = {
                'point': self,
                'distance': dist
            }
            current_best.append(point)
            return current_best
        else:
            worst_best_d = -1
            worst_id = None
            for i in range(len(current_best)):
                if current_best[i]['distance'] > worst_best_d:
                    worst_best_d = current_best[i]['distance']
                    worst_id = i
            
            if worst_best_d > dist:
                point = {
                    'point': self,
                    'distance': dist
                }
                current_best.append(point)
                del current_best[worst_id]
            return current_best

    def check_branch(self, x, y, current_best, bbox):
        worst_best_d = -1
        for i in range(len(current_best)):
            if current_best[i]['distance'] > worst_best_d:
                worst_best_d = current_best[i]['distance']

        coords = [x, y]
        dist = 0
        for dim in bbox:
            if (bbox[dim]['up'] is not None)  and (bbox[dim]['up'] > coords[dim]):
                dim_dist = (bbox[dim]['up'] - coords[dim]) ** 2
                dist += dim_dist
            elif (bbox[dim]['low'] is not None)  and (bbox[dim]['low'] < coords[dim]):
                dim_dist = (bbox[dim]['low'] - coords[dim]) ** 2
                dist += dim_dist
        dist = sqrt(dist)
        return (dist < worst_best_d)

# =============================================================================

    def print(self, lv=0):
        print(f"[{self.x}, {self.y}]", end='')
        flag = False
        for x in self.branches:
            if self.branches[x] is not None:
                flag = True
                break
        if not flag:
            print('')
            return
        else:
            print(':')
            for x in self.branches:
                blank = ''
                for i in range(lv+1):
                    blank += '  '
                print(f"{blank}{x}:", end='')
                
                if self.branches[x] is None:
                    print("-null-")
                else:
                    self.branches[x].print(lv+1)

In [3]:
class kd_tree():
    def __init__(self):
        self.root = None
    
    def add_element(self, x, y):
        if self.root is not None:
            self.root.add_element(x, y)
        else:
            self.root = kd_node(x, y, dim=0, parent=self)
    
    def delete_element(self, x, y):
        if self.root is not None:
            return self.root.delete_element(x,y)
        else:
            return False
    
    def cut_leaf(self, leaf):
        self.root = None   
    
    def storage(self):
        # huh?
        pass
    def build(self):
        # huh? 
        pass

    def update(self):
        pass

    def search(self):
        if self.root is not None:
            return self.root.search_element(x,y)
        else:
            return False

    def knn_search(self, x, y, k=1):
        if self.root is not None:
            return self.root.knn_search(x,y,k)
        else:
            return []

In [4]:
kdtree = kd_node(random.random(),random.random())

In [5]:
for i in range(100):
    x = random.randrange(100)
    y = random.randrange(100)
    print(f"({x}, {y})", end='')
    kdtree.add_element(x, y)
print('')


(95, 10)(23, 46)(72, 28)(55, 35)(34, 62)(43, 27)(74, 56)(38, 71)(42, 81)(64, 2)(63, 89)(31, 75)(32, 77)(15, 13)(38, 18)(93, 80)(41, 61)(12, 37)(27, 35)(61, 94)(46, 43)(75, 76)(40, 51)(74, 59)(56, 89)(87, 35)(30, 96)(51, 39)(11, 23)(21, 73)(19, 10)(88, 46)(27, 7)(14, 14)(96, 17)(69, 75)(77, 79)(95, 57)(91, 5)(39, 98)(61, 7)(19, 25)(69, 93)(97, 80)(92, 62)(8, 14)(11, 80)(97, 88)(47, 56)(23, 34)(98, 84)(19, 87)(93, 95)(6, 11)(8, 86)(11, 66)(11, 79)(0, 62)(32, 34)(42, 69)(11, 54)(38, 41)(94, 10)(40, 66)(77, 7)(35, 64)(91, 18)(12, 8)(22, 90)(92, 56)(18, 38)(97, 63)(92, 17)(38, 76)(21, 23)(52, 65)(54, 13)(46, 35)(15, 92)(74, 17)(18, 74)(51, 26)(11, 24)(20, 3)(13, 92)(15, 65)(57, 55)(23, 2)(60, 56)(73, 18)(80, 5)(72, 96)(30, 56)(81, 22)(20, 31)(84, 92)(61, 16)(80, 65)(72, 32)(23, 55)


In [6]:
kdtree.print()

[0.5683199801100896, 0.5042890216320932]:
  up:[95, 10]:
    up:[23, 46]:
      up:[72, 28]:
        up:[55, 35]:
          up:[74, 56]:
            up:[63, 89]:
              up:[93, 80]:
                up:[69, 93]:
                  up:[97, 80]:
                    up:[97, 88]:
                      up:[98, 84]
                      low:[93, 95]:
                        up:[72, 96]
                        low:[84, 92]
                    low:-null-
                  low:-null-
                low:[75, 76]:
                  up:[77, 79]:
                    up:-null-
                    low:[95, 57]:
                      up:[97, 63]
                      low:[92, 62]:
                        up:[80, 65]
                        low:[92, 56]
                  low:[74, 59]:
                    up:[69, 75]
                    low:-null-
              low:[61, 94]:
                up:-null-
                low:[56, 89]:
                  up:[60, 56]
                  low:-null-
         

In [7]:
if kdtree.delete_element(77, 56):
    print("DUCK yes")
else:
    print("dammit")

dammit


In [9]:
print (kdtree.knn_search(2,53.5, k=3))

[{'point': <__main__.kd_node object at 0x00000208344246D8>, 'distance': 15.402921800749363}, {'point': <__main__.kd_node object at 0x0000020834424908>, 'distance': 9.013878188659973}, {'point': <__main__.kd_node object at 0x0000020834424828>, 'distance': 8.73212459828649}]


In [9]:
test_ar = []
for i in range(5):
    point = {'a': i,
    't': 'hi'}
    test_ar.append(point)
print(test_ar)
del test_ar[2]
print("\n==========================================================\n")
print(test_ar)

[{'a': 0, 't': 'hi'}, {'a': 1, 't': 'hi'}, {'a': 2, 't': 'hi'}, {'a': 3, 't': 'hi'}, {'a': 4, 't': 'hi'}]


[{'a': 0, 't': 'hi'}, {'a': 1, 't': 'hi'}, {'a': 3, 't': 'hi'}, {'a': 4, 't': 'hi'}]
