# This notebook demonstrates a Hidden markov modeling of transition probabilities from times series data

# load required packages

In [1]:
using GraphRecipes, Plots

# generate training data

In [2]:
mutable struct Edge
    to::Int64
    value::Float64
end

function msmsample(p)
  p_cum = cumsum(p) ./ sum(p)
  r = rand()
  for i = 1:length(p_cum)
    if r <= p_cum[i]
      return i
    end
  end
end

function msmgenerate(nframe, state_num, E, pi_i, emission)
    states = zeros(typeof(nframe), nframe)
    observations = zeros(typeof(nframe), nframe)

    states[1] = msmsample(pi_i)
    observations[1] = msmsample(emission[states[1], :])
    for iframe = 2:nframe
        tmp = []
        for edge in E[states[iframe - 1]]
            push!(tmp, edge.value)
        end
        states[iframe] = E[states[iframe - 1]][msmsample(tmp)].to
        observations[iframe] = msmsample(emission[states[iframe], :])
    end
    return states, observations
end

function msmforward(data_list, state_num, E, pi_i, emission)
    alpha_list = []
    alpha_sum_list = []
    for data in data_list
        alpha_sum = []
        data_size = size(data, 1)
        alpha = zeros(Float64, (data_size, state_num))
        alpha[1, :] = pi_i .* emission[:, data[1]]
        push!(alpha_sum, sum(alpha[1, :]))
        alpha[1, :] ./= alpha_sum[1]
        
        for frame in 2:data_size
            for state in 1:state_num
                for edge in E[state]
                    to = edge.to
                    trans_prob = edge.value
                    alpha[frame, to] += alpha[frame - 1, state] * trans_prob * emission[to, data[frame]]
                end
            end
            push!(alpha_sum, sum(alpha[frame, :]))
            alpha[frame, :] ./= alpha_sum[frame]
        end
        push!(alpha_list, alpha)
        push!(alpha_sum_list, alpha_sum)
    end
    
    return alpha_list, alpha_sum_list
end

function msmbackward(data_list, state_num, E, pi_i, emission, alpha_sum_list)
    beta_list = []
    data_num = length(data_list)
    for (data, data_id) in zip(data_list, 1:data_num)
        data_size = size(data, 1)
        beta = zeros(Float64, (data_size, state_num))
        # 全ての状態が最終状態となることを仮定
        beta[data_size, :] .= 1
        
        for frame in (data_size - 1):-1:1
            for state in 1:state_num
                for edge in E[state]
                    to = edge.to
                    trans_prob = edge.value
                    beta[frame, state] += beta[frame + 1, to] * trans_prob * emission[to, data[frame + 1]]
                end
            end
            beta[frame, :] ./= alpha_sum_list[data_id][frame + 1]
        end
        
        push!(beta_list, beta)
    end
    
    return beta_list
end

function msmbaumwelch(data_list, state_num, E, pi_i, emission)
    ## setup
    tolerance = 10.0^(-3)
    check_convergence = Inf64
    count_iteration = 0
    data_num = length(data_list)
    obs_num = length(emission[1, :])
    while check_convergence > tolerance
        alpha_list, alpha_sum_list = msmforward(data_list, state_num, E, pi_i, emission)
        beta_list = msmbackward(data_list, state_num, E, pi_i, emission, alpha_sum_list)
        
        E_nxt = deepcopy(E)
        pi_nxt = deepcopy(pi_i)
        emission_nxt = deepcopy(emission)
        
        diff_sum = 0
        for state in 1:state_num
            state_sum = 0
            for edge_id in 1:size(E_nxt[state], 1)
                to = E_nxt[state][edge_id].to
                E_nxt[state][edge_id].value = 0
                for (data, data_id) in zip(data_list, 1:data_num)
                    data_size = size(data, 1)
                    for frame in 1:(data_size - 1)
                        E_nxt[state][edge_id].value += (alpha_list[data_id][frame, state]
                                                    * E[state][edge_id].value
                                                    * emission[to, data[frame + 1]]
                                                    * beta_list[data_id][frame + 1, to]
                                                    / alpha_sum_list[data_id][frame + 1])
                    end
                end
                
                state_sum += E_nxt[state][edge_id].value
            end
            
            for edge_id in 1:size(E_nxt[state], 1)
                # println("$(state) $(edge_id) $(E_nxt[state][edge_id].value) $(state_sum)")
                E_nxt[state][edge_id].value /= state_sum
                # println("$(state) $(edge_id) $(E_nxt[state][edge_id].value) $(state_sum)")
                diff_sum += abs(E[state][edge_id].value - E_nxt[state][edge_id].value)
            end
            
        end
        
        E = E_nxt
        pi_i = pi_nxt
        emission = emission_nxt
        check_convergence = diff_sum
        count_iteration += 1
        
        if count_iteration % 20 == 0
            println("count: ", count_iteration, " diff_sum: ", diff_sum)
        end
    end
    
    return E, pi_i, emission
