In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.nn import softmax
from scipy.optimize import minimize
from tqdm.auto import tqdm

def _normalize(u, axis=0, eps=1e-15):
    """Normalizes the values within the axis in a way that they sum up to 1.
    Args:
        u: Input array to normalize.
        axis: Axis over which to normalize.
        eps: Minimum value threshold for numerical stability.
    Returns:
        Tuple of the normalized values, and the normalizing denominator.
    """
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis)
    c = jnp.where(c == 0, 1, c)
    return u / c, c


# Helper functions for the two key filtering steps
def _condition_on(probs, ll):
    """Update a discrete state distribution with new emission log likelihoods.

    The largest log likelihood is subtracted before exponentiating so the
    weights cannot underflow; it is added back into the returned log normalizer.

    Args:
        probs(k): prior for state k
        ll(k): log likelihood for state k
    Returns:
        Tuple of the posterior for each state k and the log of the
        normalizing constant of the update.
    """
    peak = ll.max()
    shifted_likelihood = jnp.exp(ll - peak)
    posterior, scale = _normalize(probs * shifted_likelihood)
    return posterior, jnp.log(scale) + peak


def _predict(probs, A):
    return A.T @ probs

def _get_transition_matrix(transition_matrix, t):
        return transition_matrix[t] if transition_matrix.ndim == 3 else transition_matrix


@jax.jit
def hmm_filter(
    initial_distribution,
    transition_matrix,
    log_likelihoods,
):
    r"""Forwards filtering

    Runs the HMM forward (filtering) recursion with `jax.lax.scan`, alternating
    a conditioning step on the emissions at time $t$ with a one-step prediction.

    Transition matrix may be either 2D (if transition probabilities are fixed) or 3D
    if the transition probabilities vary over time. In the 3D case the leading axis
    is indexed by the time step: `transition_matrix[t]` propagates the filtered
    distribution at step $t$ to the predicted distribution for step $t+1$.

    Args:
        initial_distribution: $p(z_1 \mid u_1, \theta)$
        transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
        log_likelihoods: $\log p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$,
            as a 2D array with one row per time step and one column per state.
    Returns:
        Tuple of
        - the accumulated log normalizer (marginal log likelihood of the emissions),
        - the filtered state probabilities, one row per time step,
        - the one-step-ahead predicted state probabilities, one row per time step
          (the first row is `initial_distribution`).
    """
    # Unpacking also enforces that `log_likelihoods` is 2D.
    num_timesteps, _num_states = log_likelihoods.shape

    def _step(carry, t):
        log_normalizer, predicted_probs = carry

        # Select the matrix for this step; a 2D matrix is reused at every step.
        A = _get_transition_matrix(transition_matrix, t)
        ll = log_likelihoods[t]

        filtered_probs, log_norm = _condition_on(predicted_probs, ll)
        log_normalizer += log_norm
        predicted_probs_next = _predict(filtered_probs, A)

        return (log_normalizer, predicted_probs_next), (filtered_probs, predicted_probs)

    carry = (0.0, initial_distribution)
    (log_normalizer, _), (filtered_probs, predicted_probs) = jax.lax.scan(
        _step, carry, jnp.arange(num_timesteps)
    )

    return log_normalizer, filtered_probs, predicted_probs


def fit_regression(design_matrix, weights, spikes):
    """Fit a weighted Poisson regression with a log link by maximum likelihood.

    The conditional intensity is ``exp(design_matrix @ coefficients)`` and the
    objective is the negative weighted sum of Poisson log probabilities of
    ``spikes``. It is minimized with SciPy's BFGS using a JAX gradient.

    Args:
        design_matrix: 2D array of regressors, one row per observation. The
            optimizer is initialized with the log weighted mean of ``spikes`` in
            the first coefficient and zeros elsewhere.
            NOTE(review): that initialization presumes the first column is an
            intercept -- confirm.
        weights: per-observation weights applied to the log likelihood.
        spikes: observed counts, one per observation.
    Returns:
        The fitted coefficients (``res.x`` of the SciPy optimizer result).
    """
    # Local import so the function does not rely on `jax.scipy.stats` having been
    # loaded as a side effect of some other import.
    from jax.scipy.stats import poisson as jax_poisson

    @jax.jit
    def neglogp(
        coefficients, spikes=spikes, design_matrix=design_matrix, weights=weights
    ):
        conditional_intensity = jnp.exp(design_matrix @ coefficients)
        # Floor the rate away from zero. `jnp.maximum` is used instead of
        # `jnp.clip(..., a_min=..., a_max=None)` because those keyword names are
        # deprecated in recent JAX releases.
        conditional_intensity = jnp.maximum(conditional_intensity, 1e-15)
        log_likelihood = weights * jax_poisson.logpmf(spikes, conditional_intensity)
        return -log_likelihood.sum()

    dlike = jax.grad(neglogp)

    # Start from the constant-rate solution: log of the weighted mean count.
    initial_condition = np.array([np.log(np.average(spikes, weights=weights))])
    initial_condition = np.concatenate(
        [initial_condition, np.zeros(design_matrix.shape[1] - 1)]
    )

    res = minimize(neglogp, x0=initial_condition, method="BFGS", jac=dlike)

    return res.x


