# Particle Filter with JAX

## Notation

### Functions

- `state_lpdf(x_curr, x_last, theta)`: Log-density of $p(x_t | x_{t-1}, \theta)$.
- `state_sample(x_last, theta)`: Sample from $p(x_t | x_{t-1}, \theta)$.
- `meas_lpdf(y_curr, x_curr, theta)`: Log-density of $p(y_t | x_t, \theta)$.
- `meas_sample(x_curr, theta)`: Sample from $p(y_t | x_t, \theta)$.

### Dimensions

- `n_obs`: Number of time points.
- `n_state`: Number of state dimensions.
- `n_meas`: Number of measured dimensions.
- `n_particles`: Number of particles.
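
The arrays used throughout the notebook can be laid out directly from these dimensions. The snippet below is a minimal sketch with arbitrary sizes. The `-1` sentinel in the first row of ancestor indices mirrors the convention stated in the `particle_filter` docstring.

```python
import numpy as np

# Arbitrary example dimensions (not taken from any particular data set).
n_obs, n_state, n_meas, n_particles = 5, 1, 1, 7

y_meas = np.zeros((n_obs, n_meas))                     # one measurement row per time point
X_particles = np.zeros((n_obs, n_particles, n_state))  # particle states at each time point
logw_particles = np.zeros((n_obs, n_particles))        # unnormalized log-weights
ancestor_particles = np.zeros((n_obs, n_particles), dtype=int)
ancestor_particles[0] = -1                             # first time point has no ancestors
```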

## Example: Brownian motion with drift

The model is
$$
\newcommand{\N}{\mathcal{N}}
\newcommand{\dt}{\Delta t}
\begin{aligned}
x_0 & \sim \pi(x_0) \\
x_t & \sim \N(x_{t-1} + \mu \dt, \sigma^2 \dt) \\
y_t & \sim \N(x_t, \tau^2).
\end{aligned}
$$
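
The state increments $x_t - x_{t-1}$ are independent $\mathcal{N}(\mu \Delta t, \sigma^2 \Delta t)$ draws, so a whole latent path can be simulated with a cumulative sum instead of a loop. Below is a NumPy sketch that assumes a fixed starting value $x_0$ (as in the test cells below) and passes $\Delta t$ explicitly. `simulate_bm` is a hypothetical helper and is not part of the notebook's modules.

```python
import numpy as np


def simulate_bm(n_obs, x0, mu, sigma, tau, dt, rng):
    """Hypothetical helper: simulate (y_1..y_T, x_1..x_T) from the drifted Brownian motion model.

    x_t = x_{t-1} + mu*dt + sigma*sqrt(dt)*z_t  and  y_t = x_t + tau*e_t,
    with z_t, e_t independent standard normals.
    """
    # Increments have mean mu*dt and variance sigma^2*dt, i.e. sd sigma*sqrt(dt).
    dx = mu * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_obs)
    x = x0 + np.cumsum(dx)                       # latent path x_1, ..., x_T
    y = x + tau * rng.standard_normal(n_obs)     # measurements with sd tau
    # Return column vectors so that n_state = n_meas = 1, as in the notebook.
    return y[:, None], x[:, None]


rng = np.random.default_rng(1)
y_meas, x_state = simulate_bm(n_obs=5, x0=0.0, mu=5.0, sigma=1.0, tau=0.1, dt=0.1, rng=rng)
```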

Therefore `n_state = n_meas = 1`.

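Note that with a flat prior $\pi(x_0) \propto 1$, we may condition on $y_0$ and obtain $x_0 \mid y_0 \sim \N(y_0, \tau^2)$. This holds because $p(x_0 \mid y_0) \propto \N(y_0 \mid x_0, \tau^2)$, and the normal density is symmetric in $x_0$ and $y_0$.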
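
With this posterior available in closed form, the initial particles can be drawn exactly from $\mathcal{N}(y_0, \tau^2)$. All of them then receive the same (zero) log-weight, which is what `particle_filter` does below. The following is a NumPy sketch; `init_particles_flat_prior` is a hypothetical name. It is only valid because here $y_t = x_t + \text{noise}$ with `n_state = n_meas`.

```python
import numpy as np


def init_particles_flat_prior(y0, tau, n_particles, rng):
    """Hypothetical helper: draw x_0 particles from p(x_0 | y_0) = N(y_0, tau^2).

    Assumes a flat prior on x_0 and the measurement model y = x + N(0, tau^2).
    Since the particles are exact posterior draws, all log-weights are zero.
    """
    y0 = np.atleast_1d(y0)                                    # shape (n_state,)
    x0 = y0 + tau * rng.standard_normal((n_particles, y0.size))
    logw0 = np.zeros(n_particles)
    return x0, logw0


rng = np.random.default_rng(0)
x0, logw0 = init_particles_flat_prior(0.45, tau=0.1, n_particles=7, rng=rng)
```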

### Using **NumPy** and **SciPy**

In [12]:
# %load bm_model.py
"""
Brownian motion state space model.

The model is:

```
x_0 ~ pi(x_0) \propto 1
x_t ~ N(x_{t-1} + mu * dt, sigma * sqrt(dt))
y_t ~ N(x_t, tau)
```

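Here `N(m, s)` denotes a normal distribution with mean `m` and standard deviation `s`.
The parameter values are `theta = (mu, sigma, tau)`, and `dt` is a global constant that must be defined before the functions below are called.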
"""

import numpy as np
import scipy as sp
import scipy.stats


# state-space dimensions
n_meas = 1
n_state = 1


def state_lpdf(x_curr, x_prev, theta):
    """
    Calculates the log-density of `p(x_curr | x_prev, theta)`.

    Args:
        x_curr: State variable at current time `t`.
        x_prev: State variable at previous time `t-1`.
        theta: Parameter value.

    Returns:
        The log-density of `p(x_curr | x_prev, theta)`.
    """
    mu = theta[0]
    sigma = theta[1]
    return sp.stats.norm.logpdf(x_curr, loc=x_prev + mu * dt, scale=sigma * np.sqrt(dt))


def state_sample(x_prev, theta):
    """
    Samples from `x_curr ~ p(x_curr | x_prev, theta)`.

    Args:
        x_prev: State variable at previous time `t-1`.
        theta: Parameter value.

    Returns:
        Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
    """
    mu = theta[0]
    sigma = theta[1]
    return sp.stats.norm.rvs(loc=x_prev + mu * dt, scale=sigma * np.sqrt(dt))


def meas_lpdf(y_curr, x_curr, theta):
    """
    Log-density of `p(y_curr | x_curr, theta)`.

    Args:
        y_curr: Measurement variable at current time `t`.
        x_curr: State variable at current time `t`.
        theta: Parameter value.

    Returns:
        The log-density of `p(y_curr | x_curr, theta)`.
    """
    tau = theta[2]
    return sp.stats.norm.logpdf(y_curr, loc=x_curr, scale=tau)


def meas_sample(x_curr, theta):
    """
    Sample from `p(y_curr | x_curr, theta)`.

    Args:
        x_curr: State variable at current time `t`.
        theta: Parameter value.

    Returns:
        Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
    """
    tau = theta[2]
    return sp.stats.norm.rvs(loc=x_curr, scale=tau)
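
For reference, `sp.stats.norm.logpdf(x, loc, scale)` treats `scale` as a standard deviation. The transition density therefore uses `sigma * sqrt(dt)`, which matches the variance $\sigma^2 \Delta t$ in the model equations. The sketch below writes that density out explicitly. `norm_logpdf` and `state_lpdf_explicit` are hypothetical names, and `dt` is passed as an argument instead of being read from a global.

```python
import numpy as np


def norm_logpdf(x, loc, scale):
    """Normal log-density with mean `loc` and standard deviation `scale`."""
    z = (x - loc) / scale
    return -0.5 * np.log(2.0 * np.pi) - np.log(scale) - 0.5 * z**2


def state_lpdf_explicit(x_curr, x_prev, theta, dt):
    # Same density as `state_lpdf`, but with `dt` passed in rather than read from a global.
    mu, sigma = theta[0], theta[1]
    return norm_logpdf(x_curr, loc=x_prev + mu * dt, scale=sigma * np.sqrt(dt))


theta = np.array([5.0, 1.0, 0.1])
lp = state_lpdf_explicit(np.array([0.6]), np.array([0.1]), theta, dt=0.1)
```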


In [14]:
# %load particle_filter.py
"""
Prototype for particle filter using NumPy/SciPy.

The API requires the user to define the following functions:

- `state_lpdf(x_curr, x_last, theta)`: Log-density of `p(x_t | x_{t-1}, theta)`.
- `state_sample(x_last, theta)`: Sample from `x_curr ~ p(x_t | x_{t-1}, theta)`.
- `meas_lpdf(y_curr, x_curr, theta)`: Log-density of `p(y_t | x_t, theta)`.
- `meas_sample(x_curr, theta)`: Sample from `y_curr ~ p(y_t | x_t, theta)`.

For now, additional inputs are specified as global constants.

The provided functions are:

- `meas_sim(n_obs, x_init, theta)`: Obtain a sample from `y_meas = (y_1, ..., y_T)` and `x_state = (x_1, ..., x_T)`.
- `particle_filter(y_meas, theta, n_particles)`: Run the particle filter.
- `particle_loglik(logw_particles)`: Compute the particle filter marginal loglikelihood.
- `particle_smooth(logw, X_particles, ancestor_particles, n_sample)`: Posterior sampling from the particle filter distribution of `p(x_state | y_meas, theta)`.
- `particle_resample(logw)`: A rudimentary particle resampling method.
"""

import numpy as np
import scipy as sp
import scipy.special
import scipy.stats


def meas_sim(n_obs, x_init, theta):
    """
    Simulate data from the state-space model.

    Args:
        n_obs: Number of observations to generate.
        x_init: Initial state value at time `t = 0`.
        theta: Parameter value.

    Returns:
        y_meas: The sequence of measurement variables `y_meas = (y_1, ..., y_T)`, where `T = n_obs`.
        x_state: The sequence of state variables `x_state = (x_1, ..., x_T)`, where `T = n_obs`.
    """
    y_meas = np.zeros((n_obs, n_meas))
    x_state = np.zeros((n_obs, n_state))
    x_prev = x_init
    for t in range(n_obs):
        x_state[t] = state_sample(x_prev, theta)
        y_meas[t] = meas_sample(x_state[t], theta)
        x_prev = x_state[t]
    return y_meas, x_state


def particle_resample(logw):
    """
    Particle resampler.

    This basic version uses multinomial resampling, i.e., it samples indices with replacement with probabilities proportional to the weights.

    Args:
        logw: Vector of `n_particles` unnormalized log-weights.

    Returns:
        Vector of `n_particles` integers between 0 and `n_particles-1`, sampled with replacement with probability vector `exp(logw) / sum(exp(logw))`.
    """
    wgt = np.exp(logw - np.max(logw))
    prob = wgt / np.sum(wgt)
    n_particles = logw.size
    return np.random.choice(np.arange(n_particles), size=n_particles, p=prob)


def particle_filter(y_meas, theta, n_particles):
    """
    Apply particle filter for given value of `theta`.

    Closely follows Algorithm 2 of https://arxiv.org/pdf/1306.3277.pdf.

    FIXME: Uses a hard-coded prior for initial state variable `x_state[0]`.  Need to make this more general.

    Args:
        y_meas: The sequence of `n_obs` measurement variables `y_meas = (y_1, ..., y_T)`, where `T = n_obs`.
        theta: Parameter value.
        n_particles: Number of particles.

    Returns:
        A dictionary with elements:
            - `X_particles`: An `ndarray` with leading dimensions `(n_obs, n_particles)` containing the state variable particles.
            - `logw_particles`: An `ndarray` of shape `(n_obs, n_particles)` giving the unnormalized log-weights of each particle at each time point.
            - `ancestor_particles`: An integer `ndarray` of shape `(n_obs, n_particles)` where each element gives the index of the particle's ancestor at the previous time point.  Since the first time point does not have ancestors, the first row of `ancestor_particles` contains all `-1`.
    """
    # memory allocation
    n_obs = y_meas.shape[0]
    X_particles = np.zeros((n_obs, n_particles, n_state))
    logw_particles = np.zeros((n_obs, n_particles))
    ancestor_particles = np.zeros((n_obs, n_particles), dtype=int)
    ancestor_particles[0] = -1  # initial particles have no ancestors
    # initial time point
    # FIXME: Hard-coded flat prior on x_0.  Make this more general.
    for i_part in range(n_particles):
        # sample directly from the posterior p(x_0 | y_0, theta); for this model
        # `meas_sample(y_0, theta)` draws from N(y_0, tau^2), so all weights are equal
        X_particles[0, i_part, :] = meas_sample(y_meas[0, :], theta)
        logw_particles[0, i_part] = 0.
    # subsequent time points
    for t in range(1, n_obs):
        # resampling step
        ancestor_particles[t] = particle_resample(logw_particles[t-1])
        for i_part in range(n_particles):
            X_particles[t, i_part, :] = state_sample(
                X_particles[t-1, ancestor_particles[t, i_part], :], theta
            )
            logw_particles[t, i_part] = meas_lpdf(
                y_meas[t, :], X_particles[t, i_part, :], theta
            )
    return {
        "X_particles": X_particles,
        "logw_particles": logw_particles,
        "ancestor_particles": ancestor_particles
    }


def particle_loglik(logw_particles):
    """
    Calculate particle filter marginal loglikelihood.

    Args:
        logw_particles: An `ndarray` of shape `(n_obs, n_particles)` giving the unnormalized log-weights of each particle at each time point.

    Returns:
        Particle filter approximation of
        ```
        log p(y_meas | theta) = log int p(y_meas | x_state, theta) * p(x_state | theta) dx_state
        ```
        obtained by summing, over time points, the log of the average particle weight.
    """
    n_particles = logw_particles.shape[1]
    return np.sum(sp.special.logsumexp(logw_particles, axis=1) - np.log(n_particles))


def particle_smooth(logw, X_particles, ancestor_particles, n_sample=1):
    """
    Basic particle smoothing algorithm.

    Samples from the posterior distribution `p(x_state | y_meas, theta)`.

    Args:
        logw: Vector of `n_particles` unnormalized log-weights at the last time point `t = n_obs`.
        X_particles: An `ndarray` with leading dimensions `(n_obs, n_particles)` containing the state variable particles.
        ancestor_particles: An integer `ndarray` of shape `(n_obs, n_particles)` where each element gives the index of the particle's ancestor at the previous time point.
        n_sample: Number of draws of `x_state` to return.

    Returns:
        An `ndarray` of shape `(n_sample, n_obs, n_state)` containing `n_sample` draws from the particle filter approximation to the posterior distribution `p(x_state | y_meas, theta)`.
    """
    wgt = np.exp(logw - np.max(logw))
    prob = wgt / np.sum(wgt)
    n_particles = logw.size
    n_obs = X_particles.shape[0]
    n_state = X_particles.shape[2]
    x_state = np.zeros((n_sample, n_obs, n_state))
    for i_samp in range(n_sample):
        # draw one final-time particle index (a scalar), then trace its ancestry backward
        i_part = np.random.choice(n_particles, p=prob)
        # i_part_T = i_part
        x_state[i_samp, n_obs-1] = X_particles[n_obs-1, i_part, :]
        for i_obs in reversed(range(n_obs-1)):
            i_part = ancestor_particles[i_obs+1, i_part]
            x_state[i_samp, i_obs] = X_particles[i_obs, i_part, :]
    return x_state  # , i_part_T
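
The two post-processing steps can be made concrete with a NumPy-only sketch on a hand-made filter output with a few time points and three particles. `pf_loglik_mean` takes, at each time point, the log of the *average* weight, which is the usual bootstrap particle estimate of $\log p(y_t \mid y_{0:t-1}, \theta)$. It differs from a plain sum of per-time `logsumexp` values by `n_obs * log(n_particles)`, a constant that does not depend on $\theta$. `trace_lineage` follows the ancestor indices backward in the same way as `particle_smooth`. Both function names and the toy arrays are hypothetical.

```python
import numpy as np


def pf_loglik_mean(logw_particles):
    """Sum over time of log((1/N) * sum_i exp(logw[t, i])), computed stably."""
    m = logw_particles.max(axis=1, keepdims=True)  # per-time max, for stability
    log_mean_w = np.log(np.mean(np.exp(logw_particles - m), axis=1)) + m[:, 0]
    return float(np.sum(log_mean_w))


def trace_lineage(X_particles, ancestor_particles, i_final):
    """Follow ancestor indices back from particle `i_final` at the last time point."""
    n_obs = X_particles.shape[0]
    path = np.empty((n_obs,) + X_particles.shape[2:])
    i_part = i_final
    path[n_obs - 1] = X_particles[n_obs - 1, i_part]
    for t in range(n_obs - 2, -1, -1):
        i_part = ancestor_particles[t + 1, i_part]  # parent's index at time t
        path[t] = X_particles[t, i_part]
    return path


# Toy output: X[t, i, 0] = 3*t + i, and ancestors[t, i] = parent index at t-1.
X = np.arange(9.0).reshape(3, 3, 1)
anc = np.array([[-1, -1, -1], [2, 0, 0], [1, 1, 2]])
path0 = trace_lineage(X, anc, i_final=0)   # values [0., 4., 6.] down the time axis
logw = np.log(np.array([[1.0, 1.0, 1.0], [1.0, 2.0, 3.0]]))
ll = pf_loglik_mean(logw)                  # log(1) + log(2) = log(2)
```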


In [16]:
# %load test_particle_filter.py
# parameter values
mu = 5
sigma = 1
tau = .1
theta = np.array([mu, sigma, tau])

# data specification
dt = .1
n_obs = 5
x_init = np.array([0.])

# simulate data
y_meas, x_state = meas_sim(n_obs, x_init, theta)

print("y_meas = \n", y_meas)
print("x_state = \n", x_state)

n_particles = 7
pf_out = particle_filter(y_meas, theta, n_particles)

print("pf_out = \n", pf_out)

# calculate marginal loglikelihood
pf_loglik = particle_loglik(pf_out["logw_particles"])

print("pf_loglik = \n", pf_loglik)

# sample from posterior `p(x_{0:T} | y_{0:T}, theta)`
n_sample = 11
X_state = particle_smooth(
    pf_out["logw_particles"][n_obs-1],
    pf_out["X_particles"],
    pf_out["ancestor_particles"],
    n_sample
)

print("X_state = \n", X_state)


y_meas = 
 [[0.76371425]
 [0.88715404]
 [1.10640127]
 [1.36352618]
 [1.95148437]]
x_state = 
 [[0.66073978]
 [0.78324502]
 [0.95774334]
 [1.35223807]
 [1.95785123]]
pf_out = 
 {'X_particles': array([[[0.77350637],
        [0.61849112],
        [0.78837221],
        [0.69694015],
        [0.75922677],
        [0.61105708],
        [0.81612899]],

       [[1.33094042],
        [0.83430704],
        [1.72074152],
        [1.00563761],
        [1.37481499],
        [0.7076549 ],
        [0.93824445]],

       [[1.46630755],
        [1.76465845],
        [1.56231864],
        [1.49361636],
        [1.20354607],
        [1.34611214],
        [1.7137133 ]],

       [[1.77178989],
        [2.2083558 ],
        [1.99854547],
        [1.53933115],
        [1.86793308],
        [1.66248025],
        [1.62665892]],

       [[2.20552559],
        [2.04343909],
        [1.87094663],
        [2.28008091],
        [2.04926174],
        [2.60712631],
        [2.51485203]]]), 'logw_particles': array([[ 

### Using JAX

In [18]:
# %load bm_model_jax.py
"""
Brownian motion state space model in JAX.
"""

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random


# state-space dimensions
n_meas = 1
n_state = 1


def state_lpdf(x_curr, x_prev, theta):
    """
    Calculates the log-density of `p(x_curr | x_prev, theta)`.

    Args:
        x_curr: State variable at current time `t`.
        x_prev: State variable at previous time `t-1`.
        theta: Parameter value.

    Returns:
        The log-density of `p(x_curr | x_prev, theta)`.
    """
    mu = theta[0]
    sigma = theta[1]
    return jsp.stats.norm.logpdf(x_curr, loc=x_prev + mu * dt, scale=sigma * jnp.sqrt(dt))


def state_sample(x_prev, theta, key):
    """
    Samples from `x_curr ~ p(x_curr | x_prev, theta)`.

    Args:
        x_prev: State variable at previous time `t-1`.
        theta: Parameter value.
        key: PRNG key.

    Returns:
        Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
    """
    mu = theta[0]
    sigma = theta[1]
    x_mean = x_prev + mu * dt
    x_sd = sigma * jnp.sqrt(dt)
    return x_mean + x_sd * random.normal(key=key)


def meas_lpdf(y_curr, x_curr, theta):
    """
    Log-density of `p(y_curr | x_curr, theta)`.

    Args:
        y_curr: Measurement variable at current time `t`.
        x_curr: State variable at current time `t`.
        theta: Parameter value.

    Returns:
        The log-density of `p(y_curr | x_curr, theta)`.
    """
    tau = theta[2]
    return jsp.stats.norm.logpdf(y_curr, loc=x_curr, scale=tau)


def meas_sample(x_curr, theta, key):
    """
    Sample from `p(y_curr | x_curr, theta)`.

    Args:
        x_curr: State variable at current time `t`.
        theta: Parameter value.
        key: PRNG key.

    Returns:
        Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
    """
    tau = theta[2]
    return x_curr + tau * random.normal(key=key)
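
Unlike the SciPy version, the JAX samplers draw normals through the location-scale identity: if $z \sim \mathcal{N}(0, 1)$, then $m + s z \sim \mathcal{N}(m, s^2)$. The quick NumPy check below uses a `numpy.random.Generator` in place of JAX PRNG keys, purely for illustration. `normal_via_location_scale` is a hypothetical name.

```python
import numpy as np


def normal_via_location_scale(mean, sd, size, rng):
    """Draw N(mean, sd^2) samples as mean + sd * z with z ~ N(0, 1)."""
    z = rng.standard_normal(size)
    return mean + sd * z


rng = np.random.default_rng(42)
mu, sigma, dt, x_prev = 5.0, 1.0, 0.1, 0.3
# Same mean and sd as `state_sample`: x_prev + mu*dt and sigma*sqrt(dt).
draws = normal_via_location_scale(x_prev + mu * dt, sigma * np.sqrt(dt), 200_000, rng)
```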


In [20]:
# %load particle_filter_jax.py
"""
Particle filter in JAX.

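Uses the same API as the NumPy/SciPy version for the functions defined here, except that each function that draws random numbers takes an additional PRNG `key` argument.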
"""

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random


def meas_sim(n_obs, x_init, theta, key):
    """
    Simulate data from the state-space model.

    Args:
        n_obs: Number of observations to generate.
        x_init: Initial state value at time `t = 0`.
        theta: Parameter value.
        key: PRNG key.

    Returns:
        y_meas: The sequence of measurement variables `y_meas = (y_1, ..., y_T)`, where `T = n_obs`.
        x_state: The sequence of state variables `x_state = (x_1, ..., x_T)`, where `T = n_obs`.
    """
    y_meas = jnp.zeros((n_obs, n_meas))
    x_state = jnp.zeros((n_obs, n_state))
    x_prev = x_init
    for t in range(n_obs):
        key, *subkeys = random.split(key, num=3)
        x_state = x_state.at[t].set(state_sample(x_prev, theta, subkeys[0]))
        y_meas = y_meas.at[t].set(meas_sample(x_state[t], theta, subkeys[1]))
        x_prev = x_state[t]
    return y_meas, x_state


def particle_resample(logw, key):
    """
    Particle resampler.

    This basic version uses multinomial resampling, i.e., it samples indices with replacement with probabilities proportional to the weights.

    Args:
        logw: Vector of `n_particles` unnormalized log-weights.
        key: PRNG key.

    Returns:
        Vector of `n_particles` integers between 0 and `n_particles-1`, sampled with replacement with probability vector `exp(logw) / sum(exp(logw))`.
    """
    wgt = jnp.exp(logw - jnp.max(logw))
    prob = wgt / jnp.sum(wgt)
    n_particles = logw.size
    return random.choice(key,
                         a=jnp.arange(n_particles), shape=(n_particles,), p=prob)


def particle_filter(y_meas, theta, n_particles, key):
    """
    Apply particle filter for given value of `theta`.

    Closely follows Algorithm 2 of https://arxiv.org/pdf/1306.3277.pdf.

    FIXME: Uses a hard-coded prior for initial state variable `x_state[0]`.  Need to make this more general.

    Args:
        y_meas: The sequence of `n_obs` measurement variables `y_meas = (y_1, ..., y_T)`, where `T = n_obs`.
        theta: Parameter value.
        n_particles: Number of particles.
        key: PRNG key.

    Returns:
        A dictionary with elements:
            - `X_particles`: An `ndarray` with leading dimensions `(n_obs, n_particles)` containing the state variable particles.
            - `logw_particles`: An `ndarray` of shape `(n_obs, n_particles)` giving the unnormalized log-weights of each particle at each time point.
            - `ancestor_particles`: An integer `ndarray` of shape `(n_obs, n_particles)` where each element gives the index of the particle's ancestor at the previous time point.  Since the first time point does not have ancestors, the first row of `ancestor_particles` contains all `-1`.
    """
    # memory allocation
    n_obs = y_meas.shape[0]
    X_particles = jnp.zeros((n_obs, n_particles, n_state))
    logw_particles = jnp.zeros((n_obs, n_particles))
    ancestor_particles = jnp.zeros((n_obs, n_particles), dtype=int)
    # initial particles have no ancestors
    ancestor_particles = ancestor_particles.at[0].set(-1)
    # initial time point
    # FIXME: Hard-coded flat prior on x_0.  Make this more general.
    key, *subkeys = random.split(key, num=n_particles+1)
    X_particles = X_particles.at[0].set(
        jax.vmap(lambda k: meas_sample(y_meas[0], theta, k))(
            jnp.array(subkeys)
        )
    )
    # sample directly from posterior p(x_0 | y_0, theta)
    logw_particles = logw_particles.at[0].set(0.)
    # subsequent time points
    for t in range(1, n_obs):
        # resampling step
        key, subkey = random.split(key)
        ancestor_particles = ancestor_particles.at[t].set(
            particle_resample(logw_particles[t-1], subkey)
        )
        # update
        key, *subkeys = random.split(key, num=n_particles+1)
        X_particles = X_particles.at[t].set(
            jax.vmap(lambda xs, k: state_sample(xs, theta, k))(
                X_particles[t-1, ancestor_particles[t]], jnp.array(subkeys)
            )
        )
        logw_particles = logw_particles.at[t].set(
            jnp.squeeze(
                jax.vmap(lambda xs: meas_lpdf(y_meas[t], xs, theta))(
                    X_particles[t]
                )
            )
        )
    return {
        "X_particles": X_particles,
        "logw_particles": logw_particles,
        "ancestor_particles": ancestor_particles
    }
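
The JAX filter replaces the per-particle inner loop of the NumPy version with `vmap` and fancy indexing (`X_particles[t-1, ancestor_particles[t]]`). The same vectorization is available in plain NumPy. The sketch below shows one resample-propagate-weight step for the Brownian motion model. It is written as a hypothetical helper `bootstrap_step`, with `dt` passed explicitly and `numpy.random.Generator.choice` used for multinomial resampling.

```python
import numpy as np


def bootstrap_step(x_prev, logw_prev, y_curr, theta, dt, rng):
    """Hypothetical helper: one resample-propagate-weight step for the Brownian motion model.

    x_prev: (n_particles, n_state) particles at t-1.
    logw_prev: (n_particles,) unnormalized log-weights at t-1.
    y_curr: (n_meas,) observation at t.
    """
    mu, sigma, tau = theta
    n_particles = logw_prev.size
    # Multinomial resampling: normalize the weights stably, then draw ancestor indices.
    w = np.exp(logw_prev - logw_prev.max())
    ancestors = rng.choice(n_particles, size=n_particles, p=w / w.sum())
    # Propagate all resampled particles at once (fancy indexing replaces the loop).
    x_mean = x_prev[ancestors] + mu * dt
    x_curr = x_mean + sigma * np.sqrt(dt) * rng.standard_normal(x_mean.shape)
    # Gaussian measurement log-density, summed over measurement dimensions.
    z = (y_curr - x_curr) / tau
    logw = np.sum(-0.5 * np.log(2 * np.pi) - np.log(tau) - 0.5 * z**2, axis=1)
    return x_curr, logw, ancestors


rng = np.random.default_rng(0)
theta = (5.0, 1.0, 0.1)
x_prev = np.full((7, 1), 0.45)
x_curr, logw, anc = bootstrap_step(x_prev, np.zeros(7), np.array([0.68]), theta, 0.1, rng)
```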


In [47]:
# %load test_particle_filter_jax.py
key = random.PRNGKey(0)

# parameter values
mu = 5
sigma = 1
tau = .1
theta = jnp.array([mu, sigma, tau])

print(theta)

# data specification
dt = .1
n_obs = 5
x_init = jnp.array([0.])

# simulate data
key, subkey = random.split(key)
y_meas, x_state = meas_sim(n_obs, x_init, theta, subkey)

print("y_meas = \n", y_meas)
print("x_state = \n", x_state)

# run particle filter
n_particles = 7
key, subkey = random.split(key)
particle_filter_jitted = jax.jit(particle_filter, static_argnums=(2,))
pf_out = particle_filter(y_meas, theta, n_particles, subkey)
pf_out_jitted = particle_filter_jitted(y_meas, theta, n_particles, subkey)

print("pf_out = \n", pf_out)
print("max diff between pf_out and pf_out_jitted = \n", {k:jnp.max(jnp.abs(pf_out[k] - pf_out_jitted[k])) for k in pf_out.keys()})

[5.  1.  0.1]
y_meas = 
 [[0.44947725]
 [0.68188113]
 [0.91331539]
 [1.05404084]
 [1.68557842]]
x_state = 
 [[0.4394778 ]
 [0.60337457]
 [0.88978651]
 [1.07383932]
 [1.44880131]]
pf_out = 
 {'X_particles': DeviceArray([[[0.38734738],
              [0.42851383],
              [0.44831397],
              [0.48654838],
              [0.68515367],
              [0.43273745],
              [0.57517652]],

             [[1.28661928],
              [1.42292051],
              [1.3405227 ],
              [1.04493578],
              [0.83283796],
              [0.76433738],
              [0.3247642 ]],

             [[1.11614122],
              [1.18730814],
              [1.48856869],
              [0.56884518],
              [0.97734757],
              [1.51114774],
              [1.73332512]],

             [[1.44909828],
              [1.4053407 ],
              [1.56266414],
              [1.74703783],
              [1.70365507],
              [2.07926504],
              [1.35491909]],

  

In [32]:
%timeit particle_filter(y_meas, theta, n_particles, subkey)

116 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [33]:
%timeit particle_filter_jitted(y_meas, theta, n_particles, subkey)

67.2 µs ± 318 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [34]:
(116 * 1e-3) / (67.2 * 1e-6)

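1726.1904761904764

In this run the jitted filter is therefore roughly 1700 times faster than the non-jitted one. Because JAX dispatches computations asynchronously, `%timeit` on the jitted call may partly measure dispatch rather than completion. Calling `.block_until_ready()` on an output array gives a more conservative timing.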

## Scratch

In [48]:
state_lpdf(x_state[1], x_state[0], theta)

DeviceArray([-0.33247287], dtype=float64)

In [186]:
import numpy as np
import scipy as sp
import scipy.stats

def state_lpdf(x_curr, x_last, theta):
    mu = theta[0]
    sigma = theta[1]
    return sp.stats.norm.logpdf(x_curr, loc = x_last + mu * dt, scale = sigma * np.sqrt(dt))


def state_sample(x_last, theta):
    mu = theta[0]
    sigma = theta[1]
    return sp.stats.norm.rvs(loc = x_last + mu * dt, scale = sigma * np.sqrt(dt))

def meas_lpdf(y_curr, x_curr, theta):
    tau = theta[2]
    return sp.stats.norm.logpdf(y_curr, loc = x_curr, scale = tau)

def meas_sample(x_curr, theta):
    tau = theta[2]
    return sp.stats.norm.rvs(loc = x_curr, scale = tau)

# storage
#
# first do it column-major
# n_state, n_particles, n_obs
# [:,:, i_obs] represents the state of the pf up to a given point
# [:,i_part, i_obs] is the calculation for each particle at a given point
#
# now row-major

# helper: float array of shape `dims` filled with 0., 1., 2., ... in row-major (C) order
def np_intarray(dims):
    return np.reshape(np.arange(np.prod(dims))+0., dims, order = 'C')

n_meas = 1
n_state = 1
n_particles = 7
n_obs = 5
n_tot = n_state * n_particles * n_obs
#X_particles = np.arange(n_tot)
#X_particles = np.reshape(X_particles, [n_obs, n_particles, n_state], order = 'C')
X_particles = np_intarray([n_obs, n_particles, n_state])
i_obs = 0
# X_particles[i_obs]

# weights
#logw_particles = np.reshape(np.arange(n_particles * n_obs), [n_obs, n_particles])
logw_particles = np_intarray([n_obs, n_particles])
logw_particles[i_obs]

# ancestors
ant_particles = np_intarray([n_obs, n_particles]).astype(int)
# first set of particles have no ancestors
ant_particles[0] = 0

# let's try it
mu = 5
sigma = 1
tau = .1
theta = np.array([mu, sigma, tau])
dt = .1

# first simulate data
y_obs = np_intarray([n_obs, n_meas])
x_lat = np_intarray([n_obs, n_state])
x_prev = 0.
for t in range(n_obs):
    x_lat[t] = state_sample(x_prev, theta)
    y_obs[t] = meas_sample(x_lat[t], theta)
    x_prev = x_lat[t]
x_lat_true = x_lat # reuse this as a variable name
y_obs

array([[0.34141049],
       [0.74321696],
       [0.83085765],
       [1.98326492],
       [2.79380972]])
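
The storage comments above weigh a column-major `[n_state, n_particles, n_obs]` layout against the row-major `[n_obs, n_particles, n_state]` layout that the notebook ends up using. In NumPy's default C (row-major) order, the full particle set at one time point, `X[t]`, occupies one contiguous block of memory. The column-major layout gives the same property to `X[:, :, t]`. A small check of this with arbitrary sizes:

```python
import numpy as np

n_obs, n_particles, n_state = 5, 7, 1
# Row-major (C order) layout [n_obs, n_particles, n_state], as used in the notebook.
X_c = np.arange(n_obs * n_particles * n_state, dtype=float).reshape(
    (n_obs, n_particles, n_state), order="C"
)
# Column-major alternative [n_state, n_particles, n_obs] in Fortran order:
# X_f[s, i, t] == X_c[t, i, s].
X_f = np.asfortranarray(np.transpose(X_c))

slice_c = X_c[1]        # particles at time 1, contiguous in C order
slice_f = X_f[:, :, 1]  # the same particles, contiguous in Fortran order
```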

In [187]:
# now particle filter

# sample from normalized weights with replacement
def particle_resample(logw):
    mx = np.max(logw)
    wgt = np.exp(logw - mx)
    n_particles = logw.size
    #return sp.stats.multinomial.rvs(n_particles, wgt/np.sum(wgt))
    return np.random.choice(np.arange(n_particles), size = n_particles, p = wgt / np.sum(wgt))

#sp.stats.multinomial.rvs(1, [.25, .25, .25, .25], 10)
#np.random.choice(np.arange(4), size = 10, p = [.25, .25, .25, .25])

#particle_resample(np.array([0., 0., 0.]))

# using flat prior on x_0 
for i_part in range(n_particles):
    X_particles[0,i_part,:] = meas_sample(y_obs[0,:], theta)
    #logw_particles[0,i_part] = meas_lpdf(X_particles[0,i_part,:], y_obs[0], theta)
    logw_particles[0,i_part] = 0. # sample directly from posterior p(x_0 | y_0, theta)

# remaining observations
for t in range(1, n_obs):
    # resampling step
    ant_particles[t] = particle_resample(logw_particles[t-1])
    for i_part in range(n_particles):
        X_particles[t,i_part,:] = state_sample(X_particles[t-1,ant_particles[t,i_part],:], theta)
        logw_particles[t, i_part] = meas_lpdf(y_obs[t,:], X_particles[t,i_part,:], theta)

print(X_particles)
print(logw_particles)
print(ant_particles)

[[[0.25876451]
  [0.31918554]
  [0.26557652]
  [0.33072027]
  [0.38542672]
  [0.3636304 ]
  [0.42153657]]

 [[1.11239371]
  [1.29589014]
  [0.95854899]
  [1.42828306]
  [0.53992602]
  [0.63031026]
  [0.71932819]]

 [[0.70597072]
  [1.38912009]
  [1.84765476]
  [0.94279953]
  [1.72416633]
  [1.40544201]
  [1.0829206 ]]

 [[1.48140934]
  [1.75805587]
  [1.10969347]
  [1.86911294]
  [1.16602461]
  [1.93583039]
  [1.02073788]]

 [[2.30232723]
  [2.48872364]
  [2.02646744]
  [2.44285898]
  [1.9745851 ]
  [2.83355214]
  [2.41469545]]]
[[  0.           0.           0.           0.           0.
    0.           0.        ]
 [ -5.43092725 -13.88873598  -0.93474776 -22.08213157  -0.68271372
    0.74625047   1.3551129 ]
 [  0.60380926 -14.19920086 -50.31017205   0.75709729 -38.5163734
  -15.12371285  -1.79313992]
 [-11.20930491  -1.15230939 -36.77270752   0.73211276 -32.01043983
    1.2711448  -44.93926934]
 [-10.69410502  -3.27022919 -28.05706219  -4.77467435 -32.17280205
    1.30467354  -5.8027

In [163]:
ant_particles

array([[0, 0, 0, 0, 0, 0, 0],
       [2, 0, 1, 1, 1, 1, 1],
       [0, 0, 2, 0, 0, 4, 1],
       [0, 0, 0, 6, 1, 0, 0],
       [0, 0, 0, 7, 0, 0, 0]])

In [None]:
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import HtmlFormatter
import IPython

with open('particle_filter.py') as f:
    code = f.read()

formatter = HtmlFormatter()
IPython.display.HTML('<style type="text/css">{}</style>{}'.format(
    formatter.get_style_defs('.highlight'),
    highlight(code, PythonLexer(), formatter)))