# Noise is set to 1. for testing purposes

In [None]:
using RxInfer
using TOML
using VBMC

using Distributions
using Random
using StatsBase
using LogarithmicNumbers
using LinearAlgebra
using Combinatorics: permutations

Random.seed!(1234)
if "test" in readdir()
    cd("test")
end

function elbo(mcx, mcu)
    "Assumes P1, Pt, Pe as global vars."
    elbo = 0.
    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |> sum
    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
            end
        ) |> sum
    end
    elbo + elbo2
end

struct EmissionNode{T <: Real} <: ContinuousUnivariateDistribution
    y :: T
    wt :: T
end

struct MetaTransition
    mat :: Matrix{Float64}
    t :: Int
end

struct HmmTransition{T <: Real} <: DiscreteMultivariateDistribution
    wpast :: AbstractArray{T}
    wt :: AbstractArray{T}
end

In [17]:
vars = TOML.parsefile("Constants.toml")
T, U, X, Y = vars["T"], vars["U"], vars["X"], vars["Y"]
T = 100
noise_std = vars["noise_std"]
function emissionf(randVal, u, x)
    randVal .+ x
end
function revf(emission, u, x)
    emission .- x
end #the dot vectorizes, could use X instead
emission = Emission(emissionf, revf)

function create_model()
    P1 = begin
        ones(U*X) |>
        Dirichlet |>
        rand |>
        (x -> Categorical(x)) |>
        (d -> ReshapedCategorical(d, U, X))
    end
    Pt = reshape(rand(Dirichlet(ones(U*X)), U * X), (U, X, U, X)) |> TransitionDistribution
    Pe = EmissionDistribution{Continuous}(Normal(0, noise_std), emission, U, X)
    dist = HpmmDistribution(P1, Pt, Pe)
    hpmm = rand(dist, T)
    dist, hpmm
end

create_model (generic function with 1 method)

In [18]:
function find_viterbiu(posteriors_w)
    posteriors_w .|> dist -> (sum(reshape(dist.p, U, X), dims=2) |> Iterators.flatten |> collect |> argmax)
end
# vmpus = find_viterbiu(a.posteriors[:w][20]);
function get_digits(val, base) digits(val, base=base, pad=T) .|> x -> x+1 end

function get_percentile(path, probabilities)
    len = length(probabilities)
    pos = length(probabilities[probabilities .<= get_pdf(path)])
    pos/len
end

function get_pdf(xs)
    prevalphas = ones(U) .|> ULogarithmic
    for t in 1:T
        if t == 1
            p1 = 0.
            for u in 1:U
                prevalphas[u] = P1[u,xs[t]] * Pe[u,xs[t],hpmm.Y[t]]
            end
        else
            pts = Iterators.product(1:U, 1:U) .|> ((u,uprev),) -> Pt[u,xs[t],uprev,xs[t-1]] * Pe[u,xs[t],hpmm.Y[t]]
            prevalphas = pts' * prevalphas
        end
    end
    prevalphas |> sum
end

get_pdf (generic function with 1 method)

# Variational Message Passing setup

In [19]:
function get_x(w)
    VBMC.reshapeindex(w, U, X)[2]
end
function get_u(w)
    VBMC.reshapeindex(w, U, X)[1]
end

@node HmmTransition Stochastic [wt, wp]

