In [5]:
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jax import vmap
from jax import jit
from functools import partial

def _normalize(u, axis=0, eps=1e-15):
    """Normalizes the values within the axis in a way that they sum up to 1.
    Args:
        u: Input array to normalize.
        axis: Axis over which to normalize.
        eps: Minimum value threshold for numerical stability.
    Returns:
        Tuple of the normalized values, and the normalizing denominator.
    """
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis)
    c = jnp.where(c == 0, 1, c)
    return u / c, c


# Helper functions for the two key filtering steps
def _condition_on(probs, ll):
    """Condition on new emissions, given in the form of log likelihoods
    for each discrete state, while avoiding numerical underflow.
    Args:
        probs(k): prior for state k
        ll(k): log likelihood for state k
    Returns:
        probs(k): posterior for state k
    """
    ll_max = ll.max()
    new_probs = probs * jnp.exp(ll - ll_max)
    new_probs, norm = _normalize(new_probs)
    log_norm = jnp.log(norm) + ll_max
    return new_probs, log_norm


def _predict(probs, A):
    return A.T @ probs


@partial(jit)
def hmm_filter(
    initial_distribution,
    transition_matrix,
    log_likelihoods,
):
    r"""Forwards filtering
    Transition matrix may be either 2D (if transition probabilities are fixed) or 3D
    if the transition probabilities vary over time. Alternatively, the transition
    matrix may be specified via `transition_fn`, which takes in a time index $t$ and
    returns a transition matrix.
    Args:
        initial_distribution: $p(z_1 \mid u_1, \theta)$
        transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
        log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
    Returns:
        filtered posterior distribution
    """
    num_timesteps, num_states = log_likelihoods.shape

    def _step(carry, t):
        log_normalizer, predicted_probs = carry

        ll = log_likelihoods[t]

        filtered_probs, log_norm = _condition_on(predicted_probs, ll)
        log_normalizer += log_norm
        predicted_probs_next = _predict(filtered_probs, transition_matrix)

        return (log_normalizer, predicted_probs_next), (filtered_probs, predicted_probs)

    carry = (0.0, initial_distribution)
    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(num_timesteps))

    return log_normalizer, filtered_probs, predicted_probs

In [6]:
import numpy as np
from src.hmm3 import forward

initial_distribution = np.asarray([0.9, 0.1])
transition_matrix = np.asarray([[0.9, 0.1],
                                 [0.1, 0.9]])
log_likelihoods = np.zeros((9, 2))

marginal_likelihood, _, _ = hmm_filter(
    initial_distribution,
    transition_matrix,
    log_likelihoods,
)
marginal_likelihood

Array(-1.1920929e-07, dtype=float32)

In [212]:
from jax.nn import softmax


def centered_softmax_forward(y):
    """`softmax(x) = exp(x-c) / sum(exp(x-c))` where c is the last coordinate
    
    Example
    -------
    > y = np.log([2, 3, 4])
    > np.allclose(centered_softmax_forward(y), [0.2, 0.3, 0.4, 0.1])
    """
    if y.ndim == 1:
        y = jnp.append(y, 0)
    else:
        y = jnp.column_stack((y,  jnp.zeros((y.shape[0],))))

    return softmax(y, axis=-1)


def centered_softmax_inverse(y):
    """`softmax(x) = exp(x-c) / sum(exp(x-c))` where c is the last coordinate
    
    Example
    -------
    > y = np.asarray([0.2, 0.3, 0.4, 0.1])
    > np.allclose(np.exp(centered_softmax_inverse(y)), np.asarray([2,3,4]))
    """
    return jnp.log(y[..., :-1]) - jnp.log(y[..., -1])

In [244]:
from scipy.optimize import minimize
import jax
from src.simulate import simulate_two_state_poisson


sampling_frequency = 500
time, true_rate, spikes = simulate_two_state_poisson(
    sampling_frequency=sampling_frequency
)

