In [None]:
using BenchmarkTools
using PyCall
using PyPlot
using Setfield
using Random
using Statistics
using LinearAlgebra
using CuArrays
using Adapt
using Zygote
using Zygote: @adjoint

using Revise
using CMBLensing

In [None]:
# Import the Python-side dependencies into PyCall's Python namespace:
# NumPy and TensorFlow (used by the `grad_backp` helper defined further down)
# and the `began` module (provides `CVAE`, which is instantiated in a later cell).
py"""
import numpy as np
import tensorflow as tf
import began
"""

In [None]:
# Pixel width passed to FlatMap / load_sim_dataset: 20*60 spread across 256 pixels.
# Presumably a 20-degree (1200 arcmin) patch on a 256-pixel side, i.e. arcmin per
# pixel — confirm against CMBLensing's `θpix` units.
θpix = 20 * 60 / 256;

In [None]:
# Build the Python-side generative model `model` (a `began.CVAE`; the meaning of the
# constructor arguments 256 and 5 is not visible here — check `began`) and load its
# pretrained weights from the relative path dat/vae.h5.
# `model` is used by `dust` (decode) and by `grad_backp` (decode under a GradientTape).
py"""
model = began.CVAE(256,5);
model.load_weights("dat/vae.h5")
"""

In [None]:
"""
    dust(θg)

Decode the latent vector `θg` with the Python-side `model` and return the first
image / first channel of the decoded batch (`[1, :, :, 1]`) as a GPU (`cu`) `FlatMap`
with pixel size given by the global `θpix`.

`θg` may be a CPU `Array` or a `CuArray`; it is copied to a host `Array` and handed
to Python as a 1×n row matrix, i.e. a batch containing a single latent vector.
"""
function dust(θg)
    # `reshape(Array(θg), 1, :)` yields the same 1×n host matrix as `collect(θg'[:,:])`
    # but without indexing a lazy adjoint element by element, which for a CuArray
    # input means (slow / possibly disallowed) GPU scalar indexing.
    z = reshape(Array(θg), 1, :)
    decoded = py"model.decode($z).numpy()"
    return cu(FlatMap(decoded[1, :, :, 1], θpix=θpix))
end

In [None]:
# Python helper for the pullback of `dust`. Given a cotangent `v` of shape (npix, 1)
# for the decoded image and a latent batch `z` of shape (1, n), return the gradient
# with respect to `z` of the contraction <decoded image, v>.
py"""
def grad_backp(v, z):
    v = v.astype(np.float32)
    z = z.astype(np.float32)
    assert v.ndim == 2
    assert v.shape[-1] == 1
    z = tf.constant(z)
    with tf.GradientTape() as tape:
        tape.watch(z)
        genned_image = model.decode(z)
        # v arrives flattened in Julia's column-major order, whereas tf.reshape
        # flattens row-major. Transposing the (first) image before flattening makes
        # both orderings agree pixel-for-pixel; without it the contraction used the
        # transposed image and the gradient was wrong.
        flat_image = tf.reshape(tf.transpose(genned_image[0, :, :, 0]), (1, -1))
        back = tf.tensordot(flat_image, v, axes=[[1], [0]])
    return tape.gradient(back, z)
"""

# Zygote rule for `dust`: the forward pass is `dust` itself, and the pullback maps a
# field cotangent Δ to a latent-space gradient (returned on the GPU via `cu`).
@adjoint function dust(θg)
    function dust_pullback(Δ)
        # Zygote passes `nothing` when no gradient flows back through this call.
        Δ === nothing && return (nothing,)
        # (npix, 1) cotangent; assumes `[:]` flattens the pixel matrix in Julia's
        # column-major (`vec`) order — matched by the transpose in grad_backp.
        v = reshape(Array(Map(Δ)[:]), :, 1)
        # (1, n) latent batch, built on the host without GPU scalar indexing.
        z = reshape(Array(θg), 1, :)
        g = py"grad_backp($v, $z).numpy()"
        return (cu(vec(g)),)
    end
    return dust(θg), dust_pullback
end

In [None]:
# Preview the decoder: decode ten independent random 256-element latent vectors
# and animate the resulting maps at fps=5.
animate([dust(randn(Float32, 256)) for _ in 1:10], fps=5)

In [None]:
# Simulated dataset: intensity only (pol = :I), Nside = 256, pixel size θpix, with
# CuArray storage. Keep the simulated field `f` and the dataset `ds`.
@unpack f, ds = load_sim_dataset(;
    θpix    = θpix,
    Nside   = 256,
    pol     = :I,
    storage = CuArray,
);

In [None]:
"""
    lnP(::Val{:dust}, f, θg, ds)

Return -1/2 times the sum of three quadratic forms, presumably the log-posterior
(up to a constant) of the "dust" model:

  - the residual `ds.d - f - dust(θg)` weighted by `pinv(ds.Cn)`,
  - the field `f` weighted by `pinv(ds.Cf)`,
  - the unit-weighted latent term `θg' * θg`.
"""
function lnP(::Val{:dust}, f, θg, ds)
    @unpack Cf, Cn, d = ds

    resid = d - f - dust(θg)

    noise_term  = resid' * pinv(Cn) * resid
    signal_term = f' * pinv(Cf) * f
    latent_term = θg' * θg

    return -0.5f0 * (noise_term + signal_term + latent_term)
end

In [None]:
# NOTE(review): θg₀ is not defined anywhere else in this notebook, so this cell and the
# gradient cell below would throw UndefVarError. If it is not already set in the
# session, draw it from the standard-normal prior implied by the `θg' * θg` term of
# lnP. It is placed on the GPU to match the CuArray gradient returned by the `dust`
# adjoint. Replace it if a specific starting point was intended.
θg₀ = @isdefined(θg₀) ? θg₀ : cu(randn(Float32, 256))
lnP(Val(:dust), f, θg₀, ds)

In [None]:
# Re-express the simulated field with CMBLensing's `Map` conversion before differentiating.
f = f |> Map;

In [None]:
# Gradient of the dust lnP with respect to both the field and the latent vector.
gradient(f, θg₀) do field, latent
    lnP(Val(:dust), field, latent, ds)
end

In [None]:
# Run argmaxf_lnP with no lensing (which = :sample) on the dataset and plot the result.
plot(argmaxf_lnP(NoLensing(), ds; which=:sample))