Skip to content

Commit

Permalink
loss functions explained
Browse files Browse the repository at this point in the history
  • Loading branch information
jumutc committed Sep 11, 2015
1 parent ba14ce4 commit 85bc1cb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions src/loss_derivative.jl
Expand Up @@ -13,8 +13,8 @@ function modified_huber_loss_derivative(At,yt,w)
eval = evaluate(At,yt,w)
deriv = init(At)

idx1 = find(eval.<0)
idx2 = find((eval.>=0)&(eval.<1))
idx1 = find(eval.<-1)
idx2 = find((eval.>=-1)&(eval.<1))

if ~isempty(idx1)
deriv += 4*hinge_loss(At,yt,idx1)
Expand Down Expand Up @@ -70,8 +70,8 @@ function pinball_loss_derivative(At,yt,w,tau)
d
end

pinball_loss_derivative{T <: Number, AV <: AbstractVector}(At::AV,yt::T,w,tau) = evaluate(At,yt,w) < 1 ? -At.*yt : tau*At.*yt
pinball_loss_derivative{T <: Number}(At::SparseMatrixCSC,yt::T,w,tau) = evaluate(At,yt,w)[1] < 1 ? -At.*yt : tau*At.*yt
pinball_loss_derivative{T <: Number, AV <: AbstractVector}(At::AV,yt::T,w,τ) = evaluate(At,yt,w) < 1 ? -At.*yt : τ*At.*yt
pinball_loss_derivative{T <: Number}(At::SparseMatrixCSC,yt::T,w,τ) = evaluate(At,yt,w)[1] < 1 ? -At.*yt : τ*At.*yt

# LOGISTIC LOSS
logistic_loss{T <: Number}(At,yt::T,w,eval=evaluate(At,yt,w)) = -At.*yt/(exp(eval)+1)
Expand All @@ -92,4 +92,11 @@ loss_derivative(::Type{SQUARED_HINGE}) = squared_hinge_loss_derivative
loss_derivative(::Type{MODIFIED_HUBER}) = modified_huber_loss_derivative
loss_derivative(::Type{PINBALL},tau::Float64) = (At,yt,w) -> pinball_loss_derivative(At,yt,w,tau)
loss_derivative{A <: Algorithm, M <: Euclidean}(alg::RK_MEANS{A,M}) = (At::Matrix,yt,w) -> reduce((d0,i) -> d0 + (w - At[:,i]), zeros(size(At,1),1), 1:1:size(At,2))
loss_derivative{A <: Algorithm, M <: CosineDist}(alg::RK_MEANS{A,M}) = (At::Matrix,yt,w) -> begin idx = find(evaluate(At,yt,w) .<= 0); -sum(At[:,idx],2) end
loss_derivative{A <: Algorithm, M <: CosineDist}(alg::RK_MEANS{A,M}) = (At::Matrix,yt,w) -> begin idx = find(evaluate(At,yt,w) .<= 0); -sum(At[:,idx],2) end

show(io::IO, t::Type{HINGE}) = @printf io "SALSA.HINGE (%s)" "Hinge loss, i.e. l(y,p) = max(0,1 - yp)"
show(io::IO, t::Type{LOGISTIC}) = @printf io "SALSA.HINGE (%s)" "Logistic loss, i.e. l(y,p) = log(1 + exp(-yp))"
show(io::IO, t::Type{LOGISTIC}) = @printf io "SALSA.LEAST_SQUARES (%s)" "Squared loss, i.e. l(y,p) = 1/2*(p - y)^2"
show(io::IO, t::Type{SQUARED_HINGE}) = @printf io "SALSA.SQUARED_HINGE (%s)" "Squared hinge loss, i.e. l(y,p) = max(0,1 - yp)^2"
show(io::IO, t::Type{PINBALL}) = @printf io "SALSA.PINBALL (%s)" "Pinball (quantile) loss, i.e. l(y,p) = τI(yp>=1)yp + I(yp<1)(1 - yp)"
show(io::IO, t::Type{MODIFIED_HUBER}) = @printf io "SALSA.MODIFIED_HUBER (%s)" "Modified huber loss, i.e. l(y,p) = -4I(yp<-1)yp + I(yp>=-1)max(0,1 - yp)^2"
2 changes: 1 addition & 1 deletion test/unit/test_loss_derivative.jl
Expand Up @@ -33,4 +33,4 @@ loss = loss_derivative(LEAST_SQUARES)
loss = loss_derivative(MODIFIED_HUBER)
@test loss([1;2],-1,[2;1]) == [4,8]''
@test loss([1;2],1,[.1;.1]) == [-.7,-1.4]''
@test loss([1 1;2 2],[-1,1],[.1;.1]) == [3.3,6.6]''
@test loss([1 1;2 2],[-1,1],[.1;.2]) == [1,2]''

0 comments on commit 85bc1cb

Please sign in to comment.