/
chainrulescore.jl
47 lines (38 loc) · 2 KB
/
chainrulescore.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val})
return getproperty(x, s), Δ -> getproperty_adjoint(Δ, x, s)
end
function getproperty_adjoint(Δ, x, s)
zero_x = zero(similar(x, eltype(Δ)))
zero_x = __setproperty!(zero_x, s, Δ)
return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent())
end
__setproperty!(x, s, Δ) = __setproperty!(Val(false), x, s, Δ)
function __setproperty!(::Val{false}, x, s, Δ)
setproperty!(x, s, Δ)
return x
end
# NOTE: I am not sure how this is avoiding the problem of mutation but if we wrap the
# mutating function into an `rrule` as done here, Zygote computes the correct
# gradient.
__setproperty!(::Val{true}, x, s::Symbol, Δ) = __setproperty!(Val(true), x, Val(s), Δ)
function __setproperty!(::Val{true}, x, s::Val, Δ)
setproperty!(x, s, Δ)
return x
end
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(__setproperty!), x, s, Δ)
y_, pb_f = ChainRulesCore.rrule_via_ad(cfg, __setproperty!, Val(true), x, s, Δ)
return y_, pb_f
end
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
function ChainRulesCore.ProjectTo(ca::ComponentArray)
return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca))
end
(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::AbstractArray) = ComponentArray(p.project(dx), p.axes)
# Prevent double projection
(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::ComponentArray) = dx
function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A, <:NamedTuple}) where {A}
nt = Functors.fmap(ChainRulesCore.backing, ChainRulesCore.backing(t))
return ComponentArray(nt)
end