# Example: Hidden Markov Model

In this example, we test the scale factor update rules for Categorical messages passing through the Transition and Equality nodes in ReactiveMP. All messages follow the Sum-Product update rule. 
We expect the negative log scale factor to be identical on every edge and equal to the Bethe Free Energy computed by ReactiveMP. 

The model specification is as follows:
$$ z_t \sim \mathrm{Cat}(A z_{t-1}), $$
$$ y_t \sim \mathrm{Cat}(B z_t), $$
where $z_t$ and $y_t$ are the (one-hot encoded) state and observation at time $t$, and $A$ and $B$ are the known transition and observation matrices. We perform state smoothing in this example. 

In [1]:
using ReactiveMP, Rocket, GraphPPL, Distributions
using LinearAlgebra, Random


## Generate data

In [44]:
# Number of time steps to simulate.
n_samples = 5
# Every column of A and of B sums to one, so A*z and B*z are probability vectors
# whenever z is a probability vector.
A = [0.6 0.1 0.2; 0.3 0.7 0.3; 0.1 0.2 0.5]; # Transition matrix, applied as A*z_prev in generate_data
B = [0.8 0.25 0.1; 0.1 0.5 0.6; 0.1 0.25 0.3]; # Observation matrix, applied as B*z[t] in generate_data

z_0_data = [1.0, 0.0, 0.0]; # One-hot initial state used to start the simulated chain

In [45]:
"""
    onehot_vec(dist::Categorical)

Draw one sample from `dist` and return it as a one-hot encoded `Vector{Float64}` whose
length is the number of categories of `dist`.
"""
function onehot_vec(dist::Categorical)
    encoded = zeros(ncategories(dist))
    encoded[rand(dist)] = 1.0  # flag the sampled category

    return encoded
end

"""
    generate_data(n_samples; seed = 12)

Simulate `n_samples` steps of the hidden Markov chain defined by the globals `A`, `B`
and `z_0_data`, and return the one-hot encoded observations.

The global random number generator is re-seeded with `seed` on every call. The sampled
hidden states are kept only internally; only the observations are returned.
"""
function generate_data(n_samples; seed = 12)
    Random.seed!(seed)
    states = Vector{Vector{Float64}}(undef, n_samples)        # one-hot encoded states
    observations = Vector{Vector{Float64}}(undef, n_samples)  # one-hot encoded observations
    previous = z_0_data
    for t in 1:n_samples
        # The state is sampled before the observation at every step.
        current = onehot_vec(Categorical(A * previous))
        states[t] = current
        observations[t] = onehot_vec(Categorical(B * current))
        previous = current
    end

    return observations
end

generate_data (generic function with 1 method)

In [46]:
# One-hot encoded observation sequence; passed to both `inference` calls in this notebook.
y_data = generate_data(n_samples);

## Inference by ReactiveMP (involving Scale Factor)

In [47]:
# Sum-product message towards `out`: the matrix held by `m_a` applied to the probability
# vector of `m_in`, renormalised to sum to one.
@rule Transition(:out, Marginalisation) (m_in::Categorical, m_a::PointMass) = begin
    unnormalised = mean(m_a) * probvec(m_in)
    return Categorical(unnormalised ./ sum(unnormalised))
end

# Scale-factor variant of the `out` rule for a plain Categorical input: reuses the
# meta-free rule for the message and attaches a (negative-log) scale of 0.0.
# NOTE(review): a zero scale presumably relies on `m_a` being column-stochastic, so that
# the unnormalised message already sums to one — confirm.
@rule Transition(:out, Marginalisation) (m_in::Categorical, m_a::PointMass, meta::ScaleFactorMeta) = begin
    forward = @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = m_a)
    return ScaledMessage(forward, 0.0)
end

# Scale-factor variant of the `out` rule for an already scaled input message.
# The wrapped message is pushed through the meta-free `out` rule and the incoming scale
# is forwarded unchanged.
# NOTE(review): forwarding the scale unchanged presumably relies on `m_a` being
# column-stochastic (the forward step then adds no extra normalisation) — confirm.
@rule Transition(:out, Marginalisation) (m_in::ScaledMessage, m_a::PointMass, meta::ScaleFactorMeta) = begin
    message = @call_rule Transition(:out, Marginalisation) (m_in = m_in.message, m_a = m_a)
    return ScaledMessage(message, m_in.scale)
end

# Scale-factor variant of the backward (`in`) rule for an observed (PointMass) output.
# The unnormalised backward message is A' * probvec(m_out); its sum is the normalisation
# constant, whose negative log is stored as the scale.
@rule Transition(:in, Marginalisation) (m_out::PointMass, m_a::PointMass, meta::ScaleFactorMeta) = begin
    backward = mean(m_a)' * probvec(m_out)
    evidence = sum(backward)
    return ScaledMessage(Categorical(backward ./ evidence), -log(evidence))
