In [1]:
using LinearAlgebra
using Random

In [3]:
using Test
using FrankWolfe

In [4]:
"""
    abm(data, labels, data_squared, data_labels, labels_squared; data_squared_inverse=nothing)

Run the ABM step: find the unit-norm coefficient vector `c` minimizing
`norm(hcat(data, labels) * c)^2` and return it together with the mean squared loss.

`c` is the right singular vector of `A = hcat(data, labels)` that belongs to the smallest
singular value. When `A` has more rows than columns, the smaller Gram matrix `A' * A` is
assembled from the precomputed blocks and decomposed instead. It has the same right
singular vectors as `A`.

# Arguments
- `data::Union{Matrix{Float64}, Matrix{Int64}}`: data (O_evaluations)
- `labels::Union{Matrix{Float64}, Matrix{Int64}, Vector{Float64}, Vector{Int64}}`: labels
  (term_evaluated)
- `data_squared::Union{Matrix{Float64}, Matrix{Int64}}`: squared data, expected to be
  `data' * data`
- `data_labels::Union{Matrix{Float64}, Vector{Float64}}`: `data' * labels`
- `labels_squared::Float64`: `labels' * labels`

# Keywords
- `data_squared_inverse::Union{Matrix{Float64}, Matrix{Int64}, Nothing}=nothing`: inverse
  of `data_squared`. Accepted for interface compatibility; currently unused.

# Returns
- `coefficient_vector::Vector{Float64}`: unit-norm vector minimizing the ABM problem
  (determined only up to sign)
- `loss::Float64`: `norm(hcat(data, labels) * coefficient_vector)^2 / size(data, 1)`
"""
function abm(data::Union{Matrix{Float64}, Matrix{Int64}},
        labels::Union{Matrix{Float64}, Matrix{Int64}, Vector{Float64}, Vector{Int64}},
        data_squared::Union{Matrix{Float64}, Matrix{Int64}},
        data_labels::Union{Matrix{Float64}, Vector{Float64}},
        labels_squared::Float64;
        data_squared_inverse::Union{Matrix{Float64}, Matrix{Int64}, Nothing}=nothing)
    data_with_labels = hcat(data, labels)
    num_rows, num_cols = size(data_with_labels)

    if num_rows > num_cols
        # Tall system: decompose the small num_cols × num_cols Gram matrix
        # [data'data  data'labels; labels'data  labels'labels] instead of the tall matrix.
        gram = [data_squared data_labels; vec(data_labels)' labels_squared]
        F = svd(gram)
    else
        # Wide or square system: `full = true` keeps all num_cols right singular vectors,
        # so a null-space direction is available even when num_rows < num_cols.
        F = svd(data_with_labels; full = true)
    end

    # Singular values are sorted in descending order. The minimizer of ||A c|| over unit c
    # is the right singular vector of the smallest one: the last ROW of Vt
    # (equivalently the last column of V), not the last column of Vt.
    coefficient_vector = F.Vt[end, :]
    loss = norm(data_with_labels * coefficient_vector, 2)^2 / size(data, 1)

    return coefficient_vector, loss
end

abm (generic function with 1 method)

In [5]:
include("oracle_constructors.jl")

abm (generic function with 1 method)

In [7]:
# Build a ConditionalGradients instance; the solver is written with its module prefix so
# its origin (the FrankWolfe package loaded above) is explicit.
a = ConditionalGradients([1], 2, 3, FrankWolfe.frank_wolfe)

ConditionalGradients([1.0], 2, 3, FrankWolfe.frank_wolfe)

In [34]:
# TODO: implement the construction above (↑) generically for all oracle constructors.
# Also make the constructors build a struct that holds all the information needed later.

In [34]:
# Keyword arguments as (name, value) tuples. The element type is spelled out explicitly;
# it is the same type Julia infers for the untyped literal.
kwargs = Tuple{Symbol, Real}[(:epsilon, 0.1), (:max_iteration, 10)]

2-element Vector{Tuple{Symbol, Real}}:
 (:epsilon, 0.1)
 (:max_iteration, 10)

In [35]:
"""
    foo1(; epsilon=1, max_iteration=2)

Return the product `epsilon * max_iteration`; both keywords must be subtypes of `Real`.
"""
foo1(; epsilon::E=1, max_iteration::M=2) where {E<:Real, M<:Real} =
    epsilon * max_iteration

foo1 (generic function with 5 methods)

In [36]:
# Turn each (name, value) tuple into a `name => value` pair and splat them as keywords.
foo1(; (name => value for (name, value) in kwargs)...)

1.0

In [41]:
"""
    Foo(x, y; a=1, b=2)

Return the sum `x + y + a + b` of two positional and two keyword arguments.
"""
function Foo(x, y; a=1, b=2)
    return x + y + a + b
end

Foo (generic function with 1 method)

In [42]:
# Positional arguments and (name, value) keyword tuples to be splatted into Foo.
varargs = Int[1, 2]
kwargs = Tuple{Symbol, Int}[(:a, 3), (:b, 4)]

2-element Vector{Tuple{Symbol, Int64}}:
 (:a, 3)
 (:b, 4)

In [45]:
# Splat the positional values, then turn each (name, value) tuple into a keyword pair.
Foo(varargs...; (name => value for (name, value) in kwargs)...)

10