In [1]:
using Distributions
using Seaborn
import StatsBase: wsample
import Clustering: kmeans

In [2]:
# Numerical floors (admittedly more than this notebook strictly needs).
const ϵ = eps(Float64)            # default floor used by `sanitize!`
const log_ϵ = -1000.0             # default replacement used by `sanitize_log!`
const ϵI = UniformScaling(1e-6)   # presumably a small ridge for covariances; unused in the cells shown
;

In [3]:
"""
    make_model(K, D)

Draw a random generative model with `K` states per chain and `D`-dimensional
Gaussian emissions. Returns a `Dict` with keys `:π0`, `:P`, `:μs` and `:Σs`.
"""
function make_model(K, D)
    # K positive weights rescaled to sum to one
    weights = rand(K)
    π0 = weights / sum(weights)

    # p(zᵗ⁺¹ | z₁ᵗ, z₂ᵗ) = P[zᵗ⁺¹, z₁ᵗ, z₂ᵗ], shared by both chains (Z₁ ≡ Z₂);
    # every P[:, a, b] is normalised along the first dimension
    P = rand(K, K, K)
    P = P ./ sum(P, 1)

    # one mean vector and one isotropic covariance per state
    μs = map(k -> 8 * randn(D), 1:K)
    Σs = map(k -> rand() * eye(D), 1:K)

    return Dict(:π0 => π0, :P => P, :μs => μs, :Σs => Σs)
end

make_model (generic function with 1 method)

In [4]:
# problem size
K = 3         # states per chain
KK = K * K    # number of (z₁, z₂) pairs
D = 1         # observation dimension
T = 500       # sequence length

model = make_model(K, D)
# matrix square root of each covariance, applied to standard-normal draws when sampling
sqrtm_Σs = [sqrtm(Σ) for Σ in model[:Σs]]
;

In [22]:
# Sample a length-T state trajectory for both chains plus their observations.
Z = Array{Int}(2, T)
X1 = Array{Float64}(D, T)
X2 = similar(X1)

# The π0 draw only conditions the first transition: it is overwritten inside the
# loop before anything is stored in Z.
z = wsample(1:K, model[:π0], 2)

for t in 1:T
    # two draws from the same column P[:, z₁, z₂]
    z = wsample(1:K, model[:P][:, z[1], z[2]], 2)
    Z[:, t] = z
    X1[:, t] = model[:μs][z[1]] + sqrtm_Σs[z[1]] * randn(D)
    X2[:, t] = model[:μs][z[2]] + sqrtm_Σs[z[2]] * randn(D)
end

## EM

In [6]:
"""
    sanitize_log!(A::AbstractArray{<:Real}, log_ϵ::Float64=log_ϵ)

Replace every non-finite entry of `A` (`NaN`, `Inf`, `-Inf`) with `log_ϵ`, in place.
Finite entries are left untouched. Returns `nothing`.
"""
@inline function sanitize_log!(A::AbstractArray{<:Real}, log_ϵ::Float64=log_ϵ)
    # eachindex keeps the @inbounds access valid for any index style
    @inbounds for i in eachindex(A)
        isfinite(A[i]) || (A[i] = log_ϵ)
    end
    return nothing
end

sanitize_log! (generic function with 2 methods)

In [7]:
"""
    sanitize!(A::AbstractArray{<:Real}, ϵ::Float64=ϵ)

Clamp `A` from below at `ϵ`, in place: every entry that is non-finite (`NaN`, `±Inf`)
or `≤ ϵ` is overwritten with `ϵ`. Returns `nothing`.
"""
@inline function sanitize!(A::AbstractArray{<:Real}, ϵ::Float64=ϵ)
    # eachindex keeps the @inbounds access valid for any index style
    @inbounds for i in eachindex(A)
        if !isfinite(A[i]) || A[i] ≤ ϵ
            A[i] = ϵ
        end
    end
    return nothing
end

sanitize! (generic function with 2 methods)

In [8]:
# blatantly copied from github.com/jwmi/HMM
"""
    logsumexp(X::AbstractArray{<:Real})

Compute `log Σᵢ exp(xᵢ)` without overflow, i.e. for xᵢ = log yᵢ
∴ logsumexp(x) = log(Σᵢ yᵢ).

The maximum of `X` is subtracted before exponentiating. If that maximum is not
finite (all entries `-Inf`, or some entry `Inf`/`NaN`) it is returned as is.
An empty `X` gives `-Inf`, the log of an empty sum.
"""
@inline function logsumexp(X::AbstractArray{<:Real})
    # empty sum is 0 and log(0) = -Inf; `maximum` would throw here
    isempty(X) && return -Inf

    mx = maximum(X)
    # shifting by a non-finite maximum would produce NaN
    isfinite(mx) || return mx

    # same reduction as sum(exp, X .- mx), without the temporary array
    return mx + log(sum(x -> exp(x - mx), X))
end

logsumexp

In [23]:
# initial EM estimates
# component means: k-means over the observations of both chains pooled together
R = kmeans(hcat(X1, X2), K, maxiter=100, display=:none)
ms = [R.centers[:, k] for k in 1:K]
Ss = [3 * eye(D) for _ in 1:K]

