In [11]:
import numpy as np
import heapq
from scipy.stats import gamma

class BranchingProcess:
    """Simulate a marked binary branching Brownian motion on the interval [0, T].

    Each particle lives for a Gamma(shape=kappa, scale=theta) lifetime,
    truncated at ``T``. Over its lifetime its position moves by
    ``sigma0 * sqrt(lifetime) * N(0, I_d)`` with ``sigma0 = 1 / sqrt(d)``.
    When a particle's lifetime ends before ``T``:

    * with probability ``p_branch`` it splits into two children
      ``label + (1,)`` (mark 0) and ``label + (2,)`` (mark 1). Both are born
      at the parent's final position and time;
    * otherwise it dies without offspring.

    A particle whose recorded death time equals ``T`` is still alive at the
    horizon.

    Parameters
    ----------
    T : float
        Time horizon; every lifetime is truncated at ``T``.
    alpha : float
        Stored as ``self.alpha``. NOTE(review): not used by the methods of
        this class; presumably a model parameter used elsewhere.
    d : int
        Spatial dimension.
    x0 : array-like or float, optional
        Starting position of the root particle. A scalar (or size-1 input) is
        replicated in every coordinate. The default is 0.5 in every coordinate.
    kappa, theta : float
        Shape and scale of the Gamma lifetime distribution.
    n : int
        Stored as ``self.n``. NOTE(review): not used by the methods of this
        class.
    M : int
        Number of simulated paths used by ``sample_expectation_psi``.
    p_branch : float
        Probability that a particle splits into two children when its lifetime
        ends before ``T`` (default 0.5).
    rng : numpy random source, optional
        Object providing ``gamma``, ``normal`` and ``binomial`` with NumPy's
        signatures, e.g. ``np.random.default_rng(seed)``. If None, the legacy
        global ``np.random`` state is used, so ``np.random.seed`` controls it.

    Attributes
    ----------
    population : dict
        Maps a particle label (tuple of ints, root ``(1,)``) to
        ``(mark, position at birth, position at death, birth time, death time)``.
    next_splitting_time : list
        Heap of ``(death time, label)`` for particles not yet processed.
    time : float
        Time of the last processed event.
    """

    def __init__(self, T=1.0, alpha=0.2, d=5, x0=None, kappa=0.5, theta=2.5, n=96000, M=1000,
                 p_branch=0.5, rng=None):
        self.T = T
        self.alpha = alpha
        self.d = d
        # Diffusion coefficient of each particle's Brownian motion.
        self.sigma0 = 1 / np.sqrt(d)
        if x0 is None:
            self.x0 = np.full(d, 0.5)
        else:
            x0_arr = np.array(x0, dtype=float)
            if x0_arr.size == 1:
                # A single value is used for every coordinate.
                x0_arr = np.full(d, x0_arr.item())
            elif x0_arr.shape != (d,):
                raise ValueError(f"x0 must be a scalar or have shape ({d},), got {x0_arr.shape}")
            self.x0 = x0_arr
        self.kappa = kappa
        self.theta = theta
        self.n = n
        self.M = M
        self.p_branch = p_branch
        self.rng = np.random if rng is None else rng
        self.reset()

    def _new_particle(self, mark, x_birth, t_birth, t_death):
        """Build a population record, drawing the Brownian displacement over [t_birth, t_death]."""
        displacement = np.sqrt(t_death - t_birth) * self.sigma0 * self.rng.normal(size=self.d)
        return (mark, x_birth, x_birth + displacement, t_birth, t_death)

    def reset(self):
        """Start a fresh path at time 0 containing only the root particle ``(1,)``.

        Draws the root's lifetime (truncated at ``T``) and its position at the
        end of that lifetime. Rebuilds ``population`` and ``next_splitting_time``.
        """
        self.time = 0.0
        root = (1,)
        death_time = min(self.rng.gamma(shape=self.kappa, scale=self.theta), self.T)
        # A one-element list is already a valid heap.
        self.next_splitting_time = [(death_time, root)]
        self.population = {root: self._new_particle(0, self.x0, 0.0, death_time)}

    def step(self):
        """Process the pending particle with the smallest death time.

        Sets ``self.time`` to that death time (capped at ``T``). If it has
        reached ``T``, nothing else happens: every remaining heap entry is then
        also at ``T``. Otherwise the particle either splits into two children
        (probability ``p_branch``) or dies without offspring.

        Raises
        ------
        IndexError
            If there is no pending particle (``next_splitting_time`` is empty).
        """
        tau, label = heapq.heappop(self.next_splitting_time)
        self.time = min(tau, self.T)
        if tau >= self.T:
            return
        if self.rng.binomial(n=1, p=self.p_branch) == 1:
            # Children start where and when the parent's life ended.
            x_split = self.population[label][2]
            lifetimes = self.rng.gamma(shape=self.kappa, scale=self.theta, size=2)
            for mark, lifetime in enumerate(lifetimes):
                child = label + (mark + 1,)
                death_time = min(tau + lifetime, self.T)
                heapq.heappush(self.next_splitting_time, (death_time, child))
                self.population[child] = self._new_particle(mark, x_split, tau, death_time)

    def sample_process(self):
        """Simulate one full path on [0, T] and return ``self.population``.

        Events are processed until time ``T`` is reached or every particle has
        died.
        """
        self.reset()
        while self.time < self.T and self.next_splitting_time:
            self.step()
        return self.population

    def sample_expectation_psi(self, psi=None):
        """Monte Carlo estimate of ``E[psi(population)]`` over ``self.M`` paths.

        Parameters
        ----------
        psi : callable
            Maps a population dict (as returned by ``sample_process``) to a
            number or an array of numbers.

        Returns
        -------
        float or numpy.ndarray
            Mean of ``psi`` over ``self.M`` independently simulated paths.

        Raises
        ------
        TypeError
            If ``psi`` is missing or not callable.
        ValueError
            If ``self.M`` is smaller than 1.
        """
        if not callable(psi):
            raise TypeError("psi must be a callable mapping a population dict to a number")
        if self.M < 1:
            raise ValueError(f"M must be at least 1, got {self.M}")
        samples = [psi(self.sample_process()) for _ in range(self.M)]
        return np.mean(samples, axis=0)