def centered_softmax_forward(y):
    """`softmax(x) = exp(x-c) / sum(exp(x-c))` where c is the last coordinate

    A zero logit is appended along the last axis and the softmax is taken over
    that axis, so ``n`` unconstrained values map to ``n + 1`` probabilities that
    sum to 1. Inputs of any dimensionality are accepted; leading axes are
    treated as batch axes.

    Example
    -------
    > y = np.log([2, 3, 4])
    > np.allclose(centered_softmax_forward(y), [0.2, 0.3, 0.4, 0.1])
    """
    y = jnp.atleast_1d(jnp.asarray(y))
    # One reference logit (fixed at 0) per leading index, appended on the last axis.
    reference_logit = jnp.zeros(y.shape[:-1] + (1,))
    return softmax(jnp.concatenate([y, reference_logit], axis=-1), axis=-1)


def centered_softmax_inverse(y):
    """Log-ratio of every entry but the last against the last entry.

    Operates along the last axis and drops the final coordinate, which acts as
    the reference (`c` in `softmax(x) = exp(x-c) / sum(exp(x-c))`).

    Example
    -------
    > y = np.asarray([0.2, 0.3, 0.4, 0.1])
    > np.allclose(np.exp(centered_softmax_inverse(y)), np.asarray([2,3,4]))
    """
    leading = y[..., :-1]
    # Slice (not integer index) so the reference keeps a trailing axis to broadcast on.
    reference = y[..., -1:]
    return jnp.log(leading) - jnp.log(reference)


@jax.jit
def neglogp(
    unconstrained_parameters,
    initial_distribution,
    observation_log_likelihood,
):
    """Negative marginal log likelihood as a function of the transition matrix.

    Args:
        unconstrained_parameters: flat vector of ``n_states * (n_states - 1)``
            values; reshaped to one row of free logits per state and mapped to a
            transition matrix with `centered_softmax_forward`.
        initial_distribution: initial state distribution; it is passed through
            to the filter unchanged (it is not part of the optimized parameters).
        observation_log_likelihood: log likelihoods handed to `hmm_filter`.
    Returns:
        The negated first output of `hmm_filter`. The value is also printed via
        `jax.debug.print` on every evaluation.
    """
    num_states = initial_distribution.shape[0]
    row_logits = unconstrained_parameters.reshape((num_states, num_states - 1))
    transition_matrix = centered_softmax_forward(row_logits)

    filter_outputs = hmm_filter(
        initial_distribution, transition_matrix, observation_log_likelihood
    )
    nll = -1.0 * filter_outputs[0]
    jax.debug.print("neg. log like.: {}", nll)

    return nll

In [2]:
from hmmlearn import hmm


# Fixed seed so the simulated spike train, and every fit below, is reproducible
# under Restart Kernel -> Run All.
random_seed = 0

n_states = 3
sampling_frequency = 500.0
n_time = 100_000

# Ground-truth Poisson HMM used to simulate the data.
model1 = hmm.PoissonHMM(n_components=n_states, random_state=random_seed)

model1.startprob_ = np.array([0.10, 0.10, 0.80])

model1.transmat_ = np.array(
    [[0.99, 0.005, 0.005],
     [0.005, 0.99, 0.005],
     [0.005, 0.005, 0.99]]
)

