In [406]:
import numpy as np
from tqdm import tqdm

class MultinomialHMM:
    """Discrete-observation hidden Markov model trained with Baum-Welch (EM).

    Attributes (filled in by ``fit`` / ``init_params`` / ``update``):
        Q            -- state set, ``np.arange(n_components)``
        V            -- sorted distinct observation symbols seen in training
        O            -- training observation sequence encoded as integer
                        indices into ``V`` (``V[O]`` recovers the raw symbols)
        I            -- hidden state sequence (never estimated; stays None)
        state_trans  -- (N, N) transitions, ``state_trans[i, j] = P(next=j | cur=i)``
        visit_trans  -- (N, M) emissions, ``visit_trans[j, k] = P(V[k] | state j)``
        start_prob   -- (N,) initial state distribution
    """

    def __init__(self, n_components=1):
        # State set.
        self.Q = None
        # Observation symbol set.
        self.V = None
        # Hidden state sequence (not estimated by this class).
        self.I = None
        # Observation sequence, as integer codes into V.
        self.O = None
        # Number of hidden states.
        self.n_components = n_components
        # Model parameters, set by init_params / update.
        self.state_trans = None
        self.visit_trans = None
        self.start_prob = None

    def forward(self, t, i):
        """Return the unscaled forward probability alpha_t(i).

        alpha_t(i) = P(O_0, ..., O_t, state_t = i), with ``t`` 0-indexed.
        Unscaled values underflow to 0 for long prefixes, which is why
        ``baum_welch`` uses the scaled recursion in ``_scaled_passes``.
        """
        alpha = self.start_prob * self.visit_trans[:, self.O[0]]
        for s in range(1, t + 1):
            alpha = (alpha @ self.state_trans) * self.visit_trans[:, self.O[s]]
        return alpha[i]

    def backward(self, t, i):
        """Return the unscaled backward probability beta_t(i).

        beta_t(i) = P(O_{t+1}, ..., O_{T-1} | state_t = i); beta_{T-1} = 1.
        """
        beta = np.ones(self.state_trans.shape[0])
        for s in range(len(self.O) - 2, t - 1, -1):
            beta = self.state_trans @ (self.visit_trans[:, self.O[s + 1]] * beta)
        return beta[i]

    def _scaled_passes(self):
        """Run scaled forward-backward over the whole sequence ``O``.

        Returns ``(alpha_hat, beta_hat, scale)`` with shapes (T, N), (T, N), (T,).
        ``alpha_hat[t]`` is alpha_t renormalised to sum to 1 and ``scale[t]``
        is the normaliser used at step t; ``beta_hat`` is beta divided by the
        matching later factors.  Their product is proportional to the state
        posterior without ever forming the vanishingly small raw products.
        """
        obs = self.O
        A = self.state_trans
        B = self.visit_trans
        T = len(obs)
        N = A.shape[0]

        alpha_hat = np.empty((T, N))
        scale = np.empty(T)
        a = self.start_prob * B[:, obs[0]]
        scale[0] = a.sum()
        alpha_hat[0] = a / scale[0]
        for t in range(1, T):
            a = (alpha_hat[t - 1] @ A) * B[:, obs[t]]
            scale[t] = a.sum()
            alpha_hat[t] = a / scale[t]

        beta_hat = np.empty((T, N))
        beta_hat[-1] = 1.0
        for t in range(T - 2, -1, -1):
            beta_hat[t] = A @ (B[:, obs[t + 1]] * beta_hat[t + 1]) / scale[t + 1]

        return alpha_hat, beta_hat, scale

    def _posteriors(self):
        """Return ``(gamma, xi)`` for the current parameters in one pass.

        gamma[t, i]  = P(state_t = i | O)                    shape (T, N)
        xi[t, i, j]  = P(state_t = i, state_{t+1} = j | O)   shape (T-1, N, N)
        """
        alpha_hat, beta_hat, _ = self._scaled_passes()
        A = self.state_trans
        B = self.visit_trans
        obs = self.O

        gamma = alpha_hat * beta_hat
        # Each row already sums to 1 mathematically; renormalising absorbs rounding.
        gamma /= gamma.sum(axis=1, keepdims=True)

        # emit_next[t, j] = b_j(O_{t+1}) * beta_hat[t+1, j]
        emit_next = B[:, obs[1:]].T * beta_hat[1:]
        xi = alpha_hat[:-1, :, None] * A[None, :, :] * emit_next[:, None, :]
        # The scaled product is off by a per-t constant; normalising over (i, j)
        # removes it, exactly like the P(O) denominator in the textbook formula.
        xi /= xi.sum(axis=(1, 2), keepdims=True)
        return gamma, xi

    def gamma(self, t, i):
        """Return P(state_t = i | O) under the current parameters."""
        alpha_hat, beta_hat, _ = self._scaled_passes()
        weights = alpha_hat[t] * beta_hat[t]
        return weights[i] / weights.sum()

    def xi(self, t, i, j):
        """Return P(state_t = i, state_{t+1} = j | O) under the current parameters."""
        alpha_hat, beta_hat, _ = self._scaled_passes()
        emit_next = self.visit_trans[:, self.O[t + 1]] * beta_hat[t + 1]
        pair = alpha_hat[t][:, None] * self.state_trans * emit_next[None, :]
        return pair[i, j] / pair.sum()

    def baum_welch(self):
        """Run one Baum-Welch (EM) re-estimation step and return the new parameters.

        Returns a dict with keys 'state_trans', 'visit_trans' and 'start_prob'
        suitable for ``update``.  Rows whose expected occupancy is zero (e.g.
        no transitions when T == 1) keep their current values instead of
        becoming 0/0 = nan.
        """
        M = len(self.V)
        gamma, xi = self._posteriors()

        # Expected transitions i -> j over expected departures from i (t = 0..T-2).
        departures = gamma[:-1].sum(axis=0)
        state_trans = np.array(self.state_trans, dtype=float)
        np.divide(xi.sum(axis=0), departures[:, None],
                  out=state_trans, where=departures[:, None] > 0)

        # Expected emissions of symbol k from state j over expected visits to j.
        emit_counts = gamma.T @ np.eye(M)[self.O]
        occupancy = gamma.sum(axis=0)
        visit_trans = np.array(self.visit_trans, dtype=float)
        np.divide(emit_counts, occupancy[:, None],
                  out=visit_trans, where=occupancy[:, None] > 0)

        start_prob = gamma[0].copy()

        return {'state_trans': state_trans,
                'visit_trans': visit_trans,
                'start_prob': start_prob}

    def update(self, params):
        """Replace the model parameters with those in ``params`` (as from ``baum_welch``)."""
        self.state_trans = params['state_trans']
        self.visit_trans = params['visit_trans']
        self.start_prob = params['start_prob']

    def fit(self, X_train=None, n_iterations=1):
        """Estimate the parameters from a single observation sequence.

        Parameters
        ----------
        X_train : array-like
            Observation sequence of sortable symbols (ints, floats, strings);
            it is flattened, so shape (T,) or (T, 1) both work.
        n_iterations : int
            Number of Baum-Welch iterations to run.

        Raises
        ------
        ValueError
            If ``X_train`` is None or empty.
        """
        if X_train is None:
            raise ValueError("X_train must be provided")
        observations = np.ravel(X_train)
        if observations.size == 0:
            raise ValueError("X_train must contain at least one observation")

        # State set.
        self.Q = np.arange(self.n_components)
        # Symbol set, plus the sequence re-encoded as indices into it so any
        # symbol type (e.g. 0.0 / 1.0 floats) can address emission columns.
        self.V, self.O = np.unique(observations, return_inverse=True)
        self.init_params(N=self.n_components, M=len(self.V))

        for _ in tqdm(range(n_iterations)):
            self.update(self.baum_welch())

    def init_params(self, N, M):
        """Draw random row-stochastic starting parameters for N states and M symbols."""
        state_trans = np.random.random((N, N))
        visit_trans = np.random.random((N, M))
        start_prob = np.random.random(N)
        # Normalise so every row is a valid probability distribution.
        self.state_trans = state_trans / state_trans.sum(axis=1, keepdims=True)
        self.visit_trans = visit_trans / visit_trans.sum(axis=1, keepdims=True)
        self.start_prob = start_prob / start_prob.sum()

