In [1]:
import numpy as np
RED, GREEN, BLUE = 0, 1, 2

In [64]:
class HMM:
    def __init__(self):
        self.states = [(x,y)  for y in range(3) for x in range(3)]
        self.state_colors = np.array([GREEN, RED, BLUE, RED, BLUE, GREEN, BLUE, GREEN, RED])
        self.obs_prob = np.array([[0.8, 0.1, 0.1], [0.1, 0.6, 0.2], [0.1, 0.3, 0.7]])
        self.trans = np.array([
            [0.5, 0.25, 0, 0.25, 0, 0, 0, 0, 0],
            [0.25, 0.25, 0.25, 0, 0.25, 0, 0, 0, 0],
            [0, 0.25, 0.5, 0, 0, 0.25, 0, 0, 0],
            [0.25, 0, 0, 0.25, 0.25, 0, 0.25, 0, 0],
            [0, 0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0],
            [0, 0, 0.25, 0, 0.25, 0.25, 0, 0, 0.25],
            [0, 0, 0, 0.25, 0, 0, 0.5, 0.25, 0],
            [0, 0, 0, 0, 0.25, 0, 0.25, 0.25, 0.25],
            [0, 0, 0, 0, 0, 0.25, 0, 0.25, 0.5],
        ])
        self.initial = np.array([1 / len(self.states)] * len(self.states))
        
    def viterbi(self, obs):
        delta = self.initial * self.obs_prob[self.state_colors, obs[0]]
        psis = []

        for i in range(1, len(obs)):
            psis.append(np.argmax(self.trans * delta, axis=-1))
            delta = self.obs_prob[self.state_colors, obs[i]] * np.max(self.trans * delta, axis=-1)

        p = np.max(delta)
        x = [np.argmax(delta)]
        for psi in reversed(psis):
            x.insert(0, psi[x[0]])

        return p, x
    
    def alpha(self, obs):
        alpha = self.initial * self.obs_prob[self.state_colors, obs[0]]
        
        for i in range(1, len(obs)):
            alpha = self.obs_prob[self.state_colors, obs[i]] * np.sum(self.trans * alpha, axis=-1)
            
        return alpha
        
    def forward(self, obs):
        return np.sum(self.alpha(obs))
    
    def filtering(self, obs):
        alpha = self.alpha(obs)
        return alpha / np.sum(self.alpha(obs))
        
    def beta(self, obs):
        beta = self.initial
        
        for i in reversed(range(-1, len(obs) - 1)):
            beta = np.sum(self.trans * (self.obs_prob[self.state_colors, obs[i + 1]] * beta), axis=0)
            
        return beta
        
    def smoothing(self, obs, index):
        alpha = self.alpha(obs[:index + 1])
        beta = self.beta(obs[index + 1:])
        return alpha * beta / np.sum(alpha * beta)
    
    def gamma(self, obs, index):
        return self.alpha(obs[:index + 1]) * self.beta(obs[index + 1:]) / self.forward(obs)
    
    def xi(self, obs, index):
        return self.alpha(obs[:index + 1]) * self.trans * self.obs_prob[self.state_colors, obs[index + 1]] * self.beta(obs[index + 2:])  / self.forward(obs)
    
    def baum_welsh(self, obs):
        
        gamma = np.array([self.gamma(obs, i) for i in range(len(obs))])
        xi = np.array([self.xi(obs, i) for i in range(len(obs) - 1)])
        
        self.initial = gamma[0]
        self.trans = np.sum(xi, axis=0) / np.sum(gamma, axis=0)
        

In [90]:
hmm = HMM()

In [91]:
hmm.viterbi(np.array([RED, GREEN, GREEN, RED, RED, RED]))

(6.400000000000001e-05, [8, 5, 5, 8, 8, 8])

In [92]:
hmm.filtering(np.array([RED]))

array([0.03333333, 0.26666667, 0.03333333, 0.26666667, 0.03333333,
       0.03333333, 0.03333333, 0.03333333, 0.26666667])

In [93]:
hmm.smoothing(np.array([RED, GREEN]), 0)

array([0.11764706, 0.15686275, 0.05882353, 0.15686275, 0.05882353,
       0.11764706, 0.05882353, 0.11764706, 0.15686275])

In [94]:
hmm.forward(np.array([RED, RED, RED, RED]))

0.01432118055555556

In [103]:
hmm.baum_welsh(np.array([RED, RED, RED, RED]))

In [104]:
hmm.forward(np.array([RED, RED, RED, RED]))

26.481912946503698

In [106]:
hmm.trans

array([[7.73762985e-03, 2.46948782e-01, 0.00000000e+00, 2.46948782e-01,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00],
       [3.86881493e-03, 2.46948782e-01, 3.84651336e-03, 0.00000000e+00,
        3.70791800e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00],
       [0.00000000e+00, 2.46948782e-01, 7.69302672e-03, 0.00000000e+00,
        0.00000000e+00, 3.31839061e-04, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00],
       [3.86881493e-03, 0.00000000e+00, 0.00000000e+00, 2.46948782e-01,
        3.70791800e-03, 0.00000000e+00, 3.84651336e-03, 0.00000000e+00,
        0.00000000e+00],
       [0.00000000e+00, 2.46948782e-01, 0.00000000e+00, 2.46948782e-01,
        0.00000000e+00, 3.31839061e-04, 0.00000000e+00, 3.31839061e-04,
        0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 3.84651336e-03, 0.00000000e+00,
        3.70791800e-03, 3.31839061e-04, 0.00000000e+00, 0.00000000e+00,
        2.3