In [1]:
using Revise
using CMBLensing, CUDA, Distributions, LinearAlgebra, Plots, Setfield
# Notebook-wide plot styling, followed by selecting the PlotlyJS backend.
Plots.default(;
    linewidth = 2,
    markerstrokewidth = 0,
    framestyle = :box,
    minorticks = false,
    color_palette = Plots.palette(:tab10),
    size = (525, 400),
    tickfontsize = 12,
    label = "",
)
plotlyjs();

# Model -> MUSE

In [2]:
# Lambert-projection pixelization: Ny=1024 by Nx=2048 pixels, Float32 element type,
# CuArray storage. θpix is presumably the pixel width in arcminutes — TODO confirm.
proj = ProjLambert(Ny=1024, Nx=2048, θpix=2.25, T=Float32, storage=CuArray);

In [3]:
# Spectra returned by `camb` for r=0.1 (presumably the tensor-to-scalar ratio — confirm).
Cℓ = camb(r=0.1);

In [4]:
# Covariance operator for the lensing potential, built from the total ϕϕ spectrum.
Cϕ = Cℓ_to_Cov(:I, proj, Cℓ.total.ϕϕ);

In [5]:
# Covariance operator for the polarization field, built from the unlensed EE and BB spectra.
Cf = Cℓ_to_Cov(:P, proj, Cℓ.unlensed_total.EE, Cℓ.unlensed_total.BB);

In [6]:
# Noise covariance for the polarization field. The noise spectra are evaluated once
# and their TT spectrum is supplied for both spectrum arguments of the `:P` covariance.
# NOTE(review): confirm that TT (rather than EE/BB) is the intended spectrum here.
Cn = let Nℓ = noiseCℓs(μKarcminT=4.5, ℓknee=0)
    Cℓ_to_Cov(:P, proj, Nℓ.TT, Nℓ.TT)
end;

In [7]:
# Draw a lensing-potential realization from Cϕ.
ϕ = simulate(Cϕ);

In [8]:
# Draw an unlensed polarization realization from Cf.
f = simulate(Cf);

In [9]:
# Draw a noise realization from Cn.
n = simulate(Cn);

In [10]:
# Lens f with ϕ, add the noise realization, and view the E component of the result.
plot_zoom((LenseFlow(ϕ) * f + n)[:E]);

In [None]:
# Dataset container used by the `MyDataset4` forward model defined in a later cell.
# Constructed with keyword arguments (`@kwdef`): `d`, `Cf`, `Cϕ` and `Cn` have no
# default and must be supplied; every other field falls back to the default below.
@kwdef mutable struct MyDataset4 <: DataSet
    # data (the notebook constructs the dataset with `d=nothing`)
    d
    # covariance of the unlensed field f
    Cf
    # covariance of the lensing potential ϕ
    Cϕ
    # noise covariance
    Cn
    # operator applied to the lensed field in the forward model
    M = MidPass(300,4000)

    # boilerplate: presumably fields expected by generic CMBLensing DataSet code — confirm
    # lensing operator; the forward model calls it as `L(ϕ)`
    L = LenseFlow(7)
    # defaults to Cf
    Cf̃ = Cf
    # defaults to Cn
    Cn̂ = Cn
    G = I
    B̂ = I
    M̂ = I
    # sqrt((Cf + (σ²·I + 2Cn)) * pinv(Cf)) with σ² = deg2rad(5/60)^2, i.e. (5 arcmin in radians)²
    # presumably the mixing operator for a mixed parametrization — confirm
    D = sqrt((Cf + (I*Float32(deg2rad(5/60)^2) + 2*Cn)) * pinv(Cf))
end

In [None]:
# Quadratic estimate for the dataset `ds`.
# NOTE(review): `ds` is only assigned in later cells of this notebook, so this cell
# relies on out-of-order execution — confirm which `ds` is intended.
quadratic_estimate(ds)

