In [339]:
using ReactiveMP
using StableRNGs
using Flux
using Distributions

In [340]:
# CVI approximation method used for both the `prod` call and the manual update loop below.
# Arguments are positional, in the order shown by the displayed `ProdCVI{...}(...)` value:
# rng, 1, 2000, optimiser, gradient backend, 10, Val(true), true.
# NOTE(review): 1 / 2000 / 10 are presumably n_samples / n_iterations / n_gradpoints and
# the trailing Val(true), true the enforce-proper-messages and warn flags — confirm
# against ReactiveMP's CVI docstring. (`cvi.n_iterations` is read by the update cell.)
cvi = let rng = StableRNG(42), optimiser = Flux.Adam(0.05)
    CVI(rng, 1, 2000, optimiser, ForwardDiffGrad(), 10, Val(true), true)
end

ProdCVI{StableRNGs.LehmerRNG, Flux.Optimise.Adam, ForwardDiffGrad, true}(StableRNGs.LehmerRNG(state=0x00000000000000000000000000000055), 1, 2000, Flux.Optimise.Adam(0.05, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()), ForwardDiffGrad(), 10, Val{true}(), true)

In [341]:
# Deterministic node transform y = f(x); the identity makes the exact answer easy to check.
f = x -> x

# Incoming Gaussian messages on the x and y edges, given as (mean, variance).
inbound_x = NormalMeanVariance(2, 10)
inbound_y = NormalMeanVariance(4, 10)

# Generic log-pdf type matching inbound_x's variate form (displays as
# ContinuousUnivariateLogPdf), wrapping z ↦ log p_y(f(z)) on an unspecified domain.
# The closure looks up the globals `inbound_y` and `f` at call time.
F = let form = variate_form(inbound_x)
    promote_variate_type(form, ReactiveMP.AbstractContinuousGenericLogPdf)
end
f_likelihood = convert(F, ReactiveMP.UnspecifiedDomain(), z -> logpdf(inbound_y, f(z)))

ContinuousUnivariateLogPdf(ReactiveMP.UnspecifiedDomain())

In [342]:
# Marginal on the x edge: CVI product of the x message with the y-likelihood pulled
# back through f. `prod` here is Base.prod, extended by ReactiveMP for CVI.
q_x = Base.prod(cvi, inbound_x, f_likelihood)

NormalWeightedMeanPrecision{Float64}(xi=0.6999949695037428, w=0.24999997499883103)

In [343]:
"""
    run_cvi_naturalparams(cvi, samples, inbound_x, inbound_y, f)

Run `cvi.n_iterations` natural-gradient CVI updates on the natural parameters of the
`y` edge. The updates start from `naturalparams(inbound_y)`, and the function returns
the last accepted iterate.

Each iteration:
1. builds a Monte-Carlo objective over `samples` as a function of the natural
   parameters `η`;
2. differentiates it with `cvi.grad` and solves against the Fisher matrix at the
   current iterate (natural gradient);
3. forms `λ - λ_prior - ∇f` and hands it to `cvi.opt` through `ReactiveMP.cvi_update!`;
4. accepts the candidate only if it is proper and passes
   `ReactiveMP.enforce_proper_message`; otherwise it keeps the previous iterate.

The loop lives in a function rather than at notebook top level. Reassigning a global
inside a top-level `for` only works under interactive soft scope, and function locals
are type-stable.
"""
function run_cvi_naturalparams(cvi, samples, inbound_x, inbound_y, f)
    # Separate objects for the fixed prior and the running iterate, so an in-place
    # optimiser update can never alias the prior.
    prior_params = ReactiveMP.naturalparams(inbound_y)
    λ_current = ReactiveMP.naturalparams(inbound_y)
    T = typeof(λ_current)

    # NOTE(review): this objective weights log q_η(f(x)) by -log p(x | inbound_x).
    # In the recorded run the resulting q_y had xi≈0.45, while the analytical product
    # of inbound_x and inbound_y has xi=0.6. Confirm the objective is the intended one.
    logq = η -> mean(
        sample -> -logpdf(inbound_x, sample) *
                  logpdf(ReactiveMP.as_naturalparams(T, η), f(sample)),
        samples,
    )

    for _ in 1:cvi.n_iterations
        ∇logq = ReactiveMP.compute_gradient(cvi.grad, logq, vec(λ_current))

        # Natural gradient: precondition with the Fisher matrix at the current iterate.
        fisher = ReactiveMP.compute_fisher_matrix(cvi, T, vec(λ_current))
        ∇f = fisher \ ∇logq

        ∇ = λ_current - prior_params - ReactiveMP.as_naturalparams(T, ∇f)
        λ_new = ReactiveMP.as_naturalparams(T, ReactiveMP.cvi_update!(cvi.opt, λ_current, ∇))

        # Only accept parameters that describe a proper distribution and pass the
        # enforce-proper-message check.
        if ReactiveMP.isproper(λ_new) &&
           ReactiveMP.enforce_proper_message(cvi.enforce_proper_messages, λ_new, prior_params)
            λ_current = λ_new
        end
    end

    return λ_current
end

# Fixed-seed samples from q_x, so the Monte-Carlo objective (and hence q_y) is
# reproducible between runs. A bare `rand(q_x, 2000)` would use the global RNG.
samples = rand(StableRNG(42), q_x, 2000)
inbound_params = ReactiveMP.naturalparams(inbound_y)
T = typeof(inbound_params)

@info inbound_params

λ_current = run_cvi_naturalparams(cvi, samples, inbound_x, inbound_y, f)

# Difference between the fitted parameters and the inbound y message: the outbound message.
@info λ_current - inbound_params

q_y = convert(Distribution, λ_current)
@info q_y
outbound_y = convert(Distribution, λ_current - inbound_params)

┌ Info: UnivariateNormalNaturalParameters{Float64}(0.4, -0.05)
└ @ Main /Users/mykola/repos/CIExpirements/demos/check_obvious.ipynb:6
┌ Info: UnivariateNormalNaturalParameters{Float64}(0.049999996990617857, -0.0499999899053406)
└ @ Main /Users/mykola/repos/CIExpirements/demos/check_obvious.ipynb:33


┌ Info: NormalWeightedMeanPrecision{Float64}(xi=0.4499999969906179, w=0.1999999798106812)
└ @ Main /Users/mykola/repos/CIExpirements/demos/check_obvious.ipynb:36


NormalWeightedMeanPrecision{Float64}(xi=0.049999996990617857, w=0.0999999798106812)

In [344]:
# Compare, one per line: the exact analytical product of the two inbound messages,
# the CVI marginal on x, and the CVI-fitted marginal on y.
for dist in (prod(ProdAnalytical(), inbound_x, inbound_y), q_x, q_y)
    println(dist)
end

NormalWeightedMeanPrecision{Float64}(xi=0.6000000000000001, w=0.2)
NormalWeightedMeanPrecision{Float64}(xi=0.6999949695037428, w=0.24999997499883103)
NormalWeightedMeanPrecision{Float64}(xi=0.4499999969906179, w=0.1999999798106812)
