In [78]:
from hmm.hmm import HMM, sample_poisson_stimuli

import numpy as np

In [79]:
gamma = 0.1
beta = 0.2
alpha = 0.01
rates = [1, 20]

# This is uppercase-gamma.
transition_matrix = np.array(
    [[1 - gamma, 0, gamma], [0, 1 - gamma, gamma], [beta / 2, beta / 2, 1 - beta]]
)

In [80]:
hmm = HMM(transition_matrix, alpha, lambda z: sample_poisson_stimuli(z, rates), states=[0, 1, 2], rates=rates)

In [81]:
num_nodes = 8
time_steps = 50
initial_c = 2

In [82]:
true_processing_modes, true_focus, observations = hmm.forward(
    num_nodes,
    time_steps,
    initial_c,
)

Dimensions of the joint distribution will be (T - 1, num possible Cs at t, num possible Cs at t +1)

In [83]:
joint_probabilities_normalised = hmm.infer(observations)

In [84]:
joint_probabilities_normalised.shape

(49, 3, 3)

In [94]:
# Compute the marginal probabilities of C at each time step
marginal_prob_C = np.sum(joint_probabilities_normalised, axis=2)

# Calculate the estimated C at each time step
estimated_C = np.argmax(marginal_prob_C, axis=1)

# Compute the most likely Z given the estimated C
estimated_Z = np.zeros((time_steps, num_nodes), dtype=int)

for t, c in enumerate(estimated_C):
    estimated_Z[t] = hmm.sample_hidden_z(num_nodes, c)

In [109]:
correct_C = np.sum(np.equal(estimated_C, true_processing_modes[:-1])) / time_steps
correct_Z = np.sum(estimated_Z == true_focus) / (time_steps * num_nodes)

print(f"Proportion of correct C estimations: {correct_C:.2f}")
print(f"Proportion of correct Z estimations: {correct_Z:.2f}")

Proportion of correct C estimations: 0.60
Proportion of correct Z estimations: 0.79


In [107]:
estimated_Z

array([[0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0,

In [101]:
true_processing_modes

[2,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]