# An efficient Gumbel-max sampler

Among random samplers for high-dimensional categorical distributions the Gumbel-max trick is the most stable in the sense that, fixing randomness, changes to its input parameters lead to the fewest changes to its samples. This property is desirable in systems leveraging approximate floating point numerics, where fully determinism is too costly but deterministic results are desired. However the Gumbel-max trick is much more expensive than other simpler samplers, e.g. whereas a naive inverse-cdf sampler costs only a single random number generator call, a naive Gumbel-max trick requires `num_categories`-many random number generator calls.

This notebook implements a Gumbel-max sampler that is stable but consumes much less randomness, I believe something like `O(perplexity)`-many random number generator calls in expectation. This is of practical importance in e.g. LLM serving.

In [1]:
from collections import Counter
import contextlib
import numpy as np
import scipy.stats

In [2]:
counter = Counter()

@contextlib.contextmanager
def counting():
    global counte
    counter.clear()
    yield
    for key, value in sorted(counter.items()):
        print(f"{key: >9}: {value: >5}")

In [3]:
def categorical_sample(logits: np.ndarray) -> int:
    counter["exp"] += len(logits)
    counter["uniform"] += 1
    probs = np.exp(logits - np.max(logits))
    probs /= np.sum(probs)
    return np.random.choice(len(logits), p=probs)

In [4]:
def gumbel_max_naive(logits: np.ndarray) -> int:
    counter["uniform"] += 1
    counter["log"] += 2 * len(logits)
    return int(np.argmax(logits + np.random.gumbel(size=logits.shape)))

In [5]:
LOG_2 = np.log(2)

def gumbel_icdf(u):
    counter["log"] += 2
    u = np.clip(u, np.finfo(type(u)).tiny, 1. - np.finfo(type(u)).eps)
    return -np.log(-np.log(u))

def gumbel_cdf(x):
    counter["exp"] += 2
    return np.exp(-np.exp(-x))

def uniform():
    counter["uniform"] += 1
    return np.random.uniform()

def bernoulli():
    counter["bernoulli"] += 1
    return np.random.uniform() < 0.5

def gumbel_max_recursive(logits: np.ndarray) -> int:
    N = len(logits)
    if N & (N - 1) != 0:
        raise NotImplementedError('N must be a power of 2')
    # Consider a binary tree of N leaves, with internal nodes at addresses.
    tree_height = N.bit_length() - 1
    root_address = 1
    def get_height(address: int) -> int:
        return tree_height - (address.bit_length() - 1)
    assert get_height(root_address) == tree_height
    assert get_height(N) == 0

    # Pass 1: Compute recursive maxima bottom-up.
    maxima = np.full(2 * N, np.nan, dtype=logits.dtype)
    maxima[N: 2 * N] = logits
    def _update_maxima(size: int) -> None:
        if size > 1:
            x = maxima[size: 2 * size: 2]
            y = maxima[size + 1: 2 * size: 2]
            maxima[size // 2: size] = np.maximum(x, y)
            _update_maxima(size // 2)
    _update_maxima(N)

    # Pass 2: Sample recursively top-down.
    uniforms = np.full(2 * N, np.nan, dtype=logits.dtype)
    gumbels = np.full(2 * N, np.nan, dtype=logits.dtype)
    i_best = -1
    logit_best = -np.inf
    def _sample(address: int) -> None:
        nonlocal i_best, logit_best
        height = get_height(address)
        if height == 0:  # Leaf.
            g = gumbels[address]
            x = logits[address - N] + g
            if x > logit_best:
                i_best = address - N
                logit_best = x
        else:  # Internal node.
            left = 2 * address
            right = 2 * address + 1
            # Sample children, transforming independent random variables
            #   g_parent ~ Gumbel(log(2) * (height + 1))
            #   u_new ~ Uniform(0, 1)
            # into iid random variables
            #   g_left, g_right ~ Gumbel(log(2) * height)
            u = uniforms[address]
            g = gumbels[address]
            upper, lower = (left, right) if bernoulli() else (right, left)
            gumbels[upper] = g
            u = uniforms[upper] = gumbel_cdf(g - LOG_2 * height)
            u = uniforms[lower] = u * uniform()  # truncated sample
            gumbels[lower] = gumbel_icdf(u) + LOG_2 * height
            assert gumbels[lower] <= gumbels[upper]
            # Recurse, conditional on upper bounds.
            if maxima[left] + gumbels[left] > logit_best:
                _sample(left)
            if maxima[right] + gumbels[right] > logit_best:
                _sample(right)
    # Sample root g ~ Gumbel(log(2) * tree_height).
    u = uniforms[root_address] = uniform()
    gumbels[root_address] = gumbel_icdf(u) + LOG_2 * tree_height
    _sample(root_address)

    assert i_best >= 0
    return i_best

In [6]:
logits = np.random.randn(1 << 17)  # approximately llama 3
with counting():
    print(gumbel_max_recursive(logits))

102532
bernoulli:   409
      exp:   818
      log:   820
  uniform:   410


In [7]:
def goodness_of_fit(logits: np.ndarray, counts: np.ndarray) -> float:
    probs = np.exp(logits - np.max(logits))
    probs /= np.sum(probs)
    # Use Pearson's chi-squared test.
    count = np.sum(counts)
    var = probs * (1 - probs) / count
    z = counts / count - probs
    chi2 = np.dot(z, z / var)
    p_value = 1 - scipy.stats.chi2.cdf(chi2, len(counts) - 1)
    return p_value

In [8]:
def test_gof(sampler, logits, n_samples=1000):
    counts = np.zeros_like(logits)
    for _ in range(n_samples):
        counts[sampler(logits)] += 1
    gof = goodness_of_fit(logits, counts)
    print(f"{sampler.__name__}: {gof:.3f}")
    assert gof > 0.01, f"{sampler.__name__} has bad goodness of fit"

In [9]:
%%time
logits = np.random.randn(16)
test_gof(categorical_sample, logits)
test_gof(gumbel_max_naive, logits)
test_gof(gumbel_max_recursive, logits)

categorical_sample: 0.411
gumbel_max_naive: 0.902
gumbel_max_recursive: 0.000


AssertionError: gumbel_max_recursive has bad goodness of fit