# PyTorch Implementation of the Baum-Welch Algorithm for Hidden Markov Models

# Imports

In [123]:
import torch
import numpy as np

from itertools import chain 
from collections import Counter

# Utils

In [6]:
def expected_output_occurrence(index, step_gammas, summed_gammas):
    """Return the share of each row's gamma mass that falls on the given columns.

    index: column positions of ``step_gammas`` to keep (converted with
        ``torch.LongTensor``).
    step_gammas: 2-D tensor whose dimension 1 is indexed by ``index``.
    summed_gammas: per-row normaliser the selected mass is divided by.

    Returns one value per row of ``step_gammas``: the sum of the selected
    columns divided by ``summed_gammas``.
    """
    positions = torch.LongTensor(index)
    # Keep only the requested columns, then collapse them into one value per row.
    selected_mass = torch.index_select(step_gammas, 1, positions).sum(dim=1)
    return selected_mass / summed_gammas

# Build Model

## Estimate initial parameters

$\theta=(A, B, \pi)$ can be initialized randomly. However, the parameters can also be initialized from prior information, which can speed up the algorithm and steer it towards convergence to the desired local maximum. The following functions estimate the initial parameters given prior information about the observed sequence and the hidden states.

**Initial state distribution**

$\pi_i = \frac {count(z_1 = i)} {N}$

In [78]:
def hidden_state_init(sequence):
    """Estimate a state distribution from the relative frequency of each state.

    sequence: sized collection of hashable, mutually orderable state labels.

    Returns a 1-D float64 tensor with one entry per distinct state, ordered
    by ascending state label; entry ``i`` is ``count(state_i) / len(sequence)``.
    An empty sequence yields an empty tensor.
    """
    state_counts = Counter(sequence)
    total = len(sequence)
    # Key the result by state (not by count) so that states sharing the same
    # count each keep their own entry and the output order follows the labels.
    probabilities = [state_counts[state] / total for state in sorted(state_counts)]
    return torch.from_numpy(np.array(probabilities, dtype=np.float64))

**Transition matrix**

$A(i, j) = P(Z_{t+1} = j | Z_t = i) = \frac {count(i \rightarrow j)} {count(i)}$

In [77]:
def trans_mat_init(sequence):
    """Estimate a transition matrix from the consecutive pairs of a state sequence.

    sequence: iterable of state labels. The labels are used directly as row
        and column indices of an ``N x N`` matrix, where ``N`` is the number
        of distinct labels, so they must be integers in ``0..N-1``.

    Returns a float64 tensor where entry ``(i, j)`` is
    ``count(i -> j) / count(i)``. ``count(i)`` is taken over the positions
    that actually have a successor: the final element of the sequence starts
    no transition, and counting it would leave that state's row summing to
    less than one. A state with no outgoing transition keeps an all-zero row.
    """
    states = list(sequence)
    n_states = len(set(states))
    transition_mat = torch.zeros(size=[n_states, n_states], dtype=torch.float64)

    # Every adjacent (current, next) pair is one observed transition.
    transition_counts = Counter(zip(states[:-1], states[1:]))
    # Number of transitions leaving each state (last position excluded).
    outgoing_counts = Counter(states[:-1])

    for (src, dst), count in transition_counts.items():
        transition_mat[src, dst] = count / outgoing_counts[src]

    return transition_mat

**Emission matrix**

$B(j, k) = \frac {count(z=j \:\land\: x=k)} {count(z=j)}$

In [121]:
def emiss_state_init(hid_seq, obs_seq):
    """Estimate an emission matrix from aligned hidden and observed sequences.

    hid_seq: hidden-state labels, used as row indices.
    obs_seq: observed symbols, used as column indices; paired position by
        position with ``hid_seq``.

    Returns a float64 tensor of shape
    ``[distinct hidden labels, distinct observed symbols]`` where entry
    ``(j, k)`` is ``count(z=j and x=k) / count(z=j)``. Pairs that never occur
    stay at zero.
    """
    # count(z = j)
    state_totals = Counter(hid_seq)
    # count(z = j and x = k)
    pair_totals = Counter(zip(hid_seq, obs_seq))

    n_hidden = len(set(hid_seq))
    n_symbols = len(set(obs_seq))
    emiss_mat = torch.zeros(size=[n_hidden, n_symbols], dtype=torch.float64)

    for (state, symbol), count in pair_totals.items():
        emiss_mat[state][symbol] = count / state_totals[state]

    return emiss_mat

