In [34]:
# Synthetic data set for a two-dimensional logistic-regression problem
import Distributions
using Random

# Seed the global RNG so every draw below is reproducible.
Random.seed!(42)

n_samples = 100
dimensionality = 2
order = dimensionality

# Weights used to generate the labels.
rθ = [-0.5, 0.2]

# One observation per row, standard-normal entries.
inputs = randn(n_samples, dimensionality)

# Logistic sigmoid applied to the inner product of w and x.
σ(w, x) = inv(1 + exp(-(w' * x)))

# Success probability of each sample under the generating weights ...
πs = map(i -> σ(rθ, inputs[i, :]), 1:n_samples)

# ... followed by one Bernoulli draw per sample, taken in row order.
outputs = map(p -> rand(Distributions.Bernoulli(p)), πs)

100-element Vector{Bool}:
 0
 1
 1
 0
 0
 1
 1
 1
 1
 1
 1
 1
 1
 ⋮
 0
 0
 0
 0
 1
 0
 1
 0
 0
 1
 1
 0

In [2]:
# Model 1: logistic regression in which every observation gets its own
# Nonlinear{Sampling} node whose function has row i of `inputs` baked in.
using Revise
using ForneyLab
import ForneyLab: unsafeMean, unsafeCov

graph = FactorGraph()

T = n_samples
x = Vector{Variable}(undef, T)
y = Vector{Variable}(undef, T)

# Gaussian prior on the weight vector θ; its mean and precision are supplied at
# run time through the :m_θ and :W_θ placeholders.
@RV θ  ~ GaussianMeanPrecision(placeholder(:m_θ, dims=(order,)), placeholder(:W_θ, dims=(order, order)))
# Logistic sigmoid of the inner product w'x.
f(w,x) = 1/(1+exp(-w'x))
for i in 1:T
    # Define a top-level function func<i>(θ) that fixes the second argument of f
    # to row i of `inputs`; the generated algorithm source refers to these
    # functions by name (e.g. func10).
    @eval $(Symbol("func$i"))(θ) = f(θ,inputs[$i, :])
    @RV x[i] ~ Nonlinear{Sampling}(θ, g=eval(Symbol("func$i")), in_variates=[Multivariate], out_variate=Univariate)
    @RV y[i] ~ Bernoulli(x[i])
    # The i-th observed label is read from data[:y][i] at run time.
    placeholder(y[i], :y, index=i)
end

┌ Info: Precompiling ForneyLab [9fc3f58a-c2cc-5bff-9419-6a294fefdca9]
└ @ Base loading.jl:1317


In [9]:
# Build a message passing algorithm for θ (free-energy evaluation requested via
# free_energy=true) and render it as a string of Julia source code.
algo = messagePassingAlgorithm(θ, free_energy=true)
src_code = algorithmSourceCode(algo, free_energy=true);

In [10]:
# Print the generated source for inspection.
println(src_code);

begin

function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 58))

messages[1] = ruleSPGaussianMeanPrecisionOutNPP(nothing, Message(Multivariate, PointMass, m=data[:m_θ]), Message(MatrixVariate, PointMass, m=data[:W_θ]))
messages[2] = ruleSPBernoulliIn1PN(Message(Univariate, PointMass, m=data[:y][10]), nothing)
messages[3] = ruleSPNonlinearSIn1MN(func10, messages[2], nothing, variate=Multivariate)
messages[4] = ruleSPBernoulliIn1PN(Message(Univariate, PointMass, m=data[:y][9]), nothing)
messages[5] = ruleSPNonlinearSIn1MN(func9, messages[4], nothing, variate=Multivariate)
messages[6] = ruleSPEqualityFn(messages[5], nothing, messages[3])
messages[7] = ruleSPBernoulliIn1PN(Message(Univariate, PointMass, m=data[:y][8]), nothing)
messages[8] = ruleSPNonlinearSIn1MN(func8, messages[7], nothing, variate=Multivariate)
messages[9] = ruleSPEqualityFn(messages[8], nothing, messages[6])
messages[10] = ruleSPBernoulliIn1PN(Message(Univariate, PointMass,

In [11]:
# Load algorithm: parse the generated source string and evaluate it in the
# current module, which makes the generated functions (e.g. step!) callable.
eval(Meta.parse(src_code))

freeEnergy (generic function with 1 method)

In [12]:
# Values for the placeholders of the model: the observed labels and the mean and
# precision of the Gaussian prior on θ.
data = Dict(
    :y   => outputs,
    :m_θ => zeros(dimensionality),
    :W_θ => 0.1 * diageye(dimensionality),
)

# Run the generated algorithm once and keep the resulting marginals.
marginals = step!(data)

Dict{Any, Any} with 11 entries:
  :x_3  => SampleList(s=[0.62, 0.57, 0.57, 0.58, 0.63, 0.51, 0.54, 0.53, 0.51, …
  :x_10 => SampleList(s=[0.02, 1.22e-03, 9.13e-03, 0.09, 0.01, 3.69e-03, 0.03, …
  :x_2  => SampleList(s=[0.76, 0.67, 0.77, 0.64, 0.76, 0.63, 0.79, 0.70, 0.81, …
  :x_5  => SampleList(s=[0.14, 0.92, 5.34e-03, 0.11, 0.14, 4.05e-03, 0.35, 0.67…
  :x_1  => SampleList(s=[0.72, 0.79, 0.68, 0.75, 0.86, 0.94, 0.66, 0.92, 0.50, …
  :x_4  => SampleList(s=[0.85, 0.94, 0.90, 0.48, 0.72, 0.80, 0.88, 0.69, 0.83, …
  :x_9  => SampleList(s=[1.00, 1.00, 1.00, 0.04, 1.00, 1.00, 0.97, 5.02e-03, 1.…
  :θ    => 𝒩(m=[-1.33, 0.81], w=[[1.08, -0.08][-0.08, 1.36]])…
  :x_6  => SampleList(s=[0.24, 0.90, 0.22, 0.91, 0.12, 0.40, 0.03, 0.75, 0.18, …
  :x_7  => SampleList(s=[0.73, 0.66, 0.82, 0.64, 0.85, 0.75, 0.75, 0.71, 0.52, …
  :x_8  => SampleList(s=[0.18, 0.44, 0.43, 0.46, 0.07, 0.89, 0.49, 0.89, 0.34, …

In [13]:
# Posterior mean of the weights θ; reused below to score the fitted classifier.
meθ = unsafeMean(marginals[:θ])

2-element Vector{Float64}:
 -1.3319520146225503
  0.8149393727598759

In [14]:
# Posterior covariance of θ (note: despite the `w` in its name, this variable
# holds the covariance returned by unsafeCov, not a precision matrix).
weθ = unsafeCov(marginals[:θ])

2×2 Matrix{Float64}:
 0.929187   0.0559767
 0.0559767  0.737735

In [15]:
# Count the samples whose thresholded prediction under meθ differs from the
# thresholded generating probability πs[i] (not from the sampled labels).
let diffs = map(i -> round(f(meθ, inputs[i, :])) - round(πs[i]), 1:n_samples)
    println("training errors = $(sum(abs.(diffs)))")
end

training errors = 0.0


In [58]:
# Model 2: logistic regression in which each input is a latent Gaussian variable
# z[i] centred on the observed row of `inputs`, so a single two-argument
# nonlinearity f(θ, z) can be shared by all observations.
using Revise
using ForneyLab
using LinearAlgebra
import ForneyLab: unsafeMean, unsafeCov

graph = FactorGraph()

# One (z, x, y) triple per sample. The size is tied to the data set rather than
# hard-coded, so the loop below can never index past the rows of `inputs`.
T = n_samples
x = Vector{Variable}(undef, T)
z = Vector{Variable}(undef, T)
y = Vector{Variable}(undef, T)

# Gaussian prior on the weight vector θ; its mean and precision are supplied at
# run time through the :m_θ and :W_θ placeholders.
@RV θ  ~ GaussianMeanPrecision(placeholder(:m_θ, dims=(dimensionality,)), placeholder(:W_θ, dims=(dimensionality, dimensionality)))
# Logistic sigmoid of the inner product w'x.
f(w, x) = 1/(1+exp(-w'x))
for i in 1:T
    # Latent input: Gaussian with mean equal to the observed row and identity precision.
    @RV z[i] ~ GaussianMeanPrecision(inputs[i, :], diageye(dimensionality))
    # Success probability x[i] = f(θ, z[i]), handled by the sampling-based nonlinear node.
    @RV x[i] ~ Nonlinear{Sampling}(θ, z[i], g=f, in_variates=[Multivariate, Multivariate], out_variate=Univariate)
    @RV y[i] ~ Bernoulli(x[i])
    # The i-th observed label is read from data[:y][i] at run time.
    placeholder(y[i], :y, index=i)
end

In [57]:
# Visualise the model (presumably the current FactorGraph — confirm against the
# ForneyLab docs); the trailing `;` suppresses the cell's return value.
draw();

In [59]:
# Define posterior factorization
# NOTE(review): created without arguments and never passed explicitly to
# messagePassingAlgorithm — presumably ForneyLab tracks it as the current
# factorization; confirm.
pfz = PosteriorFactorization();

In [60]:
# Compile algorithm
# Unlike the first model, no target variable is passed here; only free-energy
# evaluation is requested.
algo = messagePassingAlgorithm(free_energy=true)

# Generate source code (a string of Julia code; `;` suppresses echoing it)
src_code = algorithmSourceCode(algo, free_energy=true);

In [61]:
# println(src_code)

In [62]:
# Load algorithm: parse the generated source string and evaluate it in the
# current module so the generated functions become callable.
eval(Meta.parse(src_code))

freeEnergy (generic function with 1 method)

In [63]:
# Allocate the message array once to learn how many messages the algorithm uses.
# NOTE(review): init() is called here before the notebook defines its own
# version, so it presumably comes from the generated source evaluated above —
# confirm. n_messages is read by the init() defined further down.
messages = init();
n_messages = length(messages)

798

In [64]:
# Scratch check of a single update rule fed with point-mass mean and precision
# messages; the resulting message is only displayed, not stored.
ruleSPGaussianMeanPrecisionOutNPP(nothing, Message(Multivariate, PointMass, m=[0.15614346264074028, -1.590579974922555]), Message(MatrixVariate, PointMass, m=Diagonal([10000.0, 10000.0])))

Message: 𝒩(m=[0.16, -1.59], w=diag[1.00e+04, 1.00e+04])


In [65]:
"""
    init(n=n_messages; dims=dimensionality)

Return a `Vector{Message}` of length `n` in which every entry is a freshly
constructed vague `GaussianMeanPrecision` message of dimension `dims`.

Called without arguments it behaves as before: the length is taken from the
global `n_messages` and the dimension from `dimensionality` (2 in this
notebook). Both defaults are looked up at call time, so no `global`
declaration is needed just to read them.
"""
function init(n::Integer=n_messages; dims::Integer=dimensionality)
    # Build the array in one pass instead of allocating it undefined and filling it.
    return Message[Message(vague(GaussianMeanPrecision, dims)) for _ in 1:n]
end

init (generic function with 1 method)

In [None]:
# Values for the placeholders of the model: the observed labels and the mean and
# precision of the Gaussian prior on θ.
data = Dict(
    :y   => outputs,
    :m_θ => zeros(dimensionality),
    :W_θ => 0.1 * diageye(dimensionality),
)

# Start from pre-initialised messages and an empty marginals dictionary, both of
# which are handed to step! explicitly.
messages = init()
marginals = Dict()
step!(data, marginals, messages)

In [45]:
# Posterior mean of θ under the second model. It is stored in meθ so that the
# training-error check below scores this model instead of silently reusing the
# estimate left over from the first model.
meθ = unsafeMean(marginals[:θ])

2-element Vector{Float64}:
 -0.6590887856212173
  1.0557047930386412

In [46]:
# Posterior covariance of θ under the second model (overwrites the earlier weθ).
weθ = unsafeCov(marginals[:θ])

2×2 Matrix{Float64}:
  1.31983   -0.010104
 -0.010104   0.7657

In [47]:
# Count the samples whose thresholded prediction under meθ differs from the
# thresholded generating probability πs[i] (not from the sampled labels).
let diffs = map(i -> round(f(meθ, inputs[i, :])) - round(πs[i]), 1:n_samples)
    println("training errors = $(sum(abs.(diffs)))")
end

training errors = 5.0


In [52]:
# Posterior mean of the latent input z_2.
unsafeMean(marginals[:z_2])

2-element Vector{Float64}:
 -0.4444077176392397
 -0.025615236621922645

In [51]:
# Row 2 of the observed inputs, i.e. the prior mean given to z[2] in the model;
# shown for comparison with the posterior mean of z_2.
inputs[2, :]

2-element Vector{Float64}:
 -0.444383357109696
 -0.02566249380406308