In [12]:
# Forward model with a polarization rotation: the unlensed field `f` is rotated in
# (E, B) by the angle 2θ, lensed by `ϕ`, passed through `ds.M`, and observed with
# noise covariance `ds.Cn`.
# NOTE(review): `MyDataset5` is not defined in the visible part of this notebook, and
# the default `θ=nothing` cannot be used in `2θ` — callers presumably always pass a
# numeric θ; confirm.
@fwdmodel function (ds::MyDataset5)(; f, ϕ, θ=nothing, d=ds.d)
    (; Cf, Cϕ, Cn, M) = ds
    f ~ MvNormal(0, Cf)
    ϕ ~ MvNormal(0, Cϕ)
    # rotation coefficients and field components, each named once
    c2θ, s2θ = cos(2θ), sin(2θ)
    E, B = f[:E], f[:B]
    f_rot = FlatEBMap(E * c2θ + B * s2θ, -E * s2θ + B * c2θ)
    d ~ MvNormal(M * (LenseFlow(ϕ) * f_rot), Cn)
end

# Constant (zero) log-prior for this dataset type; extra keywords are ignored.
CMBLensing.logprior(::MyDataset5; _...) = 0

In [78]:
# Forward model for `MyDataset4`: Gaussian priors on the unlensed field `f` (covariance
# `Cf`) and the lensing potential `ϕ` (covariance `Cϕ`); the data `d` is Gaussian
# around `M * (L(ϕ) * f)` with noise covariance `Cn`.
# `θ` is accepted as a keyword but not used by this model.
@fwdmodel function (ds::MyDataset4)(; f, ϕ, θ=nothing, d=ds.d)
    (;Cf, Cϕ, Cn, M, L) = ds
    f ~ MvNormal(0, Cf)
    ϕ ~ MvNormal(0, Cϕ)
    d ~ MvNormal(M * (L(ϕ) * f), Cn)
end

# Constant (zero) log-prior, attached to the same dataset type as the model above.
CMBLensing.logprior(ds::MyDataset4; _...) = 0

In [79]:
# Build the dataset from the covariances defined above; no data is attached yet.
ds = MyDataset4(; d=nothing, Cf=Cf, Cϕ=Cϕ, Cn=Cn);

# mMAP vs jMAP

In [242]:
# Load a simulated polarization dataset and keep only its `ds` entry.
# This rebinds the global `ds` assigned earlier in the notebook.
(; ds) = load_sim(;
    θpix = 2.25,
    Nside = 512,
    pol = :P,
    T = Float32,
    storage = CuArray,
    beamFWHM = 1.5,
    μKarcminT = 3,
    pixel_mask_kwargs = (
        edge_padding_deg = 1.5,
        edge_rounding_deg = 0.5,
        apodization_deg = 0.5,
        num_ptsrcs = 0,
    ),
);

In [None]:
# Install a Gaussian penalty on the mask-weighted mean of ∇²ϕ as `ds.logprior`:
#     logprior(ϕ) = -( sum(Mborder * ∇²ϕ) / sum(diag(Mborder)) )^2 / (2 σ²κ)
# Only the `ϕ` keyword is used; any other keywords are ignored.
# NOTE(review): the variance is named σ²κ but the penalized quantity is the mean of
# ∇²ϕ (κ is conventionally -∇²ϕ/2) — confirm the intended normalization.
σ²κ = 1e-7
Mborder = ds.M[2][:Q]
T = real(eltype(ds.Cf))
# Bind the mask, its normalization and the scaled variance once. The closure then does
# not read untyped globals or redo the mask reduction on every call, and a later
# rebinding of `Mborder`, `T` or `σ²κ` in the notebook cannot silently alter the prior.
ds.logprior = let Mborder = Mborder, mask_sum = sum(diag(Mborder)), two_σ² = T(2*σ²κ)
    function(; ϕ, _...)
        -(sum(Mborder * (∇² * ϕ)) / mask_sum)^2 / two_σ²
    end
end

In [None]:
# Plot the Q component of the diagonal of `ds.M[2]` (the operator `Mborder` is taken from).
plot(diag(ds.M[2])[:Q])

In [245]:
# Simulate from `ds` and store the simulated data back into it.
# Note: this rebinds the globals `f` and `ϕ` that were assigned in earlier cells.
(;d, f, ϕ) = simulate(ds);
ds.d = d;

