In [11]:

using StatsBase
using LinearAlgebra
using Distributions
include("FHMMs.jl")

# Smoke test: build a small model with FHMM(5,3,2) and run a single call to
# update_variational_parameters on random inputs, displaying what it returns.
# Y and h are just random arrays of the sizes shown; NOTE(review): presumably
# Y is the observation matrix and h the variational parameters — confirm
# against FHMMs.jl.
Y = rand(5,5)
h = rand(3,5,2)
test = Main.FHMMs.FHMM(5,3,2)
display(Main.FHMMs.update_variational_parameters(test, h, Y))
# Last expression of the cell, so the notebook also shows the model's W field.
test.W

5×2×5 Array{Float64,3}:
[:, :, 1] =
 0.145711  1.21206
 0.28873   1.13872
 0.339681  1.05328
 0.323266  1.93457
 0.394549  0.839139

[:, :, 2] =
 0.316878  1.06391
 0.411446  0.771466
 0.279381  1.1509
 0.271214  1.71644
 0.345751  0.672053

[:, :, 3] =
 0.409832  1.23958
 0.737528  1.12121
 0.48437   1.14983
 0.507458  1.99199
 0.671766  0.822926

[:, :, 4] =
 0.824046  0.848249
 1.34362   0.585546
 0.696288  0.822015
 0.756988  1.32866
 1.03604   0.624546

[:, :, 5] =
 0.466381  0.771896
 0.542673  0.543611
 0.413145  0.574322
 0.383246  1.1531
 0.473229  0.698074



5×3×2 Array{Float64,3}:
[:, :, 1] =
 0.577751  0.290354    0.00821592
 0.575265  0.719328    0.20288
 0.484865  0.00278731  0.407989
 0.425979  0.138964    0.384693
 0.506222  0.308047    0.457901

[:, :, 2] =
 0.605271  0.540239  0.54354
 0.135333  0.292112  0.976094
 0.543871  0.907398  0.104589
 0.861569  0.955342  0.861992
 0.616815  0.138889  0.434763

In [8]:


"""
    generate(model, timesteps)

Simulate `timesteps` observations from `model`.

Each of the `model.M` chains draws its first state from the matching column of
`model.π` and every later state from the row of `model.P[:, :, m]` selected by
its previous state. The observation at time `t` is drawn from a multivariate
normal whose mean is the sum, over chains, of the `model.W` column picked by
each chain's current state, with covariance `model.C`.

Returns `(X, states)`: `X` has one column per time step, and `states` is an
`M × timesteps` integer matrix of the sampled state indices.
"""
function generate(model::Main.FHMMs.FHMM, timesteps::Int)
    observations = zeros(size(model.W, 1), timesteps)
    states = zeros(Int64, model.M, timesteps)

    # Initial state of every chain, drawn from that chain's column of π.
    for chain in 1:model.M
        states[chain, 1] = Int64(sample(weights(model.π[:, chain])))
    end

    for t in 1:timesteps
        # Emission mean: sum of the contributions of every chain's current state.
        emission_mean = zeros(model.D)
        for chain in 1:model.M
            emission_mean += model.W[:, states[chain, t], chain]
        end
        observations[:, t] = rand(MvNormal(emission_mean, model.C))

        # No transition is needed after the final time step.
        t < timesteps || continue
        for chain in 1:model.M
            transition_row = model.P[states[chain, t], :, chain]
            states[chain, t + 1] = sample(weights(transition_row))
        end
    end
    return observations, states
end



generate (generic function with 3 methods)

