Skip to content

Commit

Permalink
use NaNMath package to return NaNs instead of throwing DomainErrors. C…
Browse files Browse the repository at this point in the history
…loses #6
  • Loading branch information
mlubin committed Nov 24, 2014
1 parent 5cc7d9a commit 2818354
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
julia 0.3
DualNumbers
NaNMath
Graphs
6 changes: 6 additions & 0 deletions src/ReverseDiffSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ module ReverseDiffSparse
import Calculus
using DualNumbers
using Base.Meta
# Override basic math functions to return NaN instead of throwing errors.
# This is what NLP solvers expect, and
# sometimes the results aren't needed anyway,
# because the code may compute derivatives wrt constants.
import NaNMath: sin, cos, tan, asin, acos, acosh, atanh, log, log2, log10, lgamma, log1p, pow

if isdir(Pkg.dir("ArrayViews"))
eval(Expr(:import,:ArrayViews))
const subarr = ArrayViews.view
Expand Down
17 changes: 8 additions & 9 deletions src/revmode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,11 @@ function forwardpass(x::ExprNode, expr_out)
for i in 2:length(x.ex.args)
push!(values, forwardpass(x.ex.args[i], expr_out))
end
fcall = Expr(:call, x.ex.args[1], values...)
if x.ex.args[1] == :(^) # Use NaNMath.pow instead of ^
fcall = Expr(:call, :pow, values...)
else
fcall = Expr(:call, x.ex.args[1], values...)
end
push!(expr_out.args, :( $(x.value) = $fcall ))
return x.value
elseif isexpr(x.ex, :curly)
Expand Down Expand Up @@ -284,11 +288,6 @@ forwardpass(x, expr_out) = :(forwardvalue($x, __placevalues, __placeindex_in))
forwardvalue(x::Placeholder, placevalues, placeindex_in) = placevalues[placeindex_in[getplaceindex(x)]]
forwardvalue(x, placevalues, placeindex_in) = float(x)

# better to return NaNs than throw DomainErrors.
# sometimes the results aren't needed anyway,
# because the code may compute derivatives wrt constants.
log(x) = x <= 0 ? NaN : Base.log(x)

function revpass(x::ExprNode, expr_out)
@assert isexpr(expr_out, :block)
# compute the partial drivative wrt. each expression down the graph
Expand Down Expand Up @@ -342,12 +341,12 @@ function revpass(x::ExprNode, expr_out)
if k == 2 # base
exponent = getvalue(p.ex.args[3])
push!(expr_out.args,
:( $(x.deriv) += $(p.deriv)*$exponent*$(x.value)^($exponent-1) ))
:( $(x.deriv) += $(p.deriv)*$exponent*pow($(x.value),$exponent-1) ))
else
@assert k == 3
base = getvalue(p.ex.args[2])
push!(expr_out.args,
:( $(x.deriv) += $(p.deriv)*$base^($(x.value))*log($base) ))
:( $(x.deriv) += $(p.deriv)*pow($base,$(x.value))*log($base) ))
end
elseif f == :(/)
if k == 2 # numerator
Expand All @@ -358,7 +357,7 @@ function revpass(x::ExprNode, expr_out)
@assert k == 3 # denominator
numer = getvalue(p.ex.args[2])
push!(expr_out.args,
:( $(x.deriv) += -1*$(p.deriv)*$numer*($(x.value))^(-2) ))
:( $(x.deriv) += -1*$(p.deriv)*$numer*pow($(x.value),-2) ))
end
else
# try one of the derivative rules
Expand Down
7 changes: 7 additions & 0 deletions test/test_grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ fval = fg([-3.0],out)
@test_approx_eq fval (-3)^2
@test_approx_eq out[1] 2*-3

y = -2
ex = @processNLExpr y^x[1]
fg = genfgrad_simple(ex)
fval = fg([0.3],out)
@test isnan(fval)


# zeros in products
ex = @processNLExpr prod{ x[i], i = 1:2 }
fg = genfgrad_simple(ex)
Expand Down

0 comments on commit 2818354

Please sign in to comment.