In [113]:
using RxInfer
using GraphViz

using Test
using TOML
using BenchmarkTools
using VBMC

using Distributions
using Random
using StatsBase
using LogarithmicNumbers

Random.seed!(123)
if "test" in readdir()
    cd("test")
end

In [114]:
function elbo(mcx, mcu)
    "Assumes P1, Pt, Pe as global vars."
    elbo = 0.
    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |> sum
    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
            end
        ) |> sum
    end
    elbo + elbo2
end

elbo (generic function with 1 method)

In [329]:
vars = TOML.parsefile("Constants.toml")
T, U, X, Y = 2, vars["U"], vars["X"], vars["Y"]
# @test T == 2 && U == 3 && X == 2

P1 = begin
    Vector(1:U*X) |>
    Dirichlet |>
    rand |>
    (x -> Categorical(x)) |>
    (d -> ReshapedCategorical(d, U, X))
end

#P(·,· | u_prev, x_prev) = Pt[:,:,u_prev,x_prev] sums to 1 for fixed u_prev, x_prev
Pt = reshape(rand(Dirichlet(1:U*X), U * X), (U, X, U, X)) |> TransitionDistribution

function emissionf(randVal, u, x)
    randVal .+ x
end
function revf(emission, u, x)
    emission .- x
end #the dot vectorizes, could use X instead
emission = Emission(emissionf, revf)

Pe = EmissionDistribution{Continuous}(Normal(0, 0.1), emission, U, X)
dist = HpmmDistribution(P1, Pt, Pe)
hpmm = rand(dist, T)

TMM([1, 2], [2, 2], Real[1.8846442427557517, 1.9521787559521258])

In [330]:
function engT2(mcx, mcu)
    energy = 0.

    energy += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    energy2 = 0.
    for t in 2:T
        energy2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> log(Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
        ) |>
        sum
    end
    energy += energy2
    # println("energytotal ", energy |> Float64)
    energy
end

engT2 (generic function with 1 method)

In [331]:
function eng(mcx, mcu)
    energy = 0.

    energy += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    # println("Energy 1 ", energy |> Float64)

    energy2 = 0.
    for t in 2:T
        energy2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> log(Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
        ) |>
        sum
    end
    energy += energy2
    # println("Energy 2 ", energy2 |> Float64)
    # println("energytotal ", energy |> Float64)
    energy
end
function entx(mc, mcu)
    negent = 0.
    negent += 1:mc.Z .|>
    (
        z -> log(mc.P1[z]) * mc.P1[z]
    ) |>
    sum
    negent2 = 0.
    for t in 2:T
        negent2 += Iterators.product(1:mc.Z, 1:mc.Z) .|> (
            ((z,zprev),) -> 
            begin
                qt = VBMC.qxxprev(z, zprev, mcu, t, hpmm.Y[t], Pt, Pe)
                mc.alpha[zprev,t-1] * qt * log(qt) * mc.beta[z, t]
            end
        ) |>
        sum
    end
    negent += negent2
    -negent
end
function entu(mc, mcx)
    negent = 0.
    negent += 1:mc.Z .|>
    (
        z -> log(mc.P1[z]) * mc.P1[z]
    ) |>
    sum
    negent2 = 0.
    for t in 2:T
        negent2 += Iterators.product(1:mc.Z, 1:mc.Z) .|> (
            ((z,zprev),) -> 
            begin
                qt = VBMC.quuprev(z, zprev, mcx, t, hpmm.Y[t], Pt, Pe)
                mc.alpha[zprev,t-1] * qt * log(qt) * mc.beta[z, t]
            end
        ) |>
        sum
    end
    negent += negent2
    -negent
end
function elbo(mcx, mcu)
    elbo = 0.

    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
            end
        ) |>
        sum
    end
    elbo += elbo2
    elbo
end

elbo (generic function with 1 method)

In [332]:
function elbo2(mcx, mcu)
    elbo = 0.

    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.alpha[xprev, t-1] * qtx * mcx.beta[x, t] * mcu.alpha[uprev, t-1] * qtu * mcu.beta[u, t]
            end
        ) |>
        sum
    end
    elbo += elbo2
    elbo
end

elbo2 (generic function with 1 method)

In [391]:
T = 2

mcu = MarkovChain(U, T)
mcx = MarkovChain(X, T)
elbos = zeros(0)

function norm(A::AbstractArray; p = 2)
    sum(abs.(A) .^ p)^(1 / p)
end

