In [1]:
using RxInfer
using GraphViz

using Test
using TOML
using BenchmarkTools
using VBMC

using Distributions
using Random
using StatsBase
using LogarithmicNumbers

Random.seed!(123)
if "test" in readdir()
    cd("test")
end

In [2]:
vars = TOML.parsefile("Constants.toml")
T, U, X, Y = 2, vars["U"], vars["X"], vars["Y"]
# @test T == 2 && U == 3 && X == 2

P1 = begin
    Vector(1:U*X) |>
    Dirichlet |>
    rand |>
    (x -> Categorical(x)) |>
    (d -> ReshapedCategorical(d, U, X))
end

#P(·,· | u_prev, x_prev) = Pt[:,:,u_prev,x_prev] sums to 1 for fixed u_prev, x_prev
Pt = reshape(rand(Dirichlet(1:U*X), U * X), (U, X, U, X)) |> TransitionDistribution

function emissionf(randVal, u, x)
    randVal .+ x
end
function revf(emission, u, x)
    emission .- x
end #the dot vectorizes, could use X instead
emission = Emission(emissionf, revf)

Pe = EmissionDistribution{Continuous}(Normal(0, 1), emission, U, X)
dist = HpmmDistribution(P1, Pt, Pe)
hpmm = rand(dist, T)

TMM([2, 2], [1, 2], Real[1.9462520540394483, 0.400486398281505])

In [3]:
function engT2(mcx, mcu)
    energy = 0.

    energy += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    energy2 = 0.
    for t in 2:T
        energy2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> log(Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
        ) |>
        sum
    end
    energy += energy2
    # println("energytotal ", energy |> Float64)
    energy
end

engT2 (generic function with 1 method)

In [None]:
function eng(mcx, mcu)
    energy = 0.

    energy += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    # println("Energy 1 ", energy |> Float64)

    energy2 = 0.
    for t in 2:T
        energy2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> log(Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
        ) |>
        sum
    end
    energy += energy2
    # println("Energy 2 ", energy2 |> Float64)
    # println("energytotal ", energy |> Float64)
    energy
end
function entx(mc, mcu)
    negent = 0.
    negent += 1:mc.Z .|>
    (
        z -> log(mc.P1[z]) * mc.P1[z]
    ) |>
    sum
    negent2 = 0.
    for t in 2:T
        negent2 += Iterators.product(1:mc.Z, 1:mc.Z) .|> (
            ((z,zprev),) -> 
            begin
                qt = VBMC.qxxprev(z, zprev, mcu, t, hpmm.Y[t], Pt, Pe)
                mc.alpha[zprev,t-1] * qt * log(qt) * mc.beta[z, t]
            end
        ) |>
        sum
    end
    negent += negent2
    -negent
end
function entu(mc, mcx)
    negent = 0.
    negent += 1:mc.Z .|>
    (
        z -> log(mc.P1[z]) * mc.P1[z]
    ) |>
    sum
    negent2 = 0.
    for t in 2:T
        negent2 += Iterators.product(1:mc.Z, 1:mc.Z) .|> (
            ((z,zprev),) -> 
            begin
                qt = VBMC.quuprev(z, zprev, mcx, t, hpmm.Y[t], Pt, Pe)
                mc.alpha[zprev,t-1] * qt * log(qt) * mc.beta[z, t]
            end
        ) |>
        sum
    end
    negent += negent2
    -negent
end
function elbo(mcx, mcu)
    elbo = 0.

    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    println("Elbo1 ", elbo |> Float64)

    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.Pt[x,xprev,t] * mcu.Pt[u,uprev,t]
            end
        ) |>
        sum
    end
    println("Elbo2 ", elbo2 |> Float64)
    elbo += elbo2
    # println("elbototal ", elbo |> Float64)
    elbo
end

elbo (generic function with 1 method)

In [5]:
function elbo2(mcx, mcu)
    elbo = 0.

    elbo += Iterators.product(1:U, 1:X) .|>
    (
        ((u,x),) -> log(P1[u,x]*Pe[u,x,hpmm.Y[1]]/(mcx.P1[x] * mcu.P1[u])) * mcx.P1[x] * mcu.P1[u]
    ) |>
    sum

    println("Elbo 1 ", elbo |> Float64)

    elbo2 = 0.
    for t in 2:T
        elbo2 += Iterators.product(1:U, 1:X, 1:U, 1:X) .|>
        (
            ((u,x,uprev,xprev),) -> 
            begin
                qtu = VBMC.quuprev(u, uprev, mcx, t, hpmm.Y[t], Pt, Pe)
                qtx = VBMC.qxxprev(x, xprev, mcu, t, hpmm.Y[t], Pt, Pe)
                log(
                    Pt[u,x,uprev,xprev]*Pe[u,x,hpmm.Y[t]]/(qtu * qtx)
                ) * mcx.alpha[xprev, t] * qtx * mcx.beta[x, t] * mcu.alpha[uprev, t] * qtu * mcu.beta[u, t]
            end
        ) |>
        sum
    end
    println("Elbo 2 ", elbo2 |> Float64)
    elbo += elbo2
    # println("elbototal ", elbo |> Float64)
    elbo
