# Particle Filter with JAX

## Notation

### Functions

- `state_lpdf(x_curr, x_last, theta)`: Log-density of $p(x_t | x_{t-1}, \theta)$.
- `state_sample(x_last, theta)`: Sample from $p(x_t | x_{t-1}, \theta)$.
- `meas_lpdf(y_curr, x_curr, theta)`: Log-density of $p(y_t | x_t, \theta)$.
- `meas_sample(x_curr, theta)`: Sample from $p(y_t | x_t, \theta)$.

### Dimensions

- `n_obs`: Number of time points.
- `n_state`: Number of state dimensions.
- `n_meas`: Number of measured dimensions.
- `n_particles`: Number of particles.

## Example: Brownian motion with drift

The model is
$$
\newcommand{\N}{\mathcal{N}}
\newcommand{\dt}{\Delta t}
\begin{aligned}
x_0 & \sim \pi(x_0) \\
x_t & \sim \N(x_{t-1} + \mu \dt, \sigma^2 \dt) \\
y_t & \sim \N(x_t, \tau^2).
\end{aligned}
$$

Therefore, `n_state` = `n_meas` = 1.  

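Note that with the flat (improper) prior $\pi(x_0) \propto 1$, we may condition on $y_0$ to obtain the proper posterior $x_0 \mid y_0 \sim \N(y_0, \tau^2)$, since the Gaussian measurement density is symmetric in $x_0$ and $y_0$.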

### Using **NumPy** and **SciPy**

In [186]:
import numpy as np
import scipy as sp
import scipy.stats

def state_lpdf(x_curr, x_last, theta, dt=None):
    """Log-density of the state transition p(x_t | x_{t-1}, theta).

    Brownian motion with drift: x_t ~ N(x_{t-1} + mu * dt, sigma^2 * dt).

    Parameters
    ----------
    x_curr : float or array_like
        Current state x_t.
    x_last : float or array_like
        Previous state x_{t-1}.
    theta : array_like
        Parameter vector; uses theta[0] = mu (drift) and theta[1] = sigma.
    dt : float, optional
        Time step. When omitted, the module-level ``dt`` is used, which
        preserves the original behavior of reading the notebook global.

    Returns
    -------
    float or ndarray
        Elementwise log-density (not summed over state dimensions).

    Raises
    ------
    NameError
        If ``dt`` is not passed and no module-level ``dt`` is defined.
    """
    if dt is None:
        # Backward compatibility: the original read the global `dt`.
        dt = globals().get("dt")
        if dt is None:
            raise NameError("time step 'dt' is not defined; pass dt=... explicitly")
    mu = theta[0]
    sigma = theta[1]
    return sp.stats.norm.logpdf(x_curr, loc=x_last + mu * dt, scale=sigma * np.sqrt(dt))


def state_sample(x_last, theta, dt=None):
    """Sample the next state from p(x_t | x_{t-1}, theta).

    Draws x_t ~ N(x_{t-1} + mu * dt, sigma^2 * dt). Array-valued ``x_last``
    gives one draw per element (scipy broadcasting of ``loc``).

    Parameters
    ----------
    x_last : float or array_like
        Previous state x_{t-1}.
    theta : array_like
        Parameter vector; uses theta[0] = mu (drift) and theta[1] = sigma.
    dt : float, optional
        Time step. When omitted, the module-level ``dt`` is used, which
        preserves the original behavior of reading the notebook global.

    Returns
    -------
    float or ndarray
        Sampled state(s), drawn from NumPy's global random state.

    Raises
    ------
    NameError
        If ``dt`` is not passed and no module-level ``dt`` is defined.
    """
    if dt is None:
        # Backward compatibility: the original read the global `dt`.
        dt = globals().get("dt")
        if dt is None:
            raise NameError("time step 'dt' is not defined; pass dt=... explicitly")
    mu = theta[0]
    sigma = theta[1]
    return sp.stats.norm.rvs(loc=x_last + mu * dt, scale=sigma * np.sqrt(dt))

def meas_lpdf(y_curr, x_curr, theta):
    """Log-density of the measurement model p(y_t | x_t, theta).

    y_t ~ N(x_t, tau^2) with tau = theta[2]; evaluated elementwise.
    """
    noise = sp.stats.norm(loc=x_curr, scale=theta[2])
    return noise.logpdf(y_curr)

def meas_sample(x_curr, theta):
    """Draw a measurement y_t ~ N(x_t, tau^2), where tau = theta[2]."""
    noise = sp.stats.norm(loc=x_curr, scale=theta[2])
    return noise.rvs()

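# storage layout
#
# A column-major layout would be (n_state, n_particles, n_obs), so that
# [:, :, i_obs] is the whole particle system at one time point and
# [:, i_part, i_obs] is a single particle at that time point.
#
# We use the row-major (C-order) equivalent instead:
# (n_obs, n_particles, n_state), indexed as [i_obs] and [i_obs, i_part].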
# first a helper function
def np_intarray(dims):
    return np.reshape(np.arange(np.prod(dims))+0., dims, order = 'C')

# ---- problem dimensions ----
n_meas = 1
n_state = 1
n_particles = 7
n_obs = 5
n_tot = n_state * n_particles * n_obs  # total number of stored state values

# particle states, one slab per time point: X_particles[i_obs, i_part, :]
# (filled with placeholder values; overwritten by the filter)
X_particles = np_intarray([n_obs, n_particles, n_state])
i_obs = 0

# log-weights: logw_particles[i_obs, i_part]
logw_particles = np_intarray([n_obs, n_particles])

# ancestor indices: ant_particles[t, i_part] is the index of the particle at
# time t-1 that particle i_part at time t was propagated from.
# The first set of particles has no ancestors, so row 0 stays all zeros;
# later rows are filled by the filter.
ant_particles = np.zeros((n_obs, n_particles), dtype=int)

# ---- model parameters ----
mu = 5
sigma = 1
tau = .1
theta = np.array([mu, sigma, tau])
dt = .1

# ---- simulate data ----
y_obs = np_intarray([n_obs, n_meas])
x_lat = np_intarray([n_obs, n_state])
x_prev = 0.
for t in range(n_obs):
    x_lat[t] = state_sample(x_prev, theta)
    y_obs[t] = meas_sample(x_lat[t], theta)
    x_prev = x_lat[t]
# Keep an independent copy of the true latent path so that x_lat can be
# reused later. A plain assignment would alias the same array, and any
# in-place write to x_lat would silently change x_lat_true too.
x_lat_true = x_lat.copy()
y_obs

array([[0.34141049],
       [0.74321696],
       [0.83085765],
       [1.98326492],
       [2.79380972]])

In [187]:
# now particle filter

# sample from normalized weights with replacement
def particle_resample(logw):
    """Draw ancestor indices by multinomial resampling.

    Samples ``logw.size`` indices with replacement, with probability
    proportional to ``exp(logw)``. The log-weights are shifted by their
    maximum before exponentiating so that large values do not overflow.

    Parameters
    ----------
    logw : ndarray
        Unnormalized log-weights, one per particle.

    Returns
    -------
    ndarray of int
        Indices into ``logw``, of length ``logw.size``.
    """
    count = logw.size
    shifted = np.exp(logw - logw.max())
    probs = shifted / shifted.sum()
    pool = np.arange(count)
    return np.random.choice(pool, size=count, p=probs)

# Initialize at t = 0 using the flat prior on x_0. The posterior
# p(x_0 | y_0, theta) is N(y_0, tau^2), so draw from it directly with
# meas_sample centered at y_0, and give every particle equal (zero)
# log-weight. This shortcut requires n_state == n_meas.
for i_part in range(n_particles):
    X_particles[0, i_part, :] = meas_sample(y_obs[0, :], theta)
    logw_particles[0, i_part] = 0.

# Remaining observations. At each step:
# 1. resample ancestors from the previous weights,
# 2. propagate them through the state model,
# 3. weight them by the measurement density.
for t in range(1, n_obs):
    ant_particles[t] = particle_resample(logw_particles[t - 1])
    for i_part in range(n_particles):
        x_parent = X_particles[t - 1, ant_particles[t, i_part], :]
        X_particles[t, i_part, :] = state_sample(x_parent, theta)
        # meas_lpdf returns one log-density per measurement dimension.
        # Summing gives the joint log-density, assuming independent
        # components as the elementwise normal model implies. It also
        # yields a scalar, so this works for n_meas > 1.
        logw_particles[t, i_part] = np.sum(
            meas_lpdf(y_obs[t, :], X_particles[t, i_part, :], theta))

print(X_particles)
print(logw_particles)
print(ant_particles)

[[[0.25876451]
  [0.31918554]
  [0.26557652]
  [0.33072027]
  [0.38542672]
  [0.3636304 ]
  [0.42153657]]

 [[1.11239371]
  [1.29589014]
  [0.95854899]
  [1.42828306]
  [0.53992602]
  [0.63031026]
  [0.71932819]]

 [[0.70597072]
  [1.38912009]
  [1.84765476]
  [0.94279953]
  [1.72416633]
  [1.40544201]
  [1.0829206 ]]

 [[1.48140934]
  [1.75805587]
  [1.10969347]
  [1.86911294]
  [1.16602461]
  [1.93583039]
  [1.02073788]]

 [[2.30232723]
  [2.48872364]
  [2.02646744]
  [2.44285898]
  [1.9745851 ]
  [2.83355214]
  [2.41469545]]]
[[  0.           0.           0.           0.           0.
    0.           0.        ]
 [ -5.43092725 -13.88873598  -0.93474776 -22.08213157  -0.68271372
    0.74625047   1.3551129 ]
 [  0.60380926 -14.19920086 -50.31017205   0.75709729 -38.5163734
  -15.12371285  -1.79313992]
 [-11.20930491  -1.15230939 -36.77270752   0.73211276 -32.01043983
    1.2711448  -44.93926934]
 [-10.69410502  -3.27022919 -28.05706219  -4.77467435 -32.17280205
    1.30467354  -5.8027

In [163]:
# Display the ancestor indices: ant_particles[t, i] is the index of the
# particle at time t-1 from which particle i at time t was propagated.
# NOTE(review): the output shown below comes from an earlier execution
# (In [163]). It contains index 7, which is out of range for
# n_particles = 7; re-run this cell to refresh it.
ant_particles

array([[0, 0, 0, 0, 0, 0, 0],
       [2, 0, 1, 1, 1, 1, 1],
       [0, 0, 2, 0, 0, 4, 1],
       [0, 0, 0, 6, 1, 0, 0],
       [0, 0, 0, 7, 0, 0, 0]])