end

# Scale-factor variant of the backward (`in`) rule for an already scaled output message.
# The negative log of this step's normalisation constant is added to the scale carried
# by `m_out`.
@rule Transition(:in, Marginalisation) (m_out::ScaledMessage, m_a::PointMass, meta::ScaleFactorMeta) = begin
    backward = mean(m_a)' * probvec(m_out.message)
    evidence = sum(backward)
    return ScaledMessage(Categorical(backward ./ evidence), m_out.scale - log(evidence))
end


In [48]:
"""
    ReactiveMP.prod(::ProdAnalytical, left::ScaledMessage{<:Categorical}, right::ScaledMessage{<:Categorical})

Product of two scaled Categorical messages (used at the equality node).

The resulting message is the analytical product of the two wrapped Categoricals. The
resulting (negative-log) scale is the sum of both input scales minus the log of the inner
product of the two probability vectors.
"""
function ReactiveMP.prod(
    ::ProdAnalytical,
    left::ScaledMessage{<:Categorical},
    right::ScaledMessage{<:Categorical},
)
    overlap = probvec(left.message)' * probvec(right.message)
    combined = prod(ProdAnalytical(), left.message, right.message)

    return ScaledMessage(combined, left.scale + right.scale - log(overlap))
end

### Define Model 

In [49]:
# Hidden Markov model built with `default_meta = ScaleFactorMeta()`.
# NOTE(review): presumably the default meta is what selects the `meta::ScaleFactorMeta`
# rule overloads for every node — confirm against the GraphPPL/ReactiveMP documentation.
# `n` is the number of time steps; `A` and `B` are the matrices handed to the state and
# observation Transition nodes. Returns the state variables `z` and data variables `y`.
@model [default_meta = ScaleFactorMeta() ] function HMM_scf(n, A, B)
    # n latent states and n data placeholders (values are supplied later via `update!`)
    z = randomvar(n)
    y = datavar(Vector{Float64},n)

    # the known matrices enter the graph as constants
    cA = constvar(A)
    cB = constvar(B)
    # uniform prior over three categories for the initial state
    z_init ~ Categorical([1/3, 1/3, 1/3]) 

    z_prev = z_init

    # chain the states through cA and emit one observation per state through cB
    for t=1:n
        z[t] ~ Transition(z_prev, cA) 
        y[t] ~ Transition(z[t], cB) 
        z_prev = z[t]
    end

    return z, y
end

HMM_scf (generic function with 1 method)

In [50]:
# Run inference on the scale-factor model `HMM_scf` for the observation sequence `data`
# with matrices `A` and `B`. Returns the `keep` actor that recorded the state marginals.
function inference(data, A, B)
    _, (states, observations) = HMM_scf(length(data), A, B)

    marginal_log = keep(Vector{Marginal})
    subscription = subscribe!(getmarginals(states), marginal_log)

    # Feeding the data triggers the marginal updates recorded by `marginal_log`.
    update!(observations, data)
    unsubscribe!(subscription)

    return marginal_log
end

inference (generic function with 1 method)

### Do inference

In [51]:
# State marginals recorded while running the scale-factor model on `y_data`.
z_marginals_sf = inference(y_data, A, B);

In [52]:
# Print the (negative-log) scale factor stored with each state marginal.
# All of them are expected to be (numerically) equal.
foreach(1:n_samples) do t
    println(z_marginals_sf.values[][t].data.scale)
end

5.797739353741411
5.797739353741411
5.797739353741411
5.79773935374141
5.79773935374141


## Inference by regular ReactiveMP

In [53]:
"""
    contingency_matrix(distribution::Contingency)

Return the `p` field of `distribution`.
"""
function contingency_matrix(distribution::Contingency)
    return distribution.p
end

contingency_matrix (generic function with 1 method)

In [54]:
# Additional rules for Transition node
# Backward (`in`) message for an observed (PointMass) output: A' * probvec(m_out),
# renormalised to sum to one.
@rule Transition(:in, Marginalisation) (m_out::PointMass, m_a::PointMass) = begin
    backward = mean(m_a)' * probvec(m_out)
    return Categorical(backward ./ sum(backward))
end

# Backward (`in`) message for a Categorical output message: A' * probvec(m_out),
# renormalised to sum to one.
@rule Transition(:in, Marginalisation) (m_out::Categorical, m_a::PointMass) = begin
    backward = mean(m_a)' * probvec(m_out)
    return Categorical(backward ./ sum(backward))
end

