In [213]:
from sklearn.preprocessing import OrdinalEncoder

from MyHMM import *

%run MyHMM.py

In [214]:
# Model parameters. Hidden states are written Q = {1, 2, 3} in textbook
# notation but are stored 0-indexed here (index i <-> state i + 1).
# A[i, j]: probability of moving from hidden state i to hidden state j.
A = np.array([[0.5, 0.2, 0.3],
              [0.3, 0.5, 0.2],
              [0.2, 0.3, 0.5]])
# B[i, k]: probability that hidden state i emits observation symbol k,
# where k follows the encoder's category order below (0 = '红', 1 = '白').
B = np.array([[0.5, 0.5],
              [0.4, 0.6],
              [0.7, 0.3]])
# Initial state distribution as a (3, 1) column vector.
pi = np.array([[0.2, 0.4, 0.4]]).T

# Sanity-check that every distribution is normalized before running the
# algorithms; a typo in the numbers above would otherwise silently skew results.
assert np.allclose(A.sum(axis=1), 1), "each row of A must sum to 1"
assert np.allclose(B.sum(axis=1), 1), "each row of B must sum to 1"
assert np.isclose(pi.sum(), 1), "pi must sum to 1"

# Observation sequence O = (红, 白, 红). OrdinalEncoder expects a 2-D input of
# shape (n_samples, n_features), hence the (-1, 1) reshape.
visible_seq_init = np.array(['红', '白', '红'], dtype=object).reshape(-1, 1)

enc = OrdinalEncoder(categories=[['红', '白']])
# Encode to integer symbol ids and flatten back to a 1-D int32 array ([0, 1, 0]).
visible_seq = enc.fit_transform(visible_seq_init).astype(np.int32).ravel()

In [215]:
# Forward algorithm check.
# Without want_t, forward() yields the scalar P(O | λ); with want_t=t it yields
# the column vector α_t (one entry per hidden state) for time step t.
forward_test = MyHMM(hidden_status_num=3, visible_status_num=2, pi=pi, A=A, B=B)
print(forward_test.forward(visible_seq=visible_seq))
for time_step in (1, 2, 3):
    print(forward_test.forward(visible_seq=visible_seq, want_t=time_step))

0.130218
[[0.1 ]
 [0.16]
 [0.28]]
[[0.077 ]
 [0.1104]
 [0.0606]]
[[0.04187 ]
 [0.035512]
 [0.052836]]


In [216]:
# Backward algorithm check.
# Without want_t, backward() yields P(O | λ), which should match the forward
# result; with want_t=t it yields the column vector β_t (β_T is all ones).
backward_test = MyHMM(hidden_status_num=3, visible_status_num=2, pi=pi, A=A, B=B)
print(backward_test.backward(visible_seq=visible_seq))
for time_step in (1, 2, 3):
    print(backward_test.backward(visible_seq=visible_seq, want_t=time_step))

0.130218
[[0.2451]
 [0.2622]
 [0.2277]]
[[0.54]
 [0.49]
 [0.57]]
[[1.]
 [1.]
 [1.]]


In [217]:
# Posterior state probabilities γ_t(i) = P(i_t = q_i | O, λ) for t = 1, 2, 3;
# each printed column vector should sum to 1.
gamma_test = MyHMM(hidden_status_num=3, visible_status_num=2, pi=pi, A=A, B=B)
for time_step in range(1, 4):
    print(gamma_test.gamma_t(visible_seq=visible_seq, want_t=time_step))

[[0.18822283]
 [0.32216744]
 [0.48960973]]
[[0.31931069]
 [0.41542644]
 [0.26526287]]
[[0.32153773]
 [0.27271191]
 [0.40575036]]


In [218]:
# Pairwise posteriors ξ_t(i, j) = P(i_t = q_i, i_{t+1} = q_j | O, λ), printed as
# 3x3 matrices for t = 1 and t = 2; row i of ξ_t should sum to γ_t(i).
xi_test = MyHMM(hidden_status_num=3, visible_status_num=2, pi=pi, A=A, B=B)
for step in (1, 2):
    xi_matrix = xi_test.xi_t(visible_seq=visible_seq, t=step)
    print(xi_matrix)
    if step == 1:
        result_t_1 = xi_matrix
        # ξ_1(0, 1): probability of being in state 0 at t=1 and in state 1 at t=2.
        print(result_t_1[0, 1])

[[0.1036723  0.04515505 0.03939548]
 [0.09952541 0.18062019 0.04202184]
 [0.11611298 0.1896512  0.18384555]]
0.04515504768925955
[[0.14782903 0.04730529 0.12417638]
 [0.12717136 0.16956181 0.11869327]
 [0.04653735 0.05584481 0.16288071]]


In [219]:
vitiver_test = MyHMM(hidden_status_num=3, visible_status_num=2, pi=pi, A=A, B=B)
best_hidden_status, bset_hidden_status_pro = vitiver_test.viterbi(visible_seq=visible_seq)
print("最优路径:", best_hidden_status, '最优路径概率:', bset_hidden_status_pro)

