# Organize code to give to JAX team

In [1]:
import random
class ListDict(object):
    """Collection of unique hashable items backed by a list and a dict.

    Solution adapted from
    https://stackoverflow.com/questions/15993447/python-data-structure-for-efficient-add-remove-and-random-choice

    ``items`` stores the members in a list and ``item_to_position`` maps each
    member to its index in that list. Together they give efficient
    (i)   lookup (through the dict),
    (ii)  uniform random selection (``random.choice`` on the list),
    (iii) removal (the last list element is moved into the vacated slot).
    """

    def __init__(self, nParticles):
        """Populate the container with the integers ``0 .. nParticles - 1``."""
        self.item_to_position = {}
        self.items = []
        for particle in range(nParticles):
            self.add_item(particle)

    def add_item(self, item):
        """Append ``item`` unless it is already a member (then do nothing)."""
        if item not in self.item_to_position:
            # The new member lands at the current end of the list.
            self.item_to_position[item] = len(self.items)
            self.items.append(item)

    def remove_item(self, item):
        """Remove ``item``; raises ``KeyError`` if it is not a member."""
        slot = self.item_to_position.pop(item)
        tail = self.items.pop()
        if slot == len(self.items):
            # ``item`` was the last element, so popping the tail removed it.
            return
        # Otherwise fill the hole with the former tail and re-index it.
        self.items[slot] = tail
        self.item_to_position[tail] = slot

    def choose_random_item(self):
        """Return one member picked by ``random.choice`` from the item list."""
        return random.choice(self.items)

    def __contains__(self, item):
        """Membership test, answered by the position dict."""
        return item in self.item_to_position

    def __iter__(self):
        """Iterate over the members in their current list order."""
        return iter(self.items)

    def __len__(self):
        """Number of members currently stored."""
        return len(self.items)

In [11]:
import numpy as np

# Seed both random generators used by this notebook so the printed result is
# reproducible on Restart & Run All: NumPy's global state (used here and in
# birthDeathJumpIndicies) and the stdlib `random` module (used by
# ListDict.choose_random_item).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

nParticles = 500
# One value per particle, uniform on [-100, 100). In birthDeathJumpIndicies the
# sign selects the branch taken and the magnitude sets the jump probability.
Lambda = np.random.uniform(-100, 100, nParticles)
def birthDeathJumpIndicies(Lambda, stepsize=0.01):
    """Draw one round of jumps and return the resulting index map.

    Parameters
    ----------
    Lambda : numpy array
        One value per particle; only ``Lambda.shape[0]`` entries are used.
        NOTE(review): presumably 1-D — confirm against callers.
    stepsize : float
        Multiplies ``|Lambda|`` inside the jump probability.

    Returns
    -------
    numpy array of length ``Lambda.shape[0]``
        Starts as ``arange(n)``. For every processed jump of particle ``i``
        with randomly drawn partner ``j``: ``Lambda[i] > 0`` sets
        ``output[i] = j``; ``Lambda[i] < 0`` sets ``output[j] = i``.

    Notes
    -----
    Randomness comes from NumPy's global state (uniform draws, shuffle) and
    from the stdlib ``random`` module via ``ListDict.choose_random_item``.
    """
    n_particles = Lambda.shape[0]
    alive = ListDict(n_particles)

    # Particle i is a jump candidate with probability 1 - exp(-|Lambda_i| * stepsize).
    draws = np.random.uniform(low=0, high=1, size=n_particles)
    jump_probability = 1 - np.exp(-np.abs(Lambda) * stepsize)
    candidates = np.nonzero(draws < jump_probability)[0]
    # Process the candidates in random order.
    np.random.shuffle(candidates)

    # Particle jumps
    output = np.arange(n_particles)
    for i in candidates:
        # A particle whose slot was already overwritten no longer jumps.
        if i not in alive:
            continue
        # NOTE(review): the partner is drawn from all alive particles, so it
        # can be i itself; then the assignment is a no-op but i still leaves
        # the alive pool — confirm this is intended.
        partner = alive.choose_random_item()
        value = Lambda[i]
        if value > 0:
            # Slot i takes the partner's index; i leaves the alive pool.
            output[i] = partner
            alive.remove_item(i)
        elif value < 0:
            # The partner's slot takes index i; the partner leaves the pool.
            output[partner] = i
            alive.remove_item(partner)

    return output

# Keep the result in `a`: the next cell displays `a`, which no earlier cell
# assigns, so it would raise NameError on a fresh kernel.
# NOTE(review): assumes `a` was meant to hold this result — confirm.
a = birthDeathJumpIndicies(Lambda)
print(a)

In [12]:
# Show `a` as this cell's output (the array repr that follows).
# NOTE(review): relies on `a` already being in the kernel namespace — confirm
# it is assigned before this cell when running from a fresh kernel.
a

array([  0,  48,   2,   3, 213, 191,  31,   7,  39, 153,  10,  11, 452,
        13,   7,  15,  16,  17, 386, 149, 330,  21,  22,  23,  24,  25,
        26,  13,  28,  29,  30,  31,  32,  33, 325,  35,  36, 248,  38,
        39, 209,  41,  42, 120,  29,  45,  46,  47,  48,  49,  50, 385,
        52,  53,  54, 398,  56, 409, 282, 342, 258,  61,  23, 472,  64,
        65,  66, 323,  68,  69,  70,  71, 159,  73,  74,  75,  76,  77,
       143, 161, 392,  81,  82,  83,  84, 354,  86,   3,  88, 336, 254,
       335,  11, 484,  94,  95, 329,  97, 267, 147, 100, 101, 102, 302,
       104, 105, 106, 107, 108, 109, 110, 291, 112, 113, 242, 115, 116,
       396, 118, 119, 120, 121, 122, 123, 482, 125, 457, 147, 223, 129,
       130, 490, 132, 133, 134, 481, 332, 137, 138, 139, 140,  45, 142,
       143, 144, 145,  68, 147, 148, 149, 150, 152,  23, 153, 154,  65,
       156, 157,  97, 159, 337, 161, 162, 261, 164, 165, 166, 183, 168,
       169, 170, 148, 150, 202, 174, 161, 455, 177, 467, 179, 18