# initial distribution over K × K: outer product of a random K-vector with
# itself, flattened column-major
p0 = normalize(rand(K), 1)
p0 = vec(p0 * p0')
log_p0 = map(log, p0)

P = rand(K, KK)
P = P ./ sum(P, 1)
# make it K² × K²
# P[:, i] is symmetric if reshaped to be K×K, ∀i
P = repeat(P, inner=(K, 1)) .* repeat(P, outer=(K, 1))
log_P = map(log, P)
;

In [24]:
# KK × T work arrays; only log_b starts zeroed, the rest are uninitialised
log_b = fill(0.0, KK, T)
log_α, log_β = similar(log_b), similar(log_b)

α, β, γ = similar(log_b), similar(log_b), similar(log_b)
# one running sum over all time steps (vs a list of per-step matrices)
ξ = fill(0.0, KK, KK)
;



In [59]:
# Emission log-likelihoods: log_b[q, t] for joint state q = (i, j) is the sum of the
# log-density of X1[:, t] under component i and of X2[:, t] under component j.
# Each component distribution is built once, and its log-density is evaluated once
# per observation instead of once per joint state.
let comps = [MvNormal(ms[k], Ss[k]) for k in 1:K]
    ll1 = [logpdf(comps[k], X1[:, t]) for k in 1:K, t in 1:T]
    ll2 = [logpdf(comps[k], X2[:, t]) for k in 1:K, t in 1:T]

    for t in 1:T, q in 1:KK
        (i, j) = ind2sub((K, K), q)
        log_b[q, t] = ll1[i, t] + ll2[j, t]
    end
end

In [26]:
#
# forward
#
# temp is [ log_α[i, t-1] + log_P[j, i] ]ᵢ
# (temp is also the scratch buffer of the backward pass)
temp = zeros(KK)
log_α[:, 1] = log_p0 .+ log_b[:, 1]

for t in 2:T
    # column t-1 is not written during the inner loop, so slice it once
    prev = log_α[:, t-1]
    for j in 1:KK
        temp .= prev .+ log_P[j, :]
        log_α[j, t] = log_b[j, t] + logsumexp(temp)
    end
end

In [27]:
#
# backward pass
#
# temp is [ log_b[j, t+1] + log_P[j, i] + log_β[j, t+1] ]ⱼ
log_β[:, T] .= 0

for t in (T-1):-1:1
    # column t+1 of both arrays is fixed while column t is filled in
    next_b = log_b[:, t+1]
    next_β = log_β[:, t+1]
    for i in 1:KK
        # same summation order as accumulating b, then P, then β
        temp .= (next_b .+ log_P[:, i]) .+ next_β
        log_β[i, t] = logsumexp(temp)
    end
end

In [68]:
#
# γ
#
# per step: normalise log_α + log_β in log space, then exponentiate,
# so every column of γ sums to one
for t in 1:T
    col = log_α[:, t] .+ log_β[:, t]
    γ[:, t] = exp.(col .- logsumexp(col))
end

# γ summed over time, one entry per joint state
pseudo_counts = vec(sum(γ, 2))
;

In [77]:
#
# ξ
#
# ξ[j, i] accumulates, over t, the normalised weight
#   exp(log_P[j, i] + log_α[i, t] + log_b[j, t+1] + log_β[j, t+1]),
# and is finally rescaled so that every column sums to one.
#
# Start from zero on every run: ξ is accumulated with .+= below, so re-running this
# cell would otherwise add the new counts onto the previous, already normalised, ξ.
fill!(ξ, 0.0)
temp2 = similar(ξ)
for t in 1:(T-1)
    temp2[:] = log_P
    temp2 .+= log_α[:, t]'      # row vector: adds log_α[i, t] to column i
    temp2 .+= log_b[:, t+1]     # column vector: adds log_b[j, t+1] to row j
    temp2 .+= log_β[:, t+1]
    temp2 .-= logsumexp(temp2)  # normalise over all (j, i) pairs
    ξ .+= exp.(temp2)
end
ξ ./= sum(ξ, 1)
;

In [83]:
#
# parameter update
#
# NOTE(review): only the transition matrix and the initial distribution are
# re-estimated here; ms / Ss are not touched in this cell.

# transition matrix and its log-space copy
P[:] = ξ
log_P[:] = P
map!(log, log_P, log_P)

# initial distribution and its log-space copy; the forward pass reads log_p0,
# not p0, so both have to be refreshed together
p0[:] = γ[:, 1]
log_p0[:] = p0
map!(log, log_p0, log_p0)

9×9 Array{Float64,2}:
 -3.73248   -0.947106  -1.38392  -1.47889  …  -1.95965  -4.44827   -1.39973
 -3.44805   -1.94358   -2.36658  -2.55273     -2.78463  -3.59912   -3.3847 
 -2.44209   -2.55874   -1.88095  -1.9914      -1.8293   -2.77582   -1.63669
 -3.67896   -1.66262   -2.69442  -2.01042     -3.13105  -3.42078   -3.69352
 -3.28045   -3.21029   -3.29182  -3.3868      -3.72081  -2.87526   -5.52158
 -2.333     -3.49977   -3.1401   -2.70112  …  -2.95908  -2.00504   -3.81728
 -2.3968    -2.48621   -1.83825  -1.92272     -1.79078  -2.69875   -1.58981
 -2.31529   -3.67847   -3.02261  -2.93092     -2.82061  -2.06509   -3.77647
 -0.677924  -3.77798   -1.91497  -2.09311     -1.23559  -0.732271  -1.39802

In [71]:
# Per step, take the joint state with the largest γ and unpack it to (z₁, z₂).
Z_obs = similar(Z)
for t in 1:T
    (i, j) = ind2sub((K, K), indmax(γ[:, t]))
    Z_obs[1, t] = i
    Z_obs[2, t] = j
end
# Number of entries (out of 2T) that differ from the sampled states.
# NOTE(review): raw label indices are compared; the k-means ordering of `ms` need
# not match the ordering of the generating states — confirm before reading this as
# an error count.
sum(Z_obs .!= Z)

614