Skip to content

Commit

Permalink
AD improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Jan 7, 2023
1 parent 21d6a47 commit 162b53d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
5 changes: 3 additions & 2 deletions src/autodiff.jl
Expand Up @@ -59,8 +59,9 @@ end
end
# preserve field type for sub-component property getters
function _getproperty_subcomponent_pullback(f, k)
g = zero(f)
function getproperty_pullback(Δ)
g = similar(f, promote_type(eltype(f), eltype(Δ)))
g .= 0
getproperty(g, k) .= Δ
(g, nothing)
end
Expand Down Expand Up @@ -306,4 +307,4 @@ AbstractFFTs.plan_rfft(arr::AbstractArray{<:Dual}, region; kws...) = plan_rfft(v


# to allow stuff like Float32(::Dual) to work
(::Type{S})(x::Dual{T,V,N}) where {T,V,N,S<:Union{Float32,Float64}} = Dual{T,S,N}(S(value(x)), Partials(ntuple(i -> S(partials(x,i)), Val(N))))
# (::Type{S})(x::Dual{T,V,N}) where {T,V,N,S<:Union{Float32,Float64}} = Dual{T,S,N}(S(value(x)), Partials(ntuple(i -> S(partials(x,i)), Val(N))))
25 changes: 16 additions & 9 deletions src/proj_healpix.jl
Expand Up @@ -142,18 +142,25 @@ end

# some NFFT stuff needed for method=:fft projections
cu_nfft_loaded = false
@init begin
@require NFFT="efe261a4-0d2b-5849-be55-fc731d526b0d" begin
using .NFFT: plan_nfft, AbstractNFFTPlan
Zygote.@adjoint function *(plan::Union{Adjoint{<:Any,<:AbstractNFFTPlan}, AbstractNFFTPlan}, x::AbstractArray{T}) where {T}
function mul_nfft_plan_pullback(Δ)
(nothing, T.(adjoint(plan) * complex(Δ)))
@init @require NFFT="efe261a4-0d2b-5849-be55-fc731d526b0d" begin
using .NFFT: plan_nfft, AbstractNFFTPlan
Zygote.@adjoint function *(plan::Union{Adjoint{<:Any,<:AbstractNFFTPlan}, AbstractNFFTPlan}, x::AbstractArray{T}) where {T}
function mul_nfft_plan_pullback(Δ)
(nothing, adjoint(plan) * complex(Δ))
end
plan * x, mul_nfft_plan_pullback
end
for P in [:(AbstractNFFTPlan{S}), :(Adjoint{Complex{S},<:AbstractNFFTPlan{S}})]
for op in [:(Base.:*), :(Base.:\)]
for D in [1, 2] # need explicit dimension to resolve method ambiguity
@eval function ($op)(plan::$P, arr::AbstractArray{<:Complex{<:Dual{T}}, $D}) where {T, S}
arr_of_duals(T, apply_plan($op, plan, arr)...)
end
end
plan * x, mul_nfft_plan_pullback
end
end
@require CuNFFT="a9291f20-7f4c-4d50-b30d-4e07b13252e1" global cu_nfft_loaded = true
end
end
@init @require CuNFFT="a9291f20-7f4c-4d50-b30d-4e07b13252e1" global cu_nfft_loaded = true


@doc doc"""
Expand Down
6 changes: 3 additions & 3 deletions src/proj_lambert.jl
Expand Up @@ -379,7 +379,7 @@ function Cℓ_to_Cov(::Val{:I}, proj::ProjLambert{T,V}, (Cℓ, ℓedges, θname)
ℓbin_indices = findbin.(Ref(adapt(proj.storage, ℓedges)), proj.ℓmag)
Cov(θ) = Diagonal(LambertFourier(bandpower_rescale(C₀.diag.arr, ℓbin_indices, θ), proj))
ParamDependentOp(@eval Main let Cov=$Cov
(;$θname=$(ones(T,length(ℓedges)-1)), _...) -> Cov($T.($θname))
(;$θname=$(ones(T,length(ℓedges)-1)), _...) -> Cov($θname)
end)
end

Expand All @@ -388,7 +388,7 @@ function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert{T}, (CℓEE, ℓedges, θname)
ℓbin_indices = findbin.(Ref(adapt(proj.storage, ℓedges)), proj.ℓmag)
Cov(θ) = Diagonal(LambertEBFourier(bandpower_rescale(C₀.diag.El, ℓbin_indices, θ), one(eltype(θ)) .* C₀.diag.Bl, proj))
ParamDependentOp(@eval Main let Cov=$Cov
(;$θname=$(ones(T,length(ℓedges)-1)), _...) -> Cov($T.($θname))
(;$θname=$(ones(T,length(ℓedges)-1)), _...) -> Cov($θname)
end)
end

Expand All @@ -397,7 +397,7 @@ function findbin(ℓedges, ℓ; out_of_range=length(ℓedges))
(ℓ<ℓedges[1] ||>=ℓedges[end]) ? out_of_range : findfirst(>(ℓ), ℓedges)::Int - 1
end
function bandpower_rescale(arr::A, ℓbin_indices, amplitudes) where {T<:Real, A<:AbstractArray{T}}
amplitudes_arr = adapt(basetype(A), [T.(amplitudes); 1])
amplitudes_arr = adapt(basetype(A), [amplitudes; 1])
return amplitudes_arr[ℓbin_indices] .* arr
end

Expand Down

0 comments on commit 162b53d

Please sign in to comment.