In [None]:
"""
    sample_states(model, timesteps, iter=1000)

Estimate the state-occupancy frequencies of every chain by simulation.

Runs `iter` independent simulations of the hidden chains only (no observations
are generated, so the estimate is unconditional with respect to any data).
Initial states are drawn from `model.π[:, m]` and later states from the row of
`model.P[:, :, m]` selected by the previous state.

Returns a `K × M × timesteps` array `θ`, where `θ[k, m, t]` is the fraction of
simulations in which chain `m` was in state `k` at time `t`; for example
`θ[3, 1, 2]` is the estimated probability of chain 1 being in state 3 at time 2.
"""
function sample_states(model::Main.FHMMs.FHMM, timesteps::Int, iter=1000)
    theta_sampled = zeros(model.K, model.M, timesteps)
    # State each chain occupied at the previous time step of the current run.
    current = zeros(Int, model.M)
    for _ in 1:iter
        for m in 1:model.M
            s = sample(weights(model.π[:, m]))
            current[m] = s
            theta_sampled[s, m, 1] += 1
        end
        for t in 2:timesteps
            # Use the model passed in, not a global, for the chain count.
            for m in 1:model.M
                s = sample(weights(model.P[current[m], :, m]))
                current[m] = s
                theta_sampled[s, m, t] += 1
            end
        end
    end
    # Turn visit counts into relative frequencies.
    theta_sampled ./= iter
    return theta_sampled
end


In [None]:
"""
    get_state_combinations(M, K)

Return every joint assignment of `M` chains to states `1:K`.

The result is a `Vector{Vector{Int}}` with `K^M` entries; entry `c` satisfies
`c[m] ∈ 1:K` for each chain `m`. Combinations are listed in lexicographic order
with the last chain varying fastest, e.g. for `M = 2, K = 2`:
`[1, 1], [1, 2], [2, 1], [2, 2]`. The result is empty when `K < 1`.

Throws `ArgumentError` if `M < 1`.
"""
function get_state_combinations(M::Int, K::Int)
    M >= 1 || throw(ArgumentError("M must be at least 1, got $M"))
    state_combinations = Vector{Vector{Int}}()
    # `product` varies its first iterator fastest; reversing each tuple makes the
    # last chain the fastest-varying one.
    for states in Iterators.product(ntuple(_ -> 1:K, M)...)
        push!(state_combinations, collect(reverse(states)))
    end
    return state_combinations
end

In [None]:
"""
    decode(model, X, θ)

Pick, independently for every time step, the most probable joint state of all
chains given the observation at that step.

For a joint assignment `S_t = (s_1, ..., s_M)` the score is

    P(Y_t | S_t) * P(S_t)

This decoder is specific to the completely factorized model, in which the
state probabilities are treated as independent across time steps. No Viterbi
pass is needed: the best assignment is taken separately at each `t`.

Although the chains are independent of each other, the observation depends on
*all* of them, because its mean is the sum of one `model.W` column per chain.
Hence `P(Y_t | s_1, ..., s_M) != P(Y_t | s_1)` and every joint combination of
states has to be scored.

`θ` is `K × M × T`: `θ[k, m, t]` is the probability of chain `m` being in state
`k` at time `t` (e.g. `θ[3, 1, 2]` is chain 1 in state 3 at time 2). Under the
factorization, `P(S_t)` is the product over chains of `θ[s_m, m, t]`.

Scores are computed in log space so that very small densities do not underflow.
Ties, and time steps where no combination has a finite score, resolve to the
first combination in lexicographic order (last chain varying fastest).

Returns an `M × T` matrix whose column `t` holds the chosen state of each chain.
Throws `ArgumentError` if the model has fewer than one chain or one state.
"""
function decode(model::Main.FHMMs.FHMM, X::Array{Float64,2}, θ::Array{Float64,3})
    D = size(X, 1)
    timesteps = size(X, 2)
    M = model.M
    K = model.K
    M >= 1 || throw(ArgumentError("model must have at least one chain, got M = $M"))
    K >= 1 || throw(ArgumentError("model must have at least one state, got K = $K"))

    # Joint assignments with the last chain varying fastest (tuples of length M).
    combos = vec([reverse(c) for c in Iterators.product(ntuple(_ -> 1:K, M)...)])

    # Neither the covariance nor the emission means depend on t, so build the
    # emission distribution of every combination once.
    sigma = Array{Float64,2}(Hermitian(Symmetric(model.C)))
    emissions = map(combos) do combo
        mu = zeros(D)
        for m in 1:M
            mu += model.W[:, combo[m], m]
        end
        return MvNormal(mu, sigma)
    end

    states = zeros(M, timesteps)
    for t in 1:timesteps
        x = X[:, t]
        best_score = -Inf
        best_combo = combos[1]
        for c in eachindex(combos)
            combo = combos[c]
            # log P(S_t): one θ entry per chain, the entry for that chain's own state.
            log_prior = 0.0
            for m in 1:M
                log_prior += log(θ[combo[m], m, t])
            end
            score = logpdf(emissions[c], x) + log_prior
            if score > best_score
                best_score = score
                best_combo = combo
            end
        end
        states[:, t] = collect(best_combo)
    end
    return states
