In [2]:
from copy import deepcopy

import numpy as np; np.set_printoptions(linewidth = 150, suppress=True)
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

In [3]:
from generate import generate_TPM, generate_contexts_cues
from particle import Particle, resample

In [40]:
# Build a synthetic data set: a transition probability matrix (TPM) from
# generate_TPM, a hand-written cue emission matrix (CEM), and a sequence of
# contexts and cues produced by generate_contexts_cues.

# Dimensions passed to generate_TPM
input_dim = 6
n_contexts_init = 4

# Hyperparameters passed to generate_TPM (and reused when building the model)
h_gamma_c, h_alpha_c, h_kappa_c = 1.0, 5.0, 2.0

true_TPM = generate_TPM(input_dim, n_contexts_init, h_gamma_c, h_alpha_c, h_kappa_c)

# Binary 4 x 6 matrix; presumably one row per context and one column per
# input dimension (matches n_contexts_init and input_dim) — verify against
# generate_contexts_cues.
true_CEM = np.array(
    [
        [1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1],
        [0, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 1],
    ],
    dtype=np.int64,
)

# Number of time steps requested from the generator
t_steps = 1000
contexts, cues = generate_contexts_cues(true_TPM, true_CEM, t_steps)

# Report how often each distinct value occurs in the generated context sequence
unique_elements, counts = np.unique(contexts, return_counts=True)
element_frequency = dict(zip(unique_elements, counts))
for element in element_frequency:
    frequency = element_frequency[element]
    print(f"Element {element}: Frequency {frequency}")

Global context probabilities:
[0.53071743 0.04051346 0.21692036 0.21184875]
True context transition probability matrix:
[[0.86239456 0.10909311 0.02850528 0.00000705]
 [0.11357978 0.33306852 0.33658356 0.21676815]
 [0.08137829 0.         0.75657932 0.1620424 ]
 [0.23946346 0.         0.00004106 0.76049548]]
Element 0: Frequency 466
Element 1: Frequency 105
Element 2: Frequency 187
Element 3: Frequency 242


In [41]:
# Hyperparameters handed to every Particle
model_hyperparam = dict(
    h_gamma_c=h_gamma_c,
    h_alpha_c=h_alpha_c,
    h_kappa_c=h_kappa_c,
)

# Ensemble of 20 particles, each starting with no contexts
particles = [
    Particle(input_dim=input_dim, n_contexts_init=0, hyperparam=model_hyperparam)
    for _ in range(20)
]


def _score_particle(particle, cue):
    """Update `particle`'s joint and responsibility for `cue` and return the responsibility."""
    particle.calculate_joint(cue)
    particle.calculate_responsibility()
    return particle.get_responsibility()


# Process the cues one at a time
for i, cue in enumerate(cues):
    # Single-line progress indicator (carriage return overwrites the previous message)
    print(f'>> processing cue {i} of {len(cues)}     ', end = '\r')

    # Step 1: weight every particle by its responsibility, then resample the ensemble
    weights = [_score_particle(particle, cue) for particle in particles]
    particles = resample(particles, weights)

    # Step 2: propagate the context and sufficient statistics of each particle
    for particle in particles:
        particle.propagate_context()
        particle.propagate_sufficient_statistics(cue)

    # Step 3: sample the parameters to generate diversity.
    # Kept as a separate pass so that the order of calls across particles is
    # unchanged (all propagation first, then all parameter sampling).
    for particle in particles:
        particle.sample_parameters()

>> processing cue 999 of 1000     

In [42]:
# Compare the particle ensemble's estimates of the cue emission matrix (CEM)
# and the transition probability matrix (TPM) with the true ones.

# Estimate the CEM: average the per-particle matrices and round to one decimal.
# NOTE(review): stacking assumes every particle reports a CEM of the same
# shape — confirm this holds after resampling.
list_cue_emission_matrix = np.array([particle.get_cue_emission_matrix() for particle in particles])
estimated_CEM = np.round(np.average(list_cue_emission_matrix, axis = 0), 1)

# Pair rows of the true CEM with rows of the estimated CEM.
# A per-row argmin is not guaranteed to be one-to-one: two true rows can pick
# the same estimated row, which would duplicate one inferred context and drop
# another in the reordered output. An optimal assignment on the same distance
# matrix yields distinct pairs, and coincides with the per-row argmin whenever
# that argmin is already a tie-free permutation.
distance_matrix = cdist(true_CEM, estimated_CEM, 'euclidean')
matched_true, closest_indices = linear_sum_assignment(distance_matrix)
# matched_true is sorted; it equals every row index of true_CEM unless fewer
# rows were estimated than exist in true_CEM, in which case only the matched
# true rows are shown so that both printouts stay aligned.

# Compare the estimate and the true CEM
print("true_CEM:")
print(true_CEM[matched_true])
print("Estimated CEM:")
print(estimated_CEM[closest_indices])

# Estimate the TPM by averaging, over particles, the matrix
#   (h_alpha_c * global_prob[k] + h_kappa_c * [j == k] + transition_counts[j, k])
#   / (h_alpha_c + h_kappa_c + context_counts[j])
# NOTE(review): the number of contexts is taken from the first particle and
# assumed to be the same for all particles — confirm.
n_contexts_inferred = len(particles[0].get_context_counts())
self_transition_bias = h_kappa_c * np.eye(n_contexts_inferred)
exp_TPM = np.zeros((n_contexts_inferred, n_contexts_inferred))
for particle in particles:
    # Fetch each statistic once per particle and restrict it to the first
    # n_contexts_inferred entries, which are the only ones used here.
    global_prob = np.asarray(particle.get_global_context_prob())[:n_contexts_inferred]
    transition_counts = np.asarray(particle.get_transition_counts())[:n_contexts_inferred, :n_contexts_inferred]
    context_counts = np.asarray(particle.get_context_counts())[:n_contexts_inferred]
    exp_TPM += (
        h_alpha_c * global_prob[np.newaxis, :] + self_transition_bias + transition_counts
    ) / (h_alpha_c + h_kappa_c + context_counts[:, np.newaxis])
exp_TPM /= len(particles)

# Compare the estimated and the true TPM (rows and columns in matched order)
print("True TPM:")
print(np.round(np.asarray(true_TPM)[matched_true][:, matched_true], 2))
print("Estimated TPM:")
print(np.round(exp_TPM[closest_indices][:,closest_indices], 2))

true_CEM:
[[1 1 0 0 0 0]
 [0 0 0 0 1 1]
 [0 0 1 1 0 0]
 [1 0 0 0 0 1]]
Estimated CEM:
[[1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 1.]
 [0. 0. 1. 1. 0. 0.]
 [1. 0. 0. 0. 0. 1.]]
True TPM:
[[0.86 0.11 0.03 0.  ]
 [0.11 0.33 0.34 0.22]
 [0.08 0.   0.76 0.16]
 [0.24 0.   0.   0.76]]
Estimated TPM:
[[0.81 0.15 0.04 0.  ]
 [0.12 0.34 0.36 0.15]
 [0.08 0.01 0.7  0.2 ]
 [0.23 0.   0.   0.75]]
