In [2]:
from copy import deepcopy

import numpy as np; np.set_printoptions(linewidth = 150, suppress=True)
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

from generate import generate_TPM, generate_contexts_cues
from particle import Particle, resample

# A modified COIN model
## Introduction

The COIN model was originally introduced by [Heald et al. (2021)](https://doi.org/10.1038/s41586-021-04129-3) to study motor learning.

This project uses a version that is simplified by removing the latent states and modified so that the cue is a multivariate binary vector (each element is a Bernoulli variable; see the cue likelihood below).

### The modified COIN generative model
Contexts occur with frequency $\beta$.  
At each time step, the latent context variable evolves as a Markov process according to the transition probability matrix (TPM) $\Pi$.
$$
\Pi=
\begin{bmatrix}
p({c_{t}=1|c_{t-1}=1}) & p({c_{t}=2|c_{t-1}=1}) & ... \\
p({c_{t}=1|c_{t-1}=2}) & p({c_{t}=2|c_{t-1}=2}) & ... \\
... & ... & p({c_{t}=j|c_{t-1}=i})
\end{bmatrix}
$$
Each context is associated with a given row of a cue emission matrix (CEM) $\Phi$,
$$\Phi=
\begin{bmatrix}
p(q_1=1|c=1) & p(q_2=1|c=1) & ... \\
p(q_1=1|c=2) & p(q_2=1|c=2) & ... \\
p(q_1=1|c=3) & p(q_2=1|c=3) & ... \\
... & ... & p(q_i=1|c=j) 
\end{bmatrix}
$$
such that at each time step, a binary cue vector (e.g. $q = \{ 1, 0, 0, 0, 1, 0\}$) corresponding to that context is emitted.

### Inference under the modified model

The goal of the learner is to compute the joint posterior $p(\Theta_t | q_{1:t})$ of quantities $\Theta_t = \{c_t, \beta, \Pi, \Phi\}$ that are not observed by the learner: the current context $c_t$, the global context frequencies $\beta$, the TPM $\Pi$, and the CEM $\Phi$. 

This is accomplished by using particle learning, which is detailed for the original model in section 2.3 of the [supplementary materials](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-04129-3/MediaObjects/41586_2021_4129_MOESM1_ESM.pdf). Here is a short outline of the process:

The essential state vector of each particle $z_t$ is composed of:
- context $c_t$
- the sufficient statistics $\Theta^s$ 
    - $C_c$ the count of how many times each context was observed
    - $C_t$ the context transition count matrix
    - $C_q$ the cue emission count matrix, with element $C_q[context = j,i]$ being the number of times the $i$-th element of the cue vector was 1 for context $j$
- the parameters $\Theta$ are composed of
    - $\hat{\beta}$ the estimated global context frequencies
    - $\hat{\Phi}$ the estimated cue emission matrix

The distribution of essential state vectors is evolved over time with each cue by repeating the following steps:

1) #### Resampling
    First, particles are sampled with replacement according to weights $w_t$ proportional to the predictive distribution $w_t \propto \hat{p}(q_t | z_{t-1})$
    which can be decomposed into the expected local context transition probability and the local cue probability: $$\hat{p}(q_t | z_{t-1}) = \sum_{j=1}^{C+1}p(c_t = j | z_{t-1})\,p(q_t|c_t = j,z_{t-1})$$
    with $p(c_t = j|z_{t-1})$ being given by eq. S15 of the supplementary materials, and $p(q_t|c_t = j,z_{t-1})$ modified to be the probability of observing the binary cue vector $q_t$ of length $l$, given the context-specific row of the estimated CEM $\hat{\Phi}$:
    $$ p(q_t|c_t = j,z_{t-1}) = \prod_{i=1}^{l} \hat{\Phi}[j,i]^{q_i} \cdot (1 - \hat{\Phi}[j,i])^{1 - q_i} $$

2) #### Propagation
    Next, the latent context variable $c_t$ is propagated conditioned on the predictive distribution $\hat{p}(q_t | z_{t-1})$, and the sufficient statistics $\Theta^s$ (context transition counts, context counts, cue observation counts) are incremented.

3) #### Parameter sampling
    To maintain diversity in parameters over the particles, the parameters $\hat{\beta}$ and $\hat{\Phi}$ are resampled. $\hat{\beta}$ is resampled according to eqs. S25-28 of the supplementary materials, while each entry of $\hat{\Phi}$ is sampled from:
    $$
    \hat{\Phi}[context = j, i] \sim Beta(a = C_q[j,i] + \hat{\Phi}[j, i], b = C_c[j] - C_q[j,i] + (1 - \hat{\Phi}[j, i]))
    $$

Finally, to estimate the CEM, $\hat{\Phi}$ is averaged over the particles. To estimate the full TPM $\Pi$, eq. S29 of the supplementary materials is used.

# Demo
#### Generate data 

In [40]:
# --- Problem size ---
input_dim = 6        # length of each binary cue vector
n_contexts_true = 4  # number of contexts in the ground-truth model
t_steps = 1000       # number of time steps to simulate

# --- Hyperparameters of the generative model ---
h_gamma_c = 1.0
h_alpha_c = 5.0
h_kappa_c = 2.0

# Ground-truth transition probability matrix (TPM), produced by the project's generator
true_TPM = generate_TPM(input_dim, n_contexts_true, h_gamma_c, h_alpha_c, h_kappa_c)

# Ground-truth cue emission matrix (CEM): one row per context, one column per cue element
true_CEM = np.array(
    [
        [1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1],
        [0, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 1],
    ],
    dtype=np.int64,
)

# Simulate the context sequence and the cue emitted at each time step
contexts, cues = generate_contexts_cues(true_TPM, true_CEM, t_steps)