end


msmbaumwelch (generic function with 1 method)

In [3]:
# ground-truth transition probabilities
state_num = 5
E = [[] for i in 1:state_num]
for i in 1:state_num
    for j in 1:state_num
        if i == j
            push!(E[i], Edge(j, 0.6))
        end
        if abs(i - j) == 1 || state_num - abs(i - j) == 1
            push!(E[i], Edge(j, 0.2))
        end
    end
end
E

5-element Array{Array{Any,1},1}:
 [Edge(1, 0.6), Edge(2, 0.2), Edge(5, 0.2)]
 [Edge(1, 0.2), Edge(2, 0.6), Edge(3, 0.2)]
 [Edge(2, 0.2), Edge(3, 0.6), Edge(4, 0.2)]
 [Edge(3, 0.2), Edge(4, 0.6), Edge(5, 0.2)]
 [Edge(1, 0.2), Edge(4, 0.2), Edge(5, 0.6)]

In [4]:
# equilibrium probabilities
pi_i = [
0.2, 
0.2, 
0.2,
0.2,
0.2
]

5-element Array{Float64,1}:
 0.2
 0.2
 0.2
 0.2
 0.2

In [5]:
# emission probabilities
emission = [
0.7 0.2 0.05 0.05;
0.2 0.7 0.05 0.05;
0.25 0.25 0.25 0.25;
0.05 0.05 0.7 0.2;
0.05 0.05 0.2 0.7
]

5×4 Array{Float64,2}:
 0.7   0.2   0.05  0.05
 0.2   0.7   0.05  0.05
 0.25  0.25  0.25  0.25
 0.05  0.05  0.7   0.2
 0.05  0.05  0.2   0.7

In [6]:
observations = []
for i in 1:1
    state, observation = msmgenerate(10000, state_num, E, pi_i, emission)
    push!(observations, observation)
end

# Hidden markov modeling

In [7]:
# initial transition probabilities
E0 = [[] for i in 1:state_num]
for i in 1:state_num
    for j in 1:state_num
        if i == j
            push!(E0[i], Edge(j, rand()))
        end
        if abs(i - j) == 1 || state_num - abs(i - j) == 1
            push!(E0[i], Edge(j, rand()))
        end
    end
end

In [8]:
@time E1, emission1, pi_i1 = msmbaumwelch(observations, state_num, E0, pi_i, emission);

count: 20 diff_sum: 0.09985931734119688
count: 40 diff_sum: 0.025454057668619967
count: 60 diff_sum: 0.005962296976204706
count: 80 diff_sum: 0.0018366372684995302
 17.459352 seconds (449.44 M allocations: 7.535 GiB, 8.65% gc time)


In [9]:
T = zeros(state_num, state_num)
for state in 1:state_num
    for edge in E[state]
        T[state, edge.to] = edge.value
    end
end

T0 = zeros(state_num, state_num)
for state in 1:state_num
    for edge in E0[state]
        T0[state, edge.to] = edge.value
    end
end

T1 = zeros(state_num, state_num)
for state in 1:state_num
    for edge in E1[state]
        T1[state, edge.to] = edge.value
    end
end

In [10]:
T

5×5 Array{Float64,2}:
 0.6  0.2  0.0  0.0  0.2
 0.2  0.6  0.2  0.0  0.0
 0.0  0.2  0.6  0.2  0.0
 0.0  0.0  0.2  0.6  0.2
 0.2  0.0  0.0  0.2  0.6