end

In [None]:
# Numerically stable log(sum(exp.(v))); returns the maximum itself when it is
# not finite (e.g. -Inf when every entry is -Inf).
function _logsumexp(v)
    vmax = maximum(v)
    isfinite(vmax) || return vmax
    return vmax + log(sum(x -> exp(x - vmax), v))
end

"""
    decode_full(model, X)

Decode the hidden states of all chains by running a forward recursion over
every joint state combination and taking, at each time step, the combination
with the largest (unnormalized, log) forward probability.

For joint state `s` at time `t`:

- `t == 1`: `log α_1(s) = log P(Y_1 | s) + Σ_m log π[s_m, m]`
- `t > 1`:  `log α_t(s) = log P(Y_t | s) +
             log Σ_{s'} exp(log α_{t-1}(s') + Σ_m log P[s'_m, s_m, m])`

The joint transition probability is the product, over chains, of the per-chain
transition probabilities, since the chains evolve independently. The sum over
previous joint states is evaluated with log-sum-exp to avoid underflow.

`X` holds one observation per column. Returns an `M × T` matrix whose column
`t` is the selected state of each chain at time `t`.
"""
function decode_full(model::Main.FHMMs.FHMM, X::Array{Float64,2})
    D = size(X, 1)
    timesteps = size(X, 2)
    M = model.M
    state_combinations = get_state_combinations(M, model.K)
    ncombos = length(state_combinations)

    # Emission distribution of every joint state; independent of t, so built once.
    sigma = Array{Float64,2}(Hermitian(Symmetric(model.C)))
    emissions = map(state_combinations) do state
        mu = zeros(D)
        for m in 1:M
            mu += model.W[:, state[m], m]
        end
        return MvNormal(mu, sigma)
    end

    # log_transition[prev, cur] = Σ_m log P[prev_m, cur_m, m]
    logP = log.(model.P)
    log_transition = zeros(ncombos, ncombos)
    for (cur, state) in enumerate(state_combinations)
        for (prev, last_state) in enumerate(state_combinations)
            for m in 1:M
                log_transition[prev, cur] += logP[last_state[m], state[m], m]
            end
        end
    end

    # state_probabilities[s, t] holds log α_t(s).
    state_probabilities = zeros(ncombos, timesteps)
    for t in 1:timesteps
        x = X[:, t]
        for (cur, state) in enumerate(state_combinations)
            if t == 1
                log_prior = 0.0
                for m in 1:M
                    log_prior += log(model.π[state[m], m])
                end
            else
                log_prior = _logsumexp(
                    state_probabilities[:, t - 1] .+ log_transition[:, cur]
                )
            end
            state_probabilities[cur, t] = logpdf(emissions[cur], x) + log_prior
        end
    end

    # Per time step, copy the best-scoring joint state into the output column.
    max_probs = zeros(M, timesteps)
    for p in argmax(state_probabilities, dims=1)
        max_probs[:, p[2]] = state_combinations[p[1]]
    end
    return max_probs
end

#X, actual = generate(model_actual, 2)
#decoded = decode_full(model_actual,X)
#actual .== decoded

In [None]:
# Set up the known ("ground truth") model that the experiments sample from.
# D: observation dimension, K: number of states per chain, M: number of chains.
D=15
K=3
M=2

# Random emission means, one D-vector per (state, chain).
W = rand(D,K,M)
# Initial-state probabilities: each chain's column is normalized to sum to 1.
# NOTE(review): this rebinds the name `π` in Main, which fails if Base's π has
# already been referenced in the session.
π = rand(K,M)
π ./= sum(π,dims=(1))
# Transition matrices: row = previous state, column = next state; rows sum to 1.
P = rand(K, K, M)
P ./= sum(P, dims=2)
# Isotropic observation covariance with variance 0.1 on the diagonal.
C = Matrix(Diagonal(fill(0.1, D)))
model_actual = Main.FHMMs.FHMM(D,K,M,
    W, π, P, C)