# Additional marginal rules for Transition node
# Joint marginal over (out, in): the matrix held by `m_a` with rows weighted by the `out`
# message and columns weighted by the `in` message, normalised to sum to one.
@marginalrule Transition(:out_in_a) (m_out::Categorical, m_in::Categorical, m_a::PointMass) = begin
    joint = Diagonal(probvec(m_out)) * mean(m_a) * Diagonal(probvec(m_in))
    return (out_in = Contingency(joint ./ sum(joint)), a = m_a)
end

# Marginal rule for an observed (PointMass) output: `out` and `a` are returned as given,
# and the `in` marginal is the product of the backward message with `m_in`.
@marginalrule Transition(:out_in_a) (m_out::PointMass, m_in::Categorical, m_a::PointMass) = begin
    # Clamp from below at `tiny` so no entry of the backward message is smaller than `tiny`.
    backward = clamp.(mean(m_a)' * probvec(m_out), tiny, Inf)
    in_marginal = prod(ProdAnalytical(), Categorical(backward ./ sum(backward)), m_in)
    return (out = m_out, in = in_marginal, a = m_a)
end

# Additional marginal rule for the Categorical node.
# The `out` marginal is the product of the prior Categorical(mean(m_p)) with the incoming
# `out` message; `p` is returned as given.
@marginalrule Categorical(:out_p) (m_out::Categorical, m_p::PointMass) = begin
    prior = Categorical(mean(m_p))
    return (out = prod(ProdAnalytical(), prior, m_out), p = m_p)
end


In [55]:
# Average energy of the Transition node for separate `out` and `in` marginals:
# minus probvec(q_out)' * E[log A] * probvec(q_in).
@average_energy Transition (q_out::Any, q_in::Any, q_a::PointMass) = begin
    log_a = mean(log, q_a)
    return -probvec(q_out)' * log_a * probvec(q_in)
end

# Average energy of the Transition node for a joint (out, in) Contingency marginal:
# minus the trace of (joint matrix)' * E[log A].
@average_energy Transition (q_out_in::Contingency, q_a::PointMass) = begin
    log_a = mean(log, q_a)
    return -tr(contingency_matrix(q_out_in)' * log_a)
end

In [56]:
# Hidden Markov model without any meta specification (regular ReactiveMP inference).
# `n` is the number of time steps; `A` and `B` are the matrices handed to the state and
# observation Transition nodes. Returns the state variables `z` and data variables `y`.
@model function HMM(n, A, B)
    # n latent states and n data placeholders (values are supplied later via `update!`)
    z = randomvar(n)
    y = datavar(Vector{Float64},n)

    # the known matrices enter the graph as constants
    cA = constvar(A)
    cB = constvar(B)
    # uniform prior over three categories for the initial state
    z_init ~ Categorical([1/3, 1/3, 1/3])

    z_prev = z_init

    # chain the states through cA and emit one observation per state through cB
    for t=1:n
        z[t] ~ Transition(z_prev, cA)
        y[t] ~ Transition(z[t], cB) 
        z_prev = z[t]
    end

    return z, y
end

HMM (generic function with 1 method)

In [57]:
# Run inference on the regular model `HMM` for the observation sequence `data` with
# matrices `A` and `B`. Returns two `keep` actors: one with the recorded state marginals
# and one with the recorded Bethe Free Energy values.
# This method has the same signature as the earlier scale-factor `inference`, so defining
# it replaces that method.
function inference(data, A, B)
    model, (states, observations) = HMM(length(data), A, B)

    marginal_log = keep(Vector{Marginal})
    energy_log = keep(Float64)

    subscriptions = (
        subscribe!(getmarginals(states), marginal_log),
        subscribe!(score(Float64, BetheFreeEnergy(), model), energy_log),
    )

    # Feeding the data triggers the updates recorded by both actors.
    update!(observations, data)
    unsubscribe!(subscriptions)

    return marginal_log, energy_log
end

inference (generic function with 1 method)

In [58]:
# State marginals and Bethe Free Energy recorded while running the regular model on `y_data`.
z_mar, bfe = inference(y_data, A, B);

In [59]:
# Bethe Free Energy value recorded by the `keep(Float64)` actor.
bfe.values[]

5.797739353741412

In [60]:
# (Negative-log) scale stored with the first state marginal of the scale-factor run.
sf = z_marginals_sf.values[][1].data.scale

5.797739353741411

## Compare the results of the scale-factor and the regular ReactiveMP inference

In [61]:
# Check that the BFE from the regular run approximately equals the negative-log scale
# factor stored with each state marginal of the scale-factor run.
foreach(1:n_samples) do t
    println(bfe.values[] ≈ z_marginals_sf.values[][t].data.scale)
end

true
true
true
true
true


In [62]:
# Check that the state marginals of the scale-factor run and of the regular run have
# approximately equal probability vectors at every time step.
foreach(1:n_samples) do t
    println(z_mar.values[][t].data.p ≈ z_marginals_sf.values[][t].data.message.p)
end

true
true
true
true
true