## Expectation Procedure

### Forward Algorithm

$ \alpha_i(t) = P(Y_1 = y_1, \dots, Y_t = y_t, X_t = i | \theta) $
1. Initialization step: $\alpha_i(1) = \pi_i b_i(y_1) $
1. Induction step: $\alpha_i(t+1) = b_i(y_{t+1}) \sum_{j=1}^{N} \alpha_j(t)a_{ji} $

In [51]:
def forward(emission_matrix, log = True):
    """Fill the module-level ``alpha`` buffer with forward values and return it.

    Reads the module-level ``pi``, ``A`` and ``obs_sequence`` and overwrites
    one column of the module-level ``alpha`` per observation.

    emission_matrix: indexed as ``emission_matrix[:, symbol]`` to get the
        per-state probability of a symbol.
    log: when equal to ``True`` every column is divided by its own sum before
        it is stored (a rescaling; no logarithm is taken). Otherwise the
        unscaled values are stored.
    """
    def _prepare(column):
        # Strict comparison: only a value equal to True enables the scaling.
        if log == True:
            return column / column.sum()
        return column

    # Initialization: α(1, i) = π(i) * B(i, y_1)
    alpha[:, 0] = _prepare(pi * emission_matrix[:, obs_sequence[0]])

    # Induction: α(t+1, i) = B(i, y_{t+1}) * Σ_j α(t, j) * A(j, i)
    for t, obs in enumerate(obs_sequence[1:], start=1):
        # Σ_j α(t, j) * A(j, i), built from the column stored in the previous step
        predicted = alpha[:, t - 1] @ A
        alpha[:, t] = _prepare(predicted * emission_matrix[:, obs])

    return alpha

### Backward Algorithm

$ \beta_i(t) = P(Y_{t+1} = y_{t+1}, \dots, Y_T = y_T | X_t = i , \theta) $
1. Initialization step: $ \beta_i(T) = 1$
2. Induction step: $\beta_i(t) = \sum_{j=1}^{N} \beta_j(t+1)a_{ij}b_j(y_{t+1})$

In [57]:
def backward(observation, log = True):
    """Fill the module-level ``beta`` buffer with backward values and return it.

    Reads the module-level ``A`` and ``emission_matrix`` and overwrites the
    columns of the module-level ``beta`` from last to first.

    observation: sliceable sequence of observed symbols; each symbol is used
        as a column index into ``emission_matrix``.
    log: when truthy every column is divided by its own sum before it is
        stored (a rescaling; no logarithm is taken).
    """
    # Initialization: β(T, i) = 1 for every state. A scalar fill works for any
    # number of hidden states instead of assuming exactly two.
    beta[:, -1] = 1.0

    # Walk the observations from the last one down to the second one; step i
    # reads column -(i+1) and writes the column just before it.
    for i, obs in enumerate(observation[:0:-1]):
        # Induction: Σ_j A(i, j) * B(j, y_{t+1}) * β(t+1, j)
        _beta = torch.matmul(emission_matrix[:, obs] * A, beta[:, -(i+1)])
        if log:
            _beta = torch.div(_beta, _beta.sum())
        beta[:, -(i+2)] = _beta
    return beta

## Maximization Procedure

### Gammas

$\gamma_i(t) = P(X_t = i | Y, \theta) = \frac {\alpha_i(t) \beta_i(t)} {\sum_{j=1}^{N}\alpha_j(t) \beta_j(t)}$

In [34]:
def calculate_gammas(alpha, beta):
    """Return the element-wise product of alpha and beta, normalised over dim 0.

    alpha, beta: tensors of matching (broadcastable) shape.

    Each column of the result is ``alpha * beta`` for that column divided by
    the column's total, i.e. γ(t, i) = α(t, i) β(t, i) / Σ_j α(t, j) β(t, j).
    """
    # α(t, i) * β(t, i)
    joint = alpha * beta
    # Σ_j α(t, j) * β(t, j): one normaliser per column
    column_totals = joint.sum(dim=0)
    return joint / column_totals