# One rate per state, as a column (n_states, 1).
# NOTE(review): presumably rates in spikes/s divided by the sampling frequency
# to get the expected count per time bin -- confirm units.
model1.lambdas_ = np.array([10.0, 25.0, 50.0])[:, np.newaxis] / sampling_frequency

spikes1, state_sequence1 = model1.sample(n_time)

In [3]:
# The initial state distribution is held at the simulation's true value
# (a uniform alternative would be `np.ones((n_states,)) / n_states`).
initial_distribution = model1.startprob_

# Starting guess for the transition matrix.
transition_matrix = np.array(
    [[0.90, 0.01, 0.01],
     [0.01, 0.90, 0.01],
     [0.01, 0.01, 0.90]]
)
# The rows above sum to 0.92, so normalize them into a valid stochastic matrix.
# `centered_softmax_inverse` only uses within-row ratios, so `x0` is the same
# (up to rounding) as for the unnormalized guess.
transition_matrix = transition_matrix / transition_matrix.sum(axis=1, keepdims=True)
local_rates = model1.lambdas_.squeeze()

# Flat vector of unconstrained parameters that the optimizers start from.
x0 = centered_softmax_inverse(transition_matrix).ravel()

In [4]:
import scipy.optimize
import scipy.stats  # used below; do not rely on it being loaded implicitly

# Log probability of the observed counts under each state's rate.
# NOTE(review): assumes `spikes1` is (n_time, 1) so that broadcasting against the
# (1, n_states) rates gives one column per state -- confirm.
observation_log_likelihood = scipy.stats.poisson.logpmf(spikes1, local_rates[np.newaxis])

# Fit the transition matrix with SciPy's BFGS, using the JAX gradient of `neglogp`.
res = scipy.optimize.minimize(
    neglogp,
    x0=x0,
    method="BFGS",
    jac=jax.grad(neglogp),
    args=(
        initial_distribution,
        observation_log_likelihood,
    ),
    options={
        "disp": True,
        "gtol": 1e-2,  # Adjust gradient tolerance value
    },
)

# Difference between the fitted and the true transition matrix.
np.array(centered_softmax_forward(res.x.reshape((n_states, n_states - 1)))) - model1.transmat_

neg. log like.: 20953.4609375
neg. log like.: 20953.4609375
neg. log like.: 20962.16796875
neg. log like.: 20962.16796875
neg. log like.: 20930.037109375
neg. log like.: 20930.037109375
neg. log like.: 20920.26953125
neg. log like.: 20920.26953125
neg. log like.: 20914.458984375
neg. log like.: 20914.458984375
neg. log like.: 20912.802734375
neg. log like.: 20912.802734375
neg. log like.: 20942.30078125
neg. log like.: 20942.30078125
neg. log like.: 20910.52734375
neg. log like.: 20910.52734375
neg. log like.: 20915.470703125
neg. log like.: 20915.470703125
neg. log like.: 20910.072265625
neg. log like.: 20910.072265625
neg. log like.: 20909.388671875
neg. log like.: 20909.388671875
neg. log like.: 20909.22265625
neg. log like.: 20909.22265625
neg. log like.: 20908.787109375
neg. log like.: 20908.787109375
neg. log like.: 20909.46875
neg. log like.: 20909.46875
neg. log like.: 20908.705078125
neg. log like.: 20908.705078125
neg. log like.: 20908.759765625
neg. log like.: 20908.75976562

array([[ 0.00318886, -0.00355115,  0.00036236],
       [-0.00177907,  0.00263436, -0.00085529],
       [ 0.00028294,  0.00041303, -0.00069593]])

In [5]:
# Full SciPy optimizer result for the BFGS fit above.
# NOTE(review): the recorded run stopped with status 2 ("precision loss") and a
# float32 gradient -- possibly a float32 precision issue; confirm.
res

      fun: 20908.294921875
 hess_inv: array([[ 0.03045055,  0.09985725, -0.00478979, -0.01431632, -0.01596873,
         0.01098402],
       [ 0.09985725,  0.71826047,  0.19469098, -0.07738144, -0.07425328,
         0.09445286],
       [-0.00478979,  0.19469098,  0.29474507,  0.12904511, -0.05047152,
        -0.0399247 ],
       [-0.01431632, -0.07738144,  0.12904511,  0.19918681,  0.01372938,
        -0.13622989],
       [-0.01596873, -0.07425328, -0.05047152,  0.01372938,  0.07126171,
        -0.05240719],
       [ 0.01098402,  0.09445286, -0.0399247 , -0.13622989, -0.05240719,
         0.14263034]])
      jac: array([-0.95703125,  0.31788453, -0.35074598,  0.296875  , -0.16613615,
        0.04679678], dtype=float32)
  message: 'Desired error not necessarily achieved due to precision loss.'
     nfev: 91
      nit: 14
     njev: 80
   status: 2
  success: False
        x: array([ 5.22151741, -1.30863249, -0.25216175,  5.47852979, -5.23251959,
       -5.20819207])

