## Discrete sampling as noise + argmax

Say you have log-probabilities (logits) $\phi_i$ for a categorical distribution $\pi$ from which you want to generate a sample $z$. You could undo the log (exponentiate) and normalize the logits to get the probabilities, then sample $z$ directly from them:

$$
\pi_i = \frac{\exp(\phi_i)}{\sum_j \exp(\phi_j)}
$$
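A minimal sketch of this direct route in PyTorch, assuming an example distribution $\pi = (0.15, 0.4, 0.45)$. `torch.softmax` does the "undo the log and normalize" step, and `torch.multinomial` draws the samples:

```python
import torch

torch.manual_seed(0)

phi = torch.log(torch.tensor([0.15, 0.4, 0.45]))  # logits (log-probabilities)

# "Undo the log and normalize": softmax turns the logits into probabilities pi
pi = torch.softmax(phi, dim=-1)

# Draw z directly from the categorical distribution pi
z = torch.multinomial(pi, num_samples=100_000, replacement=True)

# Empirical class frequencies should be close to pi
freq = torch.bincount(z, minlength=pi.numel()) / z.numel()
print(freq)
```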

However, it turns out that you can draw from the same distribution by adding i.i.d. noise $g_i \sim \text{Gumbel}(0, 1)$ to the logits and taking the argmax:

$$
P\left[\arg\max_j \, (\phi_j + g_j) = i\right] = \pi_i
$$

This is called the Gumbel-Max trick.
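Why does this work? A short derivation, using the fact that a standard Gumbel variable has CDF $F(g) = \exp(-e^{-g})$. Write $x_j = \phi_j + g_j$, so $x_j$ has CDF $F(x - \phi_j)$. The argmax equals $i$ exactly when $x_i$ exceeds every other $x_j$, so

$$
\begin{aligned}
P\left[\arg\max_j \, (\phi_j + g_j) = i\right]
&= \int_{-\infty}^{\infty} F'(x - \phi_i) \prod_{j \neq i} F(x - \phi_j) \, dx \\
&= \int_{-\infty}^{\infty} e^{\phi_i - x} \exp\Big(-e^{-x} \sum_j e^{\phi_j}\Big) \, dx \\
&= e^{\phi_i} \int_{0}^{\infty} \exp\Big(-t \sum_j e^{\phi_j}\Big) \, dt \qquad (t = e^{-x}) \\
&= \frac{e^{\phi_i}}{\sum_j e^{\phi_j}} = \pi_i
\end{aligned}
$$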

```python
#| echo: false
#| output: false
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
```

```python
#| code-fold: false
# Categorical distribution
pi = torch.tensor([0.15, 0.4, 0.45])
logits = torch.log(pi)

N = 100_000     # no. of samples
C = pi.shape[0] # no. of categories

# Now the trick
def gumbel_noise(shape): # Gumbel(0,1) via inverse CDF: g = -log(-log(u))
    u = torch.rand(shape)
    return -torch.log(-torch.log(u))


g = gumbel_noise((N, C))
sample = torch.argmax(logits + g, dim=-1)

# Verify: empirical class frequencies should be close to pi
# (bincount with minlength keeps one entry per class, even if a class never appears)
torch.bincount(sample, minlength=C) / N
```

```
tensor([0.1502, 0.3990, 0.4507])
```
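`gumbel_noise` uses inverse-CDF sampling: if $u \sim \text{Uniform}(0, 1)$, then $-\log(-\log u) \sim \text{Gumbel}(0, 1)$. As a sanity check, a small sketch that compares it with `torch.distributions.Gumbel` and with the known mean (the Euler–Mascheroni constant) and variance ($\pi^2/6$) of a standard Gumbel. The clamp on `u` is an addition of this sketch, to keep `log(u)` finite if `torch.rand` returns exactly 0:

```python
import math
import torch

torch.manual_seed(0)

def gumbel_noise(shape):
    # Inverse-CDF sampling: u ~ Uniform(0, 1)  =>  -log(-log(u)) ~ Gumbel(0, 1).
    # Clamping u away from 0 is an extra safety tweak, not part of the original cell.
    u = torch.rand(shape).clamp_min(torch.finfo(torch.float32).tiny)
    return -torch.log(-torch.log(u))

n = 1_000_000
g_manual = gumbel_noise((n,))
g_builtin = torch.distributions.Gumbel(0.0, 1.0).sample((n,))

euler_gamma = 0.5772156649     # mean of Gumbel(0, 1)
gumbel_var = math.pi ** 2 / 6  # variance of Gumbel(0, 1)

print(g_manual.mean().item(), g_builtin.mean().item(), euler_gamma)
print(g_manual.var().item(), g_builtin.var().item(), gumbel_var)
```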

## Differentiable soft sampling

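For our purposes, the Gumbel-Max trick by itself is not very useful: the argmax is piecewise constant, so its gradient with respect to the logits is zero almost everywhere and we still cannot backpropagate through the sample. However, if we replace the argmax with its differentiable approximation, the [soft(arg)max](https://en.wikipedia.org/wiki/Softmax_function#Smooth_arg_max), we gain the ability to backpropagate through the sampling. Instead of a class index we now get a soft sample $y$, a probability vector that approximates the one-hot encoding of $z$:

$$
y_i = \text{softmax}_\tau(\phi + g)_i = \frac{\exp((\phi_i + g_i)/\tau)}{\sum_j \exp((\phi_j + g_j) / \tau)}
$$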


We also introduce a "temperature" parameter $\tau$ that controls how close each soft sample is to a one-hot vector. As $\tau \to 0$ the output approaches the one-hot encoding of the sampled class, and as $\tau \to \infty$ it approaches the uniform distribution.
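To see both limits on a single draw, a small sketch that keeps the logits and one fixed Gumbel noise vector and only varies the temperature (the example $\pi$ and the temperature values are arbitrary choices for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.log(torch.tensor([0.15, 0.4, 0.45]))
u = torch.rand(3).clamp_min(torch.finfo(torch.float32).tiny)
g = -torch.log(-torch.log(u))  # one fixed draw of Gumbel(0, 1) noise

ys = {}
for tau in [0.001, 0.1, 1.0, 100.0]:
    # Same logits and same noise; only the temperature changes
    ys[tau] = F.softmax((logits + g) / tau, dim=-1)
    print(tau, ys[tau])

# Small tau: the mass concentrates on argmax(logits + g), the Gumbel-Max sample.
# Large tau: every entry drifts towards 1/3.
```

Since softmax preserves the ordering of its inputs, the largest entry of every soft sample sits at the same class as the Gumbel-Max sample, whatever the temperature.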



```python
temp = 0.1
g = gumbel_noise((N, C))
soft_sample = F.softmax((logits + g) / temp, dim=-1)

torch.round(soft_sample, decimals=2)
```

```
tensor([[0.0000, 0.0200, 0.9800],
        [0.0000, 0.4300, 0.5700],
        [1.0000, 0.0000, 0.0000],
        ...,
        [0.0700, 0.0000, 0.9300],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.9200, 0.0800]])
```

```python
soft_sample.mean(dim=0) # close to pi
```

```
tensor([0.1511, 0.4001, 0.4489])
```

```python
for temp in [0.01, 0.05, 0.1, 1, 10]:
    g = gumbel_noise((N, C))
    soft_sample = F.softmax((logits + g) / temp, dim=-1)
    print(temp, '\t', soft_sample.mean(dim=0))
```

```
0.01 	 tensor([0.1496, 0.3990, 0.4513])
0.05 	 tensor([0.1497, 0.4006, 0.4498])
0.1 	 tensor([0.1508, 0.4000, 0.4492])
1 	 tensor([0.1951, 0.3871, 0.4178])
10 	 tensor([0.3107, 0.3428, 0.3465])
```


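As you can see, because of the softmax we get soft (or fuzzy) samples: instead of a single class, each output is a distribution over the classes. The lower the temperature, the closer each soft sample gets to one-hot and the closer their average gets to $\pi$ (compare the rows for 0.01 to 0.1 above); as the temperature grows, both drift towards the uniform distribution.

Why keep the Gumbel noise at all? Without it, $\text{softmax}_\tau(\phi)$ is deterministic: it returns the same sharpened (or flattened) copy of $\pi$ every time. The noise is what turns each output into a random draw.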
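To see why the Gumbel noise is needed, a small sketch comparing five "samples" taken with and without it. The temperature `tau = 0.5` is an arbitrary choice, and `gumbel_noise` is redefined here (with `u` clamped away from 0) so the block runs on its own:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

pi = torch.tensor([0.15, 0.4, 0.45])
logits = torch.log(pi)
tau = 0.5

def gumbel_noise(shape):  # Gumbel(0, 1) via inverse CDF, u clamped away from 0
    u = torch.rand(shape).clamp_min(torch.finfo(torch.float32).tiny)
    return -torch.log(-torch.log(u))

# Without noise: every "sample" is the same tempered copy of pi
no_noise = torch.stack([F.softmax(logits / tau, dim=-1) for _ in range(5)])

# With noise: each row is a different relaxed sample
with_noise = F.softmax((logits + gumbel_noise((5, 3))) / tau, dim=-1)

print(no_noise)
print(with_noise)
```

With `tau = 1` the noiseless version returns exactly $\pi$, which is just the normalization step from the first section, not a sample.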

TODO some applications here



## Discrete and differentiable sampling

However, some applications require that we commit to a single class (a hard, one-hot sample) *and* still be able to backpropagate through the sampling. We could try taking the argmax of the softmax output, but then we are back where we started, since the argmax is not differentiable.
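A quick way to see the problem in PyTorch: `argmax` returns integer indices, which are not part of the autograd graph, so nothing downstream of them can send gradients back to the logits. A minimal sketch (the temperature 0.5 is arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.log(torch.tensor([0.15, 0.4, 0.45])).requires_grad_()
y_soft = F.softmax(logits / 0.5, dim=-1)
index = y_soft.argmax(dim=-1)

print(y_soft.requires_grad)              # True: the softmax output is in the graph
print(index.dtype, index.requires_grad)  # torch.int64 False: argmax cuts the graph
```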


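The trick is to use the straight-through estimator: in the forward pass we use the hard one-hot sample, but in the backward pass we use the gradient of the soft sample, as if the discretization step were the identity. In code this is done with `detach()`: the expression `(hard - soft).detach() + soft` evaluates to `hard`, but gradients only flow through the final `soft` term.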

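```python
temp = 0.5
g = gumbel_noise((N, C))
gumb_softmax_sample = F.softmax((logits + g) / temp, dim=-1)

# argmax gives class indices; one-hot them so the hard sample has the same shape as the soft one.
# (The hard sample does not depend on temp: softmax preserves the argmax of logits + g.)
index = torch.argmax(gumb_softmax_sample, dim=-1)
gumb_hard_sample = F.one_hot(index, num_classes=C).to(gumb_softmax_sample.dtype)

# Straight-through estimator:
# forward value = hard sample, gradients = those of the soft sample
st_sample = (gumb_hard_sample - gumb_softmax_sample).detach() + gumb_softmax_sample

st_sample.mean(dim=0) # close to pi
```

```
tensor([0.1509, 0.3994, 0.4497])
```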
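In the cells above `logits` never requires gradients, so the backward pass is never exercised. The sketch below does exercise it, assuming a made-up downstream loss (a weighted sum with hypothetical per-class `values`). `st_gumbel_softmax` is a hypothetical helper written for this example, not a library function:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Leaf tensor that requires gradients, so backward() has something to fill in
logits = torch.log(torch.tensor([0.15, 0.4, 0.45])).requires_grad_()
tau = 0.5

def st_gumbel_softmax(logits, tau):
    """Straight-through Gumbel-Softmax sample (hypothetical helper, a sketch)."""
    u = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
    g = -torch.log(-torch.log(u))                                # Gumbel(0, 1) noise
    y_soft = F.softmax((logits + g) / tau, dim=-1)               # relaxed sample
    index = y_soft.argmax(dim=-1, keepdim=True)                  # hard class index
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)   # one-hot
    # Forward value is y_hard; gradients flow only through y_soft
    return (y_hard - y_soft).detach() + y_soft

y = st_gumbel_softmax(logits, tau)

# Hypothetical downstream loss: weight each class by an arbitrary value
values = torch.tensor([1.0, 2.0, 3.0])
loss = (y * values).sum()
loss.backward()

print(y)            # one-hot in the forward pass
print(logits.grad)  # non-zero: the soft sample's gradient reaches the logits
```

PyTorch ships the same idea as `F.gumbel_softmax(logits, tau=..., hard=True)`, which returns a one-hot sample in the forward pass and uses the soft sample's gradient in the backward pass.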

## Applications

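Sources:
- Jang, Gu and Poole, *Categorical Reparameterization with Gumbel-Softmax* (2016): https://arxiv.org/pdf/1611.01144
- PyTorch source, `torch/nn/functional.py` (v2.7.0): https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/functional.py#L2146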