In [None]:
using Revise
using CMBLensing, CUDA, Distributions, LinearAlgebra, Plots, Setfield
# Notebook-wide plot styling: boxed frames, no minor ticks, 12pt tick labels,
# the tab10 palette, 2px lines, no marker strokes and no default series labels.
# Afterwards switch Plots to the PlotlyJS backend.
Plots.default(;
    size          = (525, 400),
    framestyle    = :box,
    minorticks    = false,
    tickfontsize  = 12,
    color_palette = Plots.palette(:tab10),
    lw            = 2,
    msw           = 0,
    label         = "",
)
plotlyjs();

# Model -> MUSE

In [2]:
# Lambert-projection geometry: Ny × Nx = 1024 × 2048 pixels, pixel size θpix = 2.25
# (presumably arcmin — confirm), Float32 precision, CUDA `CuArray` storage.
proj = ProjLambert(; T = Float32, storage = CuArray, θpix = 2.25, Nx = 2048, Ny = 1024);

In [3]:
# Theory spectra from CAMB with r = 0.1 (presumably the tensor-to-scalar ratio).
Cℓ = camb(; r = 0.1);

In [4]:
# Covariance of the lensing potential ϕ, built from the ϕϕ spectrum in `Cℓ.total`.
Cϕ = let Cℓϕϕ = Cℓ.total.ϕϕ
    Cℓ_to_Cov(:I, proj, Cℓϕϕ)
end;

In [5]:
# Covariance of the unlensed polarization field, from the unlensed EE and BB spectra.
Cf = let unl = Cℓ.unlensed_total
    Cℓ_to_Cov(:P, proj, unl.EE, unl.BB)
end;

In [6]:
# Noise covariance for the polarization field.
# The noise spectra are computed once and reused, instead of calling `noiseCℓs`
# twice with identical arguments.
# NOTE(review): the TT noise spectrum (set by μKarcminT) is used for BOTH the E and
# B noise, not the `.EE`/`.BB` spectra. Confirm this polarization noise level is intended.
Cn = let Nℓ = noiseCℓs(μKarcminT=4.5, ℓknee=0)
    Cℓ_to_Cov(:P, proj, Nℓ.TT, Nℓ.TT)
end;

In [7]:
# Random realization of the lensing potential drawn from Cϕ.
ϕ = Cϕ |> simulate;

In [8]:
# Random realization of the unlensed polarization field drawn from Cf.
f = Cf |> simulate;

In [9]:
# Random noise realization drawn from Cn.
n = Cn |> simulate;

In [10]:
# Zoomed view of the E component of simulated data: LenseFlow-lensed f plus noise n.
let d_sim = LenseFlow(ϕ) * f + n
    plot_zoom(d_sim[:E])
end;

In [None]:
# Custom CMBLensing `DataSet` paired with the `@fwdmodel` written for `MyDataset4`.
# `@kwdef` provides a keyword constructor. `d`, `Cf`, `Cϕ` and `Cn` have no defaults
# and must be passed, e.g. `MyDataset4(; d=nothing, Cf, Cϕ, Cn)`. The struct is
# mutable so `d` can be assigned after construction.
@kwdef mutable struct MyDataset4 <: DataSet
    d   # data; may be `nothing` until assigned
    Cf  # field covariance (this notebook passes the unlensed E/B covariance)
    Cϕ  # lensing-potential covariance
    Cn  # noise covariance
    # Operator applied to the lensed field in the forward model.
    # NOTE(review): presumably a band-pass keeping ℓ ≈ 300–4000 — confirm.
    M = MidPass(300,4000)

    # Remaining fields are defaults, presumably expected by CMBLensing's generic
    # DataSet machinery — confirm which ones are actually read.
    L = LenseFlow(7)  # lensing operator, called as `L(ϕ)`; 7 presumably ODE steps — confirm
    Cf̃ = Cf           # presumably the lensed-field covariance; set equal to Cf here
    Cn̂ = Cn           # presumably an approximate noise covariance; set equal to Cn
    G = I
    B̂ = I
    # NOTE(review): M̂ is the identity even though M above is a MidPass — confirm intended.
    M̂ = I
    # sqrt((Cf + σ²·I + 2·Cn) · pinv(Cf)) with σ² = (5 arcmin in radians)² as Float32.
    # NOTE(review): presumably the mixing matrix for the mixed parametrization — confirm.
    D = sqrt((Cf + (I*Float32(deg2rad(5/60)^2) + 2*Cn)) * pinv(Cf))
end

In [None]:
# Quadratic estimate on the current data set.
# NOTE(review): `ds` is only assigned in later cells (`ds = MyDataset4(...)` and the
# `load_sim` cell). In a fresh session this cell fails until one of those has run.
quadratic_estimate(ds)

In [12]:
# Forward model for `MyDataset5` (struct defined outside this view):
#   f ~ N(0, Cf), ϕ ~ N(0, Cϕ);
#   f's (E, B) components are rotated by angle 2θ;
#   d ~ N(M * LenseFlow(ϕ) * f_rot, Cn).
# θ = nothing (the default) means "no rotation". Previously an omitted θ made
# `cos(2θ)` throw a MethodError.
@fwdmodel function (ds::MyDataset5)(; f, ϕ, θ=nothing, d=ds.d)
    f ~ MvNormal(0, ds.Cf)
    ϕ ~ MvNormal(0, ds.Cϕ)
    f_rot = if isnothing(θ)
        f
    else
        # compute each trig factor once; both are reused in the rotation below
        c, s = cos(2θ), sin(2θ)
        FlatEBMap(
             f[:E] * c + f[:B] * s,
            -f[:E] * s + f[:B] * c
        )
    end
    d ~ MvNormal(ds.M * (LenseFlow(ϕ) * f_rot), ds.Cn)
