In [8]:
import heapq
import itertools
import queue

import numpy as np


class VPTRee:
    """Vantage-point tree supporting k-nearest-neighbour queries.

    Each internal node stores a vantage point ``vp`` and a radius ``tau`` (the
    median distance from ``vp`` to the node's remaining points). Points at
    distance ``<= tau`` go into ``left``, the rest into ``right``.

    ``distance_fn(a, b)`` must be vectorised: ``distance_fn(vp, pts)`` with
    ``pts`` stacked along axis 0 must return one distance per row, and
    ``distance_fn(p, q)`` for two single points must return a scalar. The
    search prunes subtrees via the triangle inequality, so ``distance_fn``
    should be a true metric (e.g. Euclidean, not squared Euclidean), otherwise
    true neighbours may be missed.
    """

    # Shared tie-breaker so heap entries with equal distances never fall back
    # to comparing the stored points themselves (ambiguous for arrays).
    _tiebreak = itertools.count()

    def __init__(self, points, distance_fn):
        # build() calls self.distance, so the distance function must be set first.
        self.distance = distance_fn
        self.build(points)

    def build(self, points):
        """Build this node (and, recursively, its subtrees) from ``points``.

        ``points`` is split with ``np.delete(..., axis=0)`` and boolean-mask
        indexing, so a NumPy array of shape (n, ...) is assumed.
        """
        self.left = None
        self.right = None

        if len(points) == 0:
            self.isleaf = True
            self.vp = None
            self.tau = 0
            return

        if len(points) == 1:
            self.isleaf = True
            self.vp = points[0]
            self.tau = 0
            return

        self.isleaf = False

        # Random vantage point drawn from NumPy's global RNG (seed it for reproducible trees).
        vp_idx = np.random.randint(0, len(points))
        vp = points[vp_idx]

        other_points = np.delete(points, vp_idx, axis=0)
        distances = self.distance(vp, other_points)  # one distance per remaining point
        med = np.median(distances)

        self.tau = med
        self.vp = vp

        # Children must share this node's distance function.
        # NOTE(review): when many distances equal the median (e.g. duplicate
        # points) the right side can be empty and recursion depth grows linearly.
        self.left = VPTRee(other_points[distances <= med], self.distance)
        self.right = VPTRee(other_points[distances > med], self.distance)

    def search_knn(self, point, k, heap=None):
        """Find the (at most) ``k`` stored points closest to ``point``.

        Args:
            point: query point; passed as the second argument of the distance function.
            k: number of neighbours wanted; ``k <= 0`` yields no results.
            heap: optional list used as the working max-heap, with entries
                ``(-distance, tiebreak, point)``. A fresh list is created when
                omitted (the previous mutable default leaked state across calls).

        Returns:
            A list of ``(distance, point)`` pairs sorted by increasing distance.
        """
        if heap is None:
            heap = []
        if k > 0:
            self._search_knn(point, k, heap)
        return [(-neg_dist, p) for neg_dist, _, p in sorted(heap, key=lambda entry: -entry[0])]

    def _search_knn(self, point, k, heap):
        """Recursive worker for ``search_knn``: offers this subtree's points to ``heap``."""
        if self.vp is None:  # empty subtree: nothing to compare against
            return

        mu = float(self.distance(self.vp, point))

        # heap is a max-heap on distance (stored negated); heap[0] is the worst kept candidate.
        if len(heap) < k:
            heapq.heappush(heap, (-mu, next(self._tiebreak), self.vp))
        elif mu < -heap[0][0]:
            heapq.heappushpop(heap, (-mu, next(self._tiebreak), self.vp))

        if self.isleaf:  # leaves have no subtrees to visit
            return

        # Visit the side of the tau-sphere containing the query first; the other
        # side is at least `gap` away from the query by the triangle inequality.
        if mu <= self.tau:
            near, far, gap = self.left, self.right, self.tau - mu
        else:
            near, far, gap = self.right, self.left, mu - self.tau

        near._search_knn(point, k, heap)
        if len(heap) < k or -heap[0][0] > gap:
            far._search_knn(point, k, heap)

from torch import Tensor
import torch

def distance_fn(t1: Tensor, t2: Tensor, squared: bool = True):
    """Euclidean distance along the last axis, broadcasting ``t1`` against ``t2``.

    Either argument may be a single point of shape (d,) or a stack of points of
    shape (..., d); the result has the broadcast shape without the last axis.
    Works for torch tensors and NumPy arrays alike, and the argument order does
    not matter.

    Args:
        t1, t2: points to compare.
        squared: if True (default, the original behaviour) return the squared
            distance; if False return the true Euclidean distance. Use
            ``squared=False`` when the result drives triangle-inequality pruning
            (as in ``VPTRee``), since squared distances are not a metric.
    """
    # Broadcasting already aligns trailing axes, so no explicit reshape is needed.
    sq = ((t2 - t1) ** 2).sum(-1)
    return sq if squared else sq ** 0.5

In [5]:
import torch

# Display ten samples drawn uniformly from [0, 1). The torch RNG is not seeded
# in this cell, so the displayed values change on every run.
torch.rand((10,))

tensor([0.4857, 0.9776, 0.7542, 0.6113, 0.9356, 0.1930, 0.6962, 0.9267, 0.0754,
        0.7629])

In [9]:
import numpy as np
import vptree

# Define distance function.
def euclidean(p1, p2):
  """Euclidean (L2) distance between p1 and p2.

  The squares are summed over every element of p2 - p1 (no axis), so the
  result is a single scalar even for multi-dimensional inputs.
  """
  delta = p2 - p1
  squared_total = np.sum(np.square(delta))
  return np.sqrt(squared_total)

def euclidean_torch(p1, p2):
  """Torch counterpart of `euclidean`: L2 norm of p2 - p1 over all elements.

  Inputs may be tensors, NumPy arrays or (nested) lists; they are converted
  with `torch.as_tensor` (so a plain list query works), and a 0-dim tensor is
  returned. Previously this was a stub that returned None.
  """
  diff = torch.as_tensor(p2) - torch.as_tensor(p1)
  return torch.sqrt(torch.sum(diff ** 2))

# Generate some random points.
points = np.random.randn(200000, 10)
query = [.5] * 10

# Build tree in O(n log n) time complexity.
tree = vptree.VPTree(points, euclidean)

# Query single point.
tree.get_nearest_neighbor(query)

# Query n-points.
tree.get_n_nearest_neighbors(query, 10)

# Get all points within certain distance.
%timeit tree.get_all_in_range(query, 3.14)

3.08 s ± 55.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
