In [1]:
using Revise
using Rocket
using ReactiveMP
using GraphPPL
using Distributions
using LinearAlgebra
import ProgressMeter
using WAV
using Plots

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1317
[33m[1m│ [22m[39m- If you have ReactiveMP checked out for development and have
[33m[1m│ [22m[39m  added LineSearches as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with ReactiveMP


In [5]:
# clean speech
cl, fs = wavread("speech/clean.wav")
σ² = 0.0001
# white noise
wn = sqrt(σ²)*randn(length(cl))
# noised speech
ns = cl .+ wn
wavwrite(ns, "speech/noised.wav")

In [6]:
# dividing into 10ms frames with 2.5ms overlap
start = 1
l = Int(round(0.01*fs))
overlap = Int(round(0.0025*fs))
totseg = Int(ceil(length(ns)/(l-overlap)))
segment = zeros(totseg, l)
zseg = zeros(totseg, l)
for i in 1:totseg - 1
    global start
    segment[i,1:l]=ns[start:start+l-1]
    zseg[i, 1:l] = cl[start:start+l-1]
    start = (l-overlap)*i+1
end
segment[totseg, 1:length(ns)-start+1] = ns[start:length(ns)]
zseg[totseg, 1:length(cl)-start+1] = cl[start:length(cl)];

In [7]:
function ar_ssm(series, order)
    inputs = [reverse!(series[1:order])]
    outputs = [series[order + 1]]
    for x in series[order+2:end]
        push!(inputs, vcat(outputs[end], inputs[end])[1:end-1])
        push!(outputs, x)
    end
    return inputs, outputs
end

ar_ssm (generic function with 1 method)

In [90]:
@model function lar_model(n, order, artype, c)

    x = randomvar(n)
    y = datavar(Float64, n)

    γ ~ GammaShapeRate(1.0, 1.0) where {q=MeanField()}
    θ ~ MvNormalMeanPrecision(randn(order), Matrix{Float64}(I, order, order)) where {q=MeanField()}
    x0 ~ MvNormalMeanPrecision(100.0 * ones(order), Matrix{Float64}(I, order, order)) where {q=MeanField()}

    x_prev = x0

    ct  = constvar(c)
    γ_y = constvar(1.0)

    ar_nodes = Vector{FactorNode}(undef, n)

    for i in 1:n
        ar_nodes[i], x[i] ~ AR(x_prev, θ, γ) where { q = q(y, x)q(γ)q(θ), meta = ARMeta(artype, order, ARsafe) }

        y[i] ~ NormalMeanPrecision(dot(ct, x[i]), γ_y) where {q=MeanField()}

        x_prev = x[i]
    end

    return x, y, θ, γ, ar_nodes
end

lar_model (generic function with 1 method)

In [106]:
function inference_lar(data, order, niter)
    n = length(data)
    artype = Multivariate
    c = zeros(order); c[1] = 1.0
    model, (x, y, θ, γ, ar_nodes) = lar_model(n, order, artype, c)

    γ_buffer = nothing
    θ_buffer = nothing
    x_buffer = Vector{Marginal}(undef, n)
    fe = Vector{Float64}()

    γsub = subscribe!(getmarginal(γ), (mγ) -> γ_buffer = mγ)
    θsub = subscribe!(getmarginal(θ), (mθ) -> θ_buffer = mθ)
    xsub = subscribe!(getmarginals(x), (mx) -> copyto!(x_buffer, mx))
    fesub = subscribe!(score(Float64, BetheFreeEnergy(), model), (f) -> push!(fe, f))

    setmarginal!(γ, GammaShapeRate(1.0, 1.0))
    setmarginal!(θ, MvNormalMeanPrecision(zeros(order), Matrix{Float64}(I, order, order)))

#     for i in 1:n
#         setmarginal!(ar_nodes[i], :y_x, MvNormalMeanPrecision(100.0 * ones(2*order), Matrix{Float64}(I, 2*order, 2*order)))
#     end

    ProgressMeter.@showprogress for i in 1:niter
        update!(y, data)
    end

    unsubscribe!(γsub)
    unsubscribe!(θsub)
    unsubscribe!(xsub)
    unsubscribe!(fesub)

    return γ_buffer, θ_buffer, x_buffer, fe
end

inference_lar (generic function with 1 method)

In [95]:
@model function gaussian_model(n)

    y = datavar(Float64, n)

    γ ~ GammaShapeRate(1.0, 1.0) where {q=MeanField()}
    x ~ NormalMeanPrecision(0.0, 1.0) where {q=MeanField()}

    for i in 1:n
        y[i] ~ NormalMeanPrecision(x, γ) where {q=MeanField()}
    end

    return y, x, γ
end

gaussian_model (generic function with 2 methods)

In [96]:
function inference_gaussian(outputs, niter)
    n = length(outputs)
    model, (y, x, γ) = gaussian_model(n, options = (limit_stack_depth = 500, ))

    γ_buffer = nothing
    x_buffer = nothing
    fe = Vector{Float64}()

    γsub = subscribe!(getmarginal(γ), (my) -> γ_buffer = my)
    xsub = subscribe!(getmarginal(x), (mx) -> x_buffer = mx)
    fesub = subscribe!(score(Float64, BetheFreeEnergy(), model), (f) -> push!(fe, f))

    setmarginal!(γ, GammaShapeRate(1.0, 1.0))

    for i in 1:niter
        update!(y, outputs)
    end

    unsubscribe!(γsub)
    unsubscribe!(xsub)
    unsubscribe!(fesub)

    return x_buffer, γ_buffer, fe
end

inference_gaussian (generic function with 1 method)

In [97]:
ar_order = 10
vmp_iter = 10
fe_ar = zeros(totseg, vmp_iter)
fe_gaussian = zeros(totseg, vmp_iter);

In [98]:
ProgressMeter.@showprogress for segnum in 1:totseg
    inputs, outputs = ar_ssm(zseg[segnum, :], ar_order)
    γ, θ, x, fe = inference_lar(outputs, ar_order, vmp_iter)
    mθ, vθ = mean(θ), cov(θ)
    mγ = mean(γ)
    fe_ar[segnum, :] = fe
    
    x, γ, fe = inference_gaussian(outputs, vmp_iter)
    mx, vx = mean(x), cov(x)
    mγ = mean(γ)
    fe_gaussian[segnum, :] = fe
end

LoadError: MethodError: no method matching iterate(::Rocket.CollectLatestObservable{Marginal, Vector{ProxyObservable{Marginal, ReactiveMP.MarginalObservable, Rocket.FilterProxy{ReactiveMP.var"#30#31"}}}, Vector{Marginal}, typeof(copy)})
[0mClosest candidates are:
[0m  iterate([91m::Union{LinRange, StepRangeLen}[39m) at range.jl:664
[0m  iterate([91m::Union{LinRange, StepRangeLen}[39m, [91m::Int64[39m) at range.jl:664
[0m  iterate([91m::T[39m) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
[0m  ...

In [108]:
inference_lar(outputs, ar_order, vmp_iter)

LoadError: MethodError: no method matching ar_y_x_marginal(::Type{ARsafe}, ::MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}, ::MvNormalMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}, ::MvNormalMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}, ::GammaShapeRate{Float64}, ::ARMeta{Multivariate, DataType})
[0mClosest candidates are:
[0m  ar_y_x_marginal([91m::ARsafe[39m, ::Union{GaussianDistributionsFamily{T}, NormalDistributionsFamily{T}} where T, ::Union{GaussianDistributionsFamily{T}, NormalDistributionsFamily{T}} where T, ::Union{GaussianDistributionsFamily{T}, NormalDistributionsFamily{T}} where T, ::GammaShapeRate, ::ARMeta) at /Users/apodusenko/.julia/dev/ReactiveMP/src/rules/autoregressive/marginals.jl:7
[0m  ar_y_x_marginal([91m::ARunsafe[39m, ::Union{GaussianDistributionsFamily{T}, NormalDistributionsFamily{T}} where T, ::Union{GaussianDistributionsFamily{T}, NormalDistributionsFamily{T}} where T, ::Union{GaussianDistributionsFamily{T}, NormalDistributionsFamily{T}} where T, ::GammaShapeRate, ::ARMeta) at /Users/apodusenko/.julia/dev/ReactiveMP/src/rules/autoregressive/marginals.jl:23

In [100]:
outputs

70-element Vector{Float64}:
 -6.103701895199438e-5
 -0.00015259254737998596
 -0.0003662221137119663
  0.0
  0.0005798516800439467
  0.0004272591326639607
  0.00024414807580797754
  0.0003967406231879635
  0.00024414807580797754
  6.103701895199438e-5
 -0.0003662221137119663
 -0.0005188146610919523
  6.103701895199438e-5
  ⋮
 -6.103701895199438e-5
 -3.051850947599719e-5
 -0.00024414807580797754
 -0.0005493331705679495
 -0.0004882961516159551
 -0.00027466658528397473
 -0.00012207403790398877
 -9.155552842799158e-5
 -9.155552842799158e-5
 -0.00015259254737998596
 -0.00027466658528397473
 -3.051850947599719e-5

In [87]:
minimum(fe_ar[1:end-1, :])

-27.809869504666494

In [89]:
minimum(fe_gaussian)

-23.897701140779446