-
Notifications
You must be signed in to change notification settings - Fork 10
/
autodiff.jl
185 lines (148 loc) · 8.07 KB
/
autodiff.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# this does basis promotion, unlike Zygote's default for AbstractArrays
Zygote.accum(a::Field, b::Field) = a+b
# this may create a LazyBinaryOp, unlike Zygote's
Zygote.accum(a::FieldOp, b::FieldOp) = a+b
# constant functions, as far as AD is concerned
@nograd ProjLambert
@nograd fieldinfo
@nograd hasfield
# AD for Fourier Fields can be really subtle because such objects are
# still supposed to represent a pixel-space field, despite that
# they're stored as the half-plane real FFT coefficients. this leads
# to needing factors of Npix(=Ny*Nx) and rfft_degeneracy_fac (which
# gives weight of 2 to coefficient which would be there twice in the
# full-plane FFT) in a few cases below to make the logic all work. the
# Npix factor is handled by the Zfac function.
Zfac(::SpatialBasis{Map}, proj::FlatProj) = 1
Zfac(::SpatialBasis{Fourier}, proj::FlatProj) = proj.Ny * proj.Nx
Zfac(L::DiagOp{<:Field{B}}) where {B} = Zfac(B(), L.diag.metadata)
## constructors and getproperty
# these are needed to allow you to "bail out" to the underlyig arrays
# via f.arr and reconstruct them with FlatField(f.arr, f.metadata).
@adjoint function (::Type{F})(arr::A, metadata::M) where {B<:SpatialBasis{Map},M<:FieldMetadata,T,A<:AbstractArray{T},F<:BaseField{B}}
F(arr, metadata), Δ -> (Δ.arr, nothing)
end
@adjoint function (::Type{F})(arr::A, metadata::M) where {B<:SpatialBasis{Fourier},M<:FieldMetadata,T,A<:AbstractArray{T},F<:BaseField{B}}
F(arr, metadata), Δ -> (Δ.arr .* adapt(Δ.storage, T.(rfft_degeneracy_fac(metadata.Ny) ./ Zfac(B(), metadata))), nothing)
end
# the factors here need to cancel the ones in the corresponding constructors above
@adjoint function Zygote.literal_getproperty(f::BaseField{B}, ::Val{:arr}) where {B<:SpatialBasis{Map}}
getfield(f,:arr), Δ -> (BaseField{B}(Δ, f.metadata),)
end
@adjoint function Zygote.literal_getproperty(f::BaseField{B,M,T}, ::Val{:arr}) where {B<:SpatialBasis{Fourier},M,T}
getfield(f,:arr), Δ -> (BaseField{B}(Δ ./ adapt(typeof(Δ), T.(rfft_degeneracy_fac(f.Ny) ./ Zfac(B(), f.metadata))), f.metadata),)
end
Zygote.accum(f::BaseField, nt::NamedTuple{(:arr,:metadata)}) = (@assert(isnothing(nt.arr)); f)
# FieldTuple
@adjoint (::Type{FT})(fs) where {FT<:FieldTuple} = FT(fs), Δ -> (Δ.fs,)
@adjoint Zygote.literal_getproperty(f::FieldTuple, ::Val{:fs}) = getfield(f,:fs), Δ -> (FieldTuple(map((f,f̄) -> isnothing(f̄) ? zero(f) : f̄, getfield(f,:fs), Δ)),)
# BatchedReals
@adjoint Zygote.literal_getproperty(br::BatchedReal, ::Val{:vals}) = getfield(br,:vals), Δ -> (batch(real.(Δ)),)
## Field algebra
# Zygote has lots of rules for AbstractVectors / AbstractMatrices that
# don't quite work right due to the auto-basis conversions done for
# Fields, or which work right but trigger scalar indexing (thus don't
# work on GPU). this leads us to need a few more custom rules below
# than might be ideal, although its not too bad. see also:
# https://discourse.julialang.org/t/how-to-deal-with-zygote-sometimes-pirating-its-own-adjoints-with-worse-ones
# ℝᴺ -> ℝ¹
@adjoint sum(f::Field{B}) where {B} = sum(f), Δ -> (Δ*one(f),)
@adjoint norm(f::Field) = Zygote.pullback(f->sqrt(dot(f,f)), f)
@adjoint dot(f::Field{B1}, g::Field{B2}) where {B1,B2} = dot(f,g), Δ -> (Δ*B1(g), Δ*B2(f))
@adjoint (*)(f::Adjoint{<:Any,<:Field}, g::Field) = Zygote.pullback((f,g)->dot(f',g),f,g)
# ℝᴺˣᴺ -> ℝ¹
@adjoint logdet(L::ParamDependentOp, θ) = Zygote._pullback(θ->logdet(L(;θ...)), θ)
@adjoint logdet(L::DiagOp) = logdet(L), Δ -> (Δ * Zfac(L) * pinv(L)',)
# basis conversion
@adjoint (::Type{B})(f::Field{B′}) where {B<:Basis, B′} = B(f), Δ -> (B′(Δ),)
# algebra
@adjoint (+)(f::Field{B1}, g::Field{B2}) where {B1,B2} = f+g, Δ -> (B1(Δ), B2(Δ))
@adjoint (*)(a::Real, L::DiagOp) = a*L, Δ -> (tr(L'*Δ)/Zfac(L), a*Δ)
@adjoint (*)(L::DiagOp, a::Real) = a*L, Δ -> (a*Δ, tr(L'*Δ)/Zfac(L))
# operators
@adjoint function (*)(D::DiagOp{<:Field{B}}, v::Field{B′}) where {B,B′}
D*v, Δ -> (Diagonal(B(Δ) .* conj.(B(v))), B′(D'*Δ))
end
@adjoint function (\)(D::DiagOp{<:Field{B}}, v::Field{B′}) where {B,B′}
z = D \ v
function back(Δ)
v̄ = D' \ Δ
-Diagonal(B(v̄) .* conj.(B(z))), B′(v̄)
end
z, back
end
# this makes it so we only have to define a L*f and L\f adjoint (like
# above) for any new operators, and we get f'*L and f'/L for free
# without necessarily needing to fully implement the AbstractMatrix
# interface for L
@adjoint *(f::Adjoint{<:Any,<:Field}, L::Union{DiagOp,ImplicitOp}) = Zygote.pullback((f,L)->(L'*f')', f, L)
@adjoint /(f::Adjoint{<:Any,<:Field}, L::Union{DiagOp,ImplicitOp}) = Zygote.pullback((f,L)->(L'\f')', f, L)
# special case for some ops which are constant by definition
@adjoint *(L::Union{FuncOp,DiagOp{<:∇diag}}, f::Field{B}) where {B} = L*f, Δ->(nothing, B(L'*Δ))
## FieldVectors
# following two definitions are almost definitely not totally right w.r.t
# putting stuff in the correct basis, although they're working for everything
# I've needed thus far
@adjoint function *(x::FieldOrOpRowVector, y::FieldVector)
z = x * y
# when x is a vector of Fields
back(Δ) = (Δ * y', x' * Δ)
# when x is a vector of Diagonals. in this case, Δ * basis(Δ)(y)' create an
# OuterProdOp in the same basis as the Diagonals in x
back(Δ::Field{B}) where {B} = (Δ * basis(Δ)(y)'), (x' * Δ)
z, back
end
@adjoint function *(A::FieldOrOpMatrix, x::FieldOrOpVector)
z = A * x
back(Δ::FieldVector) = (B=basis(eltype(z)); (B(Δ) * B(x)', A' * Δ))
back(Δ::FieldOrOpVector) = (Δ * x', A' * Δ) # not sure if anything special needed here?
back(Δ) = (Δ * x', A' * Δ)
z, back
end
# don't know why Zygote's default adjoint for this breaks in various ways but this is simple enough
@adjoint +(I::UniformScaling, L::Union{FieldOp, FieldOrOpMatrix}) = I+L, Δ->(nothing, Δ)
# Zygote/lib/array.jl:311 would suggest this should be:
# M⁻¹, Δ->(-M⁻¹' * Δ * M⁻¹' + (- M * M⁻¹ * Δ' * M⁻¹ * M⁻¹' + Δ' * M⁻¹ * M⁻¹') + (M⁻¹' * M⁻¹ * Δ' - M⁻¹' * M⁻¹ * Δ' * M⁻¹ * M),)
# I haven't derived their version, but numerically the one gives the right answer where as their doesn't...
@adjoint function pinv(L::Union{FieldOp, FieldOrOpMatrix})
L⁻¹ = pinv(L)
L⁻¹, Δ->(-L⁻¹' * Δ * L⁻¹',)
end
@adjoint sqrt(L::DiagOp) = (z=sqrt(L);), Δ -> ((pinv(z)/2)'*Δ,)
# some stuff which arguably belongs in Zygote or ChainRules
# see also: https://github.com/FluxML/Zygote.jl/issues/316
@adjoint broadcasted(::typeof(\), x ::Numeric, y::Numeric) =
broadcast(\, x, y), Δ -> (nothing, unbroadcast(x, @. -Δ*y/x^2), unbroadcast(y, @. Δ/x))
@adjoint (::Type{SA})(tup) where {SA<:SArray} = SA(tup), Δ->(tuple(Δ...),)
# workaround for https://github.com/FluxML/Zygote.jl/issues/686
@static if versionof(Zygote) > v"0.4.15"
Zygote._zero(xs::StaticArray, T) = SizedArray{Tuple{size(xs)...},Union{T,Nothing}}(map(_->nothing, xs))
end
# workaround for Zygote not working through cat when dims is a Val
# adapted from solution by Seth Axen
# see https://github.com/FluxML/Zygote.jl/pull/881
@adjoint function cat(Xs::AbstractArray...; dims)
cat(Xs...; dims = dims), Δ -> begin
start = ntuple(_ -> 0, ndims(Δ))
catdims = Base.dims2cat(dims)
dXs = map(Xs) do x
move = ntuple(d -> (d<=length(catdims) && catdims[d]) ? size(x,d) : 0, ndims(Δ))
x_in_Δ = ntuple(d -> (d<=length(catdims) && catdims[d]) ? (start[d]+1:start[d]+move[d]) : Colon(), ndims(Δ))
start = start .+ move
dx = Δ[x_in_Δ...]
end
end
end
@adjoint adapt(to, x::A) where {A<:AbstractArray} = adapt(to, x), Δ -> (nothing, adapt(A, Δ))
# finite difference Hessian using Zygote gradients
# todo: delete, just use FiniteDifferences
function hessian(f, xs::Vector; ε=1f-3)
hcat(finite_difference(xs->vcat(gradient(f,xs)[1]...),xs,ε=ε)...)
end
function finite_difference(f, xs::Vector; ε=1f-3, progress=false)
@showprogress (progress ? 1 : Inf) map(1:length(xs)) do i
xs₊ = copy(xs); xs₊[i] += ε
xs₋ = copy(xs); xs₋[i] -= ε
(f(xs₊) .- f(xs₋)) ./ (2ε)
end
end