In [21]:
import jax.numpy as np
import jax.random as jr
import jax.scipy.special as spsp
from jax import value_and_grad

from ssm.hmm.posterior import _factorial_hmm_log_normalizer, hmm_log_normalizer

In [2]:
rng = jr.PRNGKey(0)
num_states = (3, 4)
num_timesteps = 10

log_initial_state_probs = np.zeros(num_states)
log_transition_matrices = tuple(
    jr.normal(key, (k, k)) for key, k in zip(jr.split(rng, len(num_states)), num_states))
log_transition_matrices = tuple(
    x - spsp.logsumexp(x, axis=1, keepdims=True)
    for x in log_transition_matrices
)

rng, this_rng = jr.split(rng, 2)
log_likelihoods = jr.normal(this_rng, (num_timesteps,) + num_states)



In [15]:
log_normalizer, filtered_potentials = \
    _factorial_hmm_log_normalizer(log_initial_state_probs,
                                  log_transition_matrices,
                                  log_likelihoods)



In [16]:
big_transition_matrix = np.kron(
    np.exp(log_transition_matrices[0]),
    np.exp(log_transition_matrices[1]),
)


In [17]:
log_normalizer2, filtered_potentials2 = \
    hmm_log_normalizer(log_initial_state_probs.reshape(-1),
                       np.log(big_transition_matrix),
                       log_likelihoods.reshape(num_timesteps, -1))


In [18]:
log_normalizer, log_normalizer2

(DeviceArray(5.865858, dtype=float32), DeviceArray(5.865858, dtype=float32))

In [19]:
np.allclose(filtered_potentials.reshape(num_timesteps, -1), filtered_potentials2)

DeviceArray(True, dtype=bool)

# Test the gradient

In [22]:
# _factorial_hmm_log_normalizer(log_initial_state_probs,
#                                   log_transition_matrices,
#                                   log_likelihoods)

f = value_and_grad(_factorial_hmm_log_normalizer, argnums=(1, 2), has_aux=True)
(log_normalizer, filtered_potentials), (expected_transitions, expected_states) = \
    f(log_initial_state_probs, log_transition_matrices, log_likelihoods)

# Joint distribution for the transitions

In [31]:
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

In [43]:
# In TFP, we can write this as:
prev_states = tuple(1 for _ in num_states)

    Root = tfd.JointDistributionCoroutine.Root  # Convenient alias.
def model():
    for prev_state, log_transition_matrix in zip(prev_states, log_transition_matrices):
        yield Root(tfd.Categorical(logits=log_transition_matrix[prev_state]))
        
joint = tfd.JointDistributionCoroutine(model)
next_states = tuple(joint.sample(seed=jr.PRNGKey(1), sample_shape=(100,)))
joint.log_prob(next_states)
# next_states

DeviceArray([-1.1101403, -1.1101403, -1.1101403, -1.587949 , -2.3004153,
             -2.5072618, -2.778224 , -1.1101403, -3.746516 , -1.1101403,
             -2.5072618, -1.587949 , -2.029453 , -1.587949 , -4.130764 ,
             -1.587949 , -1.587949 , -1.1101403, -1.587949 , -2.5072618,
             -1.587949 , -2.029453 , -1.587949 , -1.1101403, -2.3004153,
             -1.1101403, -2.5072618, -2.778224 , -1.1101403, -1.1101403,
             -1.587949 , -1.587949 , -1.1101403, -2.029453 , -2.3004153,
             -1.587949 , -2.3004153, -2.3004153, -2.5072618, -2.5072618,
             -2.029453 , -1.587949 , -1.1101403, -1.587949 , -1.1101403,
             -2.5072618, -1.1101403, -1.1101403, -1.587949 , -2.5072618,
             -1.1101403, -1.1101403, -1.1101403, -1.1101403, -2.3004153,
             -1.1101403, -2.3004153, -2.778224 , -1.1101403, -1.587949 ,
             -1.1101403, -1.587949 , -2.3004153, -1.1101403, -1.1101403,
             -1.1101403, -1.1101403, -1.1101403, -1