end

elbo2 (generic function with 1 method)

In [6]:
T = 2

mcu = MarkovChain(U, T)
mcx = MarkovChain(X, T)

function norm(A::AbstractArray; p = 2)
    sum(abs.(A) .^ p)^(1 / p)
end

val = 1.0
for _ = 1:2
    tmpP1, tmpPt = mcx.P1 |> deepcopy, mcx.Pt |> deepcopy

    fillalphaX!(mcx, mcu, P1, Pt, Pe, hpmm.Y)
    fillbetaX!(mcx, mcu, Pt, Pe, hpmm.Y)
    VBMC.fillPtx!(mcx, mcu, Pt, Pe, hpmm.Y)

    fillalphaU!(mcu, mcx, P1, Pt, Pe, hpmm.Y)
    fillbetaU!(mcu, mcx, Pt, Pe, hpmm.Y)
    VBMC.fillPtu!(mcu, mcx, Pt, Pe, hpmm.Y)

    signeda = mcx.Pt .|> Logarithmic
    signedb = tmpPt .|> Logarithmic
    val = norm(signeda .- signedb)

    println("elbomanual terms ", round(eng(mcx, mcu), digits=3), " ", round(VBMC.norm(mcu) * entx(mcx, mcu),digits=3), " ", round(VBMC.norm(mcx)*entu(mcu, mcx), digits=3))
    println("elbomanual ", Float64(eng(mcx, mcu) + VBMC.norm(mcu) * entx(mcx,mcu) + VBMC.norm(mcx) * entu(mcu,mcx)))
    println("elbofunction ", elbo(mcx, mcu) |> Float64)
    println("===")
    # println(total_elbo(mcu, mcx) |> Float64)
    if val < 1.0e-14
        break
    end
end

elbomanual terms -0.007 0.01 0.0
elbomanual 0.002282548278445081
elbofunction 0.01550808840596067
===
elbomanual terms -0.0 0.0 0.0
elbomanual 2.8350688199328317e-6
elbofunction 2.8733930811314854e-6
===


In [7]:
elbo2(mcx, mcu) |> Float64

Elbo 1 1.923283271415242e-6
Elbo 2 1.0201538962417292e-10


1.9233852868048645e-6

In [39]:
a = -1.2803196203480313e-8 + 7.51389305709447e-8 * (mcu.P1 |> sum) + 1.0986122880593954 * (mcx.P1 |> sum)
a |> Float64

6.644322043954148e-8

In [None]:
ent(mcu)

Entropy 1 1.0986122880593954
Entropy 2 -9.88751053354595


-exp(2.173489362030741)

In [None]:
b = -1.6360277518668169e-9 + 2.1338078139101344e-10 + -9.88751053354595

In [9]:
VBMC.norm(mcx) |> ULogFloat64 #, sum(mcx.alpha[:, 2] .* mcx.beta[:, 2] .|> Float64), sum(mcx.alpha[:, 1] .* mcx.beta[:, 1] .|> Float64)

exp(-17.09257918624566)

In [11]:
function get_digits(val, base) digits(val, base=base, pad=T) .|> x -> x+1 end

function get_qpdf(xs, mc)
    total_prob = 1.
    for t in 1:T
        if t == 1
            total_prob *= mc.P1[xs[t]] |> Float64
        else
            total_prob *= mc.Pt[xs[t],xs[t-1],t]/ (mc.alpha[xs[t-1],t]*mc.beta[xs[t-1],t]) |> Float64
        end
    end
    total_prob
end

function mcnorm(mc)
    summa = 0. 
    for z in 0:(mc.Z^T-1)
        zs = get_digits(z, mc.Z)
        total_prob = get_qpdf(zs, mc)
        summa += total_prob
    end
    summa
end

mcnorm (generic function with 1 method)

## energytotal - negenttotal - negenttotal = elbomanual

In [None]:
function get_pdf(xs)
    total_prob = ULogarithmic(1.)
    for t in 1:T
        if t == 1
            p1 = 0.
            for u in 1:U
                p1 += P1[u,xs[t]] * Pe[u,xs[t],hpmm.Y[t]]
            end
            total_prob *= p1
        else
            pt = 0.
            for (u, uprev) in Iterators.product(1:U, 1:U)
                pt += Pt[u,xs[t],uprev,xs[t-1]] * Pe[u,xs[t],hpmm.Y[t]]
            end
            total_prob *= pt
        end
    end
    total_prob
end

probabilities = zeros(0) .|> ULogFloat64
maxprob = ULogFloat64(0.)
for x in 0:(X^T-1)
    xs = get_digits(x, X)
    total_prob = get_pdf(xs)
    println(xs, total_prob)
    if maxprob < total_prob
        maxprob = total_prob
    end
    append!(probabilities, total_prob)
end

[1, 1]+exp(-4.46746581556419)
[2, 1]+exp(-1.9132886267972715)
[1, 2]+exp(-4.683821365046635)
[2, 2]+exp(-2.7630895038877474)
