In [1]:
from functools import partial

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
from jax import vmap
from dynamax.hidden_markov_model import CategoricalHMM
import Preprocessing as pp

In [2]:
# Source CSV for this session, kept in one named constant so it is easy to find
# and change. The path is relative to the notebook's working directory.
FED_CSV_PATH = '../../behavior data integrated/Bhv 7 - Ctrl/M2/Contingency Flip/FED000_071823_00.CSV'

# NOTE(review): X and y are not used by the cells shown here, which fit the HMM
# to simulated dice rolls instead — TODO confirm whether later cells use them.
X, y = pp.extract_features(FED_CSV_PATH)

In [3]:
num_states = 2      # two types of dice (fair and loaded)
num_emissions = 1   # only one die is rolled at a time
num_classes = 6     # each die has six faces

# Ground-truth parameters used to simulate data below.
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05],
                               [0.10, 0.90]])
emission_probs = jnp.array([[1/6,  1/6,  1/6,  1/6,  1/6,  1/6],    # fair die
                            [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]])  # loaded die

# Sanity-check the hand-entered tables: shapes must match the declared sizes and
# every distribution must sum to one (a typo here would otherwise go unnoticed).
assert initial_probs.shape == (num_states,)
assert transition_matrix.shape == (num_states, num_states)
assert emission_probs.shape == (num_states, num_classes)
assert jnp.allclose(initial_probs.sum(), 1.0)
assert jnp.allclose(transition_matrix.sum(axis=-1), 1.0)
assert jnp.allclose(emission_probs.sum(axis=-1), 1.0)

# Construct the HMM
hmm = CategoricalHMM(num_states, num_emissions, num_classes)

# Initialize the parameters struct with known values; emission probabilities are
# reshaped to (num_states, num_emissions, num_classes) for the HMM.
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))

In [11]:
def print_params(params, decimals=3):
    """Print the initial, transition and emission probabilities of an HMM.

    Args:
        params: parameter struct such as the one returned by
            ``CategoricalHMM.initialize``; must expose ``initial.probs``,
            ``transitions.transition_matrix`` and ``emissions.probs``.
        decimals: digits shown after the decimal point (default 3).

    The float formatting is applied only for the duration of this call, so
    array printing in later cells keeps its default format.
    """
    float_fmt = f"{{:0.{decimals}f}}".format
    emission_probs = params.emissions.probs
    # With a single emission per time step, drop that length-1 axis so each
    # printed row is one state's distribution over classes.
    if emission_probs.ndim == 3 and emission_probs.shape[1] == 1:
        emission_probs = emission_probs[:, 0, :]
    # Scoped print options instead of set_printoptions, which would leak globally.
    with jnp.printoptions(formatter={'float': float_fmt}):
        print("initial probs:")
        print(params.initial.probs)
        print("transition matrix:")
        print(params.transitions.transition_matrix)
        print("emission probs:")
        print(emission_probs)

In [8]:
# Simulate sequences from the ground-truth model, reusing `hmm` and `params`
# from the model-definition cell.
num_batches = 1
num_timesteps = 5000
SAMPLE_SEED = 42

# One PRNG key per batch; vmap runs hmm.sample once per key.
sample_keys = jr.split(jr.PRNGKey(SAMPLE_SEED), num_batches)
batch_states, batch_emissions = vmap(
    partial(hmm.sample, params, num_timesteps=num_timesteps)
)(sample_keys)

In [9]:
# Peek at the first five time steps of the sampled states and emissions.
print(batch_states[:, :5], batch_emissions[:, :5], sep="\n")

[[1 1 0 0 0]]
[[[3]
  [5]
  [4]
  [0]
  [3]]]


In [13]:
# Seeded random starting point for EM, independent of the ground-truth `params`.
key = jr.PRNGKey(0)
em_params, em_param_props = hmm.initialize(key=key)
print(em_param_props)

ParamsCategoricalHMM(initial=ParamsStandardHMMInitialState(probs=<dynamax.parameters.ParameterProperties object at 0x2957e5360>), transitions=ParamsStandardHMMTransitions(transition_matrix=<dynamax.parameters.ParameterProperties object at 0x29633d600>), emissions=ParamsCategoricalHMMEmissions(probs=<dynamax.parameters.ParameterProperties object at 0x29a0a59f0>))


In [None]:
# Fit the HMM to the simulated emissions with EM, starting from `em_params`.
# NOTE(review): this rebinds `em_params`, so re-running the cell continues from
# the fitted values rather than the random initialization.
# `log_probs` is presumably the per-iteration objective — TODO confirm.
em_params, log_probs = hmm.fit_em(
    em_params,
    em_param_props,
    batch_emissions,
    num_iters=400,
)