In [7]:
ones(U*X) |> Dirichlet

Dirichlet{Float64, Vector{Float64}, Float64}(alpha=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

In [6]:
ones(U*X) |> Dirichlet

Dirichlet{Float64, Vector{Float64}, Float64}(alpha=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

In [None]:
using RxInfer
using TOML
using VBMC
using Distributions
using Random
using StatsBase
using LogarithmicNumbers
using CSV
using DataFrames

Random.seed!(1234)
if "test" in readdir()
    cd("test")
end

function elbo(mcx, mcu)
    "Assumes P1, Pt, Pe as global vars."
    elbo = 0.
    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |> sum
    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
            end
        ) |> sum
    end
    elbo + elbo2
end

struct EmissionNode{T <: Real} <: ContinuousUnivariateDistribution
    y :: T
    wt :: T
end

struct MetaTransition
    mat :: Matrix{Float64}
end

struct HmmTransition{T <: Real} <: DiscreteMultivariateDistribution
    wpast :: AbstractArray{T}
    wt :: AbstractArray{T}
end

vars = TOML.parsefile("Constants.toml")
T, U, X, Y = vars["T"], vars["U"], vars["X"], vars["Y"]
function emissionf(randVal, u, x)
    randVal .+ x
end
function revf(emission, u, x)
    emission .- x
end #the dot vectorizes, could use X instead
emission = Emission(emissionf, revf)

function create_model()
    P1 = begin
        ones(U*X) |>
        Dirichlet |>
        rand |>
        (x -> Categorical(x)) |>
        (d -> ReshapedCategorical(d, U, X))
    end
    Pt = reshape(rand(ones(U*X) |> Dirichlet, U * X), (U, X, U, X)) |> TransitionDistribution
    Pe = EmissionDistribution{Continuous}(Normal(0, 1), emission, U, X)
    dist = HpmmDistribution(P1, Pt, Pe)
    hpmm = rand(dist, T)
    dist, hpmm
end

function find_viterbiu(posteriors_w)
    posteriors_w .|> dist -> (sum(reshape(dist.p, U, X), dims=2) |> Iterators.flatten |> collect |> argmax)
end
# vmpus = find_viterbiu(a.posteriors[:w][20]);
function get_digits(val, base) digits(val, base=base, pad=T) .|> x -> x+1 end

function get_pdf(xs)
    total_prob = ULogarithmic(1.)
    for t in 1:T
        if t == 1
            p1 = 0.
            for u in 1:U
                p1 += P1[u,xs[t]] * Pe[u,xs[t],hpmm.Y[t]]
            end
            total_prob *= p1
        else
            pt = 0.
            for (u, uprev) in Iterators.product(1:U, 1:U)
                pt += Pt[u,xs[t],uprev,xs[t-1]] * Pe[u,xs[t],hpmm.Y[t]]
            end
            total_prob *= pt
        end
    end
    total_prob
end

function get_percentile(path)
    len = length(probabilities)
    pos = length(probabilities[probabilities .<= get_pdf(path)])
    pos/len
end

function get_x(w)
    VBMC.reshapeindex(w, U, X)[2]
end
function get_u(w)
    VBMC.reshapeindex(w, U, X)[1]
end

@node HmmTransition Stochastic [wt, wp]

@rule HmmTransition(:wp, Marginalisation) (q_wt :: Categorical, meta::MetaTransition) = begin
    G = q_wt.p
    A = meta.mat
    ηs = exp.(log.(A)' * G)
    νs = ηs ./ sum(ηs)
    return Categorical(νs...)
end
@rule HmmTransition(:wt, Marginalisation) (q_wp :: Categorical, meta::MetaTransition) = begin
    F = q_wp.p
    A = meta.mat #reshape(Pt.mat, U*X, U*X)
    ηs = exp.(log.(A) * F) # .* B # | clamp(⋅,tiny,one) | exp maybe or smth?
    νs = ηs ./ sum(ηs)
    return Categorical(νs...)
end
@marginalrule HmmTransition(:wt_wp) (q_wt::Categorical, q_wp::Categorical, meta::MetaTransition) = begin
    F, G = q_wp.p, q_wt.p
    AA = meta.mat
    ηs = exp.(log.(A) * F) # .* B
    ps = ηs .* G
    ps = ps ./ sum(ps)
    return (wt = Categorical(ps...), wp = q_wp)
end
@average_energy HmmTransition (q_wt::Categorical, q_wp::Categorical, meta::MetaTransition) = begin
    A = meta.mat
    G, F = q_wp.p, q_wt.p
    F' * log.(A) * G
end

@node EmissionNode Stochastic [y, wt]

@rule EmissionNode(:wt, Marginalisation) (q_y::PointMass, ) = begin 
    B = map(1:U*X) do w pdf(Normal(0,1), q_y.point-get_x(w)) end
    return Categorical(B./sum(B)...)
end
@rule EmissionNode(:y, Marginalisation) (q_wt :: Categorical, ) = begin
    B = map(1:U*X) do w pdf(Normal(0,1), y-get_x(w)) end
    G = q_wt.p
    return PointMass(exp(log.(B)' * G))
end

@marginalrule EmissionNode(:y_wt) (q_wt::Categorical, q_y::PointMass) = begin
    B = map(1:U*X) do w pdf(Normal(0,1), y-get_x(w)) end
    G = q_wt.p
    ps = log(B) .* G
    ps = ps./sum(ps)
    return (y = q_y, wt = Categorical(ps...)) 
end
@average_energy EmissionNode (q_y::PointMass, q_wt::Categorical) = begin
    B = map(1:U*X) do w pdf(Normal(0,1), q_y.point-get_x(w)) end
    F = q_wt.p
    F' * log.(B)
end

constraints = @constraints begin
    q(w) = q(w[begin])..q(w[end])
end

init = @initialization begin
    # Note T is hardcoded for now
    for t in 1:T
        q(w[t]) = vague(Categorical, U*X)
    end
end;

function infer_vmp(y, model)
    infer(
        model = model(),
        constraints = constraints,
        initialization = init,
        data = (y = y,),
        options = (limit_stack_depth = 500,),
        free_energy = true,
        showprogress=true,
        iterations = 20,    
        warn = false
    )
end
function find_viterbi(posteriors_w)
    posteriors_w .|> dist -> (sum(reshape(dist.p, U, X), dims=1) |> Iterators.flatten |> collect |> argmax)
end

function find_belief_prop_path()
    mcu = MarkovChain(U, T)
    mcx = MarkovChain(X, T)

    function norm(A::AbstractArray; p = 2)
        sum(abs.(A) .^ p)^(1 / p)
    end

    elboold = 0.
    for _ = 1:200
        tmpP1, tmpPt = mcx.P1 |> deepcopy, mcx.Pt |> deepcopy
        fillalphaX!(mcx, mcu, P1, Pt, Pe, hpmm.Y)
        fillbetaX!(mcx, mcu, Pt, Pe, hpmm.Y)
        VBMC.fillPtx!(mcx, mcu, Pt, Pe, hpmm.Y)

        fillalphaU!(mcu, mcx, P1, Pt, Pe, hpmm.Y)
        fillbetaU!(mcu, mcx, Pt, Pe, hpmm.Y)
        VBMC.fillPtu!(mcu, mcx, Pt, Pe, hpmm.Y)

        signeda = mcx.Pt .|> Logarithmic
        signedb = tmpPt .|> Logarithmic
        elbonew = elbo(mcx ,mcu)
        if (elbonew - elboold)/elboold |> abs < 1.e-17
            break
        end
        elboold = elbonew
    end
    VBMC.viterbi(mcx)
end

function find_all_probabilities()
    probabilities = zeros(0) .|> ULogFloat64
    maxprob = ULogFloat64(0.)
    best_path = zeros(T) .|> Int
    for x in 0:(X^T-1)
        xs = get_digits(x, X)
        total_prob = get_pdf(xs)
        if maxprob < total_prob
            maxprob = total_prob
            best_path .= xs
        end
        append!(probabilities, total_prob)
    end
    probabilities
end

vmp_percs, bp_percs, original_percs = zeros(0), zeros(0), zeros(0)
function iterate_append_percs!(vmp_percs :: Vector{<:Real}, bp_percs :: Vector{<:Real})
    @eval dist, hpmm = create_model()
    @eval P1,Pt,Pe = dist.P1, dist.Pt, dist.Pem
    @eval probabilities = find_all_probabilities()
    @eval p1 = P1.d.p

    # @eval @model function hidden_markov_model(y)
    #     w[1] ~ Categorical(p1)
    #     y[1] ~ EmissionNode(w[1])
    #     for t in 2:length(y)
    #         w[t] ~ HmmTransition(w[t-1]) where { meta = MetaTransition(reshape(Pt.mat, U*X, U*X)) }
    #         y[t] ~ EmissionNode(w[t])
    #     end
    # end

    # @eval vmp_perc = infer_vmp(hpmm.Y, hidden_markov_model) |> a -> a.posteriors[:w] |> last |> find_viterbi |> get_percentile
    # @eval bp_perc = find_belief_prop_path() |> get_percentile
    # @eval original_perc = hpmm.X |> get_percentile

    # append!(vmp_percs, vmp_perc), append!(bp_percs, bp_perc), append!(original_percs, original_perc)
end

for k in 1:915
    iterate_append_percs!(vmp_percs, bp_percs)
    if k % 50 == 0
        println(k/10, "%")
    end
end

df = DataFrame(BP = bp_percs, VMP = vmp_percs, ORIGINAL = original_percs)
CSV.write("percs.csv", df)

5.0%
10.0%
15.0%
20.0%
25.0%
30.0%
35.0%
40.0%
45.0%
50.0%
55.0%
60.0%
65.0%
70.0%
75.0%
80.0%
85.0%
90.0%


"percs.csv"

In [2]:
@eval dist, hpmm = create_model()
@eval P1,Pt,Pe = dist.P1, dist.Pt, dist.Pem
@eval probabilities = find_all_probabilities()
@eval p1 = P1.d.p

@eval @model function hidden_markov_model(y)
    w[1] ~ Categorical(p1)
    y[1] ~ EmissionNode(w[1])
    for t in 2:length(y)
        w[t] ~ HmmTransition(w[t-1]) where { meta = MetaTransition(reshape(Pt.mat, U*X, U*X)) }
        y[t] ~ EmissionNode(w[t])
    end
end

@eval vmp_perc = infer_vmp(hpmm.Y, hidden_markov_model) |> a -> a.posteriors[:w] |> last |> find_viterbi |> get_percentile
@eval bp_perc = find_belief_prop_path() |> get_percentile
@eval original_perc = hpmm.X |> get_percentile

append!(vmp_percs, vmp_perc), append!(bp_percs, bp_perc), append!(original_percs, original_perc)

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m


([1.0], [1.0], [0.154296875])

In [3]:
hpmm.Y

10-element Vector{Real}:
  0.9014312354221985
  1.0680640671579895
  1.7557602314803231
  1.3411823423306917
  2.2404649444967877
  0.5972328528374942
  1.72458946435425
  1.5238093121440581
 -1.3196143395031608
  0.3441977283566424

In [4]:
hpmm.X

10-element Vector{Int64}:
 2
 2
 2
 1
 1
 2
 1
 1
 2
 2