# Benchmarking CPU-based birth-death

In [2]:
import random
class ListDict(object):
    """
    Collection of unique hashable items stored twice: in a list (for
    random selection and iteration) and in an item -> list-index dictionary
    (for membership tests and removal without scanning the list).

    Adapted from
    https://stackoverflow.com/questions/15993447/python-data-structure-for-efficient-add-remove-and-random-choice

    Supported operations:
    (i)   lookup (``item in container``)
    (ii)  uniform random selection
    (iii) removal

    """
    def __init__(self, nParticles):
        """Populate the container with the integers ``0 .. nParticles - 1``."""
        self.items = []
        self.item_to_position = {}
        for particle in range(nParticles):
            self.add_item(particle)

    def add_item(self, item):
        """Append ``item`` unless it is already stored (then do nothing)."""
        if item not in self.item_to_position:
            # The new item lands at the current end of the list.
            self.item_to_position[item] = len(self.items)
            self.items.append(item)

    def remove_item(self, item):
        """Remove ``item``; raises ``KeyError`` if it is not stored.

        The vacated slot is filled with the former last element, so the
        relative order of the remaining items is not preserved.
        """
        hole = self.item_to_position.pop(item)
        tail = self.items.pop()
        if hole == len(self.items):
            # The removed item was the last element; nothing to move.
            return
        self.items[hole] = tail
        self.item_to_position[tail] = hole

    def choose_random_item(self):
        """Return one stored item drawn with ``random.choice``."""
        return random.choice(self.items)

    def __contains__(self, item):
        """Membership test via the position dictionary."""
        return item in self.item_to_position

    def __iter__(self):
        """Iterate over the stored items in their current list order."""
        return iter(self.items)

    def __len__(self):
        """Number of stored items."""
        return len(self.items)

In [3]:
import numpy as np

# Benchmark input: one rate per particle, drawn uniformly from [-100, 100).
nParticles = 500
Lambda = np.random.uniform(low=-100, high=100, size=nParticles)
def birthDeathJumpIndicies(Lambda, stepsize=0.01):
    """Compute a jump map for one birth-death step on the CPU.

    Each particle ``i`` is selected for a jump with probability
    ``1 - exp(-|Lambda[i]| * stepsize)``. Selected particles are visited in
    random order; a particle that has already been removed from the alive
    pool is skipped.

    Parameters
    ----------
    Lambda : numpy array whose first axis has one entry per particle.
    stepsize : float multiplying ``|Lambda|`` in the jump probability.

    Returns
    -------
    numpy integer array of length ``Lambda.shape[0]``, initialised to
    ``arange`` and overwritten at the positions affected by jumps.
    """
    nParticles = Lambda.shape[0]
    alive = ListDict(nParticles)

    # Decide which particles attempt a jump during this step.
    uniforms = np.random.uniform(low=0, high=1, size=nParticles)
    jump_probability = 1 - np.exp(-np.abs(Lambda) * stepsize)
    idxs = np.nonzero(uniforms < jump_probability)[0]
    np.random.shuffle(idxs)

    # Particle jumps
    output = np.arange(nParticles)
    for i in idxs:
        if i not in alive:
            continue
        j = alive.choose_random_item()
        rate = Lambda[i]
        if rate > 0:
            # Positive rate: slot i takes index j, and i leaves the pool.
            output[i] = j
            alive.remove_item(i)
        elif rate < 0:
            # Negative rate: slot j takes index i, and j leaves the pool.
            output[j] = i
            alive.remove_item(j)

    return output


In [4]:
# Time the CPU implementation first, then run it once more and keep that
# result in `a` so it can be inspected in the next cell.
%timeit birthDeathJumpIndicies(Lambda)
a = birthDeathJumpIndicies(Lambda)

148 µs ± 552 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
# Display the jump map returned by the CPU run above.
a

