In [2]:
using LinearAlgebra
using Random
using FrankWolfe
include("auxiliary_functions.jl")
include("terms_and_polynomials.jl")
include("objective_functions.jl")
include("oracle_avi.jl");

In [16]:
function abm(Data)
    data_with_labels = hcat(Data.A, Data.b)
    m = size(data_with_labels, 1)
    if size(data_with_labels, 1) > size(data_with_labels, 2)
        data_with_labels_squared = hcat(Data.A_squared, Data.A_b)
        bottom_row = vcat(Data.A_b, Data.b_squared)
        bottom_row = bottom_row'
        # (m/2) factor due to way L2 data is saved 
        data_with_labels_squared = (m / 2) * vcat(data_with_labels_squared, bottom_row)
        
        F = svd(data_with_labels_squared)
    else
        F = svd(data_with_labels)
    end
    
    U, S, Vt = F.U, F.S, F.Vt
    coefficient_vector = Vt[:, end]
    loss = (1 / size(Data.A, 1)) * norm(data_with_labels * coefficient_vector, 2)^2
    
    return coefficient_vector, loss
end        

abm (generic function with 1 method)

In [17]:
data = rand(5, 3)
labels = rand(5)
L2, _ = construct_L2Loss(data, labels)
abm(L2)

([-0.3373242244187917, -0.577006638737177, 0.07760056070492272, 0.739766084280543], 0.09119733572975218)

In [4]:
a = randn(3, 4)
b = ones(Int64, 4)

4-element Vector{Int64}:
 1
 1
 1
 1

In [7]:
r = randn(4, 2)
hcat(r, b)

4×3 Matrix{Float64}:
 -0.380353   1.01309    1.0
  0.793804   1.16646    1.0
  0.211649   1.24327    1.0
  0.13102   -0.0713635  1.0

In [12]:
a = randn(5, 4)
U, S, Vt = svd(a)

SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}
U factor:
5×4 Matrix{Float64}:
  0.477206   -0.171469   0.515172   -0.50062
 -0.845353   -0.200028   0.0569193  -0.36866
 -0.211797    0.489445   0.709175   -0.0729521
  0.0919446  -0.358993  -0.208805   -0.626054
  0.0658894   0.74977   -0.429919   -0.464977
singular values:
4-element Vector{Float64}:
 3.5272170206378526
 1.4917590367054454
 1.2625715736577416
 0.517081481618711
Vt factor:
4×4 Matrix{Float64}:
 -0.161915  -0.748572   0.0517593  -0.640894
  0.2419     0.599054  -0.0254448  -0.762871
  0.732871  -0.25057   -0.630006    0.0566367
  0.614954  -0.134125   0.774446    0.0638426

In [13]:
a ≈ U * Diagonal(S) * Vt

false