### Zetas

$ \zeta_{ij}(t) = P(X_t = i, X_{t+1} = j | Y, \theta) = \frac {P(X_t = i, X_{t+1} = j, Y | \theta)} {P( Y|\theta)} = \frac {\alpha_i (t) a_{ij} \beta_j(t+1) b_j(y_{t+1})} {{\sum_{i=1}^{N}} \sum_{j=1}^{N} \alpha_i (t) a_{ij} \beta_j(t+1) b_j(y_{t+1})}$

In [35]:
def calculate_zetas(alpha, obs_seq):
    """Return the zeta matrices summed over all columns of ``alpha``.

    Reads the module-level ``A``, ``emission_matrix`` and ``beta``.

    alpha: 2-D tensor; each column ``t`` is combined with observation
        ``t + 1`` and column ``t + 1`` of ``beta``, so it must have fewer
        columns than there are observations (callers pass ``alpha[:, :-1]``).
    obs_seq: observed symbols, used as column indices into
        ``emission_matrix``.

    Returns Σ_t ζ_t(i, j), where each ζ_t is normalised to sum to one.
    """
    per_step = []
    for t, alpha_t in enumerate(alpha.transpose(0, 1)):
        # α(t, i) * A(i, j): scale row i of A by α(t, i)
        weighted = (alpha_t * A.t()).t()

        # α(t, i) * A(i, j) * B(j, y_{t+1}) * β(t+1, j)
        joint = weighted * emission_matrix[:, obs_seq[t+1]] * beta[:, t+1]

        # Σ_i Σ_j of the joint term, used as the normaliser for this step
        total = joint.sum(dim=0).sum(dim=0)

        per_step.append(joint / total)

    # Σ_t ζ_t(i, j)
    return torch.stack(per_step, dim=0).sum(dim=0)

### Re-estimate parameters

**Hidden state distribution**

$ \pi_i^* = \gamma_i(1) $

**Transition matrix**

$a_{ij}^* = \frac {\sum_{t=1}^{T-1} \zeta_{ij}(t)} {\sum_{t=1}^{T-1} \gamma_i(t)}$

**Emission matrix**

$b_i^*(\nu_k) = \frac {\sum_{t=1}^{T} 1_{y_t = \nu_k} \gamma_i(t)} {\sum_{t=1}^{T} \gamma_i(t)}$

where

$ 1_{y_t = \nu_k} = \begin{cases}
    1 & \text{if } y_t = \nu_k\\
    0              & \text{otherwise}
\end{cases} $

In [36]:
def re_estimate_parameters(emission_matrix, alpha, beta):
    """Run one re-estimation step and return the updated HMM parameters.

    Reads the module-level ``obs_sequence`` and delegates to
    ``calculate_gammas``, ``calculate_zetas`` and
    ``expected_output_occurrence``.

    emission_matrix: current emission matrix; only its number of columns is
        used here, to fix how many symbols the new emission matrix covers.
    alpha, beta: forward and backward tensors with one column per
        observation.

    Returns ``(new_pi, new_transition_matrix, new_emission_matrix)``.
    Column ``k`` of the new emission matrix always corresponds to symbol
    ``k``; a symbol that does not occur in ``obs_sequence`` gets an all-zero
    column rather than being dropped.
    """
    # γ(t, i)
    step_gammas = calculate_gammas(alpha, beta)

    ##################################################
    # Re-estimate initial probabilities
    ##################################################

    # π*(i) = γ(1, i)
    new_pi = step_gammas[:, 0]

    ##################################################
    # Re-estimate transition matrix
    ##################################################

    # Σ(1-->T-1) ζt(i, j); the last alpha column has no successor step
    summed_zetas = calculate_zetas(alpha[:, :-1], obs_sequence)
    # Σ(1-->T-1) γ(t, i)
    transition_norm = torch.sum(step_gammas[:, :-1], dim = 1)
    # a^(i, j) = Σ(1-->T-1) ζt(i, j) / Σ(1-->T-1) γ(t, i)
    new_transition_matrix = torch.div(summed_zetas, transition_norm.view(-1, 1))

    ##################################################
    # Re-estimate emission matrix
    ##################################################

    # Σ(1-->T) γ(t, i)
    emission_norm = torch.sum(step_gammas, dim = 1)
    observations = np.asarray(obs_sequence)
    # Walk the symbols in index order instead of iterating a set, whose order
    # is not guaranteed and which would skip symbols missing from the data.
    n_symbols = emission_matrix.shape[1]
    emission_columns = [
        expected_output_occurrence(
            np.where(observations == symbol)[0], step_gammas, emission_norm
        )
        for symbol in range(n_symbols)
    ]
    new_emission_matrix = torch.stack(emission_columns, dim = 0).transpose(1, 0)

    return new_pi, new_transition_matrix, new_emission_matrix

