muse implicit diff working
marius311 committed Jan 4, 2023
1 parent 89ce8f5 commit d79184d
Showing 5 changed files with 73 additions and 34 deletions.
23 changes: 15 additions & 8 deletions src/base_fields.jl
@@ -36,7 +36,9 @@ lastindex(f::BaseField, i::Int) = lastindex(f.arr, i)
@propagate_inbounds getindex(f::BaseField, I::Union{Int,Colon,AbstractArray}...) = getindex(f.arr, I...)
@propagate_inbounds setindex!(f::BaseField, X, I::Union{Int,Colon,AbstractArray}...) = (setindex!(f.arr, X, I...); f)
similar(f::BaseField{B}, ::Type{T}) where {B,T} = BaseField{B}(similar(f.arr, T), f.metadata)
similar(f::BaseField{B}, ::Type{T}, dims::Base.DimOrInd...) where {B,T} = similar(f.arr, T, dims...)
copy(f::BaseField{B}) where {B} = BaseField{B}(copy(f.arr), f.metadata)
copyto!(dst::AbstractArray, src::BaseField) = copyto!(dst, src.arr)
(==)(f₁::BaseField, f₂::BaseField) = strict_compatible_metadata(f₁,f₂) && (f₁.arr == f₂.arr)
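These overloads give BaseField a fairly complete array interface. A quick usage sketch (hypothetical, not part of the commit, assuming f is an existing BaseField such as a FlatMap):

# hypothetical usage, assuming f is an existing BaseField (e.g. a FlatMap):
g = similar(f, Float32)   # new BaseField: same basis and metadata, Float32 entries
h = copy(f)               # copies the underlying array, keeps the metadata
h == f                    # true: metadata compatibility is checked, then array equality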


@@ -46,7 +48,9 @@ function promote(f₁::BaseField{B₁}, f₂::BaseField{B₂}) where {B₁,B₂}
B = typeof(promote_basis_generic(B₁(), B₂()))
B(f₁), B(f₂)
end

# allow very basic arithmetic with BaseField & AbstractArray
promote(f::BaseField{B}, x::AbstractArray) where {B} = (f, BaseField{B}(reshape(x, size(f.arr)), f.proj))
promote(x::AbstractArray, f::BaseField{B}) where {B} = reverse(promote(f, x))
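A sketch of what this promotion enables (hypothetical usage, not part of the commit):

# assuming f is an existing BaseField with underlying array f.arr:
x = rand(eltype(f.arr), size(f.arr)...)
f′, x′ = promote(f, x)   # x′ wraps x as a BaseField{B} sharing f's projection
f + x                    # so basic arithmetic works (see the src/generic.jl hunk below)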

## broadcasting

@@ -61,6 +65,7 @@ BroadcastStyle(::Type{F}) where {B,M,T,A,F<:BaseField{B,M,T,A}} =
BroadcastStyle(::BaseFieldStyle{S₁,B₁}, ::BaseFieldStyle{S₂,B₂}) where {S₁,B₁,S₂,B₂} =
BaseFieldStyle{typeof(result_style(S₁(), S₂())), typeof(promote_basis_strict(B₁(),B₂()))}()
BroadcastStyle(S::BaseFieldStyle, ::DefaultArrayStyle{0}) = S
BaseFieldStyle{S,B}(::Val{2}) where {S,B} = DefaultArrayStyle{2}()

# with the Broadcasted object created, we now compute the answer
function materialize(bc::Broadcasted{BaseFieldStyle{S,B}}) where {S,B}
@@ -101,10 +106,13 @@ function materialize!(dst::BaseField{B}, bc::Broadcasted{BaseFieldStyle{S,B′}}

end

# the default preprocessing, which just unwraps the underlying array.
# this doesn't dispatch on the first argument, but custom BaseFields
# are free to override this and dispatch on it if they need
preprocess(::Any, f::BaseField) = f.arr
# if broadcasting into a BaseField, the first method here is hit with
# dest::Tuple{BaseFieldStyle,M}, in which case just unwrap the array,
# since it will be fed into a downstream regular broadcast
preprocess(::Tuple{BaseFieldStyle{S,B},M}, f::BaseField) where {S,B,M} = f.arr
# if broadcasting into an Array (ie dropping the BaseField wrapper) we
# need to return the vector representation
preprocess(::AbstractArray, f::BaseField) = view(f.arr, :)

# we re-wrap each Broadcasted object as we go through preprocessing
# because some array types do special things here (e.g. CUDA wraps
@@ -135,8 +143,7 @@ function strict_compatible_metadata(f₁::BaseField, f₂::BaseField)
end

## mapping

# this comes up in Zygote.broadcast_forward, and the generic falls back to a regular Array
# map over entries in the array like a true AbstractArray
map(func, f::BaseField{B}) where {B} = BaseField{B}(map(func, f.arr), f.metadata)
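An illustrative sketch of the new behavior (hypothetical usage):

map(abs, f)      # BaseField{B} whose entries are abs.(f.arr), metadata unchanged
map(x -> 2x, f)  # entrywise, equivalent to 2f here since the function is linear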


@@ -169,4 +176,4 @@ getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where
BaseField{B₀}(_reshape_batch(view(getfield(f,:arr), pol_slice(f, pol_index(B(), k))...)), getfield(f,:metadata))
getproperty(f::BaseS02{Basis3Prod{𝐈,B₂,B₀}}, ::Val{:P}) where {B₂,B₀} =
BaseField{Basis2Prod{B₂,B₀}}(view(getfield(f,:arr), pol_slice(f, 2:3)...), getfield(f,:metadata))
getproperty(f::BaseS2, ::Val{:P}) = f
21 changes: 20 additions & 1 deletion src/field_tuples.jl
@@ -28,13 +28,15 @@ typealias_def(::Type{<:FieldTuple{FS,T}}) where {FS<:Tuple,T} =
### array interface
size(f::FieldTuple) = (mapreduce(length, +, f.fs, init=0),)
copy(f::FieldTuple) = FieldTuple(map(copy,f.fs))
copyto!(dst::AbstractArray, src::FieldTuple) = copyto!(dst, src[:]) # todo: memory optimization possible
iterate(ft::FieldTuple, args...) = iterate(ft.fs, args...)
getindex(f::FieldTuple, i::Union{Int,UnitRange}) = getindex(f.fs, i)
fill!(ft::FieldTuple, x) = (map(f->fill!(f,x), ft.fs); ft)
get_storage(f::FieldTuple) = only(unique(map(get_storage, f.fs)))
adapt_structure(to, f::FieldTuple) = FieldTuple(map(f->adapt(to,f),f.fs))
similar(ft::FieldTuple) = FieldTuple(map(similar,ft.fs))
similar(ft::FieldTuple, ::Type{T}) where {T<:Number} = FieldTuple(map(f->similar(f,T),ft.fs))
similar(ft::FieldTuple, ::Type{T}, dims::Base.DimOrInd...) where {T} = similar(ft.fs[1].arr, T, dims...) # todo: make work for heterogeneous arrays?
similar(ft::FieldTuple, Nbatch::Int) = FieldTuple(map(f->similar(f,Nbatch),ft.fs))
sum(f::FieldTuple; dims=:) = dims == (:) ? sum(sum, f.fs) : error("sum(::FieldTuple, dims=$dims) not supported")

@@ -54,6 +56,7 @@ function BroadcastStyle(::FieldTupleStyle{S₁,Names}, ::FieldTupleStyle{S₂,Na
FieldTupleStyle{Tuple{map_tupleargs((s₁,s₂)->typeof(result_style(s₁(),s₂())), S₁, S₂)...}, Names}()
end
BroadcastStyle(S::FieldTupleStyle, ::DefaultArrayStyle{0}) = S
FieldTupleStyle{S,Names}(::Val{2}) where {S,Names} = DefaultArrayStyle{2}()


@generated function materialize(bc::Broadcasted{FieldTupleStyle{S,Names}}) where {S,Names}
@@ -73,13 +76,29 @@ end
struct FieldTupleComponent{i} end

preprocess(::Tuple{<:Any,FieldTupleComponent{i}}, ft::FieldTuple) where {i} = ft.fs[i]
preprocess(::AbstractArray, ft::FieldTuple) = vcat((view(f.arr, :) for f in ft.fs)...)


### mapping
# map over entries in the component fields like a true AbstractArray
map(func, ft::FieldTuple) = FieldTuple(map(f -> map(func, f), ft.fs))

### promotion
function promote(ft1::FieldTuple, ft2::FieldTuple)
fts = map(promote, ft1.fs, ft2.fs)
FieldTuple(map(first,fts)), FieldTuple(map(last,fts))
end
# allow very basic arithmetic with FieldTuple & AbstractArray
function promote(ft::FieldTuple, x::AbstractVector)
lens = map(length, ft.fs)
offsets = typeof(lens)((cumsum([1; lens...])[1:end-1]...,))
x_ft = FieldTuple(map(ft.fs, offsets, lens) do f, offset, len
promote(f, view(x, offset:offset+len-1))[2]
end)
(ft, x_ft)
end
promote(x::AbstractVector, ft::FieldTuple) = reverse(promote(ft, x))
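Worked example of the offsets computation above (illustrative only):

lens = (4, 6)                                   # component field lengths
offsets = (cumsum([1; lens...])[1:end-1]...,)   # == (1, 5)
# a 10-element x is then split as view(x, 1:4) and view(x, 5:10),
# and each chunk is promoted against the corresponding component field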


### conversion
Basis(ft::FieldTuple) = ft
@@ -120,4 +139,4 @@ tr(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(tr∘Diagon
batch_length(ft::FieldTuple) = only(unique(map(batch_length, ft.fs)))
batch_index(ft::FieldTuple, I) = FieldTuple(map(f -> batch_index(f, I), ft.fs))
getindex(ft::FieldTuple, k::Symbol) = ft.fs[k]
haskey(ft::FieldTuple, k::Symbol) = haskey(ft.fs, k)
13 changes: 10 additions & 3 deletions src/generic.jl
@@ -330,9 +330,16 @@ show_vector(io::IO, f::Field) = !isempty(f) && show_vector(io, f[:])
Base.has_offset_axes(::Field) = false # needed for Diagonal(::Field) if the Field is implicitly-sized


# addition/subtraction works between any fields and scalars, promotion is done
# automatically if fields are in different bases
for op in (:+,:-), (T1,T2,promote) in ((:Field,:Scalar,false),(:Scalar,:Field,false),(:Field,:Field,true))
# addition/subtraction works between fields, scalars, and
# abstractarrays. promotion is done automatically for fields in
# different bases; abstractarrays are wrapped assuming they're the same field type
for op in (:+,:-), (T1,T2,promote) in [
(:Field, :Scalar, false),
(:Scalar, :Field, false),
(:Field, :Field, true),
(:Field, :AbstractArray, true),
(:AbstractArray, :Field, true)
]
@eval ($op)(a::$T1, b::$T2) = broadcast($op, ($promote ? promote(a,b) : (a,b))...)
end
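Illustrative expansion of the @eval loop above (not literal source):

(+)(a::Field, b::Scalar)        = broadcast(+, a, b)
(+)(a::Field, b::AbstractArray) = broadcast(+, promote(a, b)...)
(-)(a::AbstractArray, b::Field) = broadcast(-, promote(a, b)...)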

40 changes: 23 additions & 17 deletions src/muse.jl
@@ -2,7 +2,8 @@
# interface with MuseInference.jl

using .MuseInference: AbstractMuseProblem, MuseResult
import .MuseInference: ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ
using .MuseInference.AbstractDifferentiation
import .MuseInference: logLike, ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ

export CMBLensingMuseProblem

@@ -14,10 +15,20 @@ struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem
θ_fixed
x
latent_vars
autodiff
end

function CMBLensingMuseProblem(ds, ds_for_sims=ds; parameterization=0, MAP_joint_kwargs=(;), θ_fixed=(;), latent_vars=nothing)
CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars)
function CMBLensingMuseProblem(
ds,
ds_for_sims = ds;
parameterization = 0,
MAP_joint_kwargs = (;),
θ_fixed = (;),
latent_vars = nothing,
autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend())),
)
parameterization == 0 || error("only parameterization=0 (unlensed parameterization) currently implemented")
CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars, autodiff)
end
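Hypothetical construction (assumes ds is an existing DataSet, e.g. from load_sim):

prob = CMBLensingMuseProblem(ds; latent_vars = (:f, :ϕ))
# or, overriding the default ForwardDiff-over-Zygote backend:
prob = CMBLensingMuseProblem(ds; autodiff = AD.ZygoteBackend())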

mergeθ(prob::CMBLensingMuseProblem, θ) = isempty(prob.θ_fixed) ? θ : (;prob.θ_fixed..., θ...)
@@ -27,26 +38,21 @@ function standardizeθ(prob::CMBLensingMuseProblem, θ)
1f0 * ComponentVector(θ) # ensure component vector and float
end

function MuseInference.logLike(prob::CMBLensingMuseProblem, d, z, θ)
logpdf(prob.ds; z..., θ = mergeθ(prob, θ), d)
end

function ∇θ_logLike(prob::CMBLensingMuseProblem, d, z, θ)
@unpack ds, parameterization = prob
@set! ds.d = d
if parameterization == 0
gradient(θ -> logpdf(ds; z..., θ = mergeθ(prob, θ)), θ)[1]
elseif parameterization == :mix
z° = mix(ds; z..., θ = mergeθ(prob, θ))
gradient(θ -> logpdf(Mixed(ds); z°..., θ = mergeθ(prob, θ)), θ)[1]
else
error("parameterization should be 0 or :mix")
end
AD.gradient(prob.autodiff, θ -> logLike(prob, d, z, θ), θ)[1]
end
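A standalone sketch of the AbstractDifferentiation API used above (assumes the ForwardDiff and Zygote packages are loaded; the AD alias is exported by AbstractDifferentiation):

using AbstractDifferentiation, ForwardDiff, Zygote
backend = AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend()))
AD.gradient(backend, x -> sum(abs2, x), [1.0, 2.0])[1]   # == [2.0, 4.0]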

function sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
sim = simulate(rng, prob.ds_for_sims, θ = mergeθ(prob, θ))
if prob.latent_vars == nothing
# this is a guess which might not work for everything necessarily
z = FieldTuple(delete(sim, (:f̃, :d, )))
z = LenseBasis(FieldTuple(delete(sim, (:f̃, :d, ))))
else
z = FieldTuple(select(sim, prob.latent_vars))
z = LenseBasis(FieldTuple(select(sim, prob.latent_vars)))
end
x = sim.d
(;x, z)
@@ -56,12 +62,12 @@ function ẑ_at_θ(prob::CMBLensingMuseProblem, d, zguess, θ; ∇z_logLike_atol
@unpack ds = prob
Ωstart = delete(NamedTuple(zguess), :f)
MAP = MAP_joint(mergeθ(prob, θ), @set(ds.d=d), Ωstart; fstart=zguess.f, prob.MAP_joint_kwargs...)
FieldTuple(;delete(MAP, :history)...), MAP.history
LenseBasis(FieldTuple(;delete(MAP, :history)...)), MAP.history
end

function ẑ_at_θ(prob::CMBLensingMuseProblem{<:NoLensingDataSet}, d, (f₀,), θ; ∇z_logLike_atol=nothing)
@unpack ds = prob
FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...)), nothing
LenseBasis(FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...))), nothing
end

function muse!(result::MuseResult, ds::DataSet, θ₀=nothing; parameterization=0, MAP_joint_kwargs=(;), kwargs...)
10 changes: 5 additions & 5 deletions src/proj_lambert.jl
@@ -131,17 +131,17 @@ promote_metadata_generic(metadata₁::ProjLambert, metadata₂::ProjLambert) =
# return `Broadcasted` objects which are spliced into the final
# broadcast, thus avoiding allocating any temporary arrays.

function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
r isa BatchedReal ? adapt(V, reshape(r.vals, 1, 1, 1, :)) : r
end
# need custom adjoint here bc Δ can come back batched from the
# backward pass even though r was not batched on the forward pass
@adjoint function preprocess(m::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
@adjoint function preprocess(m::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
preprocess(m, r), Δ -> (nothing, Δ isa AbstractArray ? batch(real.(Δ[:])) : Δ)
end


function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}
function preprocess((_,proj)::Tuple{<:BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}

(B <: Union{Fourier,QUFourier,IQUFourier}) ||
error("Can't broadcast ∇[$(∇d.coord)] as a $(typealias(B)), its not diagonal in this basis.")
@@ -156,15 +156,15 @@ function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::
end
end

function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ::∇²diag) where {S,B}
function preprocess((_,proj)::Tuple{<:BaseFieldStyle{S,B},<:ProjLambert}, ::∇²diag) where {S,B}

(B <: Union{Fourier,<:Basis2Prod{<:Any,Fourier},<:Basis3Prod{<:Any,<:Any,Fourier}}) ||
error("Can't broadcast a BandPass as a $(typealias(B)), its not diagonal in this basis.")

broadcasted(+, broadcasted(^, proj.ℓx', 2), broadcasted(^, proj.ℓy, 2))
end

function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert}, bp::BandPass)
function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert}, bp::BandPass)
Cℓ_to_2D(bp.Wℓ, proj)
end