In [21]:
SEED = 0
np.random.seed(SEED)  # BranchingProcess draws from the global NumPy RNG by default

bp = BranchingProcess()
# Simulate one path on [0, T] and display the resulting particle records.
bp.sample_process()
bp.population

{(1,): (0, array([0.5, 0.5, 0.5, 0.5, 0.5]), array([0.93226583, 0.36668135, 0.44105715, 0.29713114, 0.54938209]), 0.0, 0.40531445131694765), (1, 1): (0, array([0.93226583, 0.36668135, 0.44105715, 0.29713114, 0.54938209]), array([1.49943025, 1.21632229, 0.94748254, 0.12483685, 0.32339806]), 0.40531445131694765, 1.0), (1, 2): (1, array([0.93226583, 0.36668135, 0.44105715, 0.29713114, 0.54938209]), array([0.88496852, 0.39886124, 0.40206394, 0.48832135, 0.50209867]), 0.40531445131694765, 0.5304172741839053), (1, 2, 1): (0, array([0.88496852, 0.39886124, 0.40206394, 0.48832135, 0.50209867]), array([0.78432305, 0.16703611, 0.39147257, 0.42554991, 0.57956546]), 0.5304172741839053, 0.8788331859371037), (1, 2, 2): (1, array([0.88496852, 0.39886124, 0.40206394, 0.48832135, 0.50209867]), array([0.90803518, 0.25612803, 0.52547719, 0.4055137 , 0.38829999]), 0.5304172741839053, 0.5775470930642472), (1, 2, 2, 1): (0, array([0.90803518, 0.25612803, 0.52547719, 0.4055137 , 0.38829999]), array([0.902624