Modified version of https://github.com/genkuroki/public/blob/main/0016/apricot/julia_translation_of_python_reimpl.ipynb

In [1]:
#using Seaborn
using ScikitLearn: @sk_import
@sk_import datasets: fetch_covtype
#using Random
#using StatsBase: sample

PyObject <function fetch_covtype at 0x000000006A38AC10>

In [2]:
# Load scikit-learn's Forest Covertypes dataset (`fetch_covtype`, imported above).
# NOTE(review): the `digits_*` variable names look like leftovers from a digits-dataset
# version of this notebook; the data loaded here is covtype.
digits_data = fetch_covtype()

Dict{Any, Any} with 6 entries:
  "feature_names" => ["Elevation", "Aspect", "Slope", "Horizontal_Distance_To_H…
  "frame"         => nothing
  "target_names"  => ["Cover_Type"]
  "data"          => [2596.0 51.0 … 0.0 0.0; 2590.0 56.0 … 0.0 0.0; … ; 2384.0 …
  "target"        => Int32[5, 5, 2, 2, 5, 2, 5, 5, 5, 5  …  3, 3, 3, 3, 3, 3, 3…
  "DESCR"         => ".. _covtype_dataset:\n\nForest covertypes\n--------------…

In [3]:
# Store one sample per column: Julia arrays are column-major, so `X_digits[:, i]`
# (used by the selection code) is contiguous in memory.
# `abs` keeps every entry non-negative for the `sqrt` in `calculate_gains!`
# (presumably some raw covtype features are negative — TODO confirm).
X_digits = map(abs, permutedims(digits_data["data"]))
summary(X_digits)

"54×581012 Matrix{Float64}"

In [4]:
"""
    calculate_gains!(X, gains, current_values, idxs, current_concave_values_sum)

Compute the marginal gain of every candidate column `idxs[i]` of `X`, mutating `gains` only.

For each position `i` of `idxs`:

    gains[i] = sum(sqrt.(current_values .+ X[:, idxs[i]])) - current_concave_values_sum

The remaining entries `gains[length(idxs)+1:end]` are set to `typemin(eltype(gains))`
(`-Inf` for floats). A caller taking `argmax(gains)` over the whole buffer therefore can
never select a stale value left over from an earlier, longer `idxs`.

The loop over candidates is split across threads with `Threads.@threads`. Iteration `i`
writes only `gains[i]`.

# Arguments
- `X`: matrix whose columns are the candidates.
- `gains`: output buffer; needs `length(gains) >= length(idxs)`.
- `current_values`: vector with `length(current_values) == size(X, 1)`.
- `idxs`: column indices of `X` to evaluate.
- `current_concave_values_sum`: value subtracted from every gain. The `fit_*` callers pass
  `sum(sqrt, current_values)`.

`sqrt` throws a `DomainError` if any `current_values[j] + X[j, idx]` is negative.

# Returns
`gains`.

# Throws
- `DimensionMismatch` if `gains` is too short or `current_values` does not match `size(X, 1)`.
- `BoundsError` if `idxs` contains an index that is not a column of `X`.
"""
function calculate_gains!(X, gains, current_values, idxs, current_concave_values_sum)
    # Validate up front: the hot loop below runs with bounds checks disabled.
    length(gains) >= length(idxs) || throw(DimensionMismatch(
        "gains has length $(length(gains)) but idxs has length $(length(idxs))"))
    axes(current_values, 1) == axes(X, 1) || throw(DimensionMismatch(
        "length(current_values) must equal size(X, 1)"))
    checkbounds(X, :, idxs)

    Threads.@threads for i in eachindex(idxs)
        @inbounds idx = idxs[i]
        @inbounds total = sum(j -> sqrt(current_values[j] + X[j, idx]), axes(X, 1))
        @inbounds gains[i] = total - current_concave_values_sum
    end

    # Entries past length(idxs) were written by an earlier call with more candidates;
    # make sure they can never win an argmax over the full buffer.
    fill!(view(gains, length(idxs)+1:lastindex(gains)), typemin(eltype(gains)))
    return gains
end

@doc calculate_gains!

`calculate_gains!(X, gains, current_values, idxs, current_concave_values_sum)` mutates `gains` only


In [5]:
"""
    fit_popat(X, k; calculate_gains! = calculate_gains!)

Greedily select columns of `X`. At every step, pick the remaining column whose marginal
gain (as computed by `calculate_gains!`) is largest, then add it to the running column sum.

Each selection costs `1.0`, so at most `k` columns are selected. Selection also stops once
every column of `X` has been picked. Remaining candidates are tracked in an index vector;
the chosen position is removed from it with `popat!`.

# Arguments
- `X`: matrix whose columns are the candidates.
- `k`: selection budget.
- `calculate_gains!`: gain kernel with the signature of the module-level
  `calculate_gains!`. Only `gains[1:length(idxs)]` is read back after each call.

# Returns
`(ranking, total_gains)`: the selected column indices in selection order, and the gain of
each selection.
"""
# The default looks the function up in the module explicitly. In a bare
# `calculate_gains! = calculate_gains!` default, the right-hand side can be resolved as the
# keyword's own, still-unassigned local instead of the module-level function.
function fit_popat(X, k; calculate_gains! = getfield(@__MODULE__, :calculate_gains!))
    d, n = size(X)

    cost = 0.0
    ranking = Int[]
    total_gains = Float64[]

    current_values = zeros(d)
    current_concave_values_sum = sum(sqrt, current_values)

    # Columns not selected yet; shrinks by one element per selection.
    idxs = collect(1:n)

    gains = zeros(n)
    while cost < k && !isempty(idxs)
        calculate_gains!(X, gains, current_values, idxs, current_concave_values_sum)

        # Only the first length(idxs) entries are fresh; the tail is left over from
        # earlier iterations and must not take part in the argmax.
        idx = argmax(view(gains, 1:length(idxs)))
        best_idx = idxs[idx]
        curr_cost = 1.0  # uniform cost per selected column

        cost + curr_cost > k && break
        cost += curr_cost

        gain = gains[idx] * curr_cost

        # Add the chosen column to the running totals.
        current_values .+= @view X[:, best_idx]
        current_concave_values_sum = sum(sqrt, current_values)

        push!(ranking, best_idx)
        push!(total_gains, gain)

        popat!(idxs, idx)
    end
    return ranking, total_gains
end

fit_popat (generic function with 1 method)

In [6]:
"""
    fit_bitvector(X, k; calculate_gains! = calculate_gains!)

Greedily select columns of `X`. At every step, pick the remaining column whose marginal
gain (as computed by `calculate_gains!`) is largest, then add it to the running column sum.

Each selection costs `1.0`, so at most `k` columns are selected. Selection also stops once
every column of `X` has been picked. Remaining candidates are tracked with a `BitVector`
mask, and the index vector is rebuilt with `findall` after every selection.

# Arguments
- `X`: matrix whose columns are the candidates.
- `k`: selection budget.
- `calculate_gains!`: gain kernel with the signature of the module-level
  `calculate_gains!`. Only `gains[1:length(idxs)]` is read back after each call.

# Returns
`(ranking, total_gains)`: the selected column indices in selection order, and the gain of
each selection.
"""
# The default looks the function up in the module explicitly. In a bare
# `calculate_gains! = calculate_gains!` default, the right-hand side can be resolved as the
# keyword's own, still-unassigned local instead of the module-level function.
function fit_bitvector(X, k; calculate_gains! = getfield(@__MODULE__, :calculate_gains!))
    d, n = size(X)

    cost = 0.0
    ranking = Int[]
    total_gains = Float64[]

    mask = trues(n)  # `true` = still a candidate, `false` = already selected ("masked")
    current_values = zeros(d)
    current_concave_values_sum = sum(sqrt, current_values)

    idxs = collect(1:n)  # same as findall(mask) for the all-`true` initial mask

    gains = zeros(n)
    while cost < k && !isempty(idxs)
        calculate_gains!(X, gains, current_values, idxs, current_concave_values_sum)

        # Only the first length(idxs) entries are fresh; the tail is left over from
        # earlier iterations and must not take part in the argmax.
        idx = argmax(view(gains, 1:length(idxs)))
        best_idx = idxs[idx]
        curr_cost = 1.0  # uniform cost per selected column

        cost + curr_cost > k && break
        cost += curr_cost

        gain = gains[idx] * curr_cost

        # Add the chosen column to the running totals.
        current_values .+= @view X[:, best_idx]
        current_concave_values_sum = sum(sqrt, current_values)

        push!(ranking, best_idx)
        push!(total_gains, gain)

        mask[best_idx] = false
        idxs = findall(mask)
    end
    return ranking, total_gains
end

fit_bitvector (generic function with 1 method)

In [7]:
"""
    fit_f64vector(X, k; calculate_gains! = calculate_gains!)

Greedily select columns of `X`. At every step, pick the remaining column whose marginal
gain (as computed by `calculate_gains!`) is largest, then add it to the running column sum.

Each selection costs `1.0`, so at most `k` columns are selected. Selection also stops once
every column of `X` has been picked. Remaining candidates are tracked with a `Float64`
mask, and the index vector is rebuilt with `findall` after every selection.

# Arguments
- `X`: matrix whose columns are the candidates.
- `k`: selection budget.
- `calculate_gains!`: gain kernel with the signature of the module-level
  `calculate_gains!`. Only `gains[1:length(idxs)]` is read back after each call.

# Returns
`(ranking, total_gains)`: the selected column indices in selection order, and the gain of
each selection.
"""
# The default looks the function up in the module explicitly. In a bare
# `calculate_gains! = calculate_gains!` default, the right-hand side can be resolved as the
# keyword's own, still-unassigned local instead of the module-level function.
function fit_f64vector(X, k; calculate_gains! = getfield(@__MODULE__, :calculate_gains!))
    d, n = size(X)

    cost = 0.0
    ranking = Int[]
    total_gains = Float64[]

    mask = zeros(n)  # `0.0` = still a candidate, `1.0` = already selected
    current_values = zeros(d)
    current_concave_values_sum = sum(sqrt, current_values)

    idxs = collect(1:n)  # same as findall(mask .== 0) for the all-zero initial mask

    gains = zeros(n)
    while cost < k && !isempty(idxs)
        calculate_gains!(X, gains, current_values, idxs, current_concave_values_sum)

        # Only the first length(idxs) entries are fresh; the tail is left over from
        # earlier iterations and must not take part in the argmax.
        idx = argmax(view(gains, 1:length(idxs)))
        best_idx = idxs[idx]
        curr_cost = 1.0  # uniform cost per selected column

        cost + curr_cost > k && break
        cost += curr_cost

        gain = gains[idx] * curr_cost

        # Add the chosen column to the running totals.
        current_values .+= @view X[:, best_idx]
        current_concave_values_sum = sum(sqrt, current_values)

        push!(ranking, best_idx)
        push!(total_gains, gain)

        mask[best_idx] = 1
        idxs = findall(mask .== 0)
    end
    return ranking, total_gains
end

fit_f64vector (generic function with 1 method)

In [8]:
# Selection budget: every pick in the `fit_*` functions costs 1.0, so this selects up to
# one thousand columns.
k = 1_000

1000

In [9]:
# First timing of the popat! variant (this run also includes compilation).
@time (ranking0_pa, gains0_pa) = fit_popat(X_digits, k; calculate_gains!);

 12.654430 seconds (1.10 M allocations: 74.974 MiB, 1.65% compilation time)


In [10]:
# Repeat the popat! timing now that the method is compiled.
@time (ranking0_pa, gains0_pa) = fit_popat(X_digits, k; calculate_gains!);

 12.403661 seconds (62.81 k allocations: 14.995 MiB)


In [11]:
# Third popat! timing, to check run-to-run variation.
@time (ranking0_pa, gains0_pa) = fit_popat(X_digits, k; calculate_gains!);

 12.514568 seconds (62.77 k allocations: 14.993 MiB)


In [12]:
# First timing of the BitVector-mask variant (this run also includes compilation).
@time (ranking0_bv, gains0_bv) = fit_bitvector(X_digits, k; calculate_gains!);

 13.675888 seconds (330.71 k allocations: 4.354 GiB, 2.32% gc time, 0.52% compilation time)


In [13]:
# Repeat the BitVector-mask timing now that the method is compiled.
@time (ranking0_bv, gains0_bv) = fit_bitvector(X_digits, k; calculate_gains!);

 13.911387 seconds (64.86 k allocations: 4.340 GiB, 2.07% gc time)


In [14]:
# Third BitVector-mask timing, to check run-to-run variation.
@time (ranking0_bv, gains0_bv) = fit_bitvector(X_digits, k; calculate_gains!);

 14.225798 seconds (64.99 k allocations: 4.340 GiB, 2.08% gc time)


In [15]:
# First timing of the Float64-mask variant (this run also includes compilation).
@time (ranking0_f64v, gains0_f64v) = fit_f64vector(X_digits, k; calculate_gains!);

 14.310422 seconds (843.62 k allocations: 4.456 GiB, 2.29% gc time, 1.21% compilation time)


In [16]:
# Repeat the Float64-mask timing now that the method is compiled.
@time (ranking0_f64v, gains0_f64v) = fit_f64vector(X_digits, k; calculate_gains!);

 13.948037 seconds (68.83 k allocations: 4.416 GiB, 2.11% gc time)


In [17]:
# Third Float64-mask timing, to check run-to-run variation.
@time (ranking0_f64v, gains0_f64v) = fit_f64vector(X_digits, k; calculate_gains!);

 13.963663 seconds (68.86 k allocations: 4.416 GiB, 2.01% gc time)


In [18]:
# The three candidate-tracking strategies were run with the same gain kernel, so they
# should select the same columns and record exactly equal gains.
@show(ranking0_pa == ranking0_bv == ranking0_f64v)
@show(gains0_pa == gains0_bv == gains0_f64v);

ranking0_pa == ranking0_bv == ranking0_f64v = true
gains0_pa == gains0_bv == gains0_f64v = true
