In [1]:
from copy import deepcopy

import numpy as np; np.set_printoptions(linewidth = 150, suppress=True)
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

from generate import generate_TPM, generate_contexts_cues
from particle import Particle, resample_particles

# A modified COIN model
## Introduction

The COIN (COntextual INference) model was originally introduced by [Heald et al. (2021)](https://doi.org/10.1038/s41586-021-04129-3) to study motor learning.

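This project uses a version of the model that is simplified by removing the per-context latent states. It is also modified by replacing the single categorical cue with a binary cue vector whose elements are emitted independently (see below).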

### The modified COIN generative model
Contexts occur with frequency $\beta$.  
At each time step, the latent context variable evolves as a Markov process according to the transition probability matrix (TPM) $\Pi$.
$$
\Pi=
\begin{bmatrix}
p({c_{t}=1|c_{t-1}=1}) & p({c_{t}=2|c_{t-1}=1}) & ... \\
p({c_{t}=1|c_{t-1}=2}) & p({c_{t}=2|c_{t-1}=2}) & ... \\
... & ... & p({c_{t}=j|c_{t-1}=i})
\end{bmatrix}
$$
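Each context $j$ is associated with row $j$ of a cue emission matrix (CEM) $\Phi$, whose element $\Phi[j,i] = p(q_i=1|c=j)$ is the probability that the $i$-th cue element equals 1 in context $j$: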
$$\Phi=
\begin{bmatrix}
p(q_1=1|c=1) & p(q_2=1|c=1) & ... \\
p(q_1=1|c=2) & p(q_2=1|c=2) & ... \\
p(q_1=1|c=3) & p(q_2=1|c=3) & ... \\
... & ... & p(q_i=1|c=j) 
\end{bmatrix}
$$
At each time step a binary cue vector (e.g. $q = (1, 0, 0, 0, 1, 0)$) is emitted. Each element $q_i$ is independently equal to 1 with probability $\Phi[c_t, i]$.

### Inference under the modified model

The goal of the learner is to compute the joint posterior $p(\Theta_t | q_{1:t})$ of quantities $\Theta_t = \{c_t, \beta, \Pi, \Phi\}$ that are not observed by the learner: the current context $c_t$, the global context frequencies $\beta$, the TPM $\Pi$, and the CEM $\Phi$. 

This is accomplished by using particle learning, which is detailed for the original model in section 2.3 of the [supplementary materials](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-04129-3/MediaObjects/41586_2021_4129_MOESM1_ESM.pdf). Here is a short outline of the process:

The essential state vector of each particle $z_t$ is composed of:
- context $c_t$
- the sufficient statistics $\Theta^s$, composed of
    - $n_t$ the context transition count matrix
    - $n_c$ the context count vector, with element $n_c[j]$ being the number of time steps assigned to context $j$
    - $n_q$ the cue emission count matrix, with element $n_q[j,i]$ being the number of times the $i$-th element of the cue vector was 1 in context $j$
- the parameters $\Theta$, composed of
    - $\hat{\beta}$ the estimated global context frequencies
    - $\hat{\Phi}$ the estimated cue emission matrix

The distribution of essential state vectors is evolved over time with each cue by repeating the following steps:

1) #### Resampling
    First, particles are sampled with replacement according to weights $w_t$ proportional to the predictive distribution $w_t \propto \hat{p}(q_t | z_{t-1})$
    which can be decomposed into the expected local context transition probability and the local cue probability: $$\sum_{j=1}^{C+1}p(c_t = j | z_{t-1})p(q_t|c_t = j,z_{t-1})$$
    with $p(c_t|z_{t-1})$ being given by eq. S15 of the supplementary materials, and $p(q_t|c_t = j,z_{t-1})$ modified to be the probability of observing a binary cue vector, given each context specific row of the estimated CEM $\hat{\Phi}$
    $$ p(q_t|c_t = j,z_{t-1}) = \prod_{i=1}^{l} \hat{\Phi}[j,i]^{q_i} \cdot (1 - \hat{\Phi}[j,i])^{1 - q_i} $$

2) #### Propagation
    Next, the latent context variable $c_t$ is sampled from its posterior $p(c_t = j | q_t, z_{t-1}) \propto p(c_t = j | z_{t-1})\,p(q_t|c_t = j, z_{t-1})$ (the same per-context terms whose sum gives the resampling weight). The sufficient statistics $\Theta^s$ (context transition counts, context counts, cue observation counts) are then incremented.

3) #### Parameter sampling
    To maintain diversity in parameters over the particles, the parameters $\beta$ and $\hat{\Phi}$ are resampled. $\beta$ is resampled according to eqs. S25-28 of the supplementary materials, while each entry of $\hat{\Phi}$ is sampled from:
    $$
    \hat{\Phi}[context = j, i] \sim Beta(a = n_q[j,i] + \hat{\Phi}[j, i], b = n_c[j] - n_q[j,i] + (1 - \hat{\Phi}[j, i]))
    $$

Finally, the CEM is estimated by averaging $\hat{\Phi}$ over particles, and the full TPM $\Pi$ is estimated using eq. S29.

# Demo
#### Generate data 

In [227]:
cue_dim = 8         # Dimensionality of binary cue vector
n_contexts_true = 4 # Number of true contexts

# Seed NumPy's global RNG so that Restart & Run All reproduces the generated data
# and the particle-learning results further down.
# NOTE(review): assumes `generate` and `particle` draw from np.random's global
# state rather than their own unseeded Generator — TODO confirm.
SEED = 0
np.random.seed(SEED)

# Set hyperparameters of the generative model
hyp_gamma = 5.0     # Controls the effective number of contexts, use larger values for more low-prob. contexts
hyp_alpha = 3.0     # Controls the resemblance of local transition probs to global transition probs
hyp_kappa = 1.5     # Controls the rate of self-transitions, use larger values for more self-transitions

# Generate the true TPM
true_TPM = generate_TPM(n_contexts_true, hyp_gamma, hyp_alpha, hyp_kappa)

# Specify the true CEM: row j holds p(q_i = 1 | c = j) for each cue element i
true_CEM = np.array([
    [1, 1, 0, 0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1],
    [0, 0, 1, 1, 0, 0, 0, 0],
    [1, 0, 0, 0, 1, 1, 1, 1],
], dtype = np.int64)
# Keep the CEM consistent with the constants above if either is edited
assert true_CEM.shape == (n_contexts_true, cue_dim), \
    "true_CEM must have one row of length cue_dim per true context"

# Simulate data according to the generative model
t_steps = 2000
contexts, cues = generate_contexts_cues(true_TPM, true_CEM, t_steps)

# Print the context frequencies in the generated data
unique_elements, counts = np.unique(contexts, return_counts=True)
element_frequency = dict(zip(unique_elements, counts))
for element, frequency in element_frequency.items():
    print(f"Element {element}: Frequency {frequency}")

Global context probabilities:
[0.27604136 0.27577765 0.25739515 0.19078585]
True context transition probability matrix:
[[0.04207087 0.95434098 0.0035598  0.00002836]
 [0.22130562 0.2036342  0.0384161  0.53664408]
 [0.16695486 0.00529811 0.66396899 0.16377804]
 [0.13138131 0.14672764 0.72187145 0.0000196 ]]
Element 0: Frequency 281
Element 1: Frequency 428
Element 2: Frequency 903
Element 3: Frequency 388


#### Use particle learning to infer the TPM and CEM

In [230]:
# Learner hyperparameters: the structural ones are reused from the generative
# model above, while hyp_lambda exists only on the learner side.
hyp_lambda = 0.5    # Prior probability of a given cue vector element being 1
model_hyperparam = dict(
    hyp_gamma=hyp_gamma,
    hyp_alpha=hyp_alpha,
    hyp_kappa=hyp_kappa,
    hyp_lambda=hyp_lambda,
)

# Build an ensemble of 100 particles, none of which starts with an instantiated context
particles = [
    Particle(cue_dim=cue_dim, n_contexts_init=0, hyperparam=model_hyperparam)
    for _particle_idx in range(100)
]

# Run particle learning over the cue sequence, one cue per time step
for i, cue in enumerate(cues):
    print(f'>> processing cue {i+1} of {len(cues)}     ', end = '\r')

    # 1) Resampling: each particle computes its per-context posterior terms for
    #    this cue and sums them into a weight; the ensemble is resampled by weight.
    weights = []
    for pt in particles:
        pt.calculate_posterior_prop(cue)
        pt.calculate_weight()
        weights.append(pt.weight)
    particles = resample_particles(particles, weights)

    # 2) Propagation of the context and the sufficient statistics
    for pt in particles:
        pt.propagate_context()
        pt.propagate_sufficient_statistics(cue)

    # 3) Parameter sampling to keep diversity, as a separate pass after propagation
    for pt in particles:
        pt.sample_parameters()

>> processing cue 2000 of 2000     

#### Compare the estimated and true TPM and CEM

In [231]:
# Estimate the CEM by averaging the sampled emission matrices over particles.
# NOTE(review): np.array(...) stacking assumes every particle returns a CEM of the
# same shape, i.e. the same number of instantiated contexts — TODO confirm.
list_cue_emission_matrix = np.array([particle.get_theta_CEM() for particle in particles])
estimated_CEM_raw = np.average(list_cue_emission_matrix, axis = 0)
estimated_CEM = np.round(estimated_CEM_raw, 2)   # rounded copy, used for display only

# Inferred context labels are arbitrary, so pair each true context with a distinct
# inferred context before comparing. A per-row argmin can map two true contexts onto
# the same inferred one, which would duplicate rows/columns below.
# linear_sum_assignment instead finds the one-to-one pairing that minimises the total
# Euclidean distance between paired CEM rows. If fewer contexts were inferred than
# exist, only that many true contexts get a partner.
distance_matrix = cdist(true_CEM, estimated_CEM_raw, 'euclidean')
matched_true_indices, closest_indices = linear_sum_assignment(distance_matrix)
if len(matched_true_indices) < len(true_CEM):
    print(f"Note: only {len(matched_true_indices)} of {len(true_CEM)} true contexts were matched to an inferred context.")

# Compare the estimate and the true CEM (rows in matched order)
print("true_CEM:")
print(true_CEM[matched_true_indices])
print("Estimated CEM:")
print(estimated_CEM[closest_indices])

# Estimate the TPM by averaging, over particles, the expected local transition
# probabilities (eq. S29 of the supplementary materials):
#   Pi[j, k] = (alpha * beta[k] + kappa * [j == k] + n_t[j, k]) / (alpha + kappa + n_c[j])
# NOTE(review): the context count is taken from particle 0. Extra contexts in other
# particles are ignored, and a particle with fewer contexts will typically cause a
# shape error — TODO confirm all particles agree.
n_contexts_inferred = len(particles[0].get_ss_n_c())
exp_TPM = np.zeros((n_contexts_inferred, n_contexts_inferred))
self_transition = hyp_kappa * np.eye(n_contexts_inferred)
for particle in particles:
    # Fetch each statistic once per particle rather than once per (j, k) entry,
    # truncated to the first n_contexts_inferred contexts.
    beta_c = np.asarray(particle.get_theta_beta_c())[:n_contexts_inferred]
    n_t = np.asarray(particle.get_ss_n_t())[:n_contexts_inferred, :n_contexts_inferred]
    n_c = np.asarray(particle.get_ss_n_c())[:n_contexts_inferred]
    # Numerator broadcasts beta over columns (k); denominator over rows (j)
    exp_TPM += (hyp_alpha * beta_c[np.newaxis, :] + self_transition + n_t) / (hyp_alpha + hyp_kappa + n_c)[:, np.newaxis]
exp_TPM /= len(particles)

# Compare the estimated and the true TPM, both in matched context order
print("True TPM:")
print(np.round(np.asarray(true_TPM)[np.ix_(matched_true_indices, matched_true_indices)], 2))
print("Estimated TPM:")
print(np.round(exp_TPM[np.ix_(closest_indices, closest_indices)], 2))

true_CEM:
[[1 1 0 0 0 0 1 1]
 [0 0 0 0 1 1 1 1]
 [0 0 1 1 0 0 0 0]
 [1 0 0 0 1 1 1 1]]
Estimated CEM:
[[1. 1. 0. 0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 1. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 1. 1. 1. 1.]]
True TPM:
[[0.04 0.95 0.   0.  ]
 [0.22 0.2  0.04 0.54]
 [0.17 0.01 0.66 0.16]
 [0.13 0.15 0.72 0.  ]]
Estimated TPM:
[[0.02 0.96 0.   0.  ]
 [0.2  0.21 0.03 0.56]
 [0.16 0.   0.67 0.16]
 [0.11 0.16 0.72 0.01]]
