In [42]:
using Random
using Statistics
using AbstractTrees

### A simple implementation of a KDTree

In [1]:
abstract type Rule end

mutable struct KDRule <: Rule
    dim::Int
    max_in_left::Float64
end

function KDRule(X::Array{T, 2}, dim::Int)
    dimvals = X[dim, :]
    max_in_left = median(dimvals)
    KDRule(dim, max_in_left)
end
    
inleft(r::KDRule, x) = x[r.dim] <= r.max_in_left

mutable struct Node{T}
    data::Array{T, 2}
    dim::Int
    npoints::Int
    indexes::Array{Int, 1}
    rule::Rule
    right::Node{T}
    left::Node{T}
    Node() where {T} = new{T}()
    function Node(X::Array{T, 2}) where {T}
        d, n = size(X)
        new{T}(X, d, n, 1:n)
    end
    function Node(X::Array{T, 2}, idxs::Array{Int, 1}) where {T, F}
        return new{T}(X, size(X)[1], length(idxs), idxs)
    end
end

In [2]:
X = rand(10, 100)
Node(X)

Node{Float64}([0.6841242995036696 0.6891073053228145 … 0.38401684979955153 0.24276299792194989; 0.15409183843897067 0.6978541245945087 … 0.6792014417177474 0.6000248594757245; … ; 0.14100081712940749 0.7506436554172047 … 0.7542783367268702 0.010295316288215783; 0.8617297505507733 0.45741212222007444 … 0.6812273582171562 0.07050116104119852], 10, 100, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  91, 92, 93, 94, 95, 96, 97, 98, 99, 100], #undef, #undef, #undef)

In [29]:
function recursive_kd_build(node::Node{T}, n0::Int, d::Int) where T
    """
    recursively builds kd tree
    
    input:
        node (Node) : the node we are working with
        n0 (int)    : random splitting location
        d (ind)     : dimension ? gosh idk what d is
    """
    # check if node points is less than n0
    if node.npoints <= n0
        # if they are then return the node
        return node
    end
    # Incement current dimension
    d += 1
    
    # check if d has inremented beyond node dim plus one
    if d == (node.dim + 1) 
        # if so then set d back to 1
        d = 1 
    end
    # Divide indexes
    n = length(node.indexes)
    dimvals = reshape(node.data[d, node.indexes], :)
    sortidx = sortperm(dimvals)
    mi = isodd(n) ? n ÷ 2 + 1 : n ÷ 2
    median = isodd(n) ? dimvals[mi] : (dimvals[mi] + dimvals[mi]) / 2 
    lidxs = node.indexes[sortidx][1:mi]
    ridxs = node.indexes[sortidx][mi+1:end]
    node.rule = KDRule(d, median)
    node.left = recursive_kd_build(Node(node.data, lidxs), n0, d)
    node.right = recursive_kd_build(Node(node.data, ridxs), n0, d)
    return node
end
    

recursive_kd_build (generic function with 1 method)

In [55]:
@time X = rand(10, 1000000);
root = Node(X)

  0.090769 seconds (2 allocations: 76.294 MiB, 34.17% gc time)


Node{Float64}([0.6531569706754898 0.14438346034148197 … 0.8806061005199359 0.7991063483967205; 0.6864117364790194 0.4814068685489876 … 0.2705885405547266 0.3461310700659086; … ; 0.5193534354443301 0.2039213556017614 … 0.348128116243144 0.8161948890528856; 0.8250022699312176 0.357104671392134 … 0.17114950825420516 0.16732525747310967], 10, 1000000, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  999991, 999992, 999993, 999994, 999995, 999996, 999997, 999998, 999999, 1000000], #undef, #undef, #undef)

In [56]:
@time recursive_kd_build(root, 10, 1);

  1.723220 seconds (1.31 M allocations: 737.966 MiB, 3.67% gc time)


In [45]:
using NearestNeighbors

In [53]:
@time tree = KDTree(X, leafsize=5)

  0.888125 seconds (15.58 k allocations: 174.829 MiB)


KDTree{StaticArrays.SArray{Tuple{10},Float64,1,10},Euclidean,Float64}
  Number of points: 1000000
  Dimensions: 10
  Metric: Euclidean(0.0)
  Reordered: true

In [43]:
function AbstractTrees.children(node::Node)
    if isdefined(node, :left)
        if isdefined(node, :right)
            return (node.left, node.right)
        end
        return (node.left,)
    end
    isdefined(node, :right) && return (node.right,)
    return ()
end

AbstractTrees.printnode(io::IO, node::Node) = print(io, node.npoints)

In [44]:
print_tree(root)

1000000
├─ 500000
│  ├─ 250000
│  │  ├─ 125000
│  │  │  ├─ 62500
│  │  │  │  ├─ 31250
│  │  │  │  │  ⋮
│  │  │  │  │  
│  │  │  │  └─ 31250
│  │  │  │     ⋮
│  │  │  │     
│  │  │  └─ 62500
│  │  │     ├─ 31250
│  │  │     │  ⋮
│  │  │     │  
│  │  │     └─ 31250
│  │  │        ⋮
│  │  │        
│  │  └─ 125000
│  │     ├─ 62500
│  │     │  ├─ 31250
│  │     │  │  ⋮
│  │     │  │  
│  │     │  └─ 31250
│  │     │     ⋮
│  │     │     
│  │     └─ 62500
│  │        ├─ 31250
│  │        │  ⋮
│  │        │  
│  │        └─ 31250
│  │           ⋮
│  │           
│  └─ 250000
│     ├─ 125000
│     │  ├─ 62500
│     │  │  ├─ 31250
│     │  │  │  ⋮
│     │  │  │  
│     │  │  └─ 31250
│     │  │     ⋮
│     │  │     
│     │  └─ 62500
│     │     ├─ 31250
│     │     │  ⋮
│     │     │  
│     │     └─ 31250
│     │        ⋮
│     │        
│     └─ 125000
│        ├─ 62500
│        │  ├─ 31250
│        │  │  ⋮
│        │  │  
│        │  └─ 31250
│        │     ⋮
│        │     
│        └

In [60]:
function findbin(node::Node{T}, x::Array{T, 1}) where T
    if isdefined(node, :left) && isdefined(root, :right)
        if inleft(node.rule, x)
            findbin(node.left, x)
        else
            findbin(node.right, x)
        end
    else
        return node
    end
end

findbin (generic function with 1 method)

In [61]:
@time findbin(root, vec(X[:, 1]))

  0.007233 seconds (4.36 k allocations: 244.964 KiB)


Node{Float64}([0.6531569706754898 0.14438346034148197 … 0.8806061005199359 0.7991063483967205; 0.6864117364790194 0.4814068685489876 … 0.2705885405547266 0.3461310700659086; … ; 0.5193534354443301 0.2039213556017614 … 0.348128116243144 0.8161948890528856; 0.8250022699312176 0.357104671392134 … 0.17114950825420516 0.16732525747310967], 10, 7, [934365, 721938, 324534, 536820, 315229, 439089, 609576], #undef, #undef, #undef)