In [194]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [195]:
from ssm.hmm import GaussianHMM
import jax.random as jr

In [196]:
hmm = GaussianHMM(num_states=3, num_emission_dims=5, seed=jr.PRNGKey(0))

In [197]:
states, emissions = hmm.sample(jr.PRNGKey(0), num_steps=100)

In [198]:
import numpy as onp

In [199]:
log_initial_probs = hmm._initial_condition.log_initial_probs(emissions)
log_transition_matrices = hmm._transitions.log_transition_matrices(emissions)
log_likelihoods = hmm._emissions.log_likelihoods(emissions)

natural_params = (log_initial_probs, log_transition_matrices, log_likelihoods)
natural_params_onp = (onp.array(log_initial_probs), onp.array(log_transition_matrices), onp.array(log_likelihoods))

In [200]:
from ssm.hmm.posterior import hmm_log_normalizer

log_normalizer, alphas = hmm_log_normalizer(log_initial_probs, log_transition_matrices, log_likelihoods)

In [227]:
import jax.numpy as np
import jax.scipy.special as spsp
from jax import lax

import chex

def normalize_vector(u):
    Z = np.sum(u)
    v = u / Z
    return v, Z
    
def log_normalize(log_u):
    log_Z = spsp.logsumexp(log_u, axis=1)
    log_v = log_u - log_Z
    return log_v, log_Z

def messages_forwards_normalized(init_state_distn, transition_matrix, log_likelihoods):
    T = log_likelihoods.shape[0]

    alpha_norm = np.empty_like(log_likelihoods)
    
    def f(carry, log_likelihood):  # (carry, x)
        in_potential, log_total = carry
        
        cmax = log_likelihood.max()
        alpha_norm_t = in_potential * np.exp(log_likelihood - cmax)
        norm = alpha_norm_t.sum()
        
        alpha_norm_t = alpha_norm_t / norm
        log_total += np.log(norm) + cmax
        
        out_potential = alpha_norm_t.dot(transition_matrix)
        
        return (out_potential, log_total), alpha_norm_t
        
    log_total = 0.
    in_potential = init_state_distn
        
    final_carry, alphas_norm = lax.scan(f, (in_potential, log_total), log_likelihoods)
    _, log_normalizer = final_carry
    
    return alphas_norm, log_normalizer

In [228]:
def sticky_transitions(num_states, stickiness=0.95):
    P = stickiness * np.eye(num_states) 
    P += (1 - stickiness) / (num_states - 1) * (1 - np.eye(num_states))
    return P

def random_args(num_timesteps, num_states, seed=0, offset=0, scale=1):
    pi = np.ones(num_states) / num_states
    P = sticky_transitions(num_states)
    log_likes = offset + scale * jr.normal(jr.PRNGKey(seed), (num_timesteps, num_states))
    return pi, P, log_likes

In [229]:
pi, P, log_likes = random_args(100, 3)

In [230]:
alphas, log_Z = messages_forwards_normalized(pi, P, log_likes)
alphas

DeviceArray([[2.87658632e-01, 3.63781899e-02, 6.75963163e-01],
             [8.71943355e-01, 5.09060770e-02, 7.71505535e-02],
             [4.34230834e-01, 7.67349452e-02, 4.89034176e-01],
             [4.41724569e-01, 8.32934305e-02, 4.74981993e-01],
             [7.70578921e-01, 1.01892382e-01, 1.27528712e-01],
             [7.91063726e-01, 1.14585310e-01, 9.43509564e-02],
             [8.29118609e-01, 1.41319511e-02, 1.56749383e-01],
             [7.91145265e-01, 2.05210112e-02, 1.88333735e-01],
             [9.93438601e-01, 9.45839973e-04, 5.61549747e-03],
             [8.61718297e-01, 6.57978430e-02, 7.24838525e-02],
             [6.67197227e-01, 1.05501324e-01, 2.27301463e-01],
             [7.21345544e-01, 1.01580642e-01, 1.77073762e-01],
             [4.24646199e-01, 3.86393070e-02, 5.36714494e-01],
             [5.72492659e-01, 2.99103539e-02, 3.97596985e-01],
             [2.65077233e-01, 6.74672127e-02, 6.67455494e-01],
             [2.77127363e-02, 5.52592389e-02, 9.1702801

In [204]:
log_Z_norm, alphas_norm = _stationary_hmm_log_normalizer_noramlized(*natural_params)
log_Z, alphas = _stationary_hmm_log_normalizer(*natural_params)

ValueError: axis 1 is out of bounds for array of dimension 1

In [189]:
Z = spsp.logsumexp(alphas, axis=1)
np.sum(np.exp(alphas - Z[:, None]), axis=1)

DeviceArray([0.99999994, 1.        , 1.0000001 , 1.        , 1.        ,
             0.99999994, 1.        , 1.0000001 , 1.        , 0.99999994,
             0.99999994, 1.        , 0.9999999 , 1.        , 1.0000001 ,
             0.9999999 , 1.        , 0.9999999 , 1.        , 1.        ,
             0.9999999 , 1.0000001 , 1.0000001 , 1.        , 0.9999999 ,
             1.        , 0.9999999 , 1.0000002 , 1.0000001 , 0.99999994,
             1.        , 0.99999976, 1.0000001 , 1.        , 0.9999999 ,
             0.99999976, 1.        , 1.0000001 , 1.0000002 , 1.0000001 ,
             0.99999976, 1.0000004 , 0.99999976, 1.0000004 , 0.9999997 ,
             1.0000002 , 1.0000005 , 1.        , 0.9999998 , 1.0000001 ,
             1.0000002 , 0.99999964, 0.99999964, 1.        , 0.99999994,
             1.0000002 , 0.9999996 , 0.99999976, 0.9999997 , 0.99999976,
             1.0000002 , 0.99999964, 1.        , 1.0000004 , 1.        ,
             0.99999976, 1.        , 1.0000004 , 0.