In [39]:
import numpy as np

def forward(
    initial_conditions: list[np.ndarray],
    continuous_state_transitions: list[list[np.ndarray]],
    discrete_state_transitions: np.ndarray,
    likelihoods: list[list[np.ndarray]]
) -> tuple[list[list[np.ndarray]], float]:
    
    n_time = len(likelihoods)
    n_states = len(initial_conditions)
    posterior = [[] for _ in range(n_time)]
    
    assert n_states == discrete_state_transitions.shape[0]
    assert n_states == len(likelihoods[0])
    
    data_log_likelihood = 0.0
    
    # Initial step
    time_ind = 0
    
    scaling = 0.0
    for ic, likelihood in zip(initial_conditions, likelihoods[time_ind]):
        posterior[time_ind].append(ic * likelihood)
        scaling += np.sum(posterior[time_ind][-1])
    
    # Scale posterior
    posterior[time_ind] = [post / scaling for post in posterior[time_ind]]
    data_log_likelihood += scaling
    
    # Rest of time
    for time_ind in range(1, n_time):
        scaling = 0.0
        
        posterior[time_ind-1]
        posterior[time_ind]
        
        for prev_state in range(n_states):
            pass
        
        
        # Scale posterior
        posterior[time_ind] = [post / scaling for post in posterior[time_ind]]
        data_log_likelihood += scaling

    return posterior, data_log_likelihood
    

In [40]:
# local, no spike, continuous, fragmented
initial_conditions = [np.ones((1,)), np.zeros((1,)), np.zeros((5,)), np.zeros((5,))]
continuous_state_transitions = [
    [np.identity(5), np.zeros((5, 5))],
    [np.zeros((5, 5)), np.identity(5)],
]
discrete_state_transitions = np.identity(4)
likelihoods = [
    [np.ones((1,)), np.zeros((1,)), np.zeros((5,)), np.zeros((5,))],
    [np.ones((1,)), np.zeros((1,)), np.zeros((5,)), np.zeros((5,))],
]

forward(
    initial_conditions,
    continuous_state_transitions,
    discrete_state_transitions,
    likelihoods,
)


([[array([1.]),
   array([0.]),
   array([0., 0., 0., 0., 0.]),
   array([0., 0., 0., 0., 0.])],
  []],
 1.0)

In [95]:
def forward2(
    initial_conditions: np.ndarray,
    continuous_state_transitions: np.ndarray,
    discrete_state_transitions: np.ndarray,
    likelihoods: np.ndarray,
    state_ind: np.ndarray,
) -> tuple[np.ndarray, float]:
    """Flat state and position bins"""
    discrete_state_transitions_per_bin = discrete_state_transitions[state_ind[:, np.newaxis], state_ind]


In [171]:
from scipy.stats import multivariate_normal


# local, no spike, continuous, fragmented

n_time = 2
n_states = 4
bin_sizes = [1, 1, 5, 5]

state_ind = np.concatenate([ind * np.ones((bin_size,), dtype=int) for ind, bin_size in enumerate(bin_sizes)])
n_state_bins = len(state_ind)

initial_conditions = np.zeros((n_state_bins,))
initial_conditions[state_ind == 0] = 1.0

continuous_state_transitions = np.zeros((n_state_bins, n_state_bins))
discrete_state_transitions = np.asarray([[0.9, 0.1/3, 0.1/3, 0.1/3],
                                         [0.1/3, 0.9, 0.1/3, 0.1/3],
                                         [0.1/3, 0.1/3, 0.9, 0.1/3],
                                         [0.1/3, 0.1/3,0.1/3, 0.9],
                                         ])
discrete_state_transitions_per_bin = discrete_state_transitions[state_ind[:, np.newaxis], state_ind]



continuous_state_transitions = np.zeros((n_state_bins, n_state_bins))
from_state_ind, to_state_ind = np.meshgrid(state_ind, state_ind)
from_state_ind, to_state_ind = from_state_ind.T, to_state_ind.T

