# Matrix Factorization
- $A = P Q$

In [5]:
using Random
using Zygote
using Plots
using SparseArrays

In [8]:
#function cost(P, Q, A)
#    M, N = size(A)
#    # D = size(P)[2]
#    sum((A - P*Q).^2)/(M*N)
#end

function cost_sp(P, Q, MNA)
    M_vec, N_vec, A_vec = A
    L = length(A_vec)
    J = 0.0
    for l in 1:L
        m, n, a = M_vec[l], N_vec[l], A_vec[l]
        J += (a - P[m,:]' * Q[:,n])^2
    end
    J / L    
end

dcost_sp_P(P,Q,MNA) = gradient(x -> cost_sp(x,Q,MNA), P)
dcost_sp_Q(P,Q,MNA) = gradient(x -> cost_sp(P,x,MNA), Q)

function update_sp(P, Q, MNA, N_epoch=100)
    J_a = zeros(Float32, N_epoch)
    J = cost_sp(P, Q, MNA)
    for i in 1:N_epoch
        dP = dcost_sp_P(P,Q,MNA)[1]
        P -= μ * dP
        dQ = dcost_sp_Q(P,Q,MNA)[1]
        Q -= μ * dQ
        J = cost(P,Q,A)
        J_a[i] = J
    end
    P, Q, J_a
end

function train_sp(A_sp, D_featrues, N_epoch=100, μ=0.1)
    M_vec, N_vec, A_vec = findnz(A_sp)
    MNA = (M_vec, N_vec, A_vec)
    
    M, N = size(A_sp)
    τ = 0.01
    P = rand(Float32, M, D_featrues) * τ
    Q = rand(Float32, D_featrues, N) * τ    
    
    update_sp(P, Q, MNA, N_epoch)
end

train_sp (generic function with 3 methods)

In [None]:
M, N, L = 4, 4, 10
A_sp = sparse(rand(1:M,L), rand(1:N,L), rand(Float32,L), M, N)
train_sp(A_sp, 2)

In [270]:
M_vec, N_vec, A_vec = findnz(A_sparse)
MNA = (M_vec, N_vec, A_vec)
MNA

([1, 2, 3, 4, 1, 3, 4], [1, 1, 1, 1, 2, 2, 3], Float32[0.33405036, 1.6304829, 1.4756708, 0.59272987, 0.6740594, 0.97151804, 0.6975906])

In [271]:
M_vec, N_vec, A_vec = findnz(A_sparse)
MNA = (M_vec, N_vec, A_vec)

M, N = size(A_sp)
τ = 0.01
P = rand(Float32, M, D_featrues) * τ
Q = rand(Float32, D_featrues, N) * τ    

update_sp(P, Q, MNA, N_epoch)

LoadError: MethodError: no method matching size(::Tuple{Vector{Int64}, Vector{Int64}, Vector{Float32}})
[0mClosest candidates are:
[0m  size(::Tuple, [91m::Integer[39m) at /opt/julia-1.7.0/share/julia/base/tuple.jl:27
[0m  size([91m::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}[39m) at /opt/julia-1.7.0/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:567
[0m  size([91m::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}[39m, [91m::Integer[39m) at /opt/julia-1.7.0/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:566
[0m  ...