In [11]:
T0

5×5 Array{Float64,2}:
 0.110465  0.381743  0.0       0.0        0.276573
 0.494781  0.828755  0.237999  0.0        0.0
 0.0       0.878591  0.930254  0.0831753  0.0
 0.0       0.0       0.7748    0.121435   0.419835
 0.240393  0.0       0.0       0.909144   0.00981252

In [12]:
T1

5×5 Array{Float64,2}:
 0.611634  0.190939  0.0       0.0       0.197427
 0.221852  0.586464  0.191684  0.0       0.0
 0.0       0.215783  0.555128  0.229089  0.0
 0.0       0.0       0.227566  0.5527    0.219734
 0.201069  0.0       0.0       0.225802  0.573129

In [13]:
alpha_list_seica, alpha_sum_list = msmforward(observations, state_num, E, pi_i, emission);

In [14]:
alpha_list_seica[1]

10000×5 Array{Float64,2}:
 0.04        0.04        0.2        0.16        0.56
 0.019802    0.00990099  0.110011   0.136414    0.723872
 0.0263958   0.00530844  0.0792605  0.579169    0.309866
 0.00972754  0.00299899  0.101411   0.734393    0.151469
 0.00571842  0.00374074  0.162169   0.305903    0.522469
 0.0145211   0.00478658  0.106383   0.171287    0.703022
 0.273389    0.172257    0.225236   0.120369    0.20875
 0.053938    0.0455957   0.217412   0.499843    0.183211
 0.0130925   0.0136796   0.200713   0.25475     0.517765
 0.16901     0.264138    0.322256   0.109772    0.134825
 0.121738    0.603708    0.225184   0.026417    0.0229523
 0.0707168   0.15386     0.465448   0.0933656   0.21661
 0.0886357   0.531264    0.312551   0.0365938   0.0309557
 ⋮                                              
 0.0145529   0.00547767  0.121655   0.65217     0.206144
 0.277759    0.0474518   0.397377   0.177532    0.0998798
 0.530697    0.126406    0.273904   0.0398106   0.029183
 0.694646    0.1

In [15]:
beta_list_seica = msmbackward(observations, state_num, E, pi_i, emission, alpha_sum_list)

1-element Array{Any,1}:
 [0.3465290480755708 0.14181223435230914 … 1.1075061436598397 1.2192882307733697; 0.11234308370912711 0.13283563194042788 … 1.9577957680826894 0.8664236646399961; … ; 0.5654693671692604 0.2827346835846302 … 0.9738639101248373 1.4765033476086245; 1.0 1.0 … 1.0 1.0]

In [16]:
beta_list_seica[1]

10000×5 Array{Float64,2}:
 0.346529  0.141812  0.60232   1.10751    1.21929
 0.112343  0.132836  0.929089  1.9578     0.866424
 0.150994  0.143859  0.681815  1.2534     0.694753
 0.391947  0.32329   0.894065  0.993359   1.15557
 0.364056  0.720761  1.57126   1.10594    0.769618
 1.07813   3.08734   3.326     1.45437    0.521494
 0.242343  0.575431  1.76061   2.09911    0.888176
 0.597771  1.16267   1.74576   0.847058   0.610224
 2.4784    5.81498   2.95311   0.493525   0.327473
 0.829253  1.66179   1.08664   0.324096   0.260733
 0.846152  0.952669  1.21444   0.793738   1.19443
 1.53021   2.37048   0.944102  0.125454   0.350502
 1.84522   1.25109   0.480022  0.107553   0.57576
 ⋮                                        
 9.21931   4.1702    0.663755  0.176767   3.13838
 2.68438   1.24232   0.233781  0.0645477  0.911915
 1.56762   0.768147  0.194616  0.0558183  0.529241
 1.23831   0.699438  0.265376  0.0765546  0.409542
 1.09061   0.867482  0.468924  0.128485   0.337082
 0.941416  1.59739

In [17]:
using Revise; using MDToolbox

└ @ Revise /Users/seica/.julia/packages/Revise/BqeJF/src/Revise.jl:1328