for from_state in range(n_states):
    for to_state in range(n_states):
        
        inds = (from_state_ind == from_state) & (to_state_ind == to_state)
        
        if (bin_sizes[from_state] == 1) & (bin_sizes[to_state] == 1):
            # transition from discrete to discrete
            continuous_state_transitions[inds] = 1.0
        elif (bin_sizes[from_state] > 1) & (bin_sizes[to_state] == 1):
            # transition from continuous to discrete
            continuous_state_transitions[inds] = 1.0
        elif (bin_sizes[from_state] == 1) & (bin_sizes[to_state] > 1):
            # transition from discrete to continuous
            continuous_state_transitions[inds] = 1.0 / bin_sizes[to_state] # uniform
        else:
            # transition from continuous to continuous
            if from_state != to_state:
                continuous_state_transitions[inds] = 1.0 / bin_sizes[to_state] # uniform
            else:
                transition_matrix = np.stack(
                    [
                        multivariate_normal(
                            mean=center, cov=6.0
                        ).pdf(np.arange(bin_sizes[to_state]))
                        for center in np.arange(bin_sizes[to_state])
                    ],
                    axis=1,
                )
                continuous_state_transitions[np.nonzero(inds)] = transition_matrix.ravel()
                
likelihoods = np.ones((n_time, n_state_bins))

forward2(initial_conditions, continuous_state_transitions, discrete_state_transitions, likelihoods, state_ind)

In [175]:
(discrete_state_transitions_per_bin * continuous_state_transitions) @ (initial_conditions * likelihoods[0])

array([0.9       , 0.03333333, 0.03333333, 0.03333333, 0.03333333,
       0.03333333, 0.03333333, 0.03333333, 0.03333333, 0.03333333,
       0.03333333, 0.03333333])

In [99]:
x, y = np.meshgrid(state_ind, state_ind)
for x_row, y_row in zip(x, y):
    print([f"({ind2}->{ind1})" for ind1, ind2 in zip(x_row, y_row)])

['(0->0)', '(0->1)', '(0->2)', '(0->2)', '(0->2)', '(0->2)', '(0->2)', '(0->3)', '(0->3)', '(0->3)', '(0->3)', '(0->3)', '(0->3)']
['(1->0)', '(1->1)', '(1->2)', '(1->2)', '(1->2)', '(1->2)', '(1->2)', '(1->3)', '(1->3)', '(1->3)', '(1->3)', '(1->3)', '(1->3)']
['(2->0)', '(2->1)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)']
['(2->0)', '(2->1)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)']
['(2->0)', '(2->1)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)']
['(2->0)', '(2->1)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)']
['(2->0)', '(2->1)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->2)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)', '(2->3)']
['(3->0)', '(3->1)', '(3->2)', '(3->2)', '(3->2)', '(3->2)', '(3->2)', '(3->3)', '(

In [131]:
bin_ind = np.arange(16).reshape(4, 4)
bin_ind = bin_ind[*np.meshgrid(state_ind, state_ind)].T
bin_ind

array([[ 0,  1,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3],
       [ 4,  5,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15]])

In [140]:
# discrete -> discrete
continuous_state_transitions = np.zeros((n_state_bins, n_state_bins))
continuous_state_transitions[np.isin(bin_ind, [0, 1, 4, 5])] = 1.0
continuous_state_transitions

array([[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [123]:
bin_ind = np.arange(16).reshape(4, 4)
bin_ind[state_ind[:, np.newaxis], state_ind]

array([[ 0,  1,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3],
       [ 4,  5,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [ 8,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15],
       [12, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15]])

In [90]:
x

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

In [91]:
state_ind

array([0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])

In [116]:
from_state_ind, to_state_ind = np.meshgrid(state_ind, state_ind)
from_state_ind, to_state_ind = from_state_ind.T, to_state_ind.T
from_state_ind, to_state_ind

(array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]),
 array([[0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        [0, 1, 2, 2, 2, 2, 2, 3,