# Test Model

To validate the implementation of the above algorithms we will use a numeric example in the following [link](http://www.cs.rochester.edu/u/james/CSC248/Lec11.pdf), and the results of the first pass will be compared as a form of sanity check.

### Numeric example: 1 pass

In [63]:
# Observed symbols, one per time step (kept as a NumPy array)
obs_sequence = np.array([0, 1, 2, 2])

# Initial state distribution π
pi = torch.tensor([0.8, 0.2], dtype = torch.float64)

# Initial transition matrix A(i, j)
A = torch.tensor([[0.6, 0.4],
                  [0.3, 0.7]], dtype = torch.float64)

# Initial emission matrix B(i, k)
emission_matrix = torch.tensor([[0.3, 0.4, 0.3],
                                [0.4, 0.3, 0.3]], dtype = torch.float64)

# Zero-filled alpha & beta buffers: one row per state, one column per observation
shape = [A.shape[0], len(obs_sequence)]
alpha, beta = (torch.zeros(shape, dtype = torch.float64) for _ in range(2))

### Expectation step

$ \alpha_i(t) $

In [64]:
# Forward pass with log = False, so the columns of alpha are stored without per-step rescaling.
alpha = forward(emission_matrix, log = False)

In [65]:
alpha

tensor([[0.2400, 0.0672, 0.0162, 0.0045],
        [0.0800, 0.0456, 0.0176, 0.0056]], dtype=torch.float64)

$ \beta_i(t) $

In [66]:
# Backward pass with log = False, so the columns of beta are stored without per-step rescaling.
beta = backward(obs_sequence, log = False)

In [67]:
beta

tensor([[0.0324, 0.0900, 0.3000, 1.0000],
        [0.0297, 0.0900, 0.3000, 1.0000]], dtype=torch.float64)

### Maximization step

$\gamma_i(t) $

In [68]:
# alpha * beta normalised over the state axis: each column of gammas sums to one.
gammas = calculate_gammas(alpha, beta)

In [69]:
gammas

tensor([[0.7660, 0.5957, 0.4787, 0.4436],
        [0.2340, 0.4043, 0.5213, 0.5564]], dtype=torch.float64)

$ \zeta_{ij}(t) $

In [70]:
# The last alpha column is dropped because every zeta pairs step t with step t+1.
zetas = calculate_zetas(alpha[:, :-1], obs_sequence)

In [71]:
zetas

tensor([[1.1553, 0.6851],
        [0.3628, 0.7968]], dtype=torch.float64)

**Re-estimated parameters**

In [72]:
# One re-estimation pass: updated initial distribution, transition matrix and emission matrix.
new_pi, new_transition_matrix, new_emission_matrix = re_estimate_parameters(emission_matrix, alpha, beta)

$ \pi_i^* $

In [73]:
new_pi

tensor([0.7660, 0.2340], dtype=torch.float64)

$ a_{ij}^* $

In [74]:
new_transition_matrix

tensor([[0.6277, 0.3723],
        [0.3128, 0.6872]], dtype=torch.float64)

$ b_i^*(\nu_k) $

In [75]:
new_emission_matrix

tensor([[0.3354, 0.2608, 0.4038],
        [0.1364, 0.2356, 0.6280]], dtype=torch.float64)