In [1]:
using Random
using LinearAlgebra
using Statistics: quantile
using AbstractTrees

# A first implementation of random projection trees (RP trees)

The algorithm follows the PhD [thesis](https://soar.wichita.edu/bitstream/handle/10057/16380/d19007_Keivani.pdf?isAllowed=y&sequence=1) by Omid Keivani.

Below we implement **Algorithm 2** (page 8) of the thesis.

In [72]:
"""
    RPNode{T, N}

Node of a random projection tree.

# Fields
- `data::Array{T, N}`: the points held by this node, one point per column.
- `numpoints::Int`: number of points held, i.e. `size(data, 2)`.
- `α_fractile::T`: split value; a point whose projection onto `random_projection` is
  below it belongs to the left subtree.
- `random_projection::Array{T, 1}`: unit-length split direction (undefined until the node
  is split).
- `leftchild`, `rightchild`: subtrees; undefined on leaves.
"""
mutable struct RPNode{T, N}
    data::Array{T, N}
    numpoints::Int
    α_fractile::T
    random_projection::Array{T, 1}
    leftchild::RPNode{T, N}
    rightchild::RPNode{T, N}

    # Empty node: no field is initialised.
    RPNode{T, N}() where {T, N} = new{T, N}()

    # Leaf wrapping the points in `X` (one per column). The split value starts at zero
    # rather than whatever bits happen to be in memory; the split direction and children
    # stay undefined until the node is split.
    RPNode(X::Array{T, N}) where {T, N} = new{T, N}(X, size(X, 2), zero(T))
end

In [73]:
""" Recursive function for making Random Projection trees
""" 

function makerptree(node::RPNode{T, N}, n0::Int) where {T, N}
    if node.numpoints <= n0
        return node
    else
        u = rand(size(node.data)[1])
        node.random_projection = u / norm(u)
        α = 0.5 * rand() + .25
        data_proj = transpose(node.data) * node.random_projection
        node.α_fractile = quantile(data_proj, α)
        mask = data_proj .< node.α_fractile
        node.leftchild = makerptree(RPNode(node.data[:, mask]), n0)
        node.rightchild = makerptree(RPNode(node.data[:, .~mask]), n0)
        return node
    end
end

makerptree (generic function with 1 method)

In [74]:
# Seed the global RNG so the data and the tree built below are reproducible when the
# cells are run in order.
Random.seed!(1234)
X = randn(10, 100)  # 100 points in 10 dimensions, one point per column
root = RPNode(X)

RPNode{Float64,2}([-2.1333141830516813 -0.2630755336295679 … 0.7264670104995037 -1.1804702958649351; 0.9184047407909289 -0.2943171039757304 … 1.316975171933601 -1.0376062553449221; … ; 0.4832684843879911 -0.0967440819339652 … 1.1364390741659327 0.38191336958374894; 0.10261043291696702 -0.31944569385079263 … 1.5168919868483055 1.0980406539624137], 100, 0.0, #undef, #undef, #undef)

In [75]:
root = makerptree(root, 5);  # split in place with leaf-size threshold n0 = 5

### Use `AbstractTrees.jl` to print the tree

In [76]:
function AbstractTrees.children(node::RPNode)
    # Report only the subtrees that have been assigned; a leaf yields ().
    hasleft = isdefined(node, :leftchild)
    hasright = isdefined(node, :rightchild)
    if hasleft && hasright
        return (node.leftchild, node.rightchild)
    elseif hasleft
        return (node.leftchild,)
    elseif hasright
        return (node.rightchild,)
    else
        return ()
    end
end

# Label each node in tree printouts with the number of points it holds.
function AbstractTrees.printnode(io::IO, node::RPNode)
    return print(io, node.numpoints)
end

In [77]:
# Draw the tree; every node is labelled with its point count.
print_tree(stdout, root)

100
├─ 41
│  ├─ 29
│  │  ├─ 21
│  │  │  ├─ 8
│  │  │  │  ├─ 5
│  │  │  │  └─ 3
│  │  │  └─ 13
│  │  │     ├─ 5
│  │  │     └─ 8
│  │  │        ⋮
│  │  │        
│  │  └─ 8
│  │     ├─ 5
│  │     └─ 3
│  └─ 12
│     ├─ 5
│     └─ 7
│        ├─ 3
│        └─ 4
└─ 59
   ├─ 35
   │  ├─ 11
   │  │  ├─ 5
   │  │  └─ 6
   │  │     ├─ 2
   │  │     └─ 4
   │  └─ 24
   │     ├─ 11
   │     │  ├─ 6
   │     │  │  ⋮
   │     │  │  
   │     │  └─ 5
   │     └─ 13
   │        ├─ 9
   │        │  ⋮
   │        │  
   │        └─ 4
   └─ 24
      ├─ 10
      │  ├─ 4
      │  └─ 6
      │     ├─ 4
      │     └─ 2
      └─ 14
         ├─ 4
         └─ 10
            ├─ 3
            └─ 7
               ⋮
               


### Nearest Neighbor Search

In [78]:
x = rand(Float64, 10)  # random query point (not used by the cells shown below)
"""
    distances(X, x)

Return the squared Euclidean distance from `x` to every column of `X`, as a
`1 × size(X, 2)` row matrix.
"""
function distances(X, x)
    squared = (X .- x) .^ 2
    return sum(squared; dims=1)
end
# Brute-force ordering of all points by distance to the first point.
sortidx = sortperm(dropdims(distances(X, X[:, 1]); dims=1))

100-element Array{Int64,1}:
   1
  80
  76
  48
  98
  67
  56
  12
  16
  60
 100
  74
  41
   ⋮
  18
  49
  37
  72
  34
   4
  90
  14
  71
  68
  53
   9

In [79]:
"""
    findbin(node::RPNode, x::AbstractVector)

Return the leaf of the tree rooted at `node` that the query point `x` is routed to.

At every node with both children, `x` is projected onto `random_projection`. It moves to
the left child when the projection is strictly below `α_fractile`, otherwise to the right.
This is the same strict `<` comparison that `makerptree` uses to split the points, so
queries are routed the way the points were stored (up to floating-point rounding of the
projection).
"""
function findbin(node::RPNode{T, N}, x::AbstractVector) where {T, N}
    current = node
    # Check the node being visited; the previous version tested the global `root` instead.
    while isdefined(current, :leftchild) && isdefined(current, :rightchild)
        if dot(current.random_projection, x) < current.α_fractile
            current = current.leftchild
        else
            current = current.rightchild
        end
    end
    return current
end


findbin (generic function with 1 method)

In [80]:
# Descend from the root to the leaf reached by the first point.
closest = findbin(root, X[:, 1])

RPNode{Float64,2}([-2.1333141830516813 -1.5383560806453147 … 0.6074044616634419 -1.1804702958649351; 0.9184047407909289 -0.33014306399090776 … -0.3087451272413634 -1.0376062553449221; … ; 0.4832684843879911 0.3782128170456381 … 0.5759480985186449 0.38191336958374894; 0.10261043291696702 1.0393896212077627 … 0.25767795039177815 1.0980406539624137], 5, 0.0, #undef, #undef, #undef)

In [81]:
# Squared distances from the first point to every point stored in that leaf.
distances(closest.data, X[:, 1])

1×5 Array{Float64,2}:
 0.0  17.8067  7.93143  14.8051  12.1625

In [82]:
# All squared distances in increasing order, for comparison with the leaf above.
vec(distances(X, X[:, 1]))[sortidx]

100-element Array{Float64,1}:
  0.0
  7.693277906605389
  7.93143206266029
  9.64083074797982
  9.67817976407559
 10.1981769356077
 10.902707359396965
 11.044106629011242
 11.207501560077135
 12.152213168910517
 12.16252239112806
 12.300036117334868
 12.317874519454953
  ⋮
 31.537746365891675
 31.734668920879663
 32.83419996543021
 33.19249788851822
 34.00149083489789
 34.244048588519874
 34.409661423721836
 35.54997381857867
 37.48068528314821
 37.94192410669977
 39.19735634476281
 39.88050586651641

In [54]:
closest.data[:, size(closest.data, 2)]  # last point stored in the leaf

10-element Array{Float64,1}:
 -0.2733214321037786
  1.8341238500809505
 -0.4765276405366542
  1.324542870346786
  0.022628409680721804
  0.2833629639205888
  0.5451544539664122
  0.5584974634583276
  1.3305645180166208
  0.9537674957578445