In [18]:
@time T2, emission2, pi_i2 = MDToolbox.msmbaumwelch(observations, T0, pi_i, emission);

100 iteration LogLikelihood = -1.295930e+04  delta = 6.278742e-04  tolerance = 1.000000e-04
 23.258108 seconds (163.75 M allocations: 6.613 GiB, 5.45% gc time)


In [19]:
T2

5×5 Array{Float64,2}:
 0.611591  0.190909  0.0       0.0       0.1975
 0.221615  0.585245  0.19314   0.0       0.0
 0.0       0.219319  0.54832   0.232361  0.0
 0.0       0.0       0.228641  0.55216   0.219198
 0.201115  0.0       0.0       0.22547   0.573415

In [20]:
logL, alpha_list, factor_list = MDToolbox.msmforward(observations, T, pi_i, emission)

([-12962.24892369659], Any[[0.04000000000000001 0.04000000000000001 … 0.16000000000000003 0.5599999999999999; 0.019801980198019802 0.009900990099009903 … 0.13641364136413645 0.7238723872387239; … ; 0.021311737447726044 0.016315046009013158 … 0.5563859772669336 0.21977060289482342; 0.009425144808162552 0.008057106358057913 … 0.2607624503718512 0.5440485496279979]], Any[[0.25, 0.3636, 0.30049504950495054, 0.4054076514735493, 0.321155376111529, 0.37419015620230833, 0.10993437633647256, 0.22269535870472107, 0.29835548058773, 0.13507406776695563  …  0.25868842620149746, 0.3522299019024262, 0.4006492640107289, 0.4237415168501674, 0.2946951095952086, 0.402899962969975, 0.11901235877584344, 0.24261502940063395, 0.29489023001439507, 0.3183196304710191]])

In [21]:
alpha_list[1]

10000×5 Array{Float64,2}:
 0.04        0.04        0.2        0.16        0.56
 0.019802    0.00990099  0.110011   0.136414    0.723872
 0.0263958   0.00530844  0.0792605  0.579169    0.309866
 0.00972754  0.00299899  0.101411   0.734393    0.151469
 0.00571842  0.00374074  0.162169   0.305903    0.522469
 0.0145211   0.00478658  0.106383   0.171287    0.703022
 0.273389    0.172257    0.225236   0.120369    0.20875
 0.053938    0.0455957   0.217412   0.499843    0.183211
 0.0130925   0.0136796   0.200713   0.25475     0.517765
 0.16901     0.264138    0.322256   0.109772    0.134825
 0.121738    0.603708    0.225184   0.026417    0.0229523
 0.0707168   0.15386     0.465448   0.0933656   0.21661
 0.0886357   0.531264    0.312551   0.0365938   0.0309557
 ⋮                                              
 0.0145529   0.00547767  0.121655   0.65217     0.206144
 0.277759    0.0474518   0.397377   0.177532    0.0998798
 0.530697    0.126406    0.273904   0.0398106   0.029183
 0.694646    0.1

In [22]:
logL2, beta_list = MDToolbox.msmbackward(observations, factor_list, T, pi_i, emission)

([-12962.24892369659], Any[[0.3465290480755833 0.14181223435231424 … 1.1075061436598796 1.2192882307734136; 0.11234308370913118 0.13283563194043263 … 1.95779576808276 0.8664236646400274; … ; 0.5654693671692604 0.2827346835846302 … 0.9738639101248373 1.4765033476086245; 1.0 1.0 … 1.0 1.0]])

In [23]:
beta_list[1]

