In [15]:
using NearestNeighbors
using NearestNeighbors: NNTree, Metric, TreeData
using NearestNeighbors: HyperRectangle, intersects, encloses, interpolate, create_bsphere
using StaticArrays: MVector

# KDTrees

Rewriting the KD tree implementation from [NearestNeighbors.jl](https://github.com/KristofferC/NearestNeighbors.jl/blob/master/src/kd_tree.jl) as an exercise in learning the package's internal types.

In [18]:
# Layout bookkeeping for a tree whose nodes live in a flat array with heap indexing
# (root = 1, children of node `i` are `2i` and `2i + 1`). Only internal nodes are
# materialized, so these numbers let leaf positions and sizes be derived by arithmetic.
#
# NOTE(review): `TreeData` is also imported by the `using NearestNeighbors: ...` line at
# the top of the notebook, which is what raised "cannot assign a value to variable
# NearestNeighbors.TreeData from module Main". Remove it from that import list to define
# the type locally.
struct TreeData
    last_node_size::Int   # points in the last leaf (== leafsize when it is full)
    leafsize::Int         # maximum number of points per leaf
    n_leafs::Int          # number of leaves, cld(n_p, leafsize)
    n_internal_nodes::Int # number of internal nodes, n_leafs - 1
    cross_node::Int       # first heap index of the deepest row of the tree
    offset::Int           # `k1` below; presumably an offset for leaf point ranges — confirm
    offset_cross::Int     # `k2` below; presumably the same offset for the deepest row — confirm
    last_full_node::Int   # total node count, n_leafs + n_internal_nodes
end

"""
    TreeData(data::AbstractVector, leafsize) -> TreeData

Compute the layout of a tree over `length(data)` points with at most `leafsize`
points per leaf. Empty `data` yields an all-zero `TreeData`.

Throws an `ArgumentError` if `data` is non-empty and `leafsize` is not positive.
"""
function TreeData(data::AbstractVector, leafsize)
    n_p = length(data)

    # No points: nothing to lay out.
    n_p == 0 && return TreeData(0, 0, 0, 0, 0, 0, 0, 0)
    leafsize > 0 || throw(ArgumentError("leafsize must be positive, got $leafsize"))

    # Enough leaves to hold every point; all but possibly the last are full.
    n_leafs = cld(n_p, leafsize)
    # A full binary tree with n leaves has n - 1 internal nodes
    # (1 + 2 + ... + 2^(k-1) = 2^k - 1 for a perfect tree).
    n_internal_nodes = n_leafs - 1
    # Deepest completely filled row: the largest k with 2^k <= n_leafs.
    leafrow = floor(Int, log2(n_leafs))
    # With heap indexing row r starts at index 2^r; start with the row below `leafrow`.
    cross_node = 2^(leafrow + 1)

    # Points left over for the last leaf; a remainder of 0 means it is full.
    last_node_size = n_p % leafsize
    if last_node_size == 0
        last_node_size = leafsize
    end

    # The row below `leafrow` exists only when n_leafs is not a power of two. When it is,
    # all 2n_leafs - 1 nodes fit in rows 0..leafrow, so the deepest row is `leafrow`.
    if cross_node >= n_internal_nodes + n_leafs
        cross_node = div(cross_node, 2)
    end

    # 2 * (number of leaves that do not fit in the full row `leafrow`) - 1
    offset = 2 * (n_leafs - 2^leafrow) - 1
    k1 = (offset - n_internal_nodes - 1) * leafsize + last_node_size + 1
    k2 = -cross_node * leafsize + 1
    last_full_node = n_leafs + n_internal_nodes
    return TreeData(last_node_size, leafsize, n_leafs, n_internal_nodes,
                    cross_node, k1, k2, last_full_node)
end
    

LoadError: cannot assign a value to variable NearestNeighbors.TreeData from module Main

In [9]:
# Internal (non-leaf) node of the KD tree: the points below it are divided at
# `split_val` along dimension `split_dim`.
struct KDNode{T}
    lo::T          # bounding-box minimum along `split_dim` when this node was built
    hi::T          # bounding-box maximum along `split_dim` when this node was built
    split_val::T   # coordinate at which this node's points are divided
    split_dim::Int # dimension used for the split
end

"""
    KDTree{V, M, T}

KD tree over points of type `V` built for metric `M`. `T` is the coordinate type
stored in the internal nodes (the constructor uses `eltype(V)`).
"""
struct KDTree{V <: AbstractVector, M <: MinkowskiMetric, T} <: NNTree{V, M}
    data::Vector{V}           # the points; empty when the tree was built with storedata = false
    hyper_rec::HyperRectangle # bounding box of all input points
    indices::Vector{Int}      # original point indices, in tree order
    metric::M                 # metric the tree was built for
    nodes::Vector{KDNode{T}}  # internal nodes in heap order (root at index 1)
    tree_data::TreeData       # arithmetic layout of nodes and leaves
    reordered::Bool           # whether `data` was copied into leaf order
end

"""
    KDTree(data [, metric = Euclidean(); leafsize = 10]) -> kdtree
Creates a `KDTree` from the data using the given `metric` and `leafsize`.
The `metric` must be a `MinkowskiMetric`.
"""
function KDTree(data::AbstractVector{V},
                metric::M = Euclidian();
                leafsize::Int = 10,
                storedata::Bool = true,
                reorder::Bool = true,
                reorderbuffer::Vector{V} = Vector{V}()) where {V <: AbstractArray, M <: MinkowskiMetric}
    # Reorder if the reorder buffer is not empty or if we are storing data and 
    # reorder is true. If we are not storing data or the reorder arg is false, 
    # we don't reorder
    reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)

    tree_data = TreeData(data, leafsize)
    n_d = length(V)
    n_p = length(data)

    indices = collect(1:n_p)
    # Allocate space for nodes. Only stores the internal nodes. No leaf nodes are stored
    nodes = Vector{KDNode{eltype(V)}}(undef, tree_data.n_internal_nodes)

    # Allocate space for reordered data
    if reorder
        indices_reordered = Vector{Int}(undef, n_p)
        if isempty(reorderbuffer)
            data_reordered = Vector{V}(undef, n_p)
        else
            data_reordered = reorderbuffer
        end
    else
        # Dummy variables
        indices_reordered = Vector{Int}()
        data_reordered = Vector{V}()
    end
    
    # Check metric parameters
    if metric isa Distances.UnionMetrics
        p = parameters(metric)
        if p !== nothing && length(p) != length(V)
            throw(ArgumentError(
                "dimension of input points:$(length(V)) and metric parameter:$(length(p)) must agree"))
        end
    end
    
    # Create first bounding hyper rectangle that bounds all the input points
    hyper_rec = compute_bbox(data)
    
    # Call the recursive KDTree builder
    build_KDTree(1, data, data_reordered, hyper_rec, nodes, indices, indices_reordered,
                 1, length(data), tree_data, reorder)
    if reorder
        data = data_reordered
        indices = indices_reordered
    end
    
    return KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, nodes, tree_data, reorder)
end

# Convenience constructor: treat each column of `data` as one point, convert the
# columns to static vectors, then delegate to the vector-of-points constructor.
function KDTree(data::AbstractVecOrMat{T},
                 metric::M = Euclidean();
                 leafsize::Int = 10,
                 storedata::Bool = true,
                 reorder::Bool = true,
                 reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0)) where {T <: AbstractFloat, M <: MinkowskiMetric}
    dim = size(data, 1)
    # `Val(dim)` lifts the dimension into the type domain so `copy_svec` can build
    # fixed-size points (presumably a `Vector{SVector{dim, T}}`, matching the empty
    # buffer below — confirm).
    points = copy_svec(T, data, Val(dim))
    if isempty(reorderbuffer)
        # No buffer supplied: pass an empty vector of static points.
        reorderbuffer_points = Vector{SVector{dim, T}}()
    else
        # Convert the user's buffer to the same static-vector representation.
        reorderbuffer_points = copy_svec(T, reorderbuffer, Val(dim))
    end
    return KDTree(points, metric; leafsize = leafsize, storedata = storedata,
                  reorder = reorder, reorderbuffer = reorderbuffer_points)
end

"""
    build_KDTree(index, data, data_reordered, hyper_rec, nodes, indices,
                 indices_reordered, low, high, tree_data, reorder)

Recursively build the subtree rooted at heap index `index` over the points
`data[indices[low:high]]`. Internal nodes are written into `nodes`; when `reorder`
is set, each leaf's points are copied into `data_reordered`.

`hyper_rec` is the bounding box of the current subtree. It is narrowed in place for
each child and restored afterwards, so the caller sees it unchanged. Returns `nothing`.
"""
function build_KDTree(index::Int,
                      data::AbstractVector{V},
                      data_reordered::Vector{V},
                      hyper_rec::HyperRectangle,
                      nodes::Vector{KDNode{T}},
                      indices::Vector{Int},
                      indices_reordered::Vector{Int},
                      low::Int, # Lowest index
                      high::Int, # Highest index
                      tree_data::TreeData,
                      reorder::Bool) where {V <: AbstractVector, T}
    n_p = high - low + 1 # points in this subtree

    # Base case: few enough points to form a leaf.
    if n_p <= tree_data.leafsize
        if reorder
            # Store the leaf's points contiguously in `data_reordered` (cache locality) and
            # record their original indices in `indices_reordered`. The current permutation
            # `indices` is what tells which points belong to this leaf.
            reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
        end
        return nothing
    end

    # Split so that one subtree holds exactly 2^p points and the left subtree gets the
    # larger share. The layout then follows from `tree_data` alone, so leaves and their
    # sizes can be identified with a few comparisons.
    mid_idx = find_split(low, tree_data.leafsize, n_p)

    # Split along the dimension in which the bounding box is widest.
    split_dim = 1
    max_spread = zero(T)
    for d in 1:length(V)
        spread = hyper_rec.maxes[d] - hyper_rec.mins[d]
        if spread > max_spread
            max_spread = spread
            split_dim = d
        end
    end

    # NOTE(review): presumably a quickselect on `indices[low:high]` along `split_dim`
    # (cf. `select!`/`partialsort!`), so `indices[mid_idx]` is the point of that rank,
    # with smaller coordinates before it — the data are not fully sorted. Confirm.
    select_spec!(indices, mid_idx, low, high, data, split_dim)
    split_val = data[indices[mid_idx]][split_dim]

    lo = hyper_rec.mins[split_dim]
    hi = hyper_rec.maxes[split_dim]
    nodes[index] = KDNode{T}(lo, hi, split_val, split_dim)

    # Left child: points low:mid_idx-1, with the box capped at `split_val`.
    hyper_rec.maxes[split_dim] = split_val
    build_KDTree(getleft(index), data, data_reordered, hyper_rec, nodes,
                 indices, indices_reordered, low, mid_idx - 1, tree_data, reorder)
    hyper_rec.maxes[split_dim] = hi # restore the box

    # Right child: points mid_idx:high, with the box starting at `split_val`.
    hyper_rec.mins[split_dim] = split_val
    build_KDTree(getright(index), data, data_reordered, hyper_rec, nodes,
                 indices, indices_reordered, mid_idx, high, tree_data, reorder)
    hyper_rec.mins[split_dim] = lo # restore the box

    return nothing
end

# Implicit binary tree with 1-based heap indexing: node `node` has children at
# `2node` (left) and `2node + 1` (right), so no child pointers are stored.
@inline getleft(node::Int) = node << 1
@inline getright(node::Int) = getleft(node) + 1
    
    
    
    

In [17]:
# Inspect `undef`, the initializer passed to constructors such as `Vector{Int}(undef, n)`
# to allocate storage without filling it.
undef

UndefInitializer(): array initializer with undefined values

In [19]:
# NOTE(review): `Val` is normally called with a value, e.g. `Val(dim)` as in
# `copy_svec(T, data, Val(dim))`; a bare `Val()` likely raises a MethodError, so the
# `true` shown below may be stale output — re-run to confirm.
Val()

true

In [47]:
# Each leaf holds at most `leafsize` points, so covering all `n_p` points takes
# ceil(n_p / leafsize) leaves (the last one possibly partially filled).
n_p, leafsize = 65, 5
n_leafs = cld(n_p, leafsize)


13

In [48]:
# Deepest completely filled row of the tree: the largest k with 2^k <= n_leafs.
k = Int(floor(log2(n_leafs)))

3

In [49]:
# Row k holds 2^k nodes, but n_leafs leaves are needed. Splitting one row-k node into two
# children adds one leaf, so `rest` row-k nodes get children and the row below holds
# 2 * rest leaves. TreeData's `offset` is 2 * rest - 1.
rest = n_leafs - 2^k # leaves that do not fit in the full row k

5

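`select(v, k[, ord])`
Find the element that would be at position `k` if `v` were sorted according to ordering `ord` (default: `Sort.Forward`), without fully sorting `v`.

`select!(v, k[, ord])`
In-place version of `select` that partially reorders `v` so that the element at position `k` is the one that would be there after a full sort.

(In Julia 0.7 and later these functions are named `partialsort` and `partialsort!`.)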