# Parallelizing the Particle Filter

**Martin Lysy, University of Waterloo**

**April 11, 2022**

## Problem Statement

Various steps in the particle filter are parallelizable across particles.  This notebook contains various experiments in how to do this most effectively.

In general, we'll need to specify `n_devices` and `n_particles_per_device`.  

In [1]:
import os

# Ask XLA to expose 8 virtual host (CPU) devices so that `jax.pmap()` has
# several devices to map over.  This is set before `jax` is imported so that
# the flag is already in the environment when the backend starts up.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

In [2]:
import functools
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from jax import lax
import pfjax as pf
from pfjax.models import BMModel
from pfjax.particle_filter import _lweight_to_prob, _tree_add, _tree_mean, _tree_zeros, _rm_keys
jax.devices()



[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

## Pmap/Vmap Version

The first attempt uses only the non-experimental `jax.pmap()`.  More specifically, we `jax.vmap()` over `n_particles_per_device` on each of the `n_devices` sent to `pmap()`.

**Notes:**

- `pmap()` jits by default.  It would be nice to prevent this, so that things can be jitted at a higher level, e.g., if `particle_filter()` is called inside another function.

    In fact, it seems that `pmap()` within `jit` will destroy the `ShardedDeviceArray`.  Since `lax.scan()` jits automatically, this is also the case there (though currently the sharded array is being destroyed by hand).  So perhaps we need to write a custom `resample_multinomial()` that can deal with shards.

In [3]:
def particle_filter_pmap(model, key, y_meas, theta,
                         n_devices,
                         n_particles_per_device,
                         particle_sampler=pf.particle_resample,
                         history=False,
                         accumulator=None):
    """
    Apply particle filter for given value of `theta`, splitting particles across devices.

    The per-particle work (`model.pf_init()`, `model.pf_step()` and the optional `accumulator`) is run through `jax.pmap()` over `n_devices`, with `jax.vmap()` over the `n_particles_per_device` particles on each device.  Resampling is done on the flattened set of `n_devices * n_particles_per_device` particles.

    Full documentation in pfjax package.

    Args:
        model: Object with methods `pf_init(key, y_init, theta)` and `pf_step(key, x_prev, y_curr, theta)`, each returning a pair `(x_particles, logw)` for a single particle.
        key: PRNG key.
        y_meas: Array of measurements; the leading dimension indexes the `n_obs` observations.
        theta: Parameter value, passed unchanged to `model` and `accumulator`.
        n_devices: Number of devices to `pmap()` over.
        n_particles_per_device: Number of particles handled by each device.
        particle_sampler: Callable taking keyword arguments `key`, `x_particles_prev` and `logw` (particles flattened to one leading dimension) and returning a dictionary with at least the key `"x_particles"`.  Entries other than `"x_particles"` and `"logw"` are passed on in `"resample_out"`.
        history: Whether to return the values at every time point (`True`) or only those at the last time point (`False`).
        accumulator: Optional callable with keyword arguments `x_prev`, `x_curr`, `y_curr` and `theta`, evaluated for each particle and summed over time points.

    Returns:
        A dictionary.  If `history=False` it is the last `lax.scan()` carry, with keys `"x_particles"`, `"logw"`, `"key"`, `"resample_out"`, `"loglik"` and, if `accumulator` is given, `"accumulate_out"` reduced by `_tree_mean()` with the final `logw`.  If `history=True` it contains the per-time-point values stacked along a new leading dimension (without `"key"`), plus `"loglik"`; in that case `"accumulate_out"` holds the running per-particle sums and is not averaged.  Particle arrays keep the leading `(n_devices, n_particles_per_device)` dimensions.
    """
    n_obs = y_meas.shape[0]
    n_particles = n_devices * n_particles_per_device
    has_acc = accumulator is not None

    # internal functions for vectorizing
    def pf_step(key, x_prev, y_curr):
        return model.pf_step(key=key, x_prev=x_prev, y_curr=y_curr, theta=theta)

    def pf_init(key):
        return model.pf_init(key=key, y_init=y_meas[0], theta=theta)

    def pf_acc(acc_prev, x_prev, x_curr, y_curr):
        return _tree_add(
            tree1=acc_prev,
            tree2=accumulator(
                x_prev=x_prev, x_curr=x_curr, y_curr=y_curr, theta=theta
            )
        )

    # single-particle accumulator call used only to obtain the output
    # structure; binds `y_curr` and `theta` by keyword, as `pf_acc()` does
    def acc_init(x_prev, x_curr):
        return accumulator(
            x_prev=x_prev, x_curr=x_curr, y_curr=y_meas[0], theta=theta
        )

    # reshape first two dimensions to one dimension
    def reshape_1d(x):
        return x.reshape((-1,) + x.shape[2:])

    # reshape first dimension into n_devices x n_particles_per_device
    def reshape_2d(x):
        return x.reshape((n_devices, n_particles_per_device) + x.shape[1:])

    # lax.scan setup
    # scan function
    def filter_step(carry, t):
        # sample particles from previous time point
        # (the sampler sees all particles at once, hence the flattening)
        key, subkey = random.split(carry["key"])
        new_particles = particle_sampler(
            key=subkey,
            x_particles_prev=reshape_1d(carry["x_particles"]),
            logw=reshape_1d(carry["logw"])
        )
        # update particles to current time point (and get weights)
        # one subkey per particle, laid out as devices x particles-per-device
        key, *subkeys = random.split(key, num=n_particles+1)
        x_particles, logw = jax.pmap(
            jax.vmap(pf_step, in_axes=(0, 0, None)),
            in_axes=(0, 0, None)
        )(reshape_2d(jnp.array(subkeys)),
          reshape_2d(new_particles["x_particles"]),
          y_meas[t])
        if has_acc:
            # accumulate expectation
            acc_curr = jax.pmap(
                jax.vmap(pf_acc, in_axes=(0, 0, 0, None)),
                in_axes=(0, 0, 0, None)
            )(carry["accumulate_out"],
              reshape_2d(new_particles["x_particles"]),
              x_particles,
              y_meas[t])
        # output
        res_carry = {
            "x_particles": x_particles,
            "logw": logw,
            "key": key,
            "loglik": carry["loglik"] + jsp.special.logsumexp(logw),
            "resample_out": _rm_keys(new_particles, ["x_particles", "logw"])
        }
        if has_acc:
            res_carry["accumulate_out"] = acc_curr
        res_stack = _rm_keys(res_carry, ["key", "loglik"]) if history else None
        return res_carry, res_stack
    # scan initial value
    key, *subkeys = random.split(key, num=n_particles+1)
    x_particles, logw = jax.pmap(
        jax.vmap(pf_init)
    )(reshape_2d(jnp.array(subkeys)))
    # dummy initialization for resample
    # (only the structure of the extra sampler outputs is kept, zeroed out)
    init_resample = particle_sampler(
        key=key,
        x_particles_prev=reshape_1d(x_particles),
        logw=reshape_1d(logw)
    )
    init_resample = _rm_keys(init_resample, ["x_particles", "logw"])
    init_resample = _tree_zeros(init_resample)
    if has_acc:
        # dummy initialization for accumulate
        # Both mapped arguments are split over devices (pmap) and over the
        # particles on each device (vmap); `y_curr` and `theta` are closed
        # over in `acc_init()`, so no `in_axes` bookkeeping is needed.
        init_acc = jax.pmap(
            jax.vmap(acc_init)
        )(x_particles, x_particles)
        init_acc = _tree_zeros(init_acc)
    filter_init = {
        "x_particles": x_particles,
        "logw": logw,
        "loglik": jsp.special.logsumexp(logw),
        "key": key,
        "resample_out": init_resample
    }
    if has_acc:
        filter_init["accumulate_out"] = init_acc
    # lax.scan itself
    last, full = lax.scan(filter_step, filter_init, jnp.arange(1, n_obs))
    if history:
        # append initial values of x_particles and logw
        full["x_particles"] = jnp.concatenate([
            filter_init["x_particles"][None], full["x_particles"]
        ])
        full["logw"] = jnp.concatenate([
            filter_init["logw"][None], full["logw"]
        ])
    else:
        full = last
        if has_acc:
            # weighted average of accumulated values
            # NOTE(review): `logw` and the accumulator leaves still carry the
            # (n_devices, n_particles_per_device) leading dimensions here,
            # whereas the sampler is handed flattened arrays -- confirm that
            # `_tree_mean()` handles the 2-d layout.
            full["accumulate_out"] = _tree_mean(
                tree=full["accumulate_out"],
                logw=full["logw"]
            )
    # calculate loglikelihood
    # (each of the n_obs logsumexp terms is turned into a log-mean)
    full["loglik"] = last["loglik"] - n_obs * jnp.log(n_particles)
    return full

In [4]:
# generate data
key = random.PRNGKey(0)
# parameter values
mu = 5
sigma = 1
tau = .1
theta = jnp.array([mu, sigma, tau])
# data specification
dt = .1
n_obs = 5
x_init = jnp.array(0.)
bm_model = BMModel(dt=dt)
# simulate without for-loop
y_meas, x_state = pf.simulate(bm_model, key, n_obs, x_init, theta)

In [22]:
# particle filter specification
n_devices = 2
n_particles_per_device = 5
n_particles = n_devices * n_particles_per_device

pf_serial = jax.jit(functools.partial(pf.particle_filter2,
                                      model=bm_model, y_meas=y_meas,
                                      n_particles=n_particles, history=True))

pf_pmap = jax.jit(functools.partial(particle_filter_pmap,
                                    model=bm_model, y_meas=y_meas, n_devices=n_devices,
                                    n_particles_per_device=n_particles_per_device, history=True))

pf_out = pf_serial(theta=theta, key=key)

pf_out2 = pf_pmap(theta=theta, key=key)

In [23]:
# element-wise difference between the particle histories of the serial and
# pmap runs (flattened, since the pmap output has an extra device dimension)
pf_out["x_particles"].ravel() - pf_out2["x_particles"].ravel()

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0.], dtype=float64)

