# Agent

Ask Bert
\begin{align}
p(r_t,u_t,\omega_t,a_t | \omega_{t-1}) \propto \underbrace{p(r_t|u_t) p(u_t|\omega_t)}_{\text{likelihood}} \underbrace{p(\omega_t | \omega_{t-1}, a_t)}_{\text{state transition}}\underbrace{p(a_t)}_{\text{control}} \qquad (1)
\end{align}

\begin{align}
p(r_t,u_t,\omega_t | \omega_{t-1}) \propto p(r_t|\omega_t)p(\omega_t | \omega_{t-1}, u_t)p(u_t) \qquad (2)
\end{align}

\begin{align}
p(r_t,u_t,\omega_t \mid u_{t-1}) \propto \underbrace{p(r_t \mid u_t)}_{\text{response}} \underbrace{ p(u_t \mid u_{t-1}) p(u_t \mid \omega_t)}_{\text{perception }p(u_t \mid u_{t-1},\omega_t)} \underbrace{p(\omega_t)}_{\substack{\text{control} \\ \text{prior}}}  \qquad (3)
\end{align}

In [1]:
using ForneyLab
using ProgressMeter
using LinearAlgebra
using Plots

┌ Info: Precompiling ForneyLab [9fc3f58a-c2cc-5bff-9419-6a294fefdca9]
└ @ Base loading.jl:1278


In [87]:
# Model (1): Gaussian state ω_t whose transition precision a_t is Gamma distributed,
# a Gaussian percept u_t centred on ω_t, and a Bernoulli response r_t whose
# parameter is the logistic transform of u_t.
n_samples = 1
fg = FactorGraph()

# Logistic link for the Nonlinear node. Defined once at top level instead of inside
# the loop: a function defined in the loop body is local to that body, whereas the
# generated algorithm source refers to `f` by name and must be able to resolve it.
f(z) = 1 / (1 + exp(-z))

# State prior; mean and precision are supplied at inference time via placeholders
@RV ω_0 ~ GaussianMeanPrecision(placeholder(:m_ω_0), placeholder(:w_ω_0))

# Containers for the per-step random variables
ω = Vector{Variable}(undef, n_samples)
a = Vector{Variable}(undef, n_samples)
u = Vector{Variable}(undef, n_samples)
p = Vector{Variable}(undef, n_samples)
r = Vector{Variable}(undef, n_samples)

# Transition and observation model
ω_i_min = ω_0
for i in 1:n_samples
    # Control: precision of the state transition
    @RV a[i] ~ Gamma(3.0, 2.0)
    # State transition
    @RV ω[i] ~ GaussianMeanPrecision(ω_i_min, a[i])
    # Percept of the state (fixed precision)
    @RV u[i] ~ GaussianMeanPrecision(ω[i], 100.0)

    # Response likelihood: the percept is squashed through the logistic link
    @RV p[i] ~ Nonlinear{Sampling}(u[i], g=f)
    @RV r[i] ~ Bernoulli(p[i])

    # Data placeholder
    placeholder(r[i], :r, index=i)

    # Reset state for next step
    ω_i_min = ω[i]
end

# Posterior factorization with three factors (states, response parameter, control),
# then generate and print the message passing algorithm including free energy code
q = PosteriorFactorization([ω_0; ω], p, a, ids=[:Ω :P :A])
algo = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algo, free_energy=true)
println(source_code)

begin

# Generated update step for the :P posterior factor (single time step).
# Reads data[:r][1] and marginals[:ω_1], so the caller must pass a `marginals` dict
# that already contains :ω_1; the empty default dict is not enough on its own.
# Fills messages[1:4], writes marginals[:p_1] and marginals[:u_1], and returns `marginals`.
function stepP!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 4))

# messages[1]: Gaussian node output given the current :ω_1 marginal and fixed precision 100.0
messages[1] = ruleVBGaussianMeanPrecisionOut(nothing, marginals[:ω_1], ProbabilityDistribution(Univariate, PointMass, m=100.0))
# messages[2]: Bernoulli node message given the observed response data[:r][1]
messages[2] = ruleVBBernoulliIn1(ProbabilityDistribution(Univariate, PointMass, m=data[:r][1]), nothing)
# messages[3] and messages[4]: messages through the Nonlinear node with function `f`
# (presumably towards its input and its output respectively, judging by the rule names);
# `f` must be resolvable in the scope where this function is evaluated
messages[3] = ruleSPNonlinearSIn1MN(f, messages[2], nothing)
messages[4] = ruleSPNonlinearSOutNM(f, nothing, messages[1])

# Each marginal is the product of the two messages meeting on that variable
marginals[:p_1] = messages[4].dist * messages[2].dist
marginals[:u_1] = messages[1].dist * messages[3].dist

return marginals

end

function stepA!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 2))

messages[1] = ruleVBGammaOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=3.0), ProbabilityDistribution(Univariate, PointMass, m=2.0))
messages[2] = ruleSVBGaussianMeanPrecisionW(marginals[:ω_1_ω_0], nothing)