10000×5 Array{Float64,2}:
 0.346529  0.141812  0.60232   1.10751    1.21929
 0.112343  0.132836  0.929089  1.9578     0.866424
 0.150994  0.143859  0.681815  1.2534     0.694753
 0.391947  0.32329   0.894065  0.993359   1.15557
 0.364056  0.720761  1.57126   1.10594    0.769618
 1.07813   3.08734   3.326     1.45437    0.521494
 0.242343  0.575431  1.76061   2.09911    0.888176
 0.597771  1.16267   1.74576   0.847058   0.610224
 2.4784    5.81498   2.95311   0.493525   0.327473
 0.829253  1.66179   1.08664   0.324096   0.260733
 0.846152  0.952669  1.21444   0.793738   1.19443
 1.53021   2.37048   0.944102  0.125454   0.350502
 1.84522   1.25109   0.480022  0.107553   0.57576
 ⋮                                        
 9.21931   4.1702    0.663755  0.176767   3.13838
 2.68438   1.24232   0.233781  0.0645477  0.911915
 1.56762   0.768147  0.194616  0.0558183  0.529241
 1.23831   0.699438  0.265376  0.0765546  0.409542
 1.09061   0.867482  0.468924  0.128485   0.337082
 0.941416  1.59739

In [24]:
alpha_list[1] - alpha_list_seica[1]

10000×5 Array{Float64,2}:
  0.0           0.0           0.0           0.0           0.0
 -3.46945e-18   0.0           0.0           0.0           0.0
  0.0           8.67362e-19   0.0          -1.11022e-16   0.0
  0.0           0.0          -1.38778e-17  -1.11022e-16   0.0
  8.67362e-19   4.33681e-19   0.0           0.0          -1.11022e-16
  3.46945e-18   8.67362e-19   4.16334e-17   5.55112e-17   0.0
 -5.55112e-17   5.55112e-17   2.77556e-17  -2.77556e-17  -2.77556e-17
 -6.93889e-18   6.93889e-18   5.55112e-17  -5.55112e-17   0.0
  1.73472e-18   0.0           0.0           0.0           0.0
  0.0           0.0           0.0           0.0           0.0
  0.0           0.0           0.0           3.46945e-18   0.0
  0.0           5.55112e-17   0.0           1.38778e-17   2.77556e-17
  0.0           0.0           0.0           0.0           0.0
  ⋮                                                      
 -5.20417e-18  -1.73472e-18  -2.77556e-17  -1.11022e-16  -2.77556e-17
 -5.55112e-17  -

In [25]:
beta_list[1] - beta_list_seica[1]

10000×5 Array{Float64,2}:
  1.25455e-14   5.10703e-15   2.15383e-14   3.9968e-14    4.39648e-14
  4.06619e-15   4.7462e-15    3.34177e-14   7.06102e-14   3.13083e-14
  5.46785e-15   5.16254e-15   2.45359e-14   4.50751e-14   2.52021e-14
  1.42664e-14   1.17129e-14   3.20854e-14   3.58602e-14   4.17444e-14
  1.31006e-14   2.58682e-14   5.63993e-14   3.9968e-14    2.78666e-14
  3.84137e-14   1.0969e-13    1.18128e-13   5.21805e-14   1.86517e-14
  8.74301e-15   2.06501e-14   6.35048e-14   7.54952e-14   3.20854e-14
  2.15383e-14   4.19664e-14   6.28386e-14   3.05311e-14   2.19824e-14
  8.9706e-14    2.10498e-13   1.06581e-13   1.76525e-14   1.17684e-14
  2.9754e-14    5.9952e-14    3.90799e-14   1.16573e-14   9.32587e-15
  3.04201e-14   3.43059e-14   4.37428e-14   2.86438e-14   4.30767e-14
  5.50671e-14   8.4821e-14    3.41949e-14   4.52416e-15   1.25455e-14
  6.59472e-14   4.50751e-14   1.72085e-14   3.8719e-15    2.07612e-14
  ⋮                                                      
  0.0 

# visualization

In [None]:
using GraphRecipes, Plots
pyplot()

In [None]:
graphplot(T,
          markersize = 0.2,
          node_weights = pi_i,
          markercolor = :white,
          names = 1:size(T, 1),
          fontsize = 10,
          linecolor = :darkgrey,
          nodeshape = :circle,
          edgewidth = T, 
          self_edge_size = 0.0, 
          arrow = true
          )

In [None]:
graphplot(T1,
          markersize = 0.2,
          node_weights = pi_i,
          markercolor = :white,
          names = 1:size(T, 1),
          fontsize = 10,
          linecolor = :darkgrey,
          nodeshape = :circle,
          edgewidth = T,
          arrow = true
          )