In [24]:
# time the jitted serial and pmap particle filters on the same inputs
# NOTE(review): JAX dispatches work asynchronously; consider blocking on the
# outputs (e.g., `.block_until_ready()`) so the timings cover the full
# computation -- confirm whether this matters on the CPU backend.
%timeit pf_serial(theta=theta, key=key)
%timeit pf_pmap(theta=theta, key=key)

24.5 µs ± 248 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
289 µs ± 18.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Scratch

In [None]:
def convolve(x, w):
    """
    Convolve x with w.

    Entry `i` of the output is the dot product of `w` with the window
    `x[i:i+len(w)]`; `w` is not reversed.  Only windows lying entirely inside
    `x` are used, so the output has `len(x) - len(w) + 1` entries.

    Must have `len(x) >= len(w)`.
    """
    n_w = len(w)
    # Use Python's `range` rather than `jnp.arange` for the window offsets:
    # the slice bounds must be plain Python ints for the indexing to work
    # when this function is traced by `jax.jit`.
    output = [
        jnp.dot(x[i:i + n_w], w)
        for i in range(len(x) - n_w + 1)
    ]
    return jnp.array(output)

# jit-compiled counterpart of `convolve`
jconvolve = jax.jit(convolve)

# test from doc: short integer signal and a length-3 filter
x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])

convolve(x, w)

## Timings

In [None]:
# random inputs for the timing comparison
key = jax.random.PRNGKey(0)
nx = 200
nw = 100

# use a separate subkey for each draw; reusing one key for both calls would
# make `x` and `w` depend on the same random stream
key_x, key_w = jax.random.split(key)
x = jax.random.normal(key_x, (nx,))
w = jax.random.normal(key_w, (nw,))

In [None]:
# unjitted: every `jnp.dot` in the Python loop is dispatched separately
%timeit convolve(x,w)

In [None]:
# jitted
# NOTE(review): the first call to `jconvolve` for these input shapes also
# triggers compilation -- presumably amortized over the %timeit loops, but
# worth confirming with a warm-up call.
%timeit jconvolve(x,w)