marginals[:a_1] = messages[1].dist * messages[2].dist

return marginals


In [111]:
# Model (2): Gaussian state ω_t whose transition precision u_t is Gamma distributed,
# and a Bernoulli response r_t whose parameter is the logistic transform of ω_t.
n_samples = 1
fg = FactorGraph()

# Logistic link for the Nonlinear node. Defined once at top level instead of inside
# the loop: a function defined in the loop body is local to that body, whereas the
# generated algorithm source refers to `f` by name and must be able to resolve it.
f(x) = 1 / (1 + exp(-x))

# State prior; mean and precision are supplied at inference time via placeholders
@RV ω_0 ~ GaussianMeanPrecision(placeholder(:m_ω_0), placeholder(:w_ω_0))

# Containers for the per-step random variables
ω = Vector{Variable}(undef, n_samples)
u = Vector{Variable}(undef, n_samples)
p = Vector{Variable}(undef, n_samples)
r = Vector{Variable}(undef, n_samples)

# Transition and observation model
ω_i_min = ω_0
for i in 1:n_samples
    # Control: precision of the state transition. Float literals are used so the
    # hyperparameters have the same type as in Model (1).
    @RV u[i] ~ Gamma(3.0, 2.0)

    # State transition
    @RV ω[i] ~ GaussianMeanPrecision(ω_i_min, u[i])

    # Response likelihood: the state is squashed through the logistic link
    @RV p[i] ~ Nonlinear{Sampling}(ω[i], g=f)

    @RV r[i] ~ Bernoulli(p[i])

    # Data placeholder
    placeholder(r[i], :r, index=i)

    # Reset state for next step
    ω_i_min = ω[i]
end
draw()

# Posterior factorization with separate factors for the initial state, the states
# and the control, then generate and print the algorithm including free energy code
q = PosteriorFactorization(ω_0, ω, u, ids=[:Ω0 :Ω :U])
algo = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algo, free_energy=true)
println(source_code)

begin

# Generated update step for the :U posterior factor (single time step).
# Reads marginals[:ω_1] and marginals[:ω_0], so the caller must pass a `marginals`
# dict that already contains both; `data` is not read in this step.
# Fills messages[1:2], writes marginals[:u_1], and returns `marginals`.
function stepU!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 2))

# messages[1]: Gamma node output with fixed hyperparameters 3 and 2
messages[1] = ruleVBGammaOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=3), ProbabilityDistribution(Univariate, PointMass, m=2))
# messages[2]: message on the precision of the Gaussian transition, given the
# current :ω_1 and :ω_0 marginals
messages[2] = ruleVBGaussianMeanPrecisionW(marginals[:ω_1], marginals[:ω_0], nothing)

# The marginal is the product of the two messages meeting on u_1
marginals[:u_1] = messages[1].dist * messages[2].dist

return marginals

end

# Generated update step for the :Ω posterior factor (single time step).
# Reads data[:r][1], marginals[:ω_0] and marginals[:u_1], so the caller must pass a
# `marginals` dict that already contains :ω_0 and :u_1.
# Fills messages[1:4], writes marginals[:p_1] and marginals[:ω_1], and returns `marginals`.
function stepΩ!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 4))

# messages[1]: Gaussian transition output given the current :ω_0 and :u_1 marginals
messages[1] = ruleVBGaussianMeanPrecisionOut(nothing, marginals[:ω_0], marginals[:u_1])
# messages[2]: Bernoulli node message given the observed response data[:r][1]
messages[2] = ruleVBBernoulliIn1(ProbabilityDistribution(Univariate, PointMass, m=data[:r][1]), nothing)
# messages[3] and messages[4]: messages through the Nonlinear node with function `f`
# (presumably towards its input and its output respectively, judging by the rule names);
# `f` must be resolvable in the scope where this function is evaluated
messages[3] = ruleSPNonlinearSIn1MN(f, messages[2], nothing)
messages[4] = ruleSPNonlinearSOutNM(f, nothing, messages[1])

# Each marginal is the product of the two messages meeting on that variable
marginals[:p_1] = messages[4].dist * messages[2].dist
marginals[:ω_1] = messages[1].dist * messages[3].dist

return marginals

end