val = 1.0
for _ = 1:4
    # println(elbo(mcx, mcu))

    fillalphaX!(mcx, mcu, P1, Pt, Pe, hpmm.Y)
    fillbetaX!(mcx, mcu, Pt, Pe, hpmm.Y)
    VBMC.fillPtx!(mcx, mcu, Pt, Pe, hpmm.Y)

    fillalphaU!(mcu, mcx, P1, Pt, Pe, hpmm.Y)
    fillbetaU!(mcu, mcx, Pt, Pe, hpmm.Y)
    VBMC.fillPtu!(mcu, mcx, Pt, Pe, hpmm.Y)
    println(elbo(mcx, mcu) |> Float64)
    # if val < 1.0e-14
        # break
    # end
end

1.740431386962883
1.7438492967914212
1.744625859555342
1.7448022294635266


In [392]:
@test ((mcx.beta.mat[:,1] .* mcx.alpha.mat[:,1]) |> sum) |> Float64 ≈ ((mcx.beta.mat[:,2] .* mcx.alpha.mat[:,2]) |> sum) |> Float64
@test (mcx.beta.mat[:,1] .* mcx.alpha.mat[:,1]) |> sum |> Float64 ≈ 1.
@test (mcx.beta.mat[:,2] .* mcx.alpha.mat[:,2]) |> sum |> Float64 ≈ 1.

[91m[1mTest Failed[22m[39m at [39m[1mIn[392]:1[22m
  Expression: (mcx.beta.mat[:, 1] .* mcx.alpha.mat[:, 1] |> sum) |> Float64 ≈ (mcx.beta.mat[:, 2] .* mcx.alpha.mat[:, 2] |> sum) |> Float64
   Evaluated: 0.9999999999999998 ≈ 0.99999957637562



LoadError: [91mThere was an error during testing[39m

In [393]:
(mcu.alpha.mat[:,1] .* mcu.beta.mat[:,1] |> sum) - (mcu.alpha[:, 2] |> sum )

exp(-14.435951016536873)

In [394]:
mcu.alpha[:, 2] |> sum 

exp(-5.377076919832469e-7)

In [395]:
@test (mcu.beta.mat[:,1] .* mcu.alpha.mat[:,1]) |> sum |> Float64 ≈ (mcu.beta.mat[:,2] .* mcu.alpha.mat[:,2]) |> sum |> Float64
@test (mcu.beta.mat[:,1] .* mcu.alpha.mat[:,1]) |> sum  ≈ 1.
@test (mcu.beta.mat[:,2] .* mcu.alpha.mat[:,2]) |> sum  ≈ 1.

[91m[1mTest Failed[22m[39m at [39m[1mIn[395]:1[22m
  Expression: (mcu.beta.mat[:, 1] .* mcu.alpha.mat[:, 1] |> sum) |> Float64 ≈ (mcu.beta.mat[:, 2] .* mcu.alpha.mat[:, 2] |> sum) |> Float64
   Evaluated: 1.0000000000000002 ≈ 0.9999994622924526



LoadError: [91mThere was an error during testing[39m

In [275]:
(mcu.beta.mat[:,1] .* mcu.alpha.mat[:,1]) |> sum, (mcu.beta.mat[:,2] .* mcu.alpha.mat[:,2]) |> sum

(exp(1.7763568394002505e-15), exp(34.004843733842804))

In [292]:
VBMC.p1u(1, mcx.Z, P1, mcx.P1, Y[1], Pe)

exp(-451.6880146156002)

In [293]:
mcu.alpha.mat[:, 1] = 1:mcu.Z .|> (u -> VBMC.p1u(u, mcx.Z, P1, mcx.P1, Y[1], Pe))
mcu.alpha.mat[:, 1] .|> Float64

3-element Vector{Float64}:
 6.829481149007878e-197
 2.3317543299886406e-197
 1.3515123121225329e-196

In [277]:
(mcu.alpha[:,1] .* mcu.beta[:,1]) |> sum

exp(1.7763568394002505e-15)

In [53]:
function get_pdf(xs)
    total_prob = ULogarithmic(1.)
    for t in 1:T
        if t == 1
            p1 = 0.
            for u in 1:U
                p1 += P1[u,xs[t]] * Pe[u,xs[t],hpmm.Y[t]]
            end
            total_prob *= p1
        else
            pt = 0.
            for (u, uprev) in Iterators.product(1:U, 1:U)
                pt += Pt[u,xs[t],uprev,xs[t-1]] * Pe[u,xs[t],hpmm.Y[t]]
            end
            total_prob *= pt
        end
    end
    total_prob
end

probabilities = zeros(0) .|> ULogFloat64
maxprob = ULogFloat64(0.)
for x in 0:(X^T-1)
    xs = get_digits(x, X)
    total_prob = get_pdf(xs)
    println(xs, total_prob)
    if maxprob < total_prob
        maxprob = total_prob
    end
    append!(probabilities, total_prob)
end

[1, 1]+exp(-4.46746581556419)
[2, 1]+exp(-1.9132886267972715)
[1, 2]+exp(-4.683821365046635)
[2, 2]+exp(-2.7630895038877474)
