In [64]:
from numpy.random import multinomial
import numpy as np

In [65]:
N = 10  # number of compartments (length of current_pop)
C = 3   # number of outcome categories per compartment
SEED = 42

# `np.random.randint` below and the `multinomial` imported from numpy.random both
# draw from NumPy's legacy global RNG; seeding it makes the population and the
# later `multinomial` draws reproducible under Restart & Run All.
np.random.seed(SEED)
current_pop = np.random.randint(0, 100, size=N)
# prob[c, i]: probability that an individual of compartment i ends in category c.
prob = np.zeros((C, N), dtype=np.float64)
print(current_pop)

[ 7 87 79 85  1 28 50 62  3 20]


In [66]:
# Row 1 receives the leftover ('stay') probability: the complement of all OTHER
# rows. Excluding row 1 from the sum keeps this cell idempotent; summing every row
# (including row 1 itself) would flip prob[1] between 1 and 0 on each re-run.
prob[1] = 1 - np.delete(prob, 1, axis=0).sum(axis=0)

In [67]:
# One draw for compartment 0: column 0 of `prob` holds its per-category probabilities.
multinomial(current_pop[0], prob[:, 0])

array([0, 7, 0])

In [68]:
def chain_multinomial(n, p, stay_idx, rng=None, check_tol=1e-12):
    """
    Competing-risks multinomial with an explicit 'stay' category.

    Each of the ``n`` individuals leaves the source compartment with probability
    ``1 - exp(-H)``, where ``H`` is the total hazard (the sum of the non-stay
    entries of ``p``). Leavers are split among the exit categories in proportion
    to their hazards; everyone else is counted in the 'stay' category.

    The entry ``p[stay_idx]`` is ignored (treated as 0). That slot may therefore
    hold 0 or a placeholder such as ``1 - sum(others)`` without affecting the result.

    Args:
        n (int): Number of individuals in the source compartment (must be >= 0).
        p (array-like): 1-D vector of length K. For k != stay_idx, p[k] = r_k * dt.
            p[stay_idx] is ignored. The caller's array is never modified.
        stay_idx (int): Index of the 'stay' category.
        rng : np.random.Generator or seed or None, forwarded to np.random.default_rng.
            With None, a fresh entropy-seeded Generator is built on every call;
            pass a Generator to avoid that per-call cost and to get reproducible draws.
        check_tol (float): A total hazard H <= check_tol is treated as zero
            (everyone stays).

    Returns:
        counts (np.ndarray of int, shape (K,)): Realized counts for each category
        (including 'stay'), summing to n.

    Raises:
        ValueError: If p is not 1-D, a non-stay entry is non-finite or negative,
            or n is negative.
        IndexError: If stay_idx is out of bounds.
    """
    # np.array copies, so zeroing/clipping below never mutates the caller's data.
    p = np.array(p, dtype=float)
    if p.ndim != 1:
        raise ValueError("p must be a 1-D vector.")
    K = p.size
    if not (0 <= stay_idx < K):
        raise IndexError("stay_idx out of bounds.")
    if n < 0:
        raise ValueError("n must be non-negative.")

    # The stay category carries no hazard. Ignoring its entry guarantees that the
    # multinomial split below gives it no weight, so the counts always sum to n.
    p[stay_idx] = 0.0
    if not np.all(np.isfinite(p)):
        raise ValueError("All entries in p must be finite.")
    if np.any(p < -1e-16):
        raise ValueError("All entries in p must be non-negative.")
    # Clip the tiny negatives tolerated above (floating-point noise); otherwise
    # rng.multinomial rejects the weights as invalid probabilities.
    np.clip(p, 0.0, None, out=p)

    # Total hazard of leaving (the stay entry is 0 by construction).
    H = float(p.sum())

    rng = np.random.default_rng(rng)
    counts = np.zeros(K, dtype=int)

    if H <= check_tol:
        # No hazard -> everyone stays
        counts[stay_idx] = n
        return counts

    # Probability of leaving under constant total hazard H; expm1 keeps
    # precision when H is small. Clamp for floating-point safety.
    p_leave = min(max(-np.expm1(-H), 0.0), 1.0)

    # Draw total exits
    exits = rng.binomial(n, p_leave)
    if exits == 0:
        counts[stay_idx] = n
        return counts

    # Split exits proportionally to hazards (stay has zero weight).
    counts[:] = rng.multinomial(exits, p / H)
    # Stays are the residual. Computing them as n minus the exits guarantees
    # counts.sum() == n, even if rounding ever leaves a stray draw in the stay slot.
    counts[stay_idx] = 0
    counts[stay_idx] = n - counts.sum()

    return counts


In [69]:
# Demo inputs: 100 individuals, hazards r_k * dt for categories 0 and 2; index 1 is 'stay'.
n, p, stay_idx = 100, [0.1, 0.0, 0.2], 1


In [None]:
# With rng=None, chain_multinomial calls np.random.default_rng(None) on every call,
# so this timing includes building a fresh Generator, not just the sampling.
# To time the sampling alone, pass a pre-built Generator as the `rng` argument.
%timeit chain_multinomial(n, p, stay_idx)

71.1 µs ± 6.72 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [71]:
# NOTE(review): per NumPy's docs, legacy `multinomial` treats the last pvals entry as
# the leftover mass 1 - sum(pvals[:-1]). With p = [0.1, 0.0, 0.2], this likely samples
# from [0.1, 0.0, 0.9], which is not the distribution chain_multinomial draws from.
# Treat this only as a raw-speed baseline — confirm before comparing outputs.
%timeit multinomial(n, p)

1.48 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [79]:
# One realization; no rng is passed, so the result changes on every run.
chain_multinomial(n, p, stay_idx=stay_idx)

array([ 8, 72, 20])

In [77]:
# Display the probability vector that was passed to chain_multinomial.
p

[0.1, 0.0, 0.2]