Skip to content

Commit

Permalink
grad rule for abs
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Aug 18, 2020
1 parent 396b3b6 commit 3d6d677
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ mydiffrule(f, xs...) = begin
f == :// && return mydivrule(xs...)
f == :inv && return mydivrule(1, xs...)[2]
f == :log && return simpliinv(xs...)
f == :abs && return myabsrule(xs...)
f == :sqrt && return mysqrtrule(xs...)
f == :relu && return myrelurule(xs...)
f in BASE_NOGRAD && return map(_->0, xs)
Expand Down Expand Up @@ -302,6 +303,9 @@ simplipow(x, p) = :($x^$p)
myrelurule(x::Number) = x>0 ? 1 : 0
myrelurule(x) = :(ifelse($x>0, 1, 0))

myabsrule(x::Number) = x<0 ? -1 : 1
myabsrule(x) = :(ifelse($x<0, -1, 1)) # matches DiffRules._abs_deriv, which uses signbit(x)

#========== CSE ==========#

# My approach was to look for things occuring twice, biggest first.
Expand Down
7 changes: 7 additions & 0 deletions test/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ dx, dy = _gradient(sum∘mm, x1, y1)
@test dx 2 * ones(3,5) * y1'
@test dy 2 * x1' * ones(3,5)

# abs, abs2
va = [1,-2,3,-4,5]
g1 = ForwardDiff.gradient(v -> sum(abs, 1 .+ v.^2), va)
@test g1 _gradient(v -> (@tullio s := abs(1 + v[i]^2)), va)[1]
g2 = ForwardDiff.gradient(v -> sum(abs2, 1 .+ v.^2), va)
@test g2 _gradient(v -> (@tullio s := abs2(1 + v[i]^2)), va)[1]

# Using zero-dim arrays fails on ReverseDiff & Tracker
# Tracker.gradient(x -> x[], fill(1.0))
# ReverseDiff.gradient(x -> x[], fill(1.0)) # is ambiguous
Expand Down

0 comments on commit 3d6d677

Please sign in to comment.