Skip to content

Commit

Permalink
simplest prod(xs; dims) gradient
Browse files Browse the repository at this point in the history
Will not treat zeros correctly, see FluxML/Flux.jl#524
  • Loading branch information
mcabbott authored and Michael Abbott committed Aug 10, 2020
1 parent 21e950a commit 7f2b4f3
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/lib/array.jl
Expand Up @@ -358,15 +358,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, )

Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))

@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
@grad prod(xs, dim) = prod(data(xs), dims = dim),
Δ -> (nobacksies(:sum,
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
nothing)
@grad function prod(xs; dims=:)
p = prod(data(xs); dims=dims)
p, Δ -> (p ./ xs .* Δ,)
end

Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)

Expand Down

0 comments on commit 7f2b4f3

Please sign in to comment.