# Report how often each context occurs in the simulated sequence
context_labels, label_counts = np.unique(contexts, return_counts=True)
element_frequency = dict(zip(context_labels, label_counts))
for element in element_frequency:
    frequency = element_frequency[element]
    print(f"Element {element}: Frequency {frequency}")

Global context probabilities:
[0.53071743 0.04051346 0.21692036 0.21184875]
True context transition probability matrix:
[[0.86239456 0.10909311 0.02850528 0.00000705]
 [0.11357978 0.33306852 0.33658356 0.21676815]
 [0.08137829 0.         0.75657932 0.1620424 ]
 [0.23946346 0.         0.00004106 0.76049548]]
Element 0: Frequency 466
Element 1: Frequency 105
Element 2: Frequency 187
Element 3: Frequency 242


#### Use particle learning to infer the TPM and CEM

In [41]:
# Learner hyperparameters, set to the same values used by the generative model
model_hyperparam = dict(
    h_gamma_c=h_gamma_c,
    h_alpha_c=h_alpha_c,
    h_kappa_c=h_kappa_c,
)


def _cue_responsibility(particle, cue):
    """Score `cue` under `particle` and return the particle's responsibility."""
    particle.calculate_joint(cue)
    particle.calculate_responsibility()
    return particle.get_responsibility()


# Build the ensemble: 20 particles, each created with n_contexts_init=0
particles = []
for _ in range(20):
    particles.append(Particle(input_dim=input_dim, n_contexts_init=0, hyperparam=model_hyperparam))

# Run one resample -> propagate -> sample cycle per observed cue
for i, cue in enumerate(cues):
    print(f'>> processing cue {i} of {len(cues)}     ', end = '\r')

    # 1) Resampling: weight every particle by its responsibility for this cue
    weights = [_cue_responsibility(particle, cue) for particle in particles]
    particles = resample(particles, weights)

    # 2) Propagation: update each particle's context and sufficient statistics
    for particle in particles:
        particle.propagate_context()
        particle.propagate_sufficient_statistics(cue)

    # 3) Parameter sampling: done in a separate pass, after all particles have
    #    been propagated, to keep the original order of operations
    for particle in particles:
        particle.sample_parameters()

>> processing cue 999 of 1000     

#### Compare the estimated and true TPM and CEM

In [42]:
# Estimate the CEM as the average over particles of each particle's cue emission matrix
list_cue_emission_matrix = np.array([particle.get_cue_emission_matrix() for particle in particles])
estimated_CEM = np.round(np.average(list_cue_emission_matrix, axis = 0), 1)

# The learner's context labels are arbitrary, so align inferred contexts with the
# true ones by the similarity of their CEM rows before comparing.
distance_matrix = cdist(true_CEM, estimated_CEM, 'euclidean')
n_contexts_true_rows, n_contexts_estimated_rows = distance_matrix.shape
if n_contexts_estimated_rows >= n_contexts_true_rows:
    # One-to-one assignment that minimises the total distance. A per-row argmin
    # could map two true contexts onto the same inferred context and silently
    # duplicate rows in the comparison below.
    _, closest_indices = linear_sum_assignment(distance_matrix)
else:
    # Fewer inferred than true contexts: a one-to-one match cannot cover every
    # true row, so fall back to the nearest inferred row for each true row.
    closest_indices = np.argmin(distance_matrix, axis=1)

# Compare the estimate and the true CEM
print("true_CEM:")
print(true_CEM)
print("Estimated CEM:")
print(estimated_CEM[closest_indices])

# Estimate the TPM: for every particle compute
#   (alpha * beta[k] + kappa * [j == k] + transition_counts[j, k]) / (alpha + kappa + context_counts[j])
# and average over particles. The getters are called once per particle instead of
# once per (j, k, particle) triple.
# NOTE(review): as in the original loop, the context count is taken from the first
# particle and every particle is assumed to expose at least that many contexts.
n_contexts_inferred = len(particles[0].get_context_counts())
self_transition_bias = h_kappa_c * np.eye(n_contexts_inferred)
exp_TPM = np.zeros((n_contexts_inferred, n_contexts_inferred))
for particle in particles:
    # Slice to the inferred size so only the entries the scalar loop would have read are used
    global_context_prob = np.asarray(particle.get_global_context_prob())[:n_contexts_inferred]
    transition_counts = np.asarray(particle.get_transition_counts())[:n_contexts_inferred, :n_contexts_inferred]
    context_counts = np.asarray(particle.get_context_counts())[:n_contexts_inferred]
    # global_context_prob broadcasts across columns (k); the denominator is per row (j)
    exp_TPM += (h_alpha_c*global_context_prob + self_transition_bias + transition_counts) / (h_alpha_c + h_kappa_c + context_counts)[:, np.newaxis]
exp_TPM /= len(particles)

# Compare the estimated and the true TPM (estimated rows/columns reordered to match the true contexts)
print("True TPM:")
print(np.round(true_TPM, 2))
print("Estimated TPM:")
print(np.round(exp_TPM[closest_indices][:,closest_indices], 2))

true_CEM:
[[1 1 0 0 0 0]
 [0 0 0 0 1 1]
 [0 0 1 1 0 0]
 [1 0 0 0 0 1]]
Estimated CEM:
[[1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 1.]
 [0. 0. 1. 1. 0. 0.]
 [1. 0. 0. 0. 0. 1.]]
True TPM:
[[0.86 0.11 0.03 0.  ]
 [0.11 0.33 0.34 0.22]
 [0.08 0.   0.76 0.16]
 [0.24 0.   0.   0.76]]
Estimated TPM:
[[0.81 0.15 0.04 0.  ]
 [0.12 0.34 0.36 0.15]
 [0.08 0.01 0.7  0.2 ]
 [0.23 0.   0.   0.75]]
