In [None]:
import numpy as np
import torch
import beer
import pickle

In [None]:
def baum_welch_forward_test(init_states, trans_mat, lhs):
    """Scaled forward recursion carried out in the probability domain.

    Args:
        init_states: indices of the states allowed at the first frame; they
            share a uniform prior of ``1 / len(init_states)``.
        trans_mat: square transition matrix indexed ``[from, to]``.
        lhs: 2-D tensor of likelihoods, one row per frame and one column
            per state.

    Returns:
        A pair ``(log_alphas, scale_factors)``: the log of the per-frame
        normalised forward variables (each row sums to one before the log)
        and the 1-D tensor of normalisers that were divided out.
    """
    n_frames = lhs.shape[0]
    uniform_prior = 1 / len(init_states)
    norm_alphas = torch.zeros_like(lhs)
    scale_factors = torch.zeros(n_frames).type(lhs.type())

    # First frame: only the initial states carry probability mass.
    first_frame = lhs[0, init_states] * uniform_prior
    scale_factors[0] = first_frame.sum()
    norm_alphas[0, init_states] = first_frame / scale_factors[0]

    # Remaining frames: propagate through the transitions, weight by the
    # frame likelihood, then renormalise so the row sums to one.
    for t in range(1, n_frames):
        unnormalised = lhs[t] * (trans_mat.t() @ norm_alphas[t - 1])
        scale_factors[t] = unnormalised.sum()
        norm_alphas[t] = unnormalised / scale_factors[t]
    return norm_alphas.log(), scale_factors

def baum_welch_backward_test(final_states, trans_mat, lhs, scale_factors):
    """Scaled backward recursion carried out in the probability domain.

    Args:
        final_states: indices of the states allowed at the last frame; they
            share a uniform weight of ``1 / len(final_states)``.
        trans_mat: square transition matrix indexed ``[from, to]``.
        lhs: 2-D tensor of likelihoods, one row per frame and one column
            per state.
        scale_factors: per-frame normalisers; frame ``t`` of the result is
            divided by ``scale_factors[t + 1]``.

    Returns:
        The log of the scaled backward variables, same shape as ``lhs``.
    """
    scaled_betas = torch.zeros_like(lhs)
    scaled_betas[-1, final_states] = 1 / len(final_states)

    # Walk from the second-to-last frame back to the first one.
    for t in range(lhs.shape[0] - 2, -1, -1):
        nxt = t + 1
        propagated = trans_mat @ (lhs[nxt] * scaled_betas[nxt])
        scaled_betas[t] = propagated / scale_factors[nxt]
    return scaled_betas.log()

In [None]:
def reverse_tensor(tensor):
    """Return a copy of ``tensor`` with the order along its first dimension reversed."""
    return torch.flip(tensor, dims=[0])

def baum_welch_forward_backward_test(init_states, final_states, trans_mat, llhs):
    """Reference forward-backward pass computed with scaled probabilities.

    The per-state log-likelihoods are moved to the probability domain, run
    through ``baum_welch_forward_test`` / ``baum_welch_backward_test`` and the
    scaling that was divided out is added back, so the returned values are
    the (unscaled) log forward and log backward variables of ``llhs``.

    Args:
        init_states: indices of the states allowed at the first frame.
        final_states: indices of the states allowed at the last frame.
        trans_mat: square transition matrix indexed ``[from, to]``; it must
            have the same dtype as ``llhs`` for the matrix products.
        llhs: 2-D tensor of log-likelihoods, one row per frame and one
            column per state.

    Returns:
        A pair ``(log_alphas, log_betas)``, each with the shape of ``llhs``.

    Note:
        Each frame is shifted by its own maximum before exponentiation. A
        state whose log-likelihood is far below the frame maximum (several
        hundred nats in double precision) still underflows to zero; if all
        reachable states underflow, the frame's scale factor is zero and the
        result contains NaN. A frame that is entirely ``-inf`` also yields
        NaN.
    """
    # Shift every frame so that its largest log-likelihood becomes 0. This
    # keeps exp() in range without distorting the ratios between states.
    # The shift is a per-frame constant, so it behaves exactly like an extra
    # multiplicative scale factor and is folded back in below.
    frame_shift = llhs.max(dim=1, keepdim=True)[0]
    lhs = torch.exp(llhs - frame_shift)

    log_alphas, scale_factors = baum_welch_forward_test(init_states, trans_mat, lhs)
    log_betas = baum_welch_backward_test(final_states, trans_mat, lhs, scale_factors)

    # Total log scale removed at each frame: the normaliser computed by the
    # forward pass plus the shift applied before exponentiation.
    log_scales = scale_factors.log() + frame_shift.view(-1)

    # alpha at frame t was divided by the scales of frames 0..t.
    scale_alpha = torch.cumsum(log_scales, dim=0)
    # beta at frame t was divided by the scales of frames t+1..T-1, so the
    # last frame gets no correction.
    scale_beta = torch.zeros_like(scale_alpha)
    scale_beta[:-1] = reverse_tensor(torch.cumsum(reverse_tensor(log_scales), dim=0))[1:]

    log_alphas += scale_alpha[:, None]
    log_betas += scale_beta[:, None]

    return log_alphas, log_betas

