In [1]:
using LinearAlgebra
using FrankWolfe
include("terms_and_polynomials.jl")
include("auxiliary_functions.jl")
include("objective_functions.jl");

# OAVI Algorithm

In [19]:
"""
    fit(X_train; max_degree=10, psi=0.1, epsilon=0.001, tau=1000, lmbda=0.0,
        tol=0.0001, objective_type="L2Loss", region_type="L1Ball",
        oracle_type="CG", max_iters=10000, inverse_hessian_boosting="false")

Create an OAVI feature transformation fitted to `X_train`.

Work in progress: the body currently stops right after the first call to the
convex oracle and returns a placeholder (see "Returns").

# Arguments
- `X_train::Vector{Vector{Float64}}`: training data, one inner vector per sample;
  must be non-empty.

# Keywords
- `max_degree::Int64=10`: maximum degree of the polynomials computed.
- `psi::Float64=0.1`: vanishing extent (not yet read by the body).
- `epsilon::Float64=0.001`: accuracy for the convex optimizer (not yet forwarded
  to `call_oracle`).
- `tau::Union{Float64, Int64}=1000`: upper bound on the L1 norm of the full
  coefficient vector; the oracle's feasible region is an L1 ball of radius
  `tau - 1` because the leading coefficient is fixed to 1 afterwards.
- `lmbda::Float64=0.0`: forwarded to `construct_L2Loss` (presumably a
  regularization weight — confirm against `objective_functions.jl`).
- `tol::Float64=0.0001`: not yet read by the body.
- `objective_type::String="L2Loss"`: objective to minimize; only `"L2Loss"` is
  supported.
- `region_type::String="L1Ball"`: feasible region; only `"L1Ball"` is supported.
- `oracle_type::String="CG"`: not yet read by the body (`call_oracle` is called
  without it).
- `max_iters::Int64=10000`: not yet read by the body.
- `inverse_hessian_boosting::String="false"`: not yet read by the body.

# Returns
Intended (not yet implemented):
- `X_train_transformed::Vector{Vector{Float64}}`: transformed `X_train`.
- `sets::SetsOandG`: the struct instance tracking the sets used by the algorithm.

Current behavior: returns the placeholder string `"We got it."` as soon as the
first border term has been passed to the oracle, or `nothing` if no border term
is ever processed (e.g. `max_degree < 1`, or every border is empty after purging).

# Throws
- `ArgumentError` if `X_train` is empty, or if `objective_type` / `region_type`
  is not one of the supported values.
"""
function fit(X_train::Vector{Vector{Float64}};
        max_degree::Int64=10, psi::Float64=0.1, epsilon::Float64=0.001, tau::Union{Float64, Int64}=1000,
        lmbda::Float64=0., tol::Float64=0.0001, objective_type::String="L2Loss", region_type::String="L1Ball",
        oracle_type::String="CG", max_iters::Int64=10000, inverse_hessian_boosting::String="false")

    # Validate inputs at the API boundary rather than via @assert deep in the loop
    # (which was skipped entirely whenever the border happened to be empty).
    isempty(X_train) && throw(ArgumentError("X_train must contain at least one sample"))
    objective_type == "L2Loss" ||
        throw(ArgumentError("unsupported objective_type \"$objective_type\"; expected \"L2Loss\""))
    region_type == "L1Ball" ||
        throw(ArgumentError("unsupported region_type \"$region_type\"; expected \"L1Ball\""))

    m = length(X_train)

    sets = SetsOandG([nothing], [nothing], [nothing], [nothing], [nothing],
        zeros(Int64, m, 1), ones(Float64, m, 1), [],
        [nothing], zeros(Float64, m, 0),
        nothing)

    # Loop-invariant feasible region. Radius tau - 1: the leading coefficient 1 is
    # appended after optimization, so the full vector's L1 norm stays <= tau.
    region = FrankWolfe.LpNormLMO{1}(tau - 1)

    degree = 0
    while degree < max_degree
        degree += 1

        border_terms_raw, border_evaluations_raw, non_purging_indices =
            construct_border(sets.O_terms, sets.O_terms_evaluations, X_train)
        border_terms = border_terms_raw[:, non_purging_indices]
        border_evaluations = border_evaluations_raw[:, non_purging_indices]

        # Placeholder for the G coefficients of this degree; never assigned yet.
        G_coefficient_vectors = nothing

        # Quantities of the current O-term evaluations shared by every border term
        # of this degree.
        data = sets.O_terms_evaluations
        data_squared = data' * data
        data_squared_inverse = nothing

        for col_idx in axes(border_terms, 2)
            if !isnothing(G_coefficient_vectors)
                G_coefficient_vectors = vcat(G_coefficient_vectors,
                    zeros(Float64, 1, size(G_coefficient_vectors, 2)))
            end

            term_evaluated = border_evaluations[:, col_idx]
            data_term_evaluated = data' * term_evaluated
            term_evaluated_squared = term_evaluated' * term_evaluated

            _, f, grad! = construct_L2Loss(data, term_evaluated; lmbda=lmbda,
                data_squared=data_squared, labels_squared=term_evaluated_squared,
                data_squared_inverse=data_squared_inverse, data_labels=data_term_evaluated)

            # Fresh starting vertex for every term; kept inside the loop in case the
            # oracle modifies x0 in place (unverified — check call_oracle).
            x0 = Vector(compute_extreme_point(region, zeros(Float64, size(data, 2))))

            coefficient_vector = call_oracle(f, grad!, region, x0)
            # Append the leading term's coefficient, fixed to 1.
            coefficient_vector_full = vcat(coefficient_vector, [1])

            # TODO: unfinished — evaluate the loss of coefficient_vector_full, decide
            # whether the term joins G or O, update `sets`, and return the
            # transformed data. Until then, bail out with a placeholder.
            return "We got it."
        end
    end
    return nothing
end

fit

In [20]:
# Run `fit` on four identical two-dimensional samples (each a distinct vector).
fit([[1.0, 2.0] for _ in 1:4])

[33m[1m└ [22m[39m[90m@ FrankWolfe C:\Users\pusty\.julia\packages\FrankWolfe\MDe7s\src\linesearch.jl:367[39m
[33m[1m└ [22m[39m[90m@ FrankWolfe C:\Users\pusty\.julia\packages\FrankWolfe\MDe7s\src\linesearch.jl:367[39m
[33m[1m└ [22m[39m[90m@ FrankWolfe C:\Users\pusty\.julia\packages\FrankWolfe\MDe7s\src\linesearch.jl:367[39m


"We got it."