In [1]:
using DataStructures
# using StaticArrays
using BenchmarkTools
using Profile
using Random
import Base: insert!, delete!, haskey, push!, @propagate_inbounds
import DataStructures: AVLTreeNode, RBTreeNode, SplayTreeNode

Classic B-Tree 

In [5]:
# Switches and node capacity shared by the B-tree routines in this notebook.
# MEMMOVE: when true, the copy!/shiftr!/shiftl! helpers call `unsafe_copyto!`
# instead of copying element by element.
const MEMMOVE = false
# BINSEARCH: when true, `fgeq` searches a node with binary search; otherwise it scans linearly.
const BINSEARCH = true
# S: maximum number of keys per node (nodes are allocated with S key slots and S + 1 child slots).
const S = 64

64

In [6]:
# One B-tree node with keys of type K.
mutable struct Node{K}
    data::Vector{K}                            # key slots; only data[1:numData] are in use
    children::Vector{Union{Missing, Node{K}}}  # child slots; `missing` marks an empty slot
    numData::Int                               # number of keys currently stored in `data`
end

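# Note: if K is not an isbits type, `data` could use the same `missing` treatment as
# `children` (Vector{Union{Missing, K}}) so that unused slots do not keep stale references alive.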

# Optional node reference: either a `Node{K}` or `missing` (no node).
OptNode{K} = Union{Missing, Node{K}}

# Build an empty node: S uninitialized key slots, S + 1 child slots all set to
# `missing`, and a key count of zero.
@inline function Node{K}() where K
    slots = Vector{K}(undef, S)
    kids = Vector{OptNode{K}}(missing, S + 1)
    return Node{K}(slots, kids, 0)
end

# A node is a leaf when its first child slot is empty.
@inline function isLeaf(node::Node{K}) where K
    return node.children[1] === missing
end

# B-tree handle; `root` is replaced when the tree grows (insert!) or shrinks (delete!).
mutable struct Tree{K}
    root::Node{K}  # current root node
end

# Create an empty tree whose root is a fresh, empty node.
function Tree{K}() where K
    return Tree{K}(Node{K}())
end

In [3]:
# Expand to `i` consecutive statements of the form
#     @inbounds if node.data[j] >= d return j end      (j = 1, 2, ..., i)
# followed by `nothing`. The whole expansion is escaped, so `node` and `d` must be
# variables in the scope where the macro is invoked.
macro unroll_fgeq(i::Int)
    esc(quote
            # the comprehension variable `i` shadows the macro argument `i` inside the brackets
            $([:(@inbounds if node.data[$i] >= d return $i end) for i in 1:i]...)
        nothing
        end)
end

# Return the first index j in 1:64 with node.data[j] >= d, or 65 when there is none.
# The scan is fully unrolled and does not look at node.numData, so slots past the
# stored keys are compared too; the caller is expected to clamp the result.
# NOTE(review): the literal 64 presumably mirrors `const S = 64` — confirm they stay in sync,
# since the reads are done under @inbounds.
function fgeq_unrolled(node::Node{K}, d::K) where K
    @unroll_fgeq(64)
    return 65
end

# Index of the first stored key >= d, clamped to numData + 1 because the unrolled
# scan may report a match in a slot past the keys actually stored.
@inline function geq(node::Node{K}, d::K) where K
    return min(fgeq_unrolled(node, d), node.numData + 1)
end

# Expand to a fully unrolled binary search over the index range l..r.
# Base case (l == r): the expansion is `return l`.
# Otherwise it tests the midpoint m: if m is within the stored keys and
# node.data[m] < d, it continues in (m+1)..r, else in l..m; each half is produced
# by a nested invocation of this macro. The conditional is escaped, so `node` and
# `d` must be variables in the scope where the macro is invoked.
macro unroll_bfgeq(l::Int, r::Int)
    if l == r
        :(return $l)
    else
        m = (l + r) >> 1
        esc(
            :(if $m <= node.numData && @inbounds node.data[$m] < d
                @unroll_bfgeq($(m+1), $r)
            else
                @unroll_bfgeq($l, $m)
            end)
        )
    end
end

# Binary-search variant of `geq`, unrolled at macro-expansion time over indices 1..64.
# It has the same signature as the `geq` defined just above in this cell, so it
# replaces that method.
@inline function geq(node::Node{K}, d::K) where K
    @unroll_bfgeq(1, 64)
end

LoadError: UndefVarError: Node not defined

In [8]:
# Return the index of the first stored key that is not less than `d`,
# or node.numData + 1 when every stored key is smaller.
# Uses binary search when BINSEARCH is set, otherwise a linear scan.
@inline function fgeq(node::Node{K}, d::K) where K
    if !BINSEARCH
        pos = 1
        @inbounds while pos <= node.numData && node.data[pos] < d
            pos += 1
        end
        return pos
    end

    # invariant: the answer lies in lo..hi
    lo, hi = 1, node.numData + 1
    while lo != hi
        mid = (lo + hi) >> 1
        if @inbounds node.data[mid] < d
            lo = mid + 1
        else
            hi = mid
        end
    end
    return lo
end

# Copy b[sB:eB] into `a` starting at index sA.
# Candidate for SIMD, unrolling, ccall, or other optimizations.
@inline function copy!(a::Vector{T}, sA::Int, b::Vector{T}, sB::Int, eB::Int) where T
    count = eB - sB + 1
    if MEMMOVE
        unsafe_copyto!(a, sA, b, sB, count)
    else
        for k in 1:count
            @inbounds a[sA + k - 1] = b[sB + k - 1]
        end
    end
end

# Move b[sB:eB] into `a` starting at index sA, then overwrite the source slots
# with `missing` so the moved references can be garbage collected from `b`.
@inline function copy!(a::Vector{Union{Missing, T}}, sA::Int, b::Vector{Union{Missing, T}}, sB::Int, eB::Int) where T
    if !MEMMOVE
        offset = sA - sB
        for j in sB:eB
            @inbounds a[j + offset] = b[j]
            @inbounds b[j] = missing
        end
    else
        unsafe_copyto!(a, sA, b, sB, eB - sB + 1)

        @inbounds @simd for j in sB:eB
            b[j] = missing
        end
    end
end

# Shift a[s:e] one slot to the right (into a[s+1:e+1]) and clear a[s] to `missing`.
@inline function shiftr!(a::Vector{Union{Missing, T}}, s::Int, e::Int) where T
    if MEMMOVE
        unsafe_copyto!(a, s + 1, a, s, e - s + 1)
    else
        # walk from the back so no element is overwritten before it is moved
        j = e
        while j >= s
            @inbounds a[j + 1] = a[j]
            j -= 1
        end
    end
    @inbounds a[s] = missing
end

# Shift a[s:e] one slot to the right (into a[s+1:e+1]); a[s] keeps its old value.
@inline function shiftr!(a::Vector{T}, s::Int, e::Int) where T
    if MEMMOVE
        unsafe_copyto!(a, s + 1, a, s, e - s + 1)
    else
        # walk from the back so no element is overwritten before it is moved
        j = e
        while j >= s
            @inbounds a[j + 1] = a[j]
            j -= 1
        end
    end
end

# Shift a[s:e] one slot to the left (into a[s-1:e-1]) and clear a[e] to `missing`.
@inline function shiftl!(a::Vector{Union{Missing, T}}, s::Int, e::Int) where T
    if MEMMOVE
        unsafe_copyto!(a, s - 1, a, s, e - s + 1)
    else
        for dst in (s - 1):(e - 1)
            @inbounds a[dst] = a[dst + 1]
        end
    end
    @inbounds a[e] = missing
end

# Shift a[s:e] one slot to the left (into a[s-1:e-1]); a[e] keeps its old value.
@inline function shiftl!(a::Vector{T}, s::Int, e::Int) where T
    if MEMMOVE
        unsafe_copyto!(a, s - 1, a, s, e - s + 1)
    else
        for dst in (s - 1):(e - 1)
            @inbounds a[dst] = a[dst + 1]
        end
    end
end

shiftl! (generic function with 2 methods)

In [9]:
# Search the subtree rooted at `node` for key `d`.
# Returns the stored key equal to `d`, or `missing` when the key is absent.
#
# At each node, `fgeq` gives the first index i whose key is >= d. If that key
# equals d we are done; otherwise the key can only live in child i. Note that
# i may be numData + 1 (d is larger than every key in the node): in an internal
# node this selects the right-most child, which must still be searched.
@inline function findkey(node::Node{K}, d::K) where K
    while true
        i = fgeq(node, d)

        if i <= node.numData
            key = @inbounds node.data[i]
            if key == d
                return key
            end
        end

        if isLeaf(node)
            return missing
        end

        node = node.children[i]
    end
end

# True when `key` is stored in `tree`.
@inline function haskey(tree::Tree{K}, key::K) where K
    found = findkey(tree.root, key)
    return found !== missing
end

# Append every key in the subtree rooted at `node` to `l` in ascending order
# (in-order traversal: child i, key i, ..., right-most child). Returns `nothing`.
function toList!(l::Vector{K}, node::Node{K}) where K
    leaf = isLeaf(node)
    for i = 1:node.numData
        if !leaf
            toList!(l, node.children[i])
        end
        # push! stores the key as one element; append! would iterate the key,
        # which breaks for non-scalar keys such as tuples
        push!(l, node.data[i])
    end
    if !leaf
        toList!(l, node.children[node.numData + 1])
    end
    return nothing
end

# Append all keys of `tree` to `l` in ascending order.
toList!(l::Vector{K}, tree::Tree{K}) where K = toList!(l, tree.root)

toList! (generic function with 2 methods)

In [10]:
# Return the smallest key >= k in the subtree rooted at `node`, or `missing`
# when every key in the subtree is smaller than k.
#
# `fgeq` gives the first index i whose key is >= k (numData + 1 if none).
# Everything in child i lies below key i, so a hit in child i beats key i.
# When i == numData + 1 there is no key i in this node, but in an internal node
# the right-most child can still hold keys >= k and must be searched.
function lower_bound(node::Node{K}, k::K) where K
    i = fgeq(node, k)

    if isLeaf(node)
        return i > node.numData ? missing : node.data[i]
    end

    v = lower_bound(node.children[i], k)
    if !ismissing(v) || i > node.numData
        # either the child holds the answer, or there is no fallback key here
        return v
    end
    return node.data[i]
end

# Smallest key >= k stored in `tree`, or `missing` when there is none.
lower_bound(tree::Tree{K}, k::K) where K = lower_bound(tree.root, k)

lower_bound (generic function with 2 methods)

In [11]:
# Insert `iData` into the subtree rooted at `node`.
# Returns `missing` when the subtree absorbed the key (or the key was already present),
# or a tuple (separator key, new right sibling) when `node` had to be split; the caller
# must then insert that separator and sibling next to `node`.
@propagate_inbounds function insert!(node::Node{K}, iData::K)::Union{Missing, Tuple{K, Node{K}}}  where K
    # position of the first key >= iData (numData + 1 when all keys are smaller)
    i = fgeq(node, iData)

    # key already present: overwrite in place, nothing to propagate
    if i <= node.numData && node.data[i] == iData
        node.data[i] = iData
        return missing
    end
    
    if !isLeaf(node)
        # descend; a non-missing result means child i was split
        ret = insert!(node.children[i], iData)
        
        if ismissing(ret)
            return missing
        end
        
        # from here on we insert the promoted separator and its right sibling into this node
        iData, iChild = ret
    end
    

    split = missing
    if node.numData == S
        # node is full: split it into `node` (left part) and `rNode` (right part)
        rNode = Node{K}()
        m = S ÷ 2 + 1 # bias left

        if S & 1 == 0 # adjustments for even case
            if i == m # median                
                # the incoming key is itself the separator: keys m..end move to rNode,
                # iData is returned to the parent without being stored in either half
                copy!(rNode.data, 1, node.data, m, node.numData)
                
                if !isLeaf(node)
                    # children m+1..end go to rNode slots 2..; iChild becomes rNode's first child
                    copy!(rNode.children, 2, node.children, m+1, node.numData + 1)
                    rNode.children[1] = iChild
                end
                
                rNode.numData = node.numData - (m - 1)
                
                node.numData = m - 1
                
                return iData, rNode
            elseif i < m # if biased and left
                m -= 1
            end
        end
        
        # general case: key m is promoted, keys m+1..end move to rNode
        copy!(rNode.data, 1, node.data, m+1, node.numData)
        
        if !isLeaf(node)
            copy!(rNode.children, 1, node.children, m+1, node.numData + 1)
        end
        
        rNode.numData = node.numData - m
        node.numData = m-1
        # read the separator before the insertion below can overwrite slot m
        split = node.data[m]
        
        if i > m
            # the new key belongs in the right half: rebase the index and switch nodes
            i -= m
            node = rNode
        end
    end

    # open a gap at position i and store the key
    shiftr!(node.data, i, node.numData)
    node.data[i] = iData
    
    # insert!(node.data, i, iData)
    
    # println(node.numData)
    if !isLeaf(node)
        # the new right sibling goes immediately after child i
        shiftr!(node.children, i+1, node.numData + 1)
        node.children[i+1] = iChild
    end
    
    node.numData += 1
        
    return ismissing(split) ? missing : (split, rNode)
end

# Insert `iData` into `tree`. When the root itself is split, a new root holding the
# promoted separator is created with the old root and the new sibling as its children.
function insert!(tree::Tree{K}, iData::K) where K
    promoted = insert!(tree.root, iData)

    if !ismissing(promoted)
        separator, sibling = promoted
        grown = Node{K}()
        grown.numData = 1
        grown.data[1] = separator
        grown.children[1] = tree.root
        grown.children[2] = sibling
        tree.root = grown
    end
end

# `push!` on a tree is an alias for `insert!`.
push!(tree::Tree{K}, iData::K) where K = insert!(tree, iData)

push! (generic function with 66 methods)

In [12]:
# Remove `dData` from the subtree rooted at `node` and rebalance.
#   par    - parent of `node`, or `missing` when `node` is the root
#   pNodeI - index of `node` in par.children, or -1 when `node` is the root
# Returns:
#   missing - nothing further for the caller to do (key absent, or rebalancing finished)
#   Int     - `node` was merged with a sibling; the caller must remove key i and
#             child i+1 from the parent, where i is the returned index
#   Node    - (root only) the root ran out of keys; the returned node is the new root
@propagate_inbounds function delete!(node::Node{K}, par::OptNode, pNodeI::Int, dData::K)::Union{Missing, Int, Node{K}} where K
    i = fgeq(node, dData)
    
    if !isLeaf(node)
        if i <= node.numData && node.data[i] == dData
            # key found in an internal node: replace it with its predecessor
            # (right-most key under child i) and delete that predecessor instead
            dNode = node.children[i]

            while !isLeaf(dNode)
                dNode = dNode.children[dNode.numData + 1]
            end
            rValue = dNode.data[dNode.numData]

            node.data[i] = rValue

            dData = rValue
        end
        
        # recurse; a returned index tells us which key/child pair to drop from this node
        i = delete!(node.children[i], node, i, dData)
        if ismissing(i)
            return missing
        end
    end
    
    if isLeaf(node)
        # key not present in the leaf: nothing to delete
        if i > node.numData || node.data[i] != dData
            return missing
        end
    end
    
    # by this point, assumes data i and children i+1 can be safely removed

    shiftl!(node.data, i+1, node.numData)
    if !isLeaf(node)
        shiftl!(node.children, i+2, node.numData+1)
    end
    
    node.numData -= 1
    
    # underflow: fewer than S / 2 keys left
    if node.numData < S / 2
        if pNodeI == -1
            # root case: can't do anything about it
            
            if node.numData == 0
                # hand back the only remaining child (`missing` if the root is a leaf)
                return node.children[1]
            end
            
            return missing
        end
        
        if pNodeI != 1
            # try to borrow from the left sibling: its last key rotates up into the
            # parent and the parent's separator rotates down to the front of `node`
            lSib = par.children[pNodeI-1]
            if lSib.numData > S / 2
                shiftr!(node.data, 1, node.numData)
                
                if !isLeaf(node)
                    shiftr!(node.children, 1, node.numData+1)
                    node.children[1] = lSib.children[lSib.numData+1]
                    lSib.children[lSib.numData+1] = missing
                end
                
                node.data[1] = par.data[pNodeI - 1]
                par.data[pNodeI - 1] = lSib.data[lSib.numData]
                
                node.numData += 1
                lSib.numData -= 1
                return missing
            end
        end
            
        if pNodeI != par.numData + 1
            # try to borrow from the right sibling: the parent's separator moves to the
            # end of `node` and the sibling's first key becomes the new separator
            rSib = par.children[pNodeI+1]
            if rSib.numData > S / 2
                node.data[node.numData+1] = par.data[pNodeI]
                par.data[pNodeI] = rSib.data[1]
                shiftl!(rSib.data, 2, rSib.numData)
                
                if !isLeaf(node)
                    node.children[node.numData+2]= rSib.children[1]
                    shiftl!(rSib.children, 2, rSib.numData+1)
                end
                
                node.numData += 1
                rSib.numData -= 1
                return missing
            end
        end
        
        # no sibling can lend a key: merge. After this branch `node` is the left
        # node of the pair, `rNode` the right one, and pNodeI indexes the left node.
        if pNodeI != 1
            rNode = node
            node = par.children[pNodeI-1]
            pNodeI -= 1
        else
            rNode = par.children[pNodeI+1]
        end
        # pull the separator down, then append all of rNode's keys and children
        node.data[node.numData + 1] = par.data[pNodeI]

        copy!(node.data, node.numData + 2, rNode.data, 1, rNode.numData)
        
        copy!(node.children, node.numData + 2, rNode.children, 1, rNode.numData+1)
        
        node.numData += 1 + rNode.numData
                
        # tell the parent to drop key pNodeI and child pNodeI + 1
        return pNodeI
    end
    return missing
end

# Remove `dData` from `tree` (no-op when absent). If the root runs out of keys,
# the node returned by the recursive delete becomes the new root. Returns `nothing`.
function delete!(tree::Tree{K}, dData::K) where K
    replacement = delete!(tree.root, missing, -1, dData)
    ismissing(replacement) || (tree.root = replacement)
    return nothing
end

delete! (generic function with 34 methods)

## Minimal Tests

In [13]:
# Insert ten million random integers and check that an in-order traversal of the
# tree equals the sorted, de-duplicated list of inserted values.
function test_add()
    tree = Tree{Int}()
    r = Vector{Int}()
    for _ in 1:10000000
        v = rand(Int)
        push!(r, v)
        push!(tree, v)
    end

    unique!(sort!(r))

    l = Vector{Int}()
    toList!(l, tree)
    @assert r == l
end

test_add()

In [14]:
# Insert 1..10000, delete them in random order, and after every deletion check that
# the tree's in-order traversal matches a reference vector maintained alongside.
function test_delete()
    tree = Tree{Int}()

    ref = collect(1:10000)
    for key in ref
        insert!(tree, key)
    end

    ro = copy(ref)
    shuffle!(ro)

    for victim in ro
        delete!(tree, victim)
        filter!(x -> x != victim, ref)

        l = Vector{Int}()
        toList!(l, tree)
        @assert l == ref
    end
end

test_delete()

## Synthetic Benchmarks

In [26]:
# Push `n` random Int values into the container `a`.
function insert_random(a, n=10000000)
    foreach(1:n) do _
        push!(a, rand(Int))
    end
end

# Probe `a` with `n` random Int keys and return how many of them were found.
function find_random(a, n=10000000)
    return count(_ -> haskey(a, rand(Int)), 1:n)
end

# Delete the first `num` keys listed in `dorder` from `a`, in order.
function delete_helper(a, dorder, num)
    idx = 1
    while idx <= num
        delete!(a, dorder[idx])
        idx += 1
    end
end

# Fill `a` with `num` random keys, then time deleting those keys in shuffled order.
function delete_random(a, num=10000000)
    dorder = Vector{Int}()

    for _ in 1:num
        key = rand(Int)
        push!(a, key)
        push!(dorder, key)
    end

    shuffle!(dorder)

    @time delete_helper(a, dorder, num)
end

# Push the first `num` entries of `iorder` into `a`, in order.
function insert_helper(a, iorder, num)
    idx = 1
    while idx <= num
        push!(a, iorder[idx])
        idx += 1
    end
end

# Probe `a` with `num` random Int keys and return the number of hits.
# Kept out-of-line (@noinline) so it is timed as a separate call.
@noinline function find_helper(a, num)
    return count(_ -> haskey(a, rand(Int)), 1:num)
end

# Time the three phases on container `a`: insert `iorder`, look up `num` random
# keys, then delete the keys in `shuffled` order. Each phase prints its own timing.
function ifd(a, iorder, shuffled, num)
    @time insert_helper(a, iorder, num)   # phase 1: insert
    @time find_helper(a, num)             # phase 2: find
    return @time delete_helper(a, shuffled, num)   # phase 3: delete
end

# Generate `num` random keys plus a shuffled copy, then time the full
# insert/find/delete run on `a`.
function ifd_test(a, num=10000000)
    iorder = Int[rand(Int) for _ in 1:num]
    shuffled = shuffle(iorder)
    @time ifd(a, iorder, shuffled, num)
end

GC.gc()

In [12]:
# Collect garbage, then create one empty container of each kind being compared:
# the DataStructures.jl trees/sorted set and the B-tree defined in this notebook.
GC.gc()
rbt = RBTree{Int}()
avlt = AVLTree{Int}()
st = SplayTree{Int}()
sset = SortedSet{Int}()
bt = Tree{Int}()

Tree{Int64}(Node{Int64}([0, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0], OptNode{Int64}[missing, missing, missing, missing, missing, missing, missing, missing, missing, missing  …  missing, missing, missing, missing, missing, missing, missing, missing, missing, missing], 0))

In [13]:
GC.gc()

In [14]:
# Sweep the node capacity S over 2, 4, ..., 1024, timing ifd_test on a fresh tree each time.
# NOTE(review): S is declared `const S = 64` at the top of this notebook; reassigning it
# here presumably relied on a session in which S was not const — confirm before re-running.
S = 2
for i = 1:10
    GC.gc()
    println(S)
    bt = Tree{Int}()
    ifd_test(bt)
    S *= 2
end

2
 53.257548 seconds (64.45 M allocations: 2.072 GiB, 4.00% gc time)
4
 33.109947 seconds (40.62 M allocations: 1.265 GiB, 2.21% gc time)
8
 22.892390 seconds (28.04 M allocations: 874.742 MiB, 2.72% gc time)
16
 18.718309 seconds (21.19 M allocations: 656.421 MiB, 1.88% gc time)
32
 15.365720 seconds (17.47 M allocations: 556.284 MiB, 1.36% gc time)
64
 13.019537 seconds (15.45 M allocations: 483.564 MiB, 1.17% gc time)
128
 14.505164 seconds (14.35 M allocations: 453.630 MiB, 0.77% gc time)
256
 13.801056 seconds (13.73 M allocations: 447.801 MiB, 0.78% gc time)
512
 17.680178 seconds (13.70 M allocations: 456.582 MiB, 0.50% gc time)
1024
 22.428574 seconds (37.14 M allocations: 820.655 MiB, 0.71% gc time)


In [18]:
# Create one container of each kind and insert a single random key into each.
# NOTE(review): presumably a warm-up so compilation is not included in later timings — confirm.
rbt = RBTree{Int}()
avlt = AVLTree{Int}()
st = SplayTree{Int}()
sset = SortedSet{Int}()
bt = Tree{Int}()

insert_random(rbt, 1)
insert_random(avlt, 1)
insert_random(st, 1)
insert_random(sset, 1)
insert_random(bt, 1)

GC.gc()

In [28]:
# Compare insert/find/delete timings across all containers for growing input sizes:
# num starts at 100 and is multiplied by 10 after each of the six outer iterations.
GC.gc()
num = 100
for i = 2:7
    # fresh, empty containers for every size
    rbt = RBTree{Int}()
    avlt = AVLTree{Int}()
    st = SplayTree{Int}()
    sset = SortedSet{Int}()
    bt = Tree{Int}()
    
    iorder = Vector{Int}()
    
    # note: this inner `i` shadows the outer loop variable
    for i = 1:num
        push!(iorder, rand(Int))
    end
    
    # every container gets the same keys and the same deletion order
    shuffled = shuffle(iorder)
    
    
    println(num)
    println("rb")
    GC.gc()
    ifd(rbt, iorder, shuffled, num)
    
    println("avl")
    GC.gc()
    ifd(avlt, iorder, shuffled, num)
    
    println("st")
    GC.gc()
    ifd(st, iorder, shuffled, num)
    
    println("sset")
    GC.gc()
    ifd(sset, iorder, shuffled, num)
    
    println("bt")
    GC.gc()
    ifd(bt, iorder, shuffled, num)
    num *= 10
end

100
rb
  0.000010 seconds (100 allocations: 6.250 KiB)
  0.000003 seconds
  0.000008 seconds
avl
  0.000171 seconds (932 allocations: 17.688 KiB)
  0.000002 seconds
  0.000145 seconds (729 allocations: 11.391 KiB)
st
  0.000019 seconds (100 allocations: 4.688 KiB)
  0.000004 seconds
  0.000016 seconds
sset
  0.000011 seconds (6 allocations: 14.234 KiB)
  0.000004 seconds
  0.000009 seconds (6 allocations: 3.719 KiB)
bt
  0.000007 seconds (7 allocations: 2.344 KiB)
  0.000002 seconds
  0.000006 seconds
1000
rb
  0.000092 seconds (1000 allocations: 62.500 KiB)
  0.000032 seconds
  0.000072 seconds
avl
  0.002516 seconds (12.71 k allocations: 229.812 KiB)
  0.000030 seconds
  0.002351 seconds (10.66 k allocations: 166.625 KiB)
st
  0.000240 seconds (1000 allocations: 46.875 KiB)
  0.000048 seconds
  0.000241 seconds
sset
  0.000111 seconds (11 allocations: 172.812 KiB)
  0.000053 seconds
  0.000091 seconds (10 allocations: 43.594 KiB)
bt
  0.000066 seconds (91 allocations: 27.281 KiB)
  0

In [None]:
# Time random lookups (default count) against each container; uses whatever
# rbt/avlt/st/sset/bt currently hold from earlier cells.
GC.gc()
@time find_random(rbt)
@time find_random(avlt)
@time find_random(st)
@time find_random(sset)
@time find_random(bt)

In [None]:
sset = SortedSet{Int}()
insert_random(sset)

In [None]:
@time find_random(sset)

In [None]:
# Recreate the red-black tree, AVL tree and B-tree and insert one random key into each.
rbt = RBTree{Int}()
avlt = AVLTree{Int}()
bt = Tree{Int}()

insert_random(rbt, 1)
insert_random(avlt, 1)
insert_random(bt, 1)

GC.gc()

In [None]:
# Time delete_random on each container, collecting garbage between runs.
# NOTE(review): `st` and `sset` are not created in the preceding cell; this relies on
# the values left over from earlier cells — confirm.
@time delete_random(rbt)
GC.gc()
@time delete_random(avlt)
GC.gc()
@time delete_random(st)
GC.gc()
@time delete_random(sset)
GC.gc()
@time delete_random(bt)
GC.gc()

In [None]:
# Scratch node holding the keys 1..5, used to try out the search routines by hand.
a = Node{Int}()
a.data[1:5] = [1, 2, 3, 4, 5]
a.numData = 5

In [None]:
geq(a, 1)

In [None]:
@code_lowered g(a, 8)

In [None]:
@inbounds if node.data[m] < d
            l = m + 1
        else
            h = m
        end
    end

## Closest Pair of Points

In [32]:
# Smallest key >= d in the binary-tree subtree rooted at `node`, or `missing`
# when there is none. A node whose data is `nothing` is treated as empty.
function lower_bound(node::Union{AVLTreeNode{K}, RBTreeNode{K}, SplayTreeNode{K}}, d::K) where K
    key = node.data
    isnothing(key) && return missing

    if key < d
        # this key is too small; only the right subtree can contain an answer
        right = node.rightChild
        return isnothing(right) ? missing : lower_bound(right, d)
    end

    # this key qualifies, but a smaller qualifying key may exist on the left
    left = node.leftChild
    if !isnothing(left)
        found = lower_bound(left, d)
        ismissing(found) || return found
    end

    return key
end

# Smallest key >= d stored in `tree`, or `missing` when the tree has no root or no such key.
function lower_bound(tree::Union{AVLTree{K}, RBTree{K}, SplayTree{K}}, d::K) where K
    top = tree.root
    return isnothing(top) ? missing : lower_bound(top, d)
end

lower_bound (generic function with 4 methods)

In [33]:
# A point is an (Int, Int) pair.
Point = Tuple{Int, Int}

# Squared Euclidean distance between `a` and `b` (no square root is taken).
function dist(a::Point, b::Point)
    d1 = a[1] - b[1]
    d2 = a[2] - b[2]
    return d1 * d1 + d2 * d2
end

# Return the point with its two coordinates exchanged.
swap(x::Point) = reverse(x)

# Sweep-line closest pair: return the smallest squared distance between any two
# points in `pts` (or 2^61 if no pair is ever compared).
#   ds  - an initially empty ordered set supporting push!, delete! and lower_bound;
#         points are stored with coordinates swapped so they are ordered by y, then x
#   pts - vector of points; it is sorted in place
# `cDist` is a squared distance (see `dist`), so the x-strip test compares squares,
# while the y-window uses the integer square root of `cDist` as its half-width.
function closest_pair_points(ds, pts)
    INF = 2^61 # placeholder for infinity
    sort!(pts)

    lI = 1
    N = length(pts)

    cDist = INF

    for i in 1:N
        cpt = pts[i]

        # in data structure, coordinates are swapped to be sorted by y first
        # drop points whose x-distance alone already rules them out
        while lI < i && (cpt[1] - pts[lI][1]) ^ 2 >= cDist
            delete!(ds, swap(pts[lI]))
            lI += 1
        end

        # a candidate can only improve cDist if dy^2 < cDist, i.e. |dy| <= isqrt(cDist)
        reach = isqrt(cDist)
        lY = cpt[2] - reach
        hY = cpt[2] + reach

        candpt = lower_bound(ds, (lY, -INF))

        # walk the stored points in (y, x) order through the window
        while !ismissing(candpt) && candpt[1] <= hY
            cDist = min(cDist, dist(swap(candpt), cpt))
            candpt = lower_bound(ds, (candpt[1], candpt[2] + 1))
        end

        push!(ds, swap(cpt)) # data structure sorted by y, then x
    end
    return cDist
end

closest_pair_points (generic function with 1 method)

In [34]:
# Generate `N` random points with coordinates of magnitude below 2^30 and time
# closest_pair_points on a red-black tree, an AVL tree and the B-tree, in that order.
# Prints N followed by one timing line per container. Returns `nothing`.
function closest_pair_points_benchmark(N = 10000000)
    MAXR = 2^30

    pts = Vector{Point}(undef, N)

    for i = 1:N
        # `%` keeps the sign of rand(Int), so coordinates may be negative
        pts[i] = (rand(Int) % MAXR, rand(Int) % MAXR)
    end

    rbt = RBTree{Point}()
    avlt = AVLTree{Point}()
    # SplayTree and SortedSet are left out of this benchmark
    bt = Tree{Point}()
    GC.gc()

    println(N)
    @time closest_pair_points(rbt, pts)
    @time closest_pair_points(avlt, pts)
    @time closest_pair_points(bt, pts)
    # return the value `nothing`, not the type `Nothing`
    return nothing
end

closest_pair_points_benchmark (generic function with 2 methods)

In [37]:
# Run the closest-pair benchmark for N = 100, 1000, ..., 10^8 (seven sizes).
GC.gc()
num = 100
for i = 2:8
    closest_pair_points_benchmark(num)
    num *= 10
end

100
  0.000014 seconds (103 allocations: 7.125 KiB)
  0.000057 seconds (788 allocations: 22.344 KiB)
  0.000008 seconds (3 allocations: 896 bytes)
1000
  0.000134 seconds (1.00 k allocations: 70.453 KiB)
  0.000748 seconds (8.78 k allocations: 251.047 KiB)
  0.000066 seconds (3 allocations: 7.953 KiB)
10000
  0.001350 seconds (10.00 k allocations: 703.266 KiB)
  0.007139 seconds (87.76 k allocations: 2.449 MiB)
  0.000649 seconds (3 allocations: 78.266 KiB)
100000
  0.015772 seconds (100.00 k allocations: 6.867 MiB)
  0.080501 seconds (902.06 k allocations: 25.240 MiB)
  0.007366 seconds (17 allocations: 788.078 KiB)
1000000
  0.142606 seconds (1.00 M allocations: 68.665 MiB)
  0.581340 seconds (8.34 M allocations: 231.505 MiB)
  0.050800 seconds (10 allocations: 7.633 MiB)
10000000
  1.706021 seconds (10.00 M allocations: 686.646 MiB, 3.47% gc time)
  6.845856 seconds (86.00 M allocations: 2.339 GiB, 3.08% gc time)
  0.546135 seconds (565 allocations: 76.554 MiB)
100000000
 18.461646 

In [None]:
closest_pair_points_benchmark()

In [None]:
# Draw `n` random Int values and return their (wrapping) sum; used to measure
# the cost of random number generation alone.
function gen_random(n=10000000)
    return foldl((acc, _) -> acc + rand(Int), 1:n; init=0)
end

In [None]:
@time gen_random()