In [None]:
# Load the features and the integer phone labels of the 10-utterance TIMIT
# subset. Both are .npz archives; paths are relative to the working directory.
feats = np.load('./recipes/timit/data/train_10utt/feats.npz')
labs = np.load('./recipes/timit/data/train_10utt/phones.int.npz')
keys = list(feats.keys())
# Disabled alternative: load a previously pickled emission model into
# `normals` instead of creating a fresh one below.
#with open('./recipes/timit/exp/emission.mdl', 'rb') as m:
#    normals = pickle.load(m)
    
# Fresh emission model set.
# NOTE(review): presumably 117 diagonal-covariance Gaussians over 13-dim
# features, perturbed by random noise (noise_std=10) — confirm against the
# beer API; no random seed is set in this notebook.
normals = beer.NormalDiagonalCovarianceSet.create(torch.zeros(13), torch.ones(13), ncomp=117, noise_std=10)

# Features of the first utterance as a float tensor. The `* 1` repeats the
# single-element list once, i.e. it currently has no effect.
ft = torch.cat([torch.from_numpy(feats[keys[0]]).float()] * 1) 
# Label sequence of the same utterance.
lab = labs[keys[0]]
# The HMM has one state per label: start in state 0, end in the last state.
init_states = torch.tensor([0])
final_states = torch.tensor([len(lab) - 1])
# Transition matrix for len(lab) states, cast to double precision.
trans_mat = beer.HMM.create_ali_trans_mat(len(lab)).double()
# Model set built from the emission models and the label sequence, and the
# HMM built on top of it.
aliset = beer.AlignModelSet(normals, lab)
hmm = beer.HMM.create(init_states, final_states, trans_mat, aliset)

In [None]:
# Sufficient statistics of the features as computed by the alignment model set.
len_s_stats = aliset.sufficient_statistics(ft)
# Evaluate the model set on those statistics, cast to double and multiply by 5.
# The result is what the forward/backward routines below receive as their
# log-likelihood matrix.
# NOTE(review): presumably per-frame, per-state expected log-likelihoods; the
# factor 5 is unexplained — confirm its purpose.
pc_exp_llh = aliset(len_s_stats).double() * 5
print(pc_exp_llh)

In [None]:
# Variant 1: forward and backward passes provided by beer.HMM.
log_alphas_1 = beer.HMM.baum_welch_forward(init_states, trans_mat, pc_exp_llh)
log_beta_1 = beer.HMM.baum_welch_backward(final_states, trans_mat, pc_exp_llh)
# Variant 2: the notebook's own scaled reference implementation.
log_alphas_2, log_beta_2 = baum_welch_forward_backward_test(init_states, final_states, trans_mat, pc_exp_llh)
# NumPy views of all four results for the comparisons in the next cells.
a1 = log_alphas_1.numpy()
a2 = log_alphas_2.numpy()
b1 = log_beta_1.numpy()
b2 = log_beta_2.numpy()

In [None]:
# Compare log alpha + log beta of both implementations. Only the last
# expression of a cell is displayed, so the check is printed explicitly;
# as a bare first statement its result would be silently dropped.
print('np.allclose(a1+b1, a2+b2):', np.allclose(a1+b1, a2+b2))
print(a1+b1, '\n')
print(a2+b2)

In [None]:
# Rank the states of every frame by log alpha + log beta for both
# implementations. Both rankings are computed on the log values: exp() is
# monotonic, but it underflows very negative entries to 0 and so turns
# distinct values into ties. The stable sort ('mergesort') makes ties, such as
# several -inf entries, come out in index order in both arrays, so that c and
# d are directly comparable.
c = (a1+b1).argsort(axis=1, kind='mergesort')
d = (a2+b2).argsort(axis=1, kind='mergesort')
c.shape

In [None]:
# Show the two per-frame state rankings one after the other, separated by
# blank lines.
print(c, end=' \n\n')
print(d)