@rule HmmTransition(:wp, Marginalisation) (q_wt :: Categorical, meta::MetaTransition) = begin
    G = q_wt.p
    A = meta.mat
    ηs = exp.(log.(A)' * G)
    νs = ηs ./ sum(ηs)
    return Categorical(νs...)
end
@rule HmmTransition(:wt, Marginalisation) (q_wp :: Categorical, meta::MetaTransition) = begin
    F = q_wp.p
    A = meta.mat #reshape(Pt.mat, U*X, U*X)
    ηs = exp.(log.(A) * F) # .* B # | clamp(⋅,tiny,one) | exp maybe or smth?
    νs = ηs ./ sum(ηs)
    return Categorical(νs...)
end
@marginalrule HmmTransition(:wt_wp) (q_wt::Categorical, q_wp::Categorical, meta::MetaTransition) = begin
    A = meta.mat
    #This is copied from ReactiveMP.jl transition/marginals.jl, however I think this marginalrule is never called
    B = Diagonal(probvec(q_wt)) * A * Diagonal(probvec(q_wp))
    P = map!(Base.Fix2(/, sum(B)), B, B) # inplace version of B ./ sum(B)
    return Contingency(P, Val(false))
    # F, G = q_wp.p, q_wt.p
    # ηs = exp.(log.(A) * F)
    # ps = ηs ./ sum(ηs)
    # # ps = ηs .* G
    # # ps = ps ./ sum(ps)
    # # return (wt = Categorical(ps...), wp = q_wp)
    # P = map!(Base.Fix2(/, sum(B)), B, B)

    # ηs2 = exp.(log.(A)' * G)
    # ps2 = ηs2 ./ sum(ηs2)
    # return G' .* (F .* A) |> Iterators.flatten |> collect
    # return (wt = Categorical(ps...), wp = Categorical(ps2...))
end
@average_energy HmmTransition (q_wt::Categorical, q_wp::Categorical, meta::MetaTransition) = begin
    A = meta.mat
    G, F = q_wp.p, q_wt.p
    F' * log.(A) * G
end

@node EmissionNode Stochastic [y, wt]

@rule EmissionNode(:wt, Marginalisation) (q_y::PointMass, ) = begin 
    B = map(1:U*X) do w pdf(Normal(0,noise_std), q_y.point-get_x(w)) end
    return Categorical(B./sum(B)...)
end
# @rule EmissionNode(:y, Marginalisation) (q_wt :: Categorical, ) = begin
#     B = map(1:U*X) do w pdf(Normal(0,noise_std), y-get_x(w)) end
#     G = q_wt.p
#     return PointMass(exp(log.(B)' * G))
# end

@marginalrule EmissionNode(:y_wt) (q_wt::Categorical, q_y::PointMass) = begin
    B = map(1:U*X) do w pdf(Normal(0,noise_std), y-get_x(w)) end
    G = q_wt.p
    ps = log(B) .* G
    ps = ps./sum(ps)
    return (y = q_y, wt = Categorical(ps...)) 
end
@average_energy EmissionNode (q_y::PointMass, q_wt::Categorical) = begin
    B = map(1:U*X) do w pdf(Normal(0,noise_std), q_y.point-get_x(w)) end
    F = q_wt.p
    F' * log.(B)
end

constraints = @constraints begin
    q(w) = q(w[begin])..q(w[end])
end

init = @initialization begin
    # Note T is hardcoded for now
    for t in 2:T
        q(w[t]) = vague(Categorical, U*X)
    end
end;

In [29]:
# p1 = P1.d.p

# @eval @model function hidden_markov_model(y)
#     w[1] ~ Categorical(p1)
#     y[1] ~ EmissionNode(w[1])
#     for t in 2:length(y)
#         w[t] ~ HmmTransition(w[t-1]) where { meta = MetaTransition(reshape(Pt.mat, U*X, U*X), t) }
#         y[t] ~ EmissionNode(w[t])
#     end
# end

# pathvmp = infer_vmp(hpmm.Y, hidden_markov_model) |> a -> a.posteriors[:w] |> last |> find_viterbi;
# # pathbp = find_belief_prop_path();

In [5]:
P1 = begin
    Vector(1:U*X) |>
    Dirichlet |>
    rand |>
    (x -> Categorical(x)) |>
    (d -> ReshapedCategorical(d, U, X))
end
Pt = reshape(rand(Dirichlet(1:U*X), U * X), (U, X, U, X)) |> TransitionDistribution
Pe = EmissionDistribution{Continuous}(Normal(0, noise_std), emission, U, X)
dist = HpmmDistribution(P1, Pt, Pe)
hpmm = rand(dist, T)
dist, hpmm

(HpmmDistribution(ReshapedCategorical(Categorical{Float64, Vector{Float64}}(
support: Base.OneTo(6)
p: [0.024394539459170042, 0.15338254375246516, 0.17617329801791798, 0.10970725111237985, 0.23675936086230961, 0.2995830067957574]
)
, 3, 2, nothing), TransitionDistribution([0.07757085772550631 0.10691600554255486; 0.0837559928742681 0.24878643665538883; 0.0933412046039046 0.3896295025983774;;; 0.13435598648589636 0.0625407873459112; 0.0875991075861795 0.1415939778226005; 0.25510102800233114 0.31880911275708135;;; 0.016794504404259784 0.0732044232648002; 0.19974813961256987 0.15271737311929992; 0.3472900755721375 0.21024548402693272;;;; 0.043591774755137265 0.19399151317488603; 0.14900786232232988 0.21358908677688454; 0.19972784145479283 0.2000919215159696;;; 0.017935345397800298 0.15447288238504406; 0.11593444808327888 0.24502198262269392; 0.14411455863909367 0.32252078287208924;;; 0.01699694042124585 0.22952807072491957; 0.17305280362262807 0.09886505269171775; 0.23863592162671304 0.24

In [30]:
function infer_vmp(y, model)
    infer(
        model = model(),
        constraints = constraints,
        initialization = init,
        data = (y = y,),
        options = (limit_stack_depth = 500,),
        free_energy = true,
        showprogress=true,
        iterations = 20,
        warn = false
    )
end
function find_viterbi(posteriors_w)
    posteriors_w .|> dist -> (sum(reshape(dist.p, U, X), dims=1) |> Iterators.flatten |> collect |> argmax)
end

find_viterbi (generic function with 1 method)

# Belief prop

In [22]:
function find_belief_prop_path()
    mcu = MarkovChain(U, T)
    mcx = MarkovChain(X, T)

    function norm(A::AbstractArray; p = 2)
        sum(abs.(A) .^ p)^(1 / p)
    end

    elboold = 0.
    for _ = 1:200
        tmpP1, tmpPt = mcx.P1 |> deepcopy, mcx.Pt |> deepcopy
        fillalphaX!(mcx, mcu, P1, Pt, Pe, hpmm.Y)
        fillbetaX!(mcx, mcu, Pt, Pe, hpmm.Y)
        VBMC.fillPtx!(mcx, mcu, Pt, Pe, hpmm.Y)

        fillalphaU!(mcu, mcx, P1, Pt, Pe, hpmm.Y)
        fillbetaU!(mcu, mcx, Pt, Pe, hpmm.Y)
        VBMC.fillPtu!(mcu, mcx, Pt, Pe, hpmm.Y)

        elbonew = elbo(mcx ,mcu)
        if (elbonew - elboold)/elboold |> abs < 1.e-17
            break
        end
        elboold = elbonew
    end
    VBMC.viterbi(mcx)
end

find_belief_prop_path (generic function with 1 method)

# Experiments

In [23]:
P1 = begin
    Vector(1:U*X) |>
    Dirichlet |>
    rand |>
    (x -> Categorical(x)) |>
    (d -> ReshapedCategorical(d, U, X))
end
Pt = reshape(rand(Dirichlet(1:U*X), U * X), (U, X, U, X)) |> TransitionDistribution
Pe = EmissionDistribution{Continuous}(Normal(0, noise_std), emission, U, X)
dist = HpmmDistribution(P1, Pt, Pe)
hpmm = rand(dist, T)
dist, hpmm

(HpmmDistribution(ReshapedCategorical(Categorical{Float64, Vector{Float64}}(
support: Base.OneTo(6)
p: [0.024394539459170042, 0.15338254375246516, 0.17617329801791798, 0.10970725111237985, 0.23675936086230961, 0.2995830067957574]
)
, 3, 2, nothing), TransitionDistribution([0.07757085772550631 0.10691600554255486; 0.0837559928742681 0.24878643665538883; 0.0933412046039046 0.3896295025983774;;; 0.13435598648589636 0.0625407873459112; 0.0875991075861795 0.1415939778226005; 0.25510102800233114 0.31880911275708135;;; 0.016794504404259784 0.0732044232648002; 0.19974813961256987 0.15271737311929992; 0.3472900755721375 0.21024548402693272;;;; 0.043591774755137265 0.19399151317488603; 0.14900786232232988 0.21358908677688454; 0.19972784145479283 0.2000919215159696;;; 0.017935345397800298 0.15447288238504406; 0.11593444808327888 0.24502198262269392; 0.14411455863909367 0.32252078287208924;;; 0.01699694042124585 0.22952807072491957; 0.17305280362262807 0.09886505269171775; 0.23863592162671304 0.24

In [26]:
p1 = P1.d.p

@eval @model function hidden_markov_model(y)
    w[1] ~ Categorical(p1)
    y[1] ~ EmissionNode(w[1])
    for t in 2:length(y)
        w[t] ~ HmmTransition(w[t-1]) where { meta = MetaTransition(reshape(Pt.mat, U*X, U*X), t) }
        y[t] ~ EmissionNode(w[t])
    end
end

pathvmp = infer_vmp(hpmm.Y, hidden_markov_model) |> a -> a.posteriors[:w] |> last |> find_viterbi
pathbp = find_belief_prop_path();

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m


# Results

In [33]:
# probabilities = zeros(0) .|> ULogFloat64
# maxprob = ULogFloat64(0.)
# best_path = zeros(T) .|> Int
# for x in 0:(X^T-1)
#     xs = get_digits(x, X)
#     total_prob = get_pdf(xs)
#     if maxprob < total_prob
#         maxprob = total_prob
#         best_path .= xs
#     end
#     append!(probabilities, total_prob)
# end
# @eval probabilities = probabilities;

In [44]:
xs = get_digits((probabilities |> argmax) - 1, X);
get_percentile(pathbp, probabilities)


1.0

In [42]:
a = best_path |> get_pdf, pathbp |> get_pdf, pathvmp |> get_pdf, naivepath |> get_pdf
a .|> float .|> b -> round(b,digits=18)

(3.36e-16, 3.36e-16, 3.36e-16, 2.0e-18)

In [40]:
function prob_arg_to_path(prob)
    (probabilities .== prob) |> argmax |> i -> get_digits(i-1, X)
end
function hamming(xs1,xs2) sum(xs1 .!= xs2) end
f = x -> abs(x-1) < abs(x-2) ? 1 : 2
function last_n(a, N) 
    a |> reverse |> a -> a[1:N] |> reverse
end
naivepath = hpmm.Y .|> f;

In [73]:
#probabilities |> sort |> l -> last_n(l, 20) .|> prob_arg_to_path .|> l -> hamming(best_path, l)

In [74]:
pathbp |> get_pdf, pathvmp |> get_pdf

(+exp(-213.21076843814436), +exp(-213.39226030862915))

In [82]:
3921225 / 1000 * 3.1 / 60 / 60

3.376610416666667

In [77]:
for _ in 1:1000
    get_pdf(pathbp)
end

In [100]:
combo = Combinatorics.combinations(1:100, 1) |> first

1-element Vector{Int64}:
 1

In [114]:
combos, vmp_probs, bp_probs, x_probs = zeros(0) |> Vector{Vector{Int}}, zeros(0), zeros(0), zeros(0)
change_x = x -> x == 1 ? 2 : 1
for combo in Combinatorics.combinations(1:100, 1)
    append!(combos, [combo])
    newpath_vmp,newpath_bp,newpath_x = zeros(100) .|> Int, zeros(100) .|> Int, zeros(100) .|> Int
    newpath_vmp .= pathvmp
    newpath_bp .= pathbp
    newpath_x .= hpmm.X
    newpath_vmp[combo] = newpath_vmp[combo] .|> change_x
    newpath_bp[combo] = newpath_bp[combo] .|> change_x
    newpath_x[combo] = newpath_x[combo] .|> change_x
    append!(vmp_probs, get_pdf(newpath_vmp))
    append!(bp_probs, get_pdf(newpath_bp))
    append!(x_probs, get_pdf(newpath_x))
end

for combo in Combinatorics.combinations(1:100, 2)
    append!(combos, [combo])
    newpath_vmp,newpath_bp,newpath_x = zeros(100) .|> Int, zeros(100) .|> Int, zeros(100) .|> Int
    newpath_vmp .= pathvmp
    newpath_bp .= pathbp
    newpath_x .= hpmm.X
    newpath_vmp[combo] = newpath_vmp[combo] .|> change_x
    newpath_bp[combo] = newpath_bp[combo] .|> change_x
    newpath_x[combo] = newpath_x[combo] .|> change_x
    append!(vmp_probs, get_pdf(newpath_vmp))
    append!(bp_probs, get_pdf(newpath_bp))
    append!(x_probs, get_pdf(newpath_x))
end

for combo in Combinatorics.combinations(1:100, 3)
    append!(combos, [combo])
    newpath_vmp,newpath_bp,newpath_x = zeros(100) .|> Int, zeros(100) .|> Int, zeros(100) .|> Int
    newpath_vmp .= pathvmp
    newpath_bp .= pathbp
    newpath_x .= hpmm.X
    newpath_vmp[combo] = newpath_vmp[combo] .|> change_x
    newpath_bp[combo] = newpath_bp[combo] .|> change_x
    newpath_x[combo] = newpath_x[combo] .|> change_x
    append!(vmp_probs, get_pdf(newpath_vmp))
    append!(bp_probs, get_pdf(newpath_bp))
    append!(x_probs, get_pdf(newpath_x))
end

In [None]:
get_percentile(pathvmp, all_neighborhoods), get_percentile(pathbp, all_neighborhoods), get_percentile(hpmm.X, all_neighborhoods)

In [119]:
vmp_probs |> maximum |> float,bp_probs |> maximum |> float,x_probs |> maximum |> float

(4.4297894663218935e-93, 4.548815314687853e-93, 1.4974498533789258e-100)

In [121]:
pathbp |> get_pdf |> float, pathvmp |> get_pdf |> float

(2.5336101193696097e-93, 2.1130942444017528e-93)

In [124]:
combos[vmp_probs |> argmax], combos[bp_probs |> argmax]

([50, 53, 61], [1, 4, 16])

In [141]:
# combos[5051] #1 2 3
onetwo_neighborhoods = vcat(vmp_probs[1:5050], bp_probs[1:5050], x_probs[1:5050], [pathvmp, pathbp, hpmm.X] .|> get_pdf)
all_neighborhoods = vcat(vmp_probs, bp_probs, x_probs, [pathvmp, pathbp, hpmm.X] .|> get_pdf)
get_percentile(pathvmp, all_neighborhoods), get_percentile(pathbp, all_neighborhoods), get_percentile(hpmm.X, all_neighborhoods)

(0.9970195081288867, 0.9987966089158885, 0.24849226291496504)

In [142]:
get_percentile(pathvmp, onetwo_neighborhoods), get_percentile(pathbp, onetwo_neighborhoods), get_percentile(hpmm.X, onetwo_neighborhoods)

(0.9848214874942256, 0.9937306143997888, 0.23579489210057414)