# Gumbel Softmax

Recall that we want to compute the gradient w.r.t. $\phi$ of the
following expectation

$$\mathbb{E}_{q_\phi(x)} \left[ f(x) \right]$$

If we estimate the expectation by drawing samples $x^{(i)} \sim q_\phi(x)$,
the resulting Monte Carlo average $\frac{1}{K} \sum_{i = 1}^K f(x^{(i)})$
no longer depends explicitly on $\phi$, so we cannot differentiate it
w.r.t. $\phi$.
Instead, we express $x$ as a deterministic function of
a random variable $z$ drawn from a distribution $p_Z(z)$ that
does not depend on $\phi$:

$$x = g_\phi(z), \text{ where } z \sim p_Z(z)$$

The expectation becomes

\begin{eqnarray}
\mathbb{E}_{p_Z(z)} \left[ f(g_\phi(z)) \right] \approx
\frac{1}{K} \sum_{i = 1}^K f(g_\phi(z^{(i)}))
\end{eqnarray}

That expression depends on $\phi$ and we can compute the derivative
w.r.t. $\phi$.
The reparameterization trick can be applied to many common continuous
distributions, such as the Gaussian, Laplace, or logistic distributions
(all location-scale families, where $x = \mu + \sigma z$).
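
As a minimal sketch of the trick for a Gaussian, the snippet below estimates the gradient of $\mathbb{E}[f(x)]$ for $x \sim \mathcal{N}(\mu, \sigma^2)$. The choice $f(x) = x^2$ and the function name `reparam_grad` are assumptions made for illustration; with this $f$ the exact gradients are $2\mu$ and $2\sigma$, which gives something to compare against.

```python
import numpy as np

def reparam_grad(mu, sigma, num_samples=100_000, seed=0):
    """Monte Carlo estimate of the gradient of E[f(x)], f(x) = x^2,
    w.r.t. (mu, sigma) for x ~ N(mu, sigma^2), via x = mu + sigma * z."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(num_samples)  # z ~ p_Z, independent of phi
    x = mu + sigma * z                    # x = g_phi(z)
    df_dx = 2.0 * x                       # f'(x) for the illustrative f
    # Chain rule: dx/dmu = 1 and dx/dsigma = z
    grad_mu = np.mean(df_dx * 1.0)
    grad_sigma = np.mean(df_dx * z)
    return grad_mu, grad_sigma

grad_mu, grad_sigma = reparam_grad(mu=1.5, sigma=0.5)
print(grad_mu, grad_sigma)  # should be close to 2*mu and 2*sigma
```

In practice an automatic differentiation library would apply the chain rule; it is written out by hand here only to make the dependence on $\phi = (\mu, \sigma)$ visible.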

What if $q_\phi$ is a discrete distribution, e.g. a categorical
distribution? In this case $\phi = (\pi_1, \dots, \pi_n)^\top$.
Samples from $q_\phi$ are one-hot vectors generated by inverse-CDF
sampling

$$x = \text{onehot}\left(\min\{i \mid \pi_1 + \dots + \pi_i \ge u\}\right)$$
where $u \sim \mathcal{U}(0, 1)$.
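
A small sketch of this sampling process: draw a uniform number and pick the first category whose cumulative probability reaches it. The function name `sample_categorical_inverse_cdf` is hypothetical, and indices are 0-based in the code.

```python
import numpy as np

def sample_categorical_inverse_cdf(pi, rng):
    """Draw one one-hot sample from a categorical distribution with
    probabilities pi by comparing u ~ U(0, 1) with the cumulative sums."""
    pi = np.asarray(pi, dtype=float)
    u = rng.uniform()
    cdf = np.cumsum(pi)
    # Smallest index i such that cdf[i] >= u
    idx = int(np.searchsorted(cdf, u, side="left"))
    idx = min(idx, len(pi) - 1)  # guard against rounding in the cumulative sum
    x = np.zeros(len(pi))
    x[idx] = 1.0
    return x

rng = np.random.default_rng(0)
pi = [0.45, 0.2, 0.2, 0.15]
samples = np.stack([sample_categorical_inverse_cdf(pi, rng) for _ in range(20_000)])
print(samples.mean(axis=0))  # empirical frequencies, should be close to pi
```

Note that the sample is a piecewise-constant function of $\pi$, so its derivative w.r.t. $\pi$ is zero almost everywhere and carries no useful gradient signal.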

We can apply the Gumbel-Max trick, which draws exact samples from a
categorical distribution using noise that does not depend on $\phi$

$$x = \text{onehot}\left( \underset{i}{\text{argmax}}(G_i + \log \pi_i) \right)$$

where $G_1, \dots, G_n$ are i.i.d. samples from the standard Gumbel distribution,
e.g. $G_i = -\log(-\log u_i)$ with $u_i \sim \mathcal{U}(0, 1)$.
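
The Gumbel-Max trick can be checked empirically. The sketch below (the function name `gumbel_max_sample` is an assumption) draws many samples in a vectorized way and compares the empirical frequencies with $\pi$.

```python
import numpy as np

def gumbel_max_sample(pi, num_samples, rng):
    """Draw one-hot samples from a categorical distribution with
    probabilities pi using the Gumbel-Max trick."""
    pi = np.asarray(pi, dtype=float)
    n = len(pi)
    # One standard Gumbel draw per category and per sample
    g = rng.gumbel(loc=0.0, scale=1.0, size=(num_samples, n))
    idx = np.argmax(g + np.log(pi), axis=1)
    return np.eye(n)[idx]  # rows of the identity matrix are one-hot vectors

rng = np.random.default_rng(0)
pi = [0.45, 0.2, 0.2, 0.15]
x = gumbel_max_sample(pi, num_samples=50_000, rng=rng)
print(x.mean(axis=0))  # empirical frequencies, should be close to pi
```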

- Question: why don't we train $\phi$ directly? It is a continuous vector.
- Answer: the input to the next step (e.g. the decoder) is a sample from
$q_\phi$, not the vector $\phi$ itself.
For example, suppose we want to model human faces and $q_\phi$ is the categorical
distribution over skin colors. The inputs to the decoder are one-hot vectors,
not continuous vectors.
$\phi$ could be $(0.45, 0.2, 0.2, 0.15)$ but samples from $q_\phi$ must be
$(1, 0, 0, 0)$ 45% of the time, $(0, 1, 0, 0)$ 20% of the time, and so on.


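Gradients cannot pass through the argmax that produces the discrete one-hot
vector, so we replace the one-hot vector with a continuous approximation
that lets the gradient pass through. The continuous approximation is a
softmax with temperature $\tau > 0$

$$y_i = \frac{\exp((\log \pi_i + G_i) / \tau)}{\sum_{j = 1}^n \exp((\log \pi_j + G_j) / \tau)}$$
for $i = 1, \dots, n$. As $\tau \to 0$ the vector $y$ approaches a one-hot
vector, and as $\tau$ grows it approaches the uniform vector
$(1/n, \dots, 1/n)$.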
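
A minimal NumPy sketch of this relaxed sample, assuming the name `gumbel_softmax_sample` and a max-subtraction step added for numerical stability (it does not change the result). The same noise is reused at two temperatures to show the effect of $\tau$.

```python
import numpy as np

def gumbel_softmax_sample(pi, tau, rng):
    """Draw one relaxed (continuous) sample y from the Gumbel-Softmax
    distribution with class probabilities pi and temperature tau."""
    pi = np.asarray(pi, dtype=float)
    g = rng.gumbel(size=len(pi))    # standard Gumbel noise G_i
    logits = (np.log(pi) + g) / tau
    logits = logits - logits.max()  # stability only; softmax is shift-invariant
    y = np.exp(logits)
    return y / y.sum()

pi = [0.45, 0.2, 0.2, 0.15]
# Same seed, hence the same Gumbel noise, at two temperatures
y_warm = gumbel_softmax_sample(pi, tau=5.0, rng=np.random.default_rng(0))
y_cold = gumbel_softmax_sample(pi, tau=0.05, rng=np.random.default_rng(0))
print(y_warm)  # fairly flat vector
print(y_cold)  # more peaked, with the same argmax as y_warm
```

The largest entry of $y$ sits at the same index that the Gumbel-Max trick would select, so lowering $\tau$ trades a closer match to the one-hot sample for gradients with higher variance.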