function stepΩ0!(data::Dic

In [152]:
# Model (3), direct attempt: one Nonlinear{Sampling} node takes both u_t (Gaussian)
# and ω_t (Beta) as inputs. Algorithm generation fails for this graph because no
# sum-product rule is available for inbound messages of these two different families.
n_samples = 1
fg = FactorGraph()

# Prior on the initial percept; parameters are provided through placeholders
@RV u_0 ~ GaussianMeanPrecision(placeholder(:m_u_0), placeholder(:w_u_0))

# Storage for the random variables of every time step
r = Vector{Variable}(undef, n_samples)
u = Vector{Variable}(undef, n_samples)
p = Vector{Variable}(undef, n_samples)
ω = Vector{Variable}(undef, n_samples)

# Build the transition and observation model step by step
u_i_min = u_0
for t in eachindex(ω)
    # Control prior
    @RV ω[t] ~ Beta(1.0, 1.0)

    # Percept transition from the previous percept
    @RV u[t] ~ GaussianMeanPrecision(u_i_min, 100.0)

    # Logistic response driven by the sum of percept and control
    sigmoid(x, z) = 1 / (1 + exp(-(x + z)))
    @RV p[t] ~ Nonlinear{Sampling}(u[t], ω[t], g=sigmoid)
    @RV r[t] ~ Bernoulli(p[t])

    # Observed response enters through a placeholder
    placeholder(r[t], :r, index=t)

    # The current percept becomes the predecessor of the next step
    u_i_min = u[t]
end
draw()

# Two-factor posterior, followed by algorithm and source generation
q = PosteriorFactorization(ω, u, ids=[:Ω :U])
algo = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algo, free_energy=true)

LoadError: No applicable SumProductRule{Nonlinear{Sampling}} update for Nonlinear{Sampling} node with inbound types: Message{Beta,var_type} where var_type<:ForneyLab.VariateType, Message{GaussianMeanPrecision,var_type} where var_type<:ForneyLab.VariateType, Nothing

In [154]:
# Model (3), workaround for the missing two-input rule: the Beta control is first
# passed through an identity Nonlinear node and a Gaussian node, so that the
# two-input summation node has Gaussian variables on both of its inputs.
n_samples = 1
fg = FactorGraph()

# Deterministic maps used by the Nonlinear nodes. Defined once at top level instead of
# inside the loop: functions defined in the loop body are local to that body, whereas
# the generated algorithm source refers to them by name (e.g. `sigmoid`) and must be
# able to resolve them.
βN(x) = x                        # identity map applied to the Beta control
sumβN(x, y) = x + y              # combines the control with the previous percept
sigmoid(x) = 1 / (1 + exp(-x))   # logistic link for the response

# State prior; mean and precision are supplied at inference time via placeholders
@RV u_0 ~ GaussianMeanPrecision(placeholder(:m_u_0), placeholder(:w_u_0))

# Containers for the per-step random variables
ω = Vector{Variable}(undef, n_samples)
ωsample = Vector{Variable}(undef, n_samples)
ωn = Vector{Variable}(undef, n_samples)
p = Vector{Variable}(undef, n_samples)
u = Vector{Variable}(undef, n_samples)
un = Vector{Variable}(undef, n_samples)
utr = Vector{Variable}(undef, n_samples)
r = Vector{Variable}(undef, n_samples)

# Transition and observation model
u_i_min = u_0
for i in 1:n_samples
    # Control prior, then identity map and Gaussian node around it
    @RV ω[i] ~ Beta(1.0, 1.0)
    @RV ωsample[i] ~ Nonlinear{Sampling}(ω[i], g=βN)
    @RV ωn[i] ~ GaussianMeanPrecision(ωsample[i], 100.0)

    # Percept transition from the previous step, combined with the control
    @RV utr[i] ~ GaussianMeanPrecision(u_i_min, 10.0)
    @RV u[i] ~ Nonlinear{Sampling}(ωn[i], utr[i], g=sumβN)
    @RV un[i] ~ GaussianMeanPrecision(u[i], 100.0)

    # Response. NOTE(review): the likelihood here is Gaussian, whereas the other
    # models use Bernoulli; presumably part of the workaround, confirm it is intended.
    @RV p[i] ~ Nonlinear{Sampling}(un[i], g=sigmoid)
    @RV r[i] ~ GaussianMeanPrecision(p[i], 1.0)

    # Data placeholder
    placeholder(r[i], :r, index=i)

    # Reset state for next step
    u_i_min = un[i]
end

# Three-factor posterior, then generate and print the algorithm including free energy code
q = PosteriorFactorization(ω, ωn, [u_0; un], ids=[:Ω :ΩN :U])
algo = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algo, free_energy=true)
println(source_code)

begin

function stepU!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 6))

messages[1] = ruleVBGaussianMeanPrecisionOut(nothing, marginals[:u_1], ProbabilityDistribution(Univariate, PointMass, m=100.0))
messages[2] = ruleVBGaussianMeanPrecisionM(ProbabilityDistribution(Univariate, PointMass, m=data[:r][1]), nothing, ProbabilityDistribution(Univariate, PointMass, m=1.0))
messages[3] = ruleSPNonlinearSIn1MN(sigmoid, messages[2], nothing)
messages[4] = ruleVBGaussianMeanPrecisionOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=data[:m_u_0]), ProbabilityDistribution(Univariate, PointMass, m=data[:w_u_0]))
messages[5] = ruleVBGaussianMeanPrecisionM(marginals[:utr_1], nothing, ProbabilityDistribution(Univariate, PointMass, m=10.0))
messages[6] = ruleSPNonlinearSOutNM(sigmoid, nothing, messages[1])

marginals[:p_1] = messages[6].dist * messages[2].dist
marginals[:u_0] = messages[4].dist * messages[5].dist
marginals[:un_1] = messages[1].dist * 