array([256,   1, 113,   3,   4,   5,   6, 278,   8,   9, 399, 184,  12,
        13,  14,  15, 317, 207,  18,  19,  20,  21,  22,  23,  24, 430,
        26,  27, 421, 104,  30,  31, 305,  33, 107,  35,  36,  37,  38,
       153, 289,  41, 191, 293, 335,  45,  46,  47,  48,  49, 416, 181,
        52,  53,  54,  55,  56,  57, 459,  59,  60,  61, 221,  63,  64,
        65,  66,  67,  68,  69, 222, 183, 161,  73,  74, 464, 309,  77,
       183,  79, 209,  81,  82,  83,  84,  85,  86,  87,  88, 427,  90,
        53,  92,  93,  94, 214,  96,  97, 343,  99, 109, 101,  59, 313,
       104, 105, 106, 107, 108, 109, 144, 111, 112, 113, 114, 115, 116,
       116,  45, 119, 128, 121, 122, 123, 436,  60, 126, 389, 128, 100,
       140, 299, 280, 133, 466, 135, 136, 137, 253, 139, 140, 141,  49,
       135, 144, 145, 146, 272,   8, 247, 150, 151, 152, 153, 154, 155,
        36, 157, 374, 110, 160, 161, 162, 163, 164, 165, 166, 167, 310,
       169, 170, 483, 193, 173,  46, 399, 176, 177, 220, 179, 18

# Benchmarking GPU-based birth-death

In [6]:
import jax
import jax.numpy as jnp

# Fixed PRNG seed so the JAX runs below start from the same key.
key = jax.random.PRNGKey(0)

def true_fun(i, j, jumps):
    """Branch for a positive rate: slot ``i`` takes index ``j`` and is flagged dead.

    ``jumps`` is a 2-row array: row 0 holds the jump targets, row 1 the
    alive flags. A new array is returned; the input is not modified.
    """
    return jumps.at[0, i].set(j).at[1, i].set(0)

def false_fun(i, j, jumps):
    """Branch for a non-positive rate: slot ``j`` takes index ``i`` and is flagged dead.

    ``jumps`` is a 2-row array: row 0 holds the jump targets, row 1 the
    alive flags. A new array is returned; the input is not modified.
    """
    return jumps.at[0, j].set(i).at[1, j].set(0)

def scan_func(carry, x):
    """One step of the birth-death scan: process candidate particle ``x``.

    Parameters
    ----------
    carry : tuple ``(key, jumps, Lambda)``
        ``key`` is the PRNG key threaded through the scan, ``jumps`` is the
        2-row state (row 0: jump targets, row 1: alive flags) and ``Lambda``
        holds the per-particle rates.
    x : integer index of the particle to process, or ``-1`` for a padding
        entry that must be ignored.

    Returns
    -------
    ``((key, jumps, Lambda), jumps)`` with the advanced key and the updated
    state; the state is also emitted as the per-step scan output.
    """
    key, jumps, Lambda = carry

    # `key` is carried to the next step and `subkey` is consumed here, so no
    # key is ever used for two different draws.
    key, subkey = jax.random.split(key)

    # Pick the partner among particles whose alive flag is still set.
    # NOTE(review): `p` is given as raw 0/1 flags, not normalised
    # probabilities -- confirm this is acceptable for the installed JAX.
    # TODO: check whether the weighted draw is what makes this slow (compare
    # against an unweighted `jax.random.choice`).
    j = jax.random.choice(subkey, Lambda.shape[0], p=jumps[1])

    # Skip padding entries and, as the CPU loop does with `if i in alive`,
    # particles that have already been marked dead by an earlier step.
    should_jump = jnp.logical_and(x != -1, jumps[1, x] == 1)
    pred = Lambda[x] > 0
    jumps = jax.lax.cond(
        should_jump,
        lambda: jax.lax.cond(pred, true_fun, false_fun, x, j, jumps),
        lambda: jumps,
    )
    return (key, jumps, Lambda), jumps

@jax.jit
def birth_death_jumps(key, Lambda, stepsize=0.01):
    """Compute a jump map for one birth-death step with JAX.

    Each particle ``i`` is selected with probability
    ``1 - exp(-|Lambda[i]| * stepsize)``; the selected indices are shuffled
    and processed sequentially by ``scan_func`` inside ``jax.lax.scan``.

    Parameters
    ----------
    key : JAX PRNG key.
    Lambda : 1-D array of per-particle rates.
    stepsize : float multiplying ``|Lambda|`` in the jump probability.

    Returns
    -------
    Integer array of length ``Lambda.shape[0]``: row 0 of the final scan
    state, i.e. ``arange`` overwritten at the positions affected by jumps.
    """
    nParticles = Lambda.shape[0]

    # Independent keys for the three random stages; reusing one key would
    # correlate the threshold draws, the shuffle and the partner choices.
    key_uniform, key_perm, key_scan = jax.random.split(key, 3)

    r = jax.random.uniform(key_uniform, shape=Lambda.shape, minval=0, maxval=1)
    threshold = r < 1 - jnp.exp(-jnp.abs(Lambda) * stepsize)

    # Fixed-size index list padded with -1 so the shape is static under jit.
    # Column indexing (rather than squeeze) keeps it 1-D even for one particle.
    selected = jnp.argwhere(threshold, size=nParticles, fill_value=-1)[:, 0]
    idxs = jax.random.permutation(key_perm, selected)

    # Row 0: jump targets (identity to start). Row 1: alive flags (all set).
    jumps = jnp.zeros((2, nParticles), dtype=int)
    jumps = jumps.at[0].set(jnp.arange(nParticles))
    jumps = jumps.at[1].set(jnp.ones(nParticles, dtype=int))

    def step(carry, x):
        # Only the final state is needed, so drop the per-step output instead
        # of letting scan stack one full state per step.
        carry, _ = scan_func(carry, x)
        return carry, None

    (_, jumps, _), _ = jax.lax.scan(step, (key_scan, jumps, Lambda), idxs)

    return jumps[0]

In [7]:
# Benchmark GPU based code
%timeit birth_death_jumps(key, Lambda)
a = birth_death_jumps(key, Lambda)

1.13 s ± 16.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Loop stride test

In [2]:
import jax
from functools import partial

def ula_kernel(key, param, iteration):
    """Placeholder kernel: advance the PRNG key and the iteration counter.

    ``param`` is ignored; the returned parameter is always ``0``.

    Returns
    -------
    ``(next_key, 0, iteration + 1)`` where ``next_key`` is the first key
    produced by ``jax.random.split(key)``.
    """
    next_key, _ = jax.random.split(key)
    return next_key, 0, iteration + 1


@partial(jax.jit, static_argnames=("n_iter",))
def ula_sampler_full_jax_jit(key, x_0=0, n_iter=1000):
    """Run ``ula_kernel`` for ``n_iter`` steps inside ``jax.lax.scan``.

    Parameters
    ----------
    key : JAX PRNG key used to start the chain.
    x_0 : initial parameter value placed in the scan carry.
        NOTE(review): the default is a placeholder; the carry type must
        match what ``ula_kernel`` returns for ``param`` -- confirm.
    n_iter : number of scan steps. It is static because it sets the scan
        length, so each distinct value triggers a recompilation.
        NOTE(review): the default is a placeholder -- confirm.

    Returns
    -------
    The stacked per-step ``param`` values emitted by the scan.
    """
    # @progress_bar_scan(n_iter)
    def ula_step(carry, x):
        key, param, iteration = carry
        key, param, iteration = ula_kernel(key, param, iteration)
        return (key, param, iteration), param

    carry = (key, x_0, 0)
    _, samples = jax.lax.scan(ula_step, carry, None, length=n_iter)
    return samples

In [None]:
# Evaluates the jitted sampler object itself; the sampler is not called here.
ula_sampler_full_jax_jit