In [1]:
import Pkg; Pkg.activate("C:/Users/s151781/AppData/Local/Julia-1.3.1/GN/Project.toml")
using Revise
using Plots
using FFTW
using Compat
using WAV
using DSP
using Base64
using ForneyLab
using LinearAlgebra
using ProgressMeter

[32m[1mActivating[22m[39m environment at `C:\Users\s151781\AppData\Local\Julia-1.3.1\GN\Project.toml`


┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1273


In [27]:
include("../extensions/complex_gaussian.jl")
include("../extensions/hgf.jl")
include("../functions/auxiliary/workflow.jl")


em (generic function with 1 method)

In [62]:
# generate data
import Distributions: Normal, MvNormal, MixtureModel, Dirichlet
N = 20
nr_freqs = 5
n_samples = N

μ_ξi = 1.0*collect(2:2:2*nr_freqs)
Σ_ξ = 1e-4*diagm(ones(nr_freqs))
Σ_meas = 1e-10*diagm(ones(nr_freqs))

ξ_samples = Array{Array{Float64,1},1}(undef, N)
X_samples = Array{Array{Complex{Float64},1},1}(undef, N)
y_samples = Array{Array{Complex{Float64},1},1}(undef, N)

for n = 1:N

    ξ_samples[n] = Array{Float64,1}(undef, nr_freqs)
    X_samples[n] = Array{Complex{Float64},1}(undef, nr_freqs)
    y_samples[n] = Array{Complex{Float64},1}(undef, nr_freqs)
    
    for k = 1:nr_freqs
        sample_ξ = rand(Normal(μ_ξi[k], sqrt(Σ_ξ[k,k])))
        sample_X = rand(Normal(0, sqrt(0.5*exp(sample_ξ)))) + 1im*rand(Normal(0, sqrt(0.5*exp(sample_ξ))))
        sample_y = rand(Normal(real(sample_X), sqrt(0.5*Σ_meas[k,k]))) + 1im*rand(Normal(imag(sample_X), sqrt(0.5*Σ_meas[k,k])))

        ξ_samples[n][k] = sample_ξ
        X_samples[n][k] = sample_X
        y_samples[n][k] = sample_y
    end
end

t = collect(1:N)
Y = y_samples

20-element Array{Array{Complex{Float64},1},1}:
 [0.38227474749833124 - 0.2994120629878414im, -0.5211510984837671 - 2.7458178527840373im, 22.285753607476888 + 9.620560869102444im, -23.920333509856544 - 13.123105590286558im, 95.93763523892326 + 21.49492235755902im] 
 [-1.180073158358913 + 1.152920004977153im, 1.4776810579212096 + 0.5454808285352166im, -5.758232968006452 + 14.054960823105263im, 6.200300644050182 + 15.482244836037243im, -154.79345810464386 + 100.77210966127728im]  
 [-0.26598148478905165 + 0.16141897380406736im, -2.665510757927842 + 6.226734101665709im, -27.082953557964803 - 15.115385489961998im, 17.305232852684423 - 38.595408136030436im, -60.3549463265912 + 37.85481313574915im]
 [0.805278983438589 - 1.1952772899547033im, -0.49920398884737427 - 2.6984110026705586im, 33.06078780063678 + 13.224305689161547im, -91.43699642038655 + 19.117424635054018im, 5.149988372306367 + 27.312229438930576im]  
 [1.0293126692794015 + 0.08199990575803795im, 1.4578244123142219 + 0.732278562279

# Building graph

In [63]:
#nr_freqs = dimension

N_clusters = 3

fg = FactorGraph()
α = 0.5
# Specify generative model
@RV _pi ~ ForneyLab.Dirichlet(α*ones(N_clusters))
@RV m_1 ~ GaussianMeanVariance(zeros(nr_freqs), 100*diagm(ones(nr_freqs)))
@RV w_1 ~ Wishart(diagm(ones(nr_freqs)), 5.0)
@RV m_2 ~ GaussianMeanVariance(zeros(nr_freqs), 100*diagm(ones(nr_freqs)))
@RV w_2 ~ Wishart(diagm(ones(nr_freqs)), 5.0)
@RV m_3 ~ GaussianMeanVariance(zeros(nr_freqs), 100*diagm(ones(nr_freqs)))
@RV w_3 ~ Wishart(diagm(ones(nr_freqs)), 5.0)

z = Vector{Variable}(undef, n_samples)
ξ = Vector{Variable}(undef, n_samples)
X = Vector{Variable}(undef, n_samples)
y = Vector{Variable}(undef, n_samples)
for i in 1:n_samples
    @RV z[i] ~ Categorical(_pi)
    @RV ξ[i] ~ GaussianMixture(z[i], m_1, w_1, m_2, w_2, m_3, w_3)
    # HGF
    @RV X[i] ~ HGF(ξ[i])

    # observation model
    @RV y[i] ~ ComplexNormal(X[i], 1e-10*diagm(ones(nr_freqs)).+0im, mat(0.0+0.0im))
    
    
    placeholder(y[i], :y, index=i, dims=(nr_freqs,))
end
# draw graph
ForneyLab.draw(fg)

In [64]:
# Build the algorithm
q = PosteriorFactorization(_pi, m_1, w_1, m_2, w_2, m_3, w_3, z, X, ξ, ids=[:PI :M1 :W1 :M2 :W2 :M3 :W3 :Z :X :Ξ])
algo = variationalAlgorithm(q)

# Generate source code
source_code = algorithmSourceCode(algo);

# Load algorithm
eval(Meta.parse(source_code));

In [65]:
println(source_code)

begin

function stepZ!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 40))

messages[1] = ruleVBCategoricalOut(nothing, marginals[:_pi])
messages[2] = ruleVBGaussianMixtureZCat(marginals[:ξ_9], nothing, marginals[:m_1], marginals[:w_1], marginals[:m_2], marginals[:w_2], marginals[:m_3], marginals[:w_3])
messages[3] = ruleVBCategoricalOut(nothing, marginals[:_pi])
messages[4] = ruleVBGaussianMixtureZCat(marginals[:ξ_8], nothing, marginals[:m_1], marginals[:w_1], marginals[:m_2], marginals[:w_2], marginals[:m_3], marginals[:w_3])
messages[5] = ruleVBCategoricalOut(nothing, marginals[:_pi])
messages[6] = ruleVBGaussianMixtureZCat(marginals[:ξ_7], nothing, marginals[:m_1], marginals[:w_1], marginals[:m_2], marginals[:w_2], marginals[:m_3], marginals[:w_3])
messages[7] = ruleVBCategoricalOut(nothing, marginals[:_pi])
messages[8] = ruleVBGaussianMixtureZCat(marginals[:ξ_6], nothing, marginals[:m_1], marginals[:w_1], marginals[:m_2], marginals[:w_2], margi

In [66]:
data = Dict(:y => Y)

# Prepare posterior factors
marginals = Dict(:_pi => vague(ForneyLab.Dirichlet, N_clusters),
                 :m_1 => ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=-1.0*ones(nr_freqs), v=1e4*diagm(ones(nr_freqs))),
                 :w_1 => vague(Wishart, nr_freqs),
                 :m_2 => ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=ones(nr_freqs), v=1e4*diagm(ones(nr_freqs))),
                 :w_2 => vague(Wishart, nr_freqs),
                 :m_3 => ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=ones(nr_freqs), v=1e4*diagm(ones(nr_freqs))),
                 :w_3 => vague(Wishart, nr_freqs))
for i in 1:n_samples
    marginals[:z_*i] = vague(Categorical)
    marginals[:X_*i] = ProbabilityDistribution(Multivariate, ComplexNormal, μ=zeros(nr_freqs) .+ 0.0im, Γ=1e-10*diagm(ones(nr_freqs)).+0im, C=mat(0.0+0.0im));
    marginals[:ξ_*i] = ProbabilityDistribution(ForneyLab.Multivariate, GaussianMeanVariance, m=zeros(nr_freqs), v=diagm(ones(nr_freqs)))
end



In [67]:
# Execute algorithm
n_its = 10
@showprogress for i in 1:n_its
    stepX!(data, marginals)
    stepΞ!(data, marginals)
    stepZ!(data, marginals)
    stepPI!(data, marginals)
    stepM1!(data, marginals)
    stepW1!(data, marginals)
    stepM2!(data, marginals)
    stepW2!(data, marginals) 
    stepM3!(data, marginals)
    stepW3!(data, marginals)   
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m


In [68]:
marginals[:_pi]

Dir(a=[0.50, 10.50, 10.50])


In [69]:
marginals[:z_1]

Cat(p=[2.61e-46, 0.50, 0.50])


In [70]:
mean(marginals[:m_1])

5-element Array{Float64,1}:
 1.1265097591891952e-8
 2.7275177402306274e-8
 5.107393013084792e-8 
 5.604115186372582e-8 
 7.21829163526821e-8  

In [71]:
mean(marginals[:m_2])

5-element Array{Float64,1}:
 1.1257050672705078
 2.726601832536714 
 5.105088363891105 
 5.6029942202434   
 7.216140116421354 

In [72]:
mean(marginals[:m_3])

5-element Array{Float64,1}:
 1.1257050672705078
 2.726601832536714 
 5.105088363891105 
 5.6029942202434   
 7.216140116421354 

In [74]:
μ_ξi

5-element Array{Float64,1}:
  2.0
  4.0
  6.0
  8.0
 10.0