In [None]:
# Joint MAP (30 steps) on `ds` as is, i.e. with whatever `ds.logprior` currently holds.
JMAP_SSP = @time MAP_joint(ds, nsteps=30, progress=false, prior_deprojection_factor=0);

In [None]:
# Joint MAP on a copy of `ds` whose logprior is a constant 0 (`@set` does not modify `ds`).
JMAP = @time MAP_joint(@set(ds.logprior=(;_...)->0), nsteps=30, progress=false, prior_deprojection_factor=0);

In [None]:
# MAP_marg with no mean-field update steps and no simulations (Nsims=0).
FJMAP_SSP = @time MAP_marg(ds, nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);

In [None]:
# Same settings as the previous run, on a copy of `ds` with `L` replaced by `BilinearLens`.
FJMAP_SSP_BL = @time MAP_marg(@set(ds.L=BilinearLens), nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);

In [None]:
# MAP_marg with nsteps_with_meanfield_update=30 (out of nsteps=30) and Nsims=30.
MMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=30, Nsims=30, nsteps=30, α=0.5, progress=false);

In [None]:
# Spectra (`get_Cℓ`) of ∇²ϕ for the simulated potential and for each estimate,
# drawn in the order listed, on a log y-axis.
plot(get_Cℓ(∇² * ϕ), label="True")
for (est, lbl) in (
    (JMAP, "JMAP"),
    (MMAP, "MMAP"),
    (JMAP_SSP, "JMAP + SSP"),
    (FJMAP_SSP, "FJMAP + SSP"),
    (FJMAP_SSP_BL, "FJMAP + SSP + BL"),
)
    plot!(get_Cℓ(∇² * est.ϕ), label=lbl)
end
plot!(yscale=:log10, ylim=(1e-10,1e-5), xlim=(0,2000))

In [None]:
# `get_ρℓ` of ∇²ϕ between pairs of estimates, added to an empty plot.
plot()
# Disabled alternative comparisons, kept for reference:
# plot!(get_ρℓ(∇²*ϕ,∇²*JMAP.ϕ,), label="true x JMAP")
# plot!(get_ρℓ(∇²*ϕ,∇²*MMAP.ϕ), label="true x MMAP")
# plot!(get_ρℓ(∇²*JMAP.ϕ,∇²*JMAP_SSP.ϕ), label="MMAP x JMAP")
plot!(get_ρℓ(∇²*FJMAP_SSP.ϕ,∇²*JMAP_SSP.ϕ), label="FJMAP+SSP x JMAP+SSP")
plot!(get_ρℓ(∇²*FJMAP_SSP_BL.ϕ,∇²*JMAP_SSP.ϕ), label="FJMAP+SSP+BL x JMAP+SSP")
plot!(ticks=:native)

In [None]:
# Side-by-side maps of ∇²ϕ for four of the estimates, one titled panel each.
let estimates = ((JMAP, "JMAP"), (JMAP_SSP, "JMAP + SSP"), (FJMAP_SSP, "FJMAP + SSP"), (MMAP, "MMAP"))
    panels = [plot(∇² * est.ϕ; title=ttl) for (est, ttl) in estimates]
    plot(panels...; layout=(1, 4), size=(1100,300), cbar=false)
end

# Lenspyx AD

# Benchmark

In [None]:
# Convert `f` with `Map` (presumably to the pixel/map basis — confirm); rebinds the global `f`.
f = Map(f);

In [None]:
# map size (spin-0)
size(f.arr)

In [None]:
# LenseFlow precomputation
# Globals are interpolated with `$` so that the benchmark measures the operation itself
# rather than dynamic dispatch on untyped global variables.
@btime CUDA.@sync precompute!!(LenseFlow($ϕ,10),$f);

In [None]:
# applying the precomputed lensing operator
# (this rebinds the global `L` to the precomputed operator)
L = precompute!!(LenseFlow(ϕ,10),f)
@btime CUDA.@sync $L * $f;

In [None]:
# gradient of lensing
# The closure argument `ϕ` is local; only the captured `L`, `f` and the evaluation
# point are interpolated.
@btime CUDA.@sync gradient(ϕ -> norm($L(ϕ) * $f), $ϕ);