最优路径: [2, 2, 2] 最优路径概率: 0.014699999999999998


In [220]:
# Toy training corpus: six observation sequences over the symbols 0-4.
O = [list(sequence) for sequence in (
    (1, 2, 3, 0, 1, 3, 4),
    (1, 2, 3),
    (0, 2, 4, 2),
    (4, 3, 2, 1),
    (3, 1, 1, 1, 1),
    (2, 1, 3, 2, 1, 3, 4),
)]
# Hidden-state sequences for supervised learning. Each hidden sequence is taken
# to be identical to its observation sequence; I is the same list object as O
# (an alias, not a copy).
I = O

In [221]:
# Supervised estimation of pi, A and B from paired observation / hidden
# sequences (the tiny ~1e-9 entries in the recorded output suggest a small
# smoothing constant inside MyHMM.supervision — unverified here).
# NOTE(review): in the recorded output the COLUMNS of A sum to ~1 while the
# rows do not (row 0 sums to ~0.29). forward() treats A[i, j] as
# P(state i -> state j), i.e. row-stochastic, so supervision() may be
# normalizing over the wrong axis — TODO confirm in MyHMM.supervision.
hmm_supervision = MyHMM(hidden_status_num=5, visible_status_num=5)
hmm_supervision.supervision(visible_seq=O, hidden_seq=I)
for param_name in ("pi", "A", "B"):
    print(getattr(hmm_supervision, param_name))

[[0.16666667]
 [0.33333333]
 [0.16666667]
 [0.16666667]
 [0.16666667]]
[[9.99999950e-09 1.25000000e-01 1.66666667e-01 1.66666665e-09
  3.33333328e-09]
 [9.99999950e-09 3.74999999e-01 3.33333332e-01 4.99999998e-01
  3.33333328e-09]
 [9.99999950e-09 3.74999999e-01 1.66666665e-09 3.33333332e-01
  3.33333331e-01]
 [9.99999960e-01 1.25000000e-01 3.33333332e-01 1.66666665e-09
  6.66666659e-01]
 [9.99999950e-09 1.24999999e-09 1.66666667e-01 1.66666667e-01
  3.33333328e-09]]
[[9.99999980e-01 1.24999999e-09 1.66666665e-09 1.66666665e-09
  4.99999988e-09]
 [4.99999988e-09 9.99999995e-01 1.66666665e-09 1.66666665e-09
  4.99999988e-09]
 [4.99999988e-09 1.24999999e-09 9.99999993e-01 1.66666665e-09
  4.99999988e-09]
 [4.99999988e-09 1.24999999e-09 1.66666665e-09 9.99999993e-01
  4.99999988e-09]
 [4.99999988e-09 1.24999999e-09 1.66666665e-09 1.66666665e-09
  9.99999980e-01]]


In [222]:
# Unsupervised learning (Baum-Welch) from the observations alone.
# All sequences in O are joined end-to-end into a single observation list. The
# result is the same as the previous hard-coded O[0] + ... + O[5], but no longer
# silently breaks or drops data if O does not have exactly six sequences.
# NOTE(review): joining independent sequences makes the fit see a transition
# across every sequence boundary (last symbol of O[k] -> first of O[k+1]); if
# MyHMM.baum_welch can take a list of sequences, prefer that — TODO confirm.
# NOTE(review): no random seed is set here; if baum_welch initializes its
# parameters randomly, the printed results will vary between runs — TODO confirm.
hmm_no_supervision = MyHMM(hidden_status_num=5, visible_status_num=5)
observations_concat = [symbol for sequence in O for symbol in sequence]
hmm_no_supervision.baum_welch(observations_concat)
print(hmm_no_supervision.pi)
print(hmm_no_supervision.A)
print(hmm_no_supervision.B)



[[1.0000000e+00]
 [3.5341444e-96]
 [0.0000000e+00]
 [0.0000000e+00]
 [0.0000000e+00]]
[[0.00000000e+00 6.01853628e-34 4.00128909e-01 3.89201723e-49
  9.97962862e-01]
 [6.24736842e-01 1.17858880e-41 0.00000000e+00 1.82054273e-01
  1.28994031e-32]
 [0.00000000e+00 1.00000000e+00 0.00000000e+00 4.98834561e-27
  2.03713792e-03]
 [0.00000000e+00 4.68387157e-50 2.00064455e-01 8.17945727e-01
  0.00000000e+00]
 [3.75263158e-01 0.00000000e+00 3.99806636e-01 0.00000000e+00
  1.61782974e-32]]
[[0.00000000e+000 7.18008670e-001 0.00000000e+000 0.00000000e+000
  0.00000000e+000]
 [1.00000000e+000 1.02467967e-001 1.85715792e-001 0.00000000e+000
  3.63846770e-001]
 [0.00000000e+000 3.02806776e-045 2.96870396e-137 1.00000000e+000
  3.51319550e-083]
 [0.00000000e+000 2.76817291e-046 3.25754025e-001 0.00000000e+000
  6.36153230e-001]
 [3.03965112e-062 1.79523363e-001 4.88530183e-001 0.00000000e+000
  2.40834687e-063]]
