In [186]:
from hmm.hmm import HMM, sample_poisson_stimuli

import numpy as np

In [187]:
training_data = np.genfromtxt("../../data/Ex_2.csv", delimiter="," ,dtype=int)[1:, 1:]

In [240]:
gamma = 0.1
beta = 0.2
alpha = 0.01
rates = [1, 20]

# This is uppercase-gamma.
transition_matrix = np.array(
    [[1 - gamma, 0, gamma], [0, 1 - gamma, gamma], [beta / 2, beta / 2, 1 - beta]]
)

In [189]:
hmm = HMM(transition_matrix, alpha, lambda z: sample_poisson_stimuli(z, rates), states=[0, 1, 2], rates=rates)

In [241]:
num_nodes = 8
time_steps = 50
initial_c = 2

In [191]:
true_processing_modes, true_focus, observations = hmm.forward(
    num_nodes,
    time_steps,
    initial_c,
)

Dimensions of the joint distribution will be (T - 1, num possible Cs at t, num possible Cs at t +1)

In [192]:
joint_probabilities_normalised = hmm.infer(observations)

In [193]:
joint_probabilities_normalised.shape

(49, 3, 3)

> ...it may be more convenient to test the implementation in other ways. You may, for instance, observe that
>
> $$P?(Z_{t,i}=z)-P(Z_{t,i}=z|X=x)$$
>
> has mean 0 and likewise for Ct. Using simulations you can compute such quantities, with $P(Z_{t,i}=z|X=x)$ computed by the inference algorithm, and empirically check if their averages across many replications of the simulations are zero.

In [252]:
num_simulations = 10
correct_C = 0
correct_Z = 0

diff_Z0 = []
diff_Z1 = []

for _ in range(num_simulations):
    true_processing_modes, true_focus, observations = hmm.forward(num_nodes, time_steps, initial_c)
    joint_probabilities_normalised = hmm.infer(observations)
    # Compute the marginal probabilities of C at each time step
    marginal_prob_C = np.sum(joint_probabilities_normalised, axis=2)

    # Calculate the estimated C at each time step
    estimated_C = np.argmax(marginal_prob_C, axis=1)

    # Compute the most likely Z given the estimated C
    estimated_Z = np.zeros((time_steps, num_nodes), dtype=int)

    for t, c in enumerate(estimated_C):
        estimated_Z[t] = hmm.sample_hidden_z(num_nodes, c)

    # Compute the accuracy
    correct_C += np.sum(true_processing_modes[:-1] == estimated_C) / time_steps
    correct_Z += np.sum(true_focus == estimated_Z) / (time_steps * num_nodes)

    # Calculate the marginal probability of Zt,i=0 given X=x
    marginal_prob_Z0_given_X = np.sum(joint_probabilities_normalised[:, :, :2], axis=2)
    # Calculate the marginal probability of Zt,i=1 given X=x
    marginal_prob_Z1_given_X = 1 - marginal_prob_Z0_given_X

    P_Z0 = np.mean(true_focus[:-1] == 0, axis=1)
    P_Z1 = np.mean(true_focus[:-1] == 1, axis=1)

    diff_Z0.append(P_Z0 - np.mean(marginal_prob_Z0_given_X, axis=1))
    diff_Z1.append(P_Z1 - np.mean(marginal_prob_Z1_given_X, axis=1))



In [254]:
correct_C /= num_simulations
correct_Z /= num_simulations

# Compute the mean differences across all simulations
mean_diff_Z0 = np.mean(np.array(diff_Z0))
mean_diff_Z1 = np.mean(np.array(diff_Z1))

print(f"Proportion of correct C estimations: {correct_C:.2f}")
print(f"Proportion of correct Z estimations: {correct_Z:.2f}")

print(f"Mean difference for Zt,i=0: {mean_diff_Z0:.5f}")
print(f"Mean difference for Zt,i=1: {mean_diff_Z1:.5f}")

Proportion of correct C estimations: 0.05
Proportion of correct Z estimations: 0.07
Mean difference for Zt,i=0: 0.22506
Mean difference for Zt,i=1: -0.22506
Mean difference for Zt,i=0 and Zt,i=1: 0.00000


We then run the implementation on the real data

In [196]:
joint_probabilities_normalised = hmm.infer(training_data)
time_steps = joint_probabilities_normalised.shape[0]

In [197]:
joint_probabilities_normalised.shape

(99, 3, 3)

In [198]:
# Compute the marginal probabilities of C at each time step
marginal_prob_C = np.sum(joint_probabilities_normalised, axis=2)

# Calculate the estimated C at each time step
estimated_C = np.argmax(marginal_prob_C, axis=1)

# Compute the most likely Z given the estimated C
estimated_Z = np.zeros((time_steps, num_nodes), dtype=int)

for t, c in enumerate(estimated_C):
    estimated_Z[t] = hmm.sample_hidden_z(num_nodes, c)

In [199]:
estimated_C

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [200]:
estimated_Z

array([[0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0,

In [230]:
joint_probabilities_normalised

array([[[-2.12588973e-09, -0.00000000e+00, -9.49733866e-06],
        [ 0.00000000e+00,  9.20777352e-01,  5.64038082e-03],
        [ 7.00233435e-08,  5.10681631e-02,  2.25235339e-02]],

       [[ 8.28474036e-10,  0.00000000e+00,  3.32784716e-06],
        [ 0.00000000e+00,  9.51682838e-01,  6.48442890e-03],
        [ 4.76014536e-08,  2.80624245e-02,  1.37669324e-02]],

       [[ 6.92784325e-10,  0.00000000e+00,  2.38420140e-06],
        [ 0.00000000e+00,  9.62653275e-01,  8.09971955e-03],
        [ 4.45253998e-08,  1.82117896e-02,  1.10327864e-02]],

       [[ 7.70210771e-10,  0.00000000e+00,  2.23149515e-06],
        [ 0.00000000e+00,  9.62177311e-01,  1.17158393e-02],
        [ 5.84592911e-08,  1.39098232e-02,  1.21947359e-02]],

       [[ 1.17031412e-09,  0.00000000e+00,  2.94398368e-06],
        [ 0.00000000e+00,  9.49686280e-01,  2.01742759e-02],
        [ 1.00609969e-07,  1.19139611e-02,  1.82224373e-02]],

       [[ 2.31076042e-09,  0.00000000e+00,  5.14621433e-06],
        [ 0.00