end

# Constant log-prior for this data set.
CMBLensing.logprior(ds::MyDataset5; _...) = 0

In [78]:
# Forward model for `MyDataset4`:
#   f ~ N(0, Cf), ϕ ~ N(0, Cϕ), d ~ N(M * L(ϕ) * f, Cn),
# with Cf, Cϕ, Cn, M and L read from the data set's fields.
# θ is accepted for interface compatibility but is not used by this model.
@fwdmodel function (ds::MyDataset4)(; f, ϕ, θ=nothing, d=ds.d)
    (;Cf, Cϕ, Cn, M, L) = ds
    f ~ MvNormal(0, Cf)
    ϕ ~ MvNormal(0, Cϕ)
    d ~ MvNormal(M * (L(ϕ) * f), Cn)
end

# Constant log-prior. It dispatches on `MyDataset4` to match the model above;
# it previously targeted `MyDataset3`.
CMBLensing.logprior(ds::MyDataset4; _...) = 0

In [79]:
# Build the custom data set without data (`d = nothing`). The struct is mutable,
# so `d` can be assigned later.
ds = MyDataset4(; Cn = Cn, Cϕ = Cϕ, Cf = Cf, d = nothing);

# mMAP vs jMAP

In [34]:
# Replace `ds` with the data set returned by `load_sim`:
# polarization only, 512 pixels per side at θpix = 2.25, Float32 on CUDA, beam FWHM 1.5,
# noise 3 μK-arcmin (units presumably arcmin / μK-arcmin — confirm), plus a pixel
# mask built from `pixel_mask_kwargs`.
ds = load_sim(;
    pol       = :P,
    Nside     = 512,
    θpix      = 2.25,
    T         = Float32,
    storage   = CuArray,
    μKarcminT = 3,
    beamFWHM  = 1.5,
    pixel_mask_kwargs = (
        edge_padding_deg  = 1.5,
        edge_rounding_deg = 0.5,
        apodization_deg   = 0.5,
        num_ptsrcs        = 0,
    ),
).ds;

In [None]:
# Plot the Q component of the diagonal of `ds.M[2]`.
# NOTE(review): assumes `ds.M` (from `load_sim`) is indexable and that its second
# element is the diagonal pixel mask — confirm.
plot(diag(ds.M[2])[:Q])

In [36]:
# Draw (d, f, ϕ) from the data set's model. This overwrites the earlier globals f and ϕ.
# Then store the simulated data in `ds`.
(; ϕ, f, d) = simulate(ds);
setproperty!(ds, :d, d);

In [None]:
# Timed joint MAP estimate: 30 steps, progress bar disabled.
jMAP = @time(MAP_joint(ds; progress = false, nsteps = 30));

In [None]:
# Plot the first entry of each history record (presumably the log-posterior —
# confirm) against iteration.
plot(map(first, jMAP.history))

In [None]:
# Timed marginal MAP estimate: 30 steps, 30 of them with mean-field updates using
# 30 simulations, α = 0.5.
mMAP = @time MAP_marg(
    ds;
    nsteps                       = 30,
    nsteps_with_meanfield_update = 30,
    Nsims                        = 30,
    α                            = 0.5,
    progress                     = false,
);

In [None]:
# Timed MAP_marg run with no mean-field updates and no simulations
# (nsteps_with_meanfield_update = 0, Nsims = 0), 30 steps, α = 0.5.
mffmMAP = @time MAP_marg(
    ds;
    nsteps                       = 30,
    nsteps_with_meanfield_update = 0,
    Nsims                        = 0,
    α                            = 0.5,
    progress                     = false,
);

In [None]:
# Power spectra of ∇²ϕ for the true ϕ and each estimate, on a log y-axis.
plot(get_Cℓ(∇²*ϕ), label="true")
for (name, est) in ("jMAP" => jMAP, "mMAP" => mMAP, "mffmMAP" => mffmMAP)
    plot!(get_Cℓ(∇² * est.ϕ), label = name)
end
plot!(yscale=:log10, ylim=(1e-10,1e-5), xlim=(0,2000))

In [None]:
# Correlation coefficient ρℓ between pairs of ∇²ϕ maps (truth and estimates).
# Each ∇²ϕ map is computed once and reused.
let lapϕ_true = ∇² * ϕ,
    lapϕ_j    = ∇² * jMAP.ϕ,
    lapϕ_m    = ∇² * mMAP.ϕ,
    lapϕ_mff  = ∇² * mffmMAP.ϕ

    plot(get_ρℓ(lapϕ_true, lapϕ_j), label = "true x jMAP")
    plot!(get_ρℓ(lapϕ_true, lapϕ_m), label = "true x mMAP")
    plot!(get_ρℓ(lapϕ_j, lapϕ_m), label = "mMAP x jMAP")
    plot!(get_ρℓ(lapϕ_j, lapϕ_mff), label = "mffmMAP x jMAP")
    plot!(ticks = :native)
end

In [None]:
# Side-by-side ∇²ϕ maps for the three estimates, without colorbars.
let estimates = ("jMAP" => jMAP, "mffmMAP" => mffmMAP, "mMAP" => mMAP),
    panels    = [plot(∇² * est.ϕ, title = t) for (t, est) in estimates]

    plot(panels...; layout = (1, 3), size = (900, 300), cbar = false)
end

# Lenspyx AD