In [6]:
import jax.scipy.optimize

# Repeat the fit from the same starting point with JAX's BFGS implementation.
# No `jac` is passed to this call. Note that this overwrites `res` from the
# SciPy fit.
res = jax.scipy.optimize.minimize(
    neglogp,
    x0,
    args=(initial_distribution, observation_log_likelihood),
    method="BFGS",
    options=dict(gtol=1e-2),
)

# Difference between the fitted and the true transition matrix.
np.array(centered_softmax_forward(res.x.reshape(n_states, n_states - 1))) - model1.transmat_

neg. log like.: 20953.4609375
neg. log like.: 20962.162109375
neg. log like.: 20930.45703125
neg. log like.: 20919.951171875
neg. log like.: 20917.216796875
neg. log like.: 20911.14453125


array([[ 0.00359846, -0.00232967, -0.00126887],
       [ 0.00013884, -0.00241481,  0.0022759 ],
       [-0.00174887,  0.00284055, -0.00109165]])

In [7]:
# Full result object of the JAX BFGS fit (this is the `res` reassigned in the
# previous cell, not the SciPy result).
res

OptimizeResults(x=Array([ 5.5846224 , -0.33450836, -0.34774008,  4.910695  , -5.7176    ,
       -4.837292  ], dtype=float32), success=Array(False, dtype=bool), status=Array(3, dtype=int32, weak_type=True), fun=Array(20911.145, dtype=float32), jac=Array([10.410156 , -4.052677 ,  2.5904493, -9.267578 , -4.9828143,
       -8.101913 ], dtype=float32), hess_inv=Array([[ 0.09723077,  0.2433599 , -0.0471726 , -0.06384337,  0.01772271,
        -0.04033443],
       [ 0.24335985,  0.87410516,  0.13700712, -0.19436541,  0.03658639,
        -0.10066385],
       [-0.0471726 ,  0.13700719,  0.71217376,  0.3188084 , -0.25752988,
         0.07181346],
       [-0.06384335, -0.19436546,  0.31880838,  0.54269755, -0.01272446,
        -0.2940241 ],
       [ 0.01772273,  0.0365864 , -0.25752985, -0.01272443,  0.30734807,
        -0.3491177 ],
       [-0.04033444, -0.10066386,  0.07181344, -0.2940241 , -0.3491177 ,
         0.64298683]], dtype=float32), nfev=Array(6, dtype=int32, weak_type=True), njev=Arra

In [8]:
# Number of objective evaluations reported by the JAX optimizer.
res.nfev

Array(6, dtype=int32, weak_type=True)

In [9]:
# Success flag reported by the JAX optimizer (False in the recorded run).
res.success

Array(False, dtype=bool)

In [10]:
# Evaluate the filter at the true simulation parameters, as a baseline for the
# objective values reached by the optimizers above.
marginal_log_likelihood, filtered_probs, _ = hmm_filter(
    model1.startprob_,
    model1.transmat_,
    observation_log_likelihood,
)

# Negative marginal log likelihood at the true parameters.
-marginal_log_likelihood

Array(20912.73, dtype=float32)

In [11]:
from src.hmm import forward

# Cross-check against the project's `forward` implementation at the same true
# parameters. Note the different argument order (log likelihoods come second).
# NOTE(review): the small gap to the `hmm_filter` value in the recorded outputs
# may be float32 vs. float64 arithmetic -- confirm.
causal_posterior, _, marginal_log_likelihood2 = forward(
    model1.startprob_,
    observation_log_likelihood,
    model1.transmat_,
)

-1.0 * marginal_log_likelihood2

20912.573858922045