-
Notifications
You must be signed in to change notification settings - Fork 10
/
flowops.jl
68 lines (52 loc) · 2.72 KB
/
flowops.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
    FlowOp{I,t₀,t₁,T}

Abstract supertype for operators that act on a field by integrating an ODE
between the times `t₀` and `t₁`. `I` is forwarded as the first argument of
`odesolve` (presumably selecting the integration scheme). Concrete subtypes
supply `velocity` (forward flow) and `velocityᴴ` (adjoint flow); multiplication,
left division, and their adjoints are then all computed with `odesolve`.
"""
abstract type FlowOp{I, t₀, t₁, T} <: ImplicitOp{T} end

"""
    FlowOpWithAdjoint{I,t₀,t₁,T}

A `FlowOp` that also supplies `negδvelocityᴴ`. Its gradients are computed by
integrating that dedicated adjoint flow instead of letting Zygote
differentiate through the ODE solver.
"""
abstract type FlowOpWithAdjoint{I, t₀, t₁, T} <: FlowOp{I, t₀, t₁, T} end
# Interface that concrete FlowOp subtypes implement. Each function is called
# as `fn(cached_op, state)` and its return value is splatted into
# `odesolve(I, ..., tstart, tend)`. It must therefore return the remaining
# positional arguments that odesolve expects, presumably the velocity
# function and the initial state. TODO(review): confirm against odesolve.
#   velocity      : forward flow, used for L*f and L\f
#   velocityᴴ     : adjoint flow, used for L'*f and L'\f
#   negδvelocityᴴ : used only by the FlowOpWithAdjoint pullbacks, with state
#                   FieldTuple(field, Δ); the solution is destructured as
#                   (_, δf, δϕ)
function velocity end
function velocityᴴ end
function negδvelocityᴴ end
# Integrations that define L*f, L'*f, L\f and L'\f.
#  - The forward operator flows from t₀ to t₁; its inverse flows back from
#    t₁ to t₀.
#  - The adjoint variants unwrap the operator with `'`, use `velocityᴴ`, and
#    integrate in the opposite direction to their non-adjoint counterparts.
# NOTE(review): the expression handed to `@⌛` is kept verbatim, since the
# timer presumably uses it as its label — confirm before editing it.
function (*)(Lϕ::FlowOp{I,t₀,t₁}, f::Field) where {I,t₀,t₁}
    @⌛ odesolve(I, velocity(cache(Lϕ, f),f)..., t₀, t₁)
end

function (*)(Lϕ::Adjoint{<:Any,<:FlowOp{I,t₀,t₁}}, f::Field) where {I,t₀,t₁}
    @⌛ odesolve(I, velocityᴴ(cache(Lϕ',f),f)..., t₁, t₀)
end

function (\)(Lϕ::FlowOp{I,t₀,t₁}, f::Field) where {I,t₀,t₁}
    @⌛ odesolve(I, velocity(cache(Lϕ, f),f)..., t₁, t₀)
end

function (\)(Lϕ::Adjoint{<:Any,<:FlowOp{I,t₀,t₁}}, f::Field) where {I,t₀,t₁}
    @⌛ odesolve(I, velocityᴴ(cache(Lϕ',f),f)..., t₀, t₁)
end
# Pullbacks for building a FlowOp from ϕ (via its type), and for calling an
# existing FlowOp with ϕ′ (presumably producing the same kind of operator for
# ϕ′). In both cases the cotangent Δ arriving at the operator is passed
# through unchanged as the gradient w.r.t. the ϕ argument. When an existing
# operator `Lϕ` is called, `Lϕ` itself receives no gradient (`nothing`).
@adjoint (::Type{L})(ϕ) where {L<:FlowOp} = L(ϕ), Δ -> (Δ,)
@adjoint (Lϕ::FlowOp)(ϕ′) = Lϕ(ϕ′), Δ -> (nothing, Δ)
# Plain FlowOps have no hand-written adjoint flow, so their pullbacks come from
# letting Zygote differentiate straight through `odesolve`. `Zygote.pullback`
# already returns `(value, back)`, where `back` yields the gradients w.r.t.
# `(Lϕ, f)` — the shape `@adjoint` expects.
@adjoint function *(Lϕ::FlowOp{I,t₀,t₁}, f::Field{B}) where {I,t₀,t₁,B}
    # forward flow: t₀ → t₁
    forward = (op, field) -> odesolve(I, velocity(cache(op, field), field)..., t₀, t₁)
    return Zygote.pullback(forward, Lϕ, f)
end
@adjoint function \(Lϕ::FlowOp{I,t₀,t₁}, f::Field{B}) where {I,t₀,t₁,B}
    # inverse flow: t₁ → t₀
    backward = (op, field) -> odesolve(I, velocity(cache(op, field), field)..., t₁, t₀)
    return Zygote.pullback(backward, Lϕ, f)
end
# FlowOpWithAdjoint subtypes provide their own adjoint velocity
# (`negδvelocityᴴ`) for computing the gradient, instead of having Zygote
# differentiate through the ODE solver.
#
# Note on the use of task_local_storage below: if we don't need the pullback
# w.r.t. ϕ, we can save time by running the transpose flow rather than the
# transpose-δ flow. However, Zygote has no way of knowing that at this point
# in the code. We therefore use this ugly hack: the higher-level function must
# declare ϕ constant by including :ϕ in the task-local :AD_constants
# collection. This can be made much cleaner once we switch to Diffractor.
@adjoint function *(Lϕ::FlowOpWithAdjoint{I,t₀,t₁}, f::Field{B}) where {I,t₀,t₁,B}
    # cache the operator once so the forward pass and the pullback share it
    cLϕ = cache(Lϕ,f)
    # forward result f̃ = Lϕ*f (flow t₀ → t₁); it is also the starting state
    # of the backward adjoint-δ integration below
    f̃ = cLϕ * f
    function back(Δ)
        if :ϕ in get(task_local_storage(), :AD_constants, ())
            # ϕ declared constant: only the gradient w.r.t. f is needed, which
            # is the adjoint operator applied to Δ. It is converted with
            # B(...), B being the type parameter of the input field
            # (presumably its basis).
            nothing, B(cLϕ' * Δ)
        else
            # Integrate the pair (f̃, Δ) from t₁ back to t₀ with the adjoint-δ
            # velocity. The solution yields the gradient w.r.t. f (δf) and
            # w.r.t. the operator (δϕ).
            (_,δf,δϕ) = @⌛ odesolve(I, negδvelocityᴴ(cLϕ, FieldTuple(f̃,Δ))..., t₁, t₀)
            δϕ, B(δf)
        end
    end
    f̃, back
end
# Pullback for f = Lϕ \ f̃ on a FlowOpWithAdjoint, i.e. the inverse flow
# t₁ → t₀. If the caller has declared ϕ constant (by including :ϕ in the
# task-local :AD_constants collection), only the cheaper transpose flow is
# run. Otherwise the adjoint-δ flow is integrated to get gradients w.r.t.
# both f̃ and the operator.
@adjoint function \(Lϕ::FlowOpWithAdjoint{I,t₀,t₁}, f̃::Field{B}) where {I,t₀,t₁,B}
    # cache the operator once so the forward pass and the pullback share it
    cLϕ = cache(Lϕ,f̃)
    # forward result f = Lϕ \ f̃ (flow t₁ → t₀); it is also the starting
    # state of the backward adjoint-δ integration below
    f = cLϕ \ f̃
    function back(Δ)
        if :ϕ in get(task_local_storage(), :AD_constants, ())
            # ϕ declared constant: the gradient w.r.t. f̃ is Lϕ' \ Δ. It is
            # converted with B(...), B being the type parameter of the input
            # field (presumably its basis).
            nothing, B(cLϕ' \ Δ)
        else
            # Integrate the pair (f, Δ) from t₀ to t₁ with the adjoint-δ
            # velocity. This is the opposite direction to the `*` pullback,
            # since this forward pass ran t₁ → t₀.
            (_,δf,δϕ) = @⌛ odesolve(I, negδvelocityᴴ(cLϕ, FieldTuple(f,Δ))..., t₀, t₁)
            δϕ, B(δf)
        end
    end
    f, back
end