In [1]:
import math
import random
def sample(probs, rng=None):
    """
    Sample an index from a discrete probability distribution.

    :param probs: a sequence of non-negative probabilities that sum to 1
        (up to floating-point rounding, checked with ``math.isclose``)
    :param rng: optional object with a ``random()`` method (e.g. a seeded
        ``random.Random``) used to draw the uniform variate; defaults to the
        module-level ``random.random``
    :return: an index ``i`` drawn with probability ``probs[i]``
    :raises ValueError: if any probability is negative or the probabilities
        do not sum to 1
    """
    # Validate with exceptions rather than ``assert``: asserts are stripped
    # when Python runs with -O, which would let bad input through silently.
    if any(p < 0 for p in probs):
        raise ValueError("probabilities must be non-negative")
    total = sum(probs)
    # An exact ``== 1`` check rejects valid distributions whose float sum is
    # off by rounding error, so compare with a relative tolerance instead.
    if not math.isclose(total, 1.0):
        raise ValueError(f"probabilities must sum to 1, got {total!r}")

    draw = random.random if rng is None else rng.random
    r = draw()  # uniform in [0, 1)

    # Walk the cumulative distribution and return the first index whose
    # cumulative probability exceeds r. A zero-probability entry never
    # raises the running total, so it can never be returned from this loop.
    cumulative = 0.0
    last_nonzero = 0
    for i, p in enumerate(probs):
        if p > 0:
            last_nonzero = i
        cumulative += p
        if r < cumulative:
            return i
    # Rounding can leave the final cumulative total slightly below r. Fall
    # back to the last outcome with non-zero probability rather than the last
    # index, which may have probability zero.
    return last_nonzero

# demo: draw 10,000 samples and tally how often each index is chosen;
# the counts should come out roughly proportional to probs
probs = [0.1, 0.2, 0.3, 0.4]
# size the tally from probs so the demo stays valid if probs changes length
counts = [0] * len(probs)
for _ in range(10000):
    counts[sample(probs)] += 1
print(counts)


[998, 1968, 3017, 4017]
