# Gumbel Machinery

[Gumbel Machinery](https://cmaddis.github.io/gumbel-machinery)

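## Basic: the Gumbel-max trick

Adding independent standard Gumbel noise to each $\log \alpha_i$ and taking the argmax draws $K$ with $P(K = k) = \alpha_k / \sum_i \alpha_i$.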

In [1]:
import numpy as np

# Gumbel-max trick: perturb each log-weight with independent standard Gumbel
# noise (inverse-CDF transform of a uniform draw, -log(-log(u))) and take the
# argmax; K then lands on index k with probability alpha[k] / alpha.sum().
alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
# One uniform per weight: sizing by len(alpha) (instead of a literal 5) keeps
# the cell correct if alpha is edited.
uniform = np.random.rand(len(alpha))
gumbels = -np.log(-np.log(uniform)) + np.log(alpha)
K = np.argmax(gumbels)

In [2]:
# Sampled category: index of the largest perturbed log-weight.
K

1

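## Top-down sampling with truncated Gumbels

To sample the perturbed log-weights conditioned on $K = k$, first draw the maximum $G_k \sim \mathrm{Gumbel}(\log \sum_i \alpha_i)$. Then draw every other $G_i$ from $\mathrm{Gumbel}(\log \alpha_i)$ truncated to lie below $G_k$.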

In [4]:
def truncated_gumbel(alpha, truncation):
    """Sample a Gumbel with location ``log(alpha)`` truncated to lie below ``truncation``.

    Draws an untruncated ``g ~ Gumbel(log(alpha))`` and maps it through
    ``-log(exp(-g) + exp(-truncation))``, the truncation trick from the linked
    Gumbel Machinery notes. The result never exceeds ``truncation``.

    Args:
        alpha: positive unnormalized weight; the Gumbel location is ``log(alpha)``.
        truncation: upper bound for the returned sample.

    Returns:
        A NumPy float no larger than ``truncation``.
    """
    gumbel = np.random.gumbel() + np.log(alpha)
    # Same formula as -log(exp(-g) + exp(-T)), but evaluated with logaddexp:
    # the naive exp(-g) overflows to inf (giving -inf) once g < about -709,
    # e.g. for very small alpha.
    return -np.logaddexp(-gumbel, -truncation)

def topdown(alphas, k):
    """Sample perturbed log-weights conditioned on their argmax being ``k``.

    The maximum is drawn first as ``Gumbel(log(sum(alphas)))`` and placed at
    index ``k``. Every other entry is then drawn with ``truncated_gumbel`` from
    ``Gumbel(log(alphas[i]))`` truncated below that maximum.

    Args:
        alphas: sequence of positive unnormalized weights.
        k: index that receives the maximum. Negative values count from the
            end, as in ordinary Python indexing.

    Returns:
        list of Python floats, one per weight. No entry exceeds the one at
        index ``k``.

    Raises:
        IndexError: if ``k`` is out of range for ``alphas``.
    """
    n = len(alphas)
    # range(n)[k] normalizes negative k and raises IndexError when out of range.
    # Without it, `i == k` would never match and no entry would get the maximum.
    k = range(n)[k]
    topgumbel = np.random.gumbel() + np.log(sum(alphas))
    # float(...) returns plain Python floats, matching rejection()'s .tolist().
    return [
        float(topgumbel) if i == k else float(truncated_gumbel(alphas[i], topgumbel))
        for i in range(n)
    ]

In [5]:
# One top-down sample conditioned on the argmax being index 1.
topdown(alphas=[1.0, 4.0, 6.0, 4.0, 1.0], k=1)

[0.4009548031127406,
 3.810074197981101,
 2.091188681819038,
 1.4685647536099093,
 0.26880406626074643]

In [6]:
# Another independent top-down sample with the same conditioning.
topdown(alphas=[1.0, 4.0, 6.0, 4.0, 1.0], k=1)

[1.4446002068406625,
 1.9673427704374287,
 1.0156692882895622,
 0.8749062964793195,
 -0.232073276072848]

In [8]:
# A third independent top-down sample with the same conditioning.
topdown(alphas=[1.0, 4.0, 6.0, 4.0, 1.0], k=1)

[-0.4130987101324357,
 2.069408069714429,
 0.36218932712618457,
 0.8769883312748507,
 -1.0999703389276922]

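## Rejection sampling

This is a reference sampler for the same conditional: draw unconstrained perturbed Gumbels and retry until the argmax is $k$. Its outputs should follow the same distribution as the top-down sampler's.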

In [9]:
def rejection(alphas, k):
    """Sample perturbed log-weights conditioned on their argmax being ``k``, by rejection.

    Repeatedly draws a fresh vector of standard Gumbels, adds ``log(alphas)``,
    and accepts the first vector whose argmax is ``k``. Each attempt succeeds
    with probability ``alphas[k] / sum(alphas)`` (Gumbel-max trick), so small
    target weights need many attempts.

    Args:
        alphas: sequence of unnormalized weights.
        k: index that must hold the maximum. Negative values count from the
            end, as in ordinary Python indexing.

    Returns:
        list of Python floats, the accepted perturbed log-weights.

    Raises:
        IndexError: if ``k`` is out of range for ``alphas``. Previously this
            made the loop spin forever.
        ValueError: if ``alphas[k]`` is not positive. Its log-weight would be
            -inf, so no draw could ever be accepted.
    """
    n = len(alphas)
    # range(n)[k] normalizes negative k and raises IndexError when out of range.
    k = range(n)[k]
    if not alphas[k] > 0:  # also rejects NaN
        raise ValueError("alphas[k] must be positive, got %r" % (alphas[k],))
    log_alphas = np.log(alphas)
    while True:
        perturbed = np.random.gumbel(size=n) + log_alphas
        if np.argmax(perturbed) == k:
            return perturbed.tolist()

In [10]:
# One rejection sample conditioned on the argmax being index 1.
rejection(alphas=[1.0, 4.0, 6.0, 4.0, 1.0], k=1)

[-0.08321533934696304,
 3.2248223834948964,
 0.9705408559121248,
 2.2549412894805028,
 2.7198182298543196]

In [11]:
# Another independent rejection sample with the same conditioning.
rejection(alphas=[1.0, 4.0, 6.0, 4.0, 1.0], k=1)

[0.6115495781386859,
 2.039384649245685,
 1.577156303506922,
 1.474300190441932,
 0.5587199074250211]