n_states = 2
n_time = spikes.shape[0]
initial_distribution = np.ones((n_states,)) / n_states
transition_matrix = np.asarray([[0.98, 0.02], [0.02, 0.98]])
is_training = np.ones((n_time,), dtype=bool)
is_training[: n_time // 2] = False

n_rate_parameters = 1
n_rates = n_states * n_rate_parameters

@jax.jit
def neglogp(log_parameters):
    unconstrained_rates = log_parameters[:n_rates]
    unconstrained_initial_distribution = log_parameters[n_rates:n_rates+n_states + 1]
    unconstrained_transition_matrix = log_parameters[n_rates+n_states + 1:]
    
    rates = jnp.exp(unconstrained_rates)
    initial_distribution = centered_softmax_inverse(unconstrained_initial_distribution)
    transition_matrix = centered_softmax_inverse(unconstrained_transition_matrix.reshape((n_states, n_states + 1), order="F"))
    
    likelihood = jax.scipy.stats.poisson.logpmf(spikes[:, jnp.newaxis], rates)

    marginal_likelihood, _, _ = hmm_filter(
        initial_distribution, transition_matrix, log_likelihoods
    )
    
    return -marginal_likelihood

dlike = jax.grad(neglogp)

x0 = np.concatenate((np.log([spikes[is_training].mean(), spikes[~is_training].mean()]), centered_softmax_forward(initial_distribution), centered_softmax_forward(transition_matrix).ravel()))

res = minimize(
    neglogp, x0=x0, method="BFGS", jac=dlike,
)

log_parameters = res.x
unconstrained_rates = log_parameters[:n_rates]
unconstrained_initial_distribution = log_parameters[n_rates:n_rates+n_states + 1]
unconstrained_transition_matrix = log_parameters[n_rates+n_states + 1:]
estimated_rates = jnp.exp(unconstrained_rates)
estimated_initial_distribution = centered_softmax_inverse(unconstrained_initial_distribution)
estimated_transition_matrix = centered_softmax_inverse(unconstrained_transition_matrix.reshape((n_states, n_states + 1)))


In [247]:
estimated_initial_distribution

Array([nan, nan], dtype=float32)

In [245]:
estimated_rates * sampling_frequency, np.unique(true_rate)

(Array([14.150001, 10.500001], dtype=float32), array([ 5., 20.]))

In [219]:
unconstrained_rates = log_parameters[:n_rates]
unconstrained_initial_distribution = log_parameters[n_rates:n_rates+n_states]
unconstrained_transition_matrix = log_parameters[n_rates+n_states:]

jnp.exp(unconstrained_rates)
centered_softmax_forward(unconstrained_initial_distribution)
centered_softmax_forward(unconstrained_transition_matrix).reshape((n_states, n_states))

Array([[0.3310811 , 0.3310811 ],
       [0.00675676, 0.3310811 ]], dtype=float32)

In [223]:
unconstrained_rates

array([-3.3639016 , -3.92713664])

In [227]:
np.allclose(np.exp(x0[:2]), jnp.exp(log_parameters[:n_rates]))

True

In [230]:
 centered_softmax_forward(unconstrained_initial_distribution)

Array([0.01960784, 0.9607844 , 0.01960784], dtype=float32)

In [20]:
np.exp(x0) * sampling_frequency

array([14.7,  9.5])

In [236]:
np.allclose(centered_softmax_inverse(centered_softmax_forward(transition_matrix)), transition_matrix)

True

In [204]:
def centered_softmax_inverse(y):
    """`softmax(x) = exp(x-c) / sum(exp(x-c))` where c is the last coordinate
    
    Example
    -------
    > y = np.asarray([0.2, 0.3, 0.4, 0.1])
    > np.allclose(np.exp(centered_softmax_inverse(y)), np.asarray([2,3,4]))
    """
    return jnp.log(y[..., :-1]) - jnp.log(y[..., -1])

y = np.asarray([0.2, 0.3, 0.4, 0.1])
np.allclose(np.exp(centered_softmax_inverse(y)), np.asarray([2,3,4]))

True

In [203]:
from jax.nn import softmax


def centered_softmax_forward(y):
    """`softmax(x) = exp(x-c) / sum(exp(x-c))` where c is the last coordinate
    
    Example
    -------
    > y = np.log([2, 3, 4])
    > np.allclose(centered_softmax_forward(y), [0.2, 0.3, 0.4, 0.1])
    """
    if y.ndim == 1:
        y = jnp.append(y, 0)
    else:
        y = jnp.column_stack((y,  jnp.zeros((y.shape[0],))))

    return softmax(y, axis=-1)

y = np.log([2, 3, 4])
np.allclose(centered_softmax_forward(y), [0.2, 0.3, 0.4, 0.1])

True

In [200]:
y = np.log([[2, 3, 4]])
centered_softmax_forward(y)

Array([[0.2, 0.3, 0.4, 0.1]], dtype=float32)

In [192]:
y = np.log([[2, 3, 4],
            [2, 3, 4],
            [2, 3, 4]])

centered_softmax_forward(y)

array([[0.2, 0.3, 0.4, 0.1],
       [0.2, 0.3, 0.4, 0.1],
       [0.2, 0.3, 0.4, 0.1]])