# A short sample: 4 observations plus the true hidden states that produced them.
X, actual = generate(model_actual, 4)



In [None]:
# Fit a fresh model of the same size (D, K, M) to the sampled observations X
# using fit_sv!, with up to 5000 iterations and verbose output.
model_learn = Main.FHMMs.FHMM(D,K,M)
result = Main.FHMMs.fit_sv!(model_learn,X,maxiter=5000,verbose=true)    

In [None]:
# Display the γ field of the fit result.
# NOTE(review): presumably the inferred state posteriors — confirm in FHMMs.jl.
result.γ

In [None]:
# Display the true hidden states returned by generate(model_actual, 4),
# for comparison with the fit result.
actual

In [None]:
# Accuracy of `decode` on the known model, using simulated state frequencies
# as θ. These globals are reused by the evaluation cells that follow.
timesteps=10
num_observations = 100

# State frequencies from 1000 simulated runs; computed once and shared by every
# sequence decoded below.
theta_sampled = sample_states(model_actual, timesteps, 1000)
# Running count of (chain, time) entries where decoded and true state agree.
j = 0
for i in 1:num_observations
    X,actual_states = generate(model_actual, timesteps)
    estimated_states = decode(model_actual,X, theta_sampled)
    j += sum(actual_states .== estimated_states)
end
# Fraction of correctly decoded entries over all sequences, steps and chains.
j / (num_observations * timesteps * model_actual.M)

In [None]:
# Same accuracy measurement, but decoding with `decode_full` on the known
# model. Relies on `timesteps` and `num_observations` set in an earlier cell.
j = 0
for i in 1:num_observations
    X,actual_states = generate(model_actual, timesteps)
    estimated_states = decode_full(model_actual,X)
    j += sum(actual_states .== estimated_states)
end
# Fraction of (chain, time) entries decoded correctly.
j / (num_observations * timesteps * model_actual.M)

In [None]:
# Accuracy when the parameters are learned: fit a model to each generated
# sequence, then decode that sequence with the learned model.
# The same model_learn object is passed to fit! on every iteration, so it is
# not re-initialized between sequences.
# NOTE(review): states are compared by raw index; a learned model may order
# its states differently from model_actual, which would lower this score —
# confirm whether that is intended.
model_learn = Main.FHMMs.FHMM(D,K,M)
j = 0
num_observations = 50
for i=1:num_observations
    X,actual_states = generate(model_actual, timesteps)
    
    # The returned result is not used here; decoding reads model_learn directly.
    result = Main.FHMMs.fit!(model_learn,X,maxiter=100)    
    
    estimated_states = decode_full(model_learn,X)
    j += sum(actual_states .== estimated_states)
end
# Fraction of (chain, time) entries decoded correctly.
j / (num_observations * timesteps * model_actual.M)

In [None]:
# Accuracy when the parameters are learned and decoding uses the θ returned by
# fit! (called with fzero=true) together with `decode`.
# The same model_learn object is reused across all sequences.
timesteps=10
model_learn = Main.FHMMs.FHMM(D,K,M)
j = 0
num_observations = 100
for i=1:num_observations
    X,actual_states = generate(model_actual, timesteps)
    result = Main.FHMMs.fit!(model_learn,X,maxiter=100,fzero=true)    
    # Add a small constant to the diagonal of the learned covariance before it
    # is used by decode; presumably to keep it positive definite — confirm.
    # This inner `i` is a separate loop variable: the outer `i` is unchanged
    # once the inner loop ends.
    for i in 1:D
        model_learn.C[i, i] += 0.0001
    end
    #println(cholesky(Hermitian(Symmetric(model_learn.C))))
    estimated_states = decode(model_learn,X, result.θ)
    j += sum(actual_states .== estimated_states)
    # Running accuracy over the sequences processed so far (outer `i`).
    println(j / (i * timesteps * model_actual.M))
end
# Final fraction of (chain, time) entries decoded correctly.
j / (num_observations * timesteps * model_actual.M)