# Gumbel-Max Trick

The Gumbel-Max trick is a way to sample from a categorical distribution. Assume we have a categorical variable $X$ with $k$ values, where $\theta_X = \{ p_1, p_2, \ldots, p_k \}$ gives the probability of each value (note $\sum_{i=1}^{k} p_i = 1$). Then we can sample values from $X$ using the Gumbel-Max trick as follows.

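- for each $p_i$, simulate $z_i \sim \mathcal{G}(0, 1)$ and set $g_i = \log p_i + z_i$
- return $\arg\max_i g_i$, the index of the sampled value

Note that the noise is added to the log-probabilities. The notebook cells in the following sections add it to the raw entries of `x` instead, which treats `x` as logits: the sampled index then follows $\mathrm{softmax}(x)$, which coincides with `x` only when all entries are equal.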
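The procedure can be sketched as a small function. This is a minimal illustration, not code from the notebook: the name `gumbel_max_sample`, the seed, and the number of draws are assumptions made for the example, and the sketch perturbs the log-probabilities, which is the standard form of the trick.

```python
import numpy as np

def gumbel_max_sample(p, rng):
    """Draw one index from the categorical distribution p via Gumbel-Max."""
    p = np.asarray(p, dtype=float)
    z = rng.gumbel(loc=0.0, scale=1.0, size=p.shape)  # z_i ~ G(0, 1)
    g = np.log(p) + z                                 # perturb the log-probabilities
    return int(np.argmax(g))                          # index of the largest perturbed value

rng = np.random.default_rng(0)  # seed chosen arbitrarily for reproducibility
p = [0.8, 0.2]
draws = np.array([gumbel_max_sample(p, rng) for _ in range(20_000)])
freq = np.bincount(draws, minlength=len(p)) / len(draws)
print(freq)  # empirical frequencies, expected to be close to p
```

With enough draws, the empirical frequencies should approach `p`.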

The [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution) models the distribution of the maximum (or minimum) of a number of samples. For example, it could be used to model the distribution of the maximum water level of a river in a given year. In the latent-variable formulation of multinomial logistic regression, the errors of the latent variables follow a Gumbel distribution.
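To make the link to maxima concrete, here is a rough sketch under stated assumptions (sample sizes and seed are arbitrary). It draws standard Gumbel noise by inverse-transform sampling, $z = -\log(-\log u)$ with $u$ uniform on $(0, 1)$, and compares it with the shifted maximum of many exponential draws, which is approximately standard Gumbel. Both should have a mean near the Euler-Mascheroni constant (about 0.5772) and a standard deviation near $\pi/\sqrt{6}$.

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed

# Inverse-transform sampling: the standard Gumbel CDF is exp(-exp(-z)).
# The lower bound keeps u strictly positive so that log(u) stays finite.
u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=200_000)
z = -np.log(-np.log(u))

# Maxima of n Exponential(1) draws, shifted by log(n), are approximately Gumbel(0, 1).
n = 200
m = rng.exponential(size=(10_000, n)).max(axis=1) - np.log(n)

print(z.mean(), m.mean(), np.euler_gamma)      # all three should be close
print(z.std(), m.std(), np.pi / np.sqrt(6))    # all three should be close
```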

## Simple example

In [1]:
import numpy as np
from numpy.random import gumbel

x = np.array([0.5, 0.5])
z = gumbel(loc=0, scale=1, size=x.shape)

x, z

(array([0.5, 0.5]), array([0.98094889, 0.74723253]))

In [2]:
x + z

array([1.48094889, 1.24723253])

In [3]:
(x + z).argmax()

0

## Sample

In [4]:
samples = np.argmax(np.array([x + gumbel(0, 1, x.shape) for _ in range(10_000)]), axis=1)
samples

array([1, 1, 0, ..., 1, 1, 0])

In [5]:
_, counts = np.unique(samples, return_counts=True)

In [6]:
counts

array([4958, 5042])

In [7]:
counts / np.sum(counts)

array([0.4958, 0.5042])

As you can see, the sampled proportions closely match the true proportions of 0.5 each.
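With equal probabilities it makes no difference whether the noise is added to `x` or to `log(x)`. The following sketch, which is not part of the original notebook, uses unequal values and a vectorized draw (one row of noise per sample) to show the difference: perturbing `x` directly samples from `softmax(x)`, while perturbing `log(x)` recovers `x`. The seed and sample size are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(2)  # arbitrary seed
n = 50_000
x = np.array([0.8, 0.2])

noise = rng.gumbel(size=(n, x.size))  # one row of Gumbel(0, 1) noise per draw

# Perturb the raw values (x treated as logits) and the log-probabilities.
raw_freq = np.bincount((x + noise).argmax(axis=1), minlength=x.size) / n
log_freq = np.bincount((np.log(x) + noise).argmax(axis=1), minlength=x.size) / n

softmax_x = np.exp(x) / np.exp(x).sum()
print(raw_freq, softmax_x)  # these two should be close
print(log_freq, x)          # these two should be close
```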

## Sample, keep track of priors

We again sample with the Gumbel-Max trick, this time with unequal values in `x`. For every draw we record the Gumbel noise `g`, the perturbed values `z = p + g`, and the selected index `j`.

In [8]:
import pandas as pd

x = np.array([0.8, 0.2])

s = pd.DataFrame([(x, gumbel(0, 1, x.shape)) for _ in range(10_000)], columns=['p', 'g']) \
    .assign(z=lambda d: d['p'] + d['g']) \
    .assign(j=lambda d: d['z'].apply(np.argmax))

s.head(10)

|   | p | g | z | j |
|---|---|---|---|---|
| 0 | [0.8, 0.2] | [-0.11362697649369283, -0.9736320492222224] | [0.6863730235063072, -0.7736320492222224] | 0 |
| 1 | [0.8, 0.2] | [-0.610495978777585, 0.8084945646445472] | [0.18950402122241505, 1.0084945646445471] | 1 |
| 2 | [0.8, 0.2] | [0.2654703065780095, 0.20365001204867103] | [1.0654703065780096, 0.403650012048671] | 0 |
| 3 | [0.8, 0.2] | [1.5661774297412336, 0.6892010308912532] | [2.3661774297412337, 0.8892010308912532] | 0 |
| 4 | [0.8, 0.2] | [2.33190284920853, -0.5669642129060366] | [3.1319028492085303, -0.3669642129060366] | 0 |
| 5 | [0.8, 0.2] | [-0.3292231899985446, 2.4961360972384687] | [0.47077681000145544, 2.696136097238469] | 1 |
| 6 | [0.8, 0.2] | [0.18872523946302888, -0.14890675021054955] | [0.9887252394630289, 0.05109324978945046] | 0 |
| 7 | [0.8, 0.2] | [6.185039489501733, 4.965461905062004] | [6.985039489501733, 5.1654619050620045] | 0 |
| 8 | [0.8, 0.2] | [1.7302317395733458, 3.031986574399711] | [2.530231739573346, 3.2319865743997114] | 1 |
| 9 | [0.8, 0.2] | [-0.4012850276990852, -0.36008669125316056] | [0.39871497230091485, -0.16008669125316055] | 0 |


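These are the approximate mean and standard deviation of the perturbed value of each category, taken over the samples in which that category was selected (the cells below compute them, with the perturbed values `z` relabeled as `g_0` and `g_1`).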

- $g_0 = 1.82 \pm 1.3$
- $g_1 = 1.81 \pm 1.3$

In [9]:
pd.DataFrame(np.array(s['z'].tolist()), columns=['g_0', 'g_1']) \
    .join(s[['j']]) \
    .query('j==0')['g_0'].mean()

1.8405172612830005

In [10]:
pd.DataFrame(np.array(s['z'].tolist()), columns=['g_0', 'g_1']) \
    .join(s[['j']]) \
    .query('j==1')['g_1'].mean()

1.7656388761849071

In [11]:
pd.DataFrame(np.array(s['z'].tolist()), columns=['g_0', 'g_1']) \
    .join(s[['j']]) \
    .query('j==0')['g_0'].std()

1.3105895776675982

In [12]:
pd.DataFrame(np.array(s['z'].tolist()), columns=['g_0', 'g_1']) \
    .join(s[['j']]) \
    .query('j==1')['g_1'].std()

1.2571968756132792
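The two conditional means and standard deviations are nearly identical, and that is expected. The maximum of independent Gumbel variables with locations $x_i$ and unit scale is itself Gumbel with location $\log \sum_i e^{x_i}$, and it is independent of which index attains it. The sketch below (an addition, with an arbitrary seed and sample size) computes the implied mean and standard deviation for `x = [0.8, 0.2]`, roughly 1.81 and 1.28, and checks them by simulation.

```python
import numpy as np

x = np.array([0.8, 0.2])

# Location of the maximum: log-sum-exp of the values the noise is added to.
loc = np.log(np.exp(x).sum())
mean_max = loc + np.euler_gamma   # mean of Gumbel(loc, 1)
std_max = np.pi / np.sqrt(6)      # standard deviation of any unit-scale Gumbel
print(mean_max, std_max)

# Simulation check: statistics of the maximum, split by the winning index.
rng = np.random.default_rng(3)    # arbitrary seed
z = x + rng.gumbel(size=(100_000, x.size))
j = z.argmax(axis=1)
top = z.max(axis=1)
for k in range(x.size):
    print(k, top[j == k].mean(), top[j == k].std())
```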

## Inferring P(U|e)

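Let's say the evidence is $e = \{X=0\}$. We can then use rejection sampling to approximate $P(U|e)$: draw samples as before, discard those whose argmax disagrees with the evidence, and summarize the perturbed values `z` of the remaining samples.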

In [13]:
U = pd.DataFrame([(x, gumbel(0, 1, x.shape)) for _ in range(10_000)], columns=['p', 'g']) \
    .assign(z=lambda d: d['p'] + d['g']) \
    .assign(j=lambda d: d['z'].apply(np.argmax)) \
    .assign(r=lambda d: (d['j'] != 0).astype(int)) \
    .query('r == 0') \
    .assign(z_0=lambda d: d['z'].apply(lambda v: v[0])) \
    .assign(z_1=lambda d: d['z'].apply(lambda v: v[1]))[['z_0', 'z_1']] \
    .describe().loc[['mean', 'std']]

U

|      | z_0      | z_1      |
|------|----------|----------|
| mean | 1.816715 | 0.230677 |
| std  | 1.293434 | 0.85721  |


In [14]:
U.loc['mean'].z_0, U.loc['mean'].z_1, U.loc['std'].z_0, U.loc['std'].z_1

(1.8167146034170647,
 0.23067720926502758,
 1.2934344579794417,
 0.85721)

In [15]:
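# 0.57721 approximates the Euler-Mascheroni constant, the mean of a standard Gumbel.
# Subtracting it from the mean of z_0 estimates the location of the maximum.
U.loc['mean'] - 0.57721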

z_0    1.239505
z_1   -0.346533
Name: mean, dtype: float64
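Rejection sampling discards every draw that disagrees with the evidence. A known alternative, sometimes called top-down sampling, draws the conditioned values directly: sample the maximum from a Gumbel with location $\log \sum_i e^{x_i}$, assign it to the observed index, and draw the other values from Gumbels truncated at that maximum. The sketch below is an illustration under assumptions, not the notebook's method. The function name `posterior_perturbed_values`, the seed, and the sample size are hypothetical, and `x` is treated as logits to match the cells above.

```python
import numpy as np

def posterior_perturbed_values(logits, observed, n, rng):
    """Sample perturbed values z conditioned on argmax(z) == observed."""
    logits = np.asarray(logits, dtype=float)
    top_loc = np.log(np.exp(logits).sum())
    top = rng.gumbel(loc=top_loc, size=n)                 # value of the maximum
    free = rng.gumbel(loc=logits, size=(n, logits.size))  # untruncated draws
    # Truncate each draw at the maximum: -log(exp(-free) + exp(-top)).
    z = -np.logaddexp(-free, -top[:, None])
    z[:, observed] = top                                  # observed index holds the maximum
    return z

rng = np.random.default_rng(4)  # arbitrary seed
z = posterior_perturbed_values([0.8, 0.2], observed=0, n=50_000, rng=rng)
print(z.mean(axis=0), z.std(axis=0))  # comparable to the rejection-sampling summary
```

Every sample is used, which matters when the evidence has low probability and rejection sampling would discard most draws.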

## References

- [Counterfactual Policy Introspection using Structural Causal Models](https://www.michaelkoberst.com/assets/papers/ms-thesis-michael-oberst.pdf)