In [21]:
# k-d tree node type
"""
    KDTreeNode(point, left, right)

A single node of a k-d tree.

# Fields
- `point::Vector{Float64}`: coordinates of the point stored at this node. The default
  constructor converts other real vectors to `Vector{Float64}`.
- `left`: the left child subtree, or `nothing` when there is none.
- `right`: the right child subtree, or `nothing` when there is none.
"""
struct KDTreeNode
    point::Vector{Float64}
    # `Nothing` is the type and `nothing` is its single instance; it marks a missing child.
    left::Union{Nothing, KDTreeNode}
    right::Union{Nothing, KDTreeNode}
end

In [23]:
"""
    build_kdtree(points, depth=0) -> Union{KDTreeNode, Nothing}

Recursively build a k-d tree from the rows of `points`.

`points` is an `n × k` matrix: each of the `n` rows is one point with `k` coordinates.
At recursion depth `depth` the rows are split on axis `mod(depth, k) + 1`, so the
splitting axis cycles through the dimensions. The row at the upper median of that axis
(index `div(n, 2) + 1` after sorting) becomes the node. Rows sorted before it form the
left subtree and rows sorted after it form the right subtree. `sortperm` is stable, so
rows whose coordinate equals the median's keep their current order, and such ties may
end up in either subtree.

Any real-valued matrix is accepted. Coordinates are stored as `Float64` in the nodes.

Returns `nothing` when `points` has no rows.

# Throws
- `ArgumentError` if `points` has rows but zero columns (there is no axis to split on).
"""
function build_kdtree(
    points::AbstractMatrix{<:Real}, depth::Int=0
)::Union{KDTreeNode, Nothing}  # `::T` after the argument list annotates the return type
    n, k = size(points)  # n points, k dimensions
    n == 0 && return nothing
    # Without this check, mod(depth, 0) below would raise an opaque DivideError.
    k > 0 || throw(ArgumentError(
        "points must have at least one column (dimension), got size $(size(points))"))

    # mod() is non-negative even for a negative depth, unlike the % (rem) operator.
    axis = mod(depth, k) + 1

    # Indexing with a permutation makes one sorted copy per level. The two halves are
    # then handed to the recursive calls as views instead of being copied again.
    sorted_points = points[sortperm(view(points, :, axis)), :]
    median_idx = div(n, 2) + 1
    median_point = sorted_points[median_idx, :]  # row copy -> Vector

    return KDTreeNode(
        median_point,
        build_kdtree(view(sorted_points, 1:median_idx-1, :), depth + 1),
        build_kdtree(view(sorted_points, median_idx+1:n, :), depth + 1),
    )
end

build_kdtree (generic function with 4 methods)

In [24]:
# Debug printer for KD-trees.
"""
    print_kdtree([io::IO=stdout,] node, depth=0)

Print the subtree rooted at `node` in pre-order: the node's point, then its left
subtree, then its right subtree. Each point goes on its own line as `Point: [...]`,
prefixed by `2 * depth` spaces. Passing `nothing` prints nothing. Always returns
`nothing`.

The `io` method lets the tree be written to any stream, such as an `IOBuffer` or a
file. The method without `io` writes to `stdout`.
"""
function print_kdtree(io::IO, node::Union{KDTreeNode, Nothing}, depth::Int=0)
    node === nothing && return nothing

    println(io, "  "^depth, "Point: ", node.point)
    print_kdtree(io, node.left, depth + 1)
    print_kdtree(io, node.right, depth + 1)
    return nothing
end

print_kdtree(node::Union{KDTreeNode, Nothing}, depth::Int=0) =
    print_kdtree(stdout, node, depth)

print_kdtree (generic function with 2 methods)

In [26]:
# Demo: 10 random points in 3-D space, each coordinate an integer drawn from 1:10.
# rand(1:10, ...) produces an Int matrix, so it is converted to Float64 coordinates.
points = convert(Matrix{Float64}, rand(1:10, 10, 3))
kdtree = build_kdtree(points)

println("KD-Tree:")
print_kdtree(kdtree)

KD-Tree:
Point: [6.0, 10.0, 2.0]
  Point: [4.0, 2.0, 3.0]
    Point: [2.0, 1.0, 8.0]
      Point: [1.0, 2.0, 5.0]
    Point: [3.0, 10.0, 7.0]
      Point: [5.0, 4.0, 1.0]
  Point: [7.0, 9.0, 6.0]
    Point: [8.0, 8.0, 9.0]
      Point: [9.0, 5.0, 6.0]
    Point: [8.0, 9.0, 3.0]