### Test

In [407]:
# Ground-truth model parameters.
# Transition matrix: row = current state, column = next state.
A = np.array(
    [
        [0.9, 0.05, 0.05],
        [0.1, 0.1, 0.8],
        [0.1, 0.7, 0.2],
    ]
)

# Emission matrix: row = state; column k holds P(symbol k | state).
# Built column-by-column: first the P(0 | state) column, then P(1 | state).
B = np.column_stack([[0.1, 0.2, 0.7], [0.9, 0.8, 0.3]])

# Presumably the initial-state distribution (one weight per state).
Pi = np.array([0.2, 0.4, 0.4])

# Placeholder state sequence holding the single state 0.
I = np.zeros(1, dtype=int)

In [408]:
T = 5000  # length of the simulated sequence

# Simulate T hidden states: the first is drawn from Pi, each later one from
# the transition row A[previous state]. (Rounding 2*U instead would give
# P = 1/4, 1/2, 1/4 regardless of Pi.) Sampling from the rows keeps this
# correct for any number of states, not just 3.
n_states = A.shape[0]
n_symbols = B.shape[1]
start = np.random.choice(n_states, p=Pi)
I = np.empty(T, dtype=int)
I[0] = start
for t in range(1, T):
    I[t] = np.random.choice(n_states, p=A[I[t - 1]])

# Emit exactly one observation per hidden state from that state's row of B.
# Integer dtype so the symbols are valid array indices downstream.
O = np.array([np.random.choice(n_symbols, p=B[state]) for state in I])