# Computing the implicit gradient w.r.t. a hyperparameter

## Elementary theory

Consider $\mathbb{R}$-normed spaces $(E, \|\cdot\|)$ and $(F, \|\cdot\|)$,
and a map $\phi\colon U \to F$, where $U$ is an open set in $E$.

The map $\phi$ is Fréchet differentiable at $a\in U$ if there exists a
bounded linear map $\ell\colon E \to F$ such that for any $\varepsilon > 0$
there is $\delta > 0$ such that for any $h$ with $\|h\| \leq \delta$
and $a+h \in U$ we have
$$
    \bigl\|
        \phi(a+h) - \phi(a)
        - \ell(h)
    \bigr\| \leq \varepsilon \|h\|
    \,. $$

This map is unique; it is called the differential of $\phi$ at $a$ and is denoted $d\phi(a) = \ell$.

Note that, since $U$ is open (in the norm topology) and $a\in U$, we can
always choose a smaller $\delta$ so that for all $\|h\|\leq \delta$ it
holds that $a + h \in U$.
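A quick numerical illustration of the definition, assuming a hypothetical map $\phi\colon\mathbb{R}^2\to\mathbb{R}^2$ chosen only for this example, whose Jacobian matrix plays the role of $\ell = d\phi(a)$: the remainder ratio $\|\phi(a+h)-\phi(a)-\ell(h)\|/\|h\|$ should shrink as $\|h\| \to 0$. This probes a single direction, so it illustrates the definition rather than proving differentiability.

```python
import numpy as np

def phi(a):
    # hypothetical smooth map R^2 -> R^2, used only for illustration
    return np.array([a[0] ** 2 * a[1], np.sin(a[0]) + a[1]])

def dphi(a):
    # Jacobian of phi at a: the matrix of the linear map d phi(a)
    return np.array([[2 * a[0] * a[1], a[0] ** 2],
                     [np.cos(a[0]), 1.0]])

a = np.array([0.7, -1.3])
direction = np.array([0.6, 0.8])  # a unit vector

ratios = []
for t in [1e-1, 1e-2, 1e-3, 1e-4]:
    h = t * direction
    remainder = phi(a + h) - phi(a) - dphi(a) @ h
    ratios.append(np.linalg.norm(remainder) / np.linalg.norm(h))

print(ratios)  # the ratio shrinks with |h|, as the definition requires
```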

As an operator, $d\phi\colon U \to \mathcal{B}(E, F)$ maps the open domain
of $\phi$ into the space of bounded linear operators from $E$ to $F$.
Note that $\mathcal{B}(E, F)$ is itself a normed space: for
$\ell \in \mathcal{B}(E, F)$ we have
$$
    \|\ell\|
        = \sup\bigl\{
            \tfrac{\|\ell x\|}{\|x\|}
            \colon x \in E\,,\,x\neq 0
        \bigr\}
        = \sup\bigl\{
            \|\ell x\|
            \colon x \in E\,,\,\|x\|\leq 1
        \bigr\}
        = \inf\bigl\{
            K \geq 0
            \colon \|\ell x \| \leq K \|x\|\,,\, \forall x\in E
        \bigr\}
    \,. $$
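The three expressions for $\|\ell\|$ can be checked numerically for a matrix acting between Euclidean spaces. The sketch below assumes $E = \mathbb{R}^3$ and $F = \mathbb{R}^5$ with 2-norms (an assumption for illustration), in which case the operator norm is the largest singular value; sampling unit vectors approaches the supremum from below.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))  # a linear map R^3 -> R^5

# sup{ |Ax| : |x| <= 1 }, estimated by sampling many unit vectors x
xs = rng.standard_normal((3, 100_000))
xs /= np.linalg.norm(xs, axis=0)
sampled_sup = np.linalg.norm(A @ xs, axis=0).max()

# with Euclidean norms the operator norm is the largest singular value,
# which is also the smallest K with |Ax| <= K |x| for all x
spectral = np.linalg.norm(A, ord=2)

print(sampled_sup, spectral)
```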


Therefore, applying the definition to $a \mapsto d\phi(a)$, we get
that the second-order differential of $\phi$ at $a$, i.e. the differential
of $a\mapsto d\phi(a)$, denoted by $d^2\phi(a)$, is an element
of $\mathcal{B}(E, \mathcal{B}(E, F))$, i.e. a bounded bilinear map. The
definition reads: $\phi$ is twice differentiable at $a$ if $\phi$
is differentiable on $U$ and the map $d\phi$ is differentiable at
$a$, i.e. for any $\varepsilon > 0$ there is $\delta > 0$ such that
for all $\|h\| \leq \delta$ with $a + h \in U$ we have

$$
\bigl\|
    d\phi(a+h) - d\phi(a)
    - d(d\phi(a))(h)
\bigr\| \leq \varepsilon \|h\|
    \,. $$

The linear map $d(d\phi(a))$ takes linear maps as values, so
$h\mapsto d(d\phi(a))(h)$ is abbreviated to $h\mapsto d^2\phi(a)(h, \cdot)$.

The trick is to notice that for any fixed $h\in E$ we can
define $\psi_h \colon U \to F \colon a\mapsto d\phi(a)(h)$.
Then $d\psi_h(a) = d^2\phi(a)(\cdot, h)$.

Indeed, fix some $h \in E$ and $\varepsilon > 0$. There exists
$\delta > 0$ such that for any $\|v\| \leq \delta$ with $a+v \in U$ both
linearization bounds hold: the one for $\psi_h$ at $a$ and the one for
$d\phi$ at $a$. By the triangle inequality,

$$
\begin{aligned}
    \bigl\|
        d\psi_h(a)(v) - d^2\phi(a)(v, h)
    \bigr\|
        &\leq \bigl\|
            d\phi(a + v)(h) - d\phi(a)(h) - d^2\phi(a)(v, h)
        \bigr\| + \bigl\|
            d\phi(a + v)(h) - d\phi(a)(h) - d\psi_h(a)(v)
        \bigr\|
        \\
        &\leq \bigl\|
            d\phi(a + v) - d\phi(a) - d^2\phi(a)(v, \cdot)
        \bigr\| \| h \| + \bigl\|
            \psi_h(a + v) - \psi_h(a) - d\psi_h(a)(v)
        \bigr\|
        \\
        &\leq \varepsilon \|v\| (1 + \| h \|)
    \,.
\end{aligned}
$$

Then for any $\|x\| \leq 1$ we have $\|\delta x\|\leq \delta$ and

$$
\bigl\|
    d\psi_h(a)(x) - d^2\phi(a)(x, h)
\bigr\|
    = \tfrac1\delta \bigl\|
        d\psi_h(a)(\delta x) - d^2\phi(a)(\delta x, h)
    \bigr\|
    \leq \tfrac1\delta \varepsilon (1 + \| h \|) \|\delta x \|
    \,. $$

Since $\|\delta x\| = \delta \|x\|$, the right-hand side equals
$\varepsilon (1 + \| h \|) \|x\| \leq \varepsilon (1 + \| h \|)$ for all
$\|x\| \leq 1$; taking the supremum over the unit ball bounds the operator
norm. Therefore for any $\varepsilon > 0$ we have

$$
\bigl\|
    d\psi_h(a) - d^2\phi(a)(\cdot, h)
\bigr\|
    \leq \varepsilon (1 + \| h \|)
    \,, $$
which implies that $d\psi_h(a) = d^2\phi(a)(\cdot, h)$, since $\varepsilon > 0$ is arbitrary.

**(duh)**
This implies that we can take $a\mapsto d\phi(a)$, evaluate it at
some constant vector $h$ to get $a\mapsto d\phi(a)(h)$, and
differentiate. In $\mathbb{R}^n$ this means that differentiating
$a \mapsto \nabla \phi(a)^\top h$ yields $\nabla^2 \phi(a) h$.
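In PyTorch this recipe is a double backward pass. A minimal sketch, assuming a hypothetical $\phi\colon\mathbb{R}^3\to\mathbb{R}$ picked to have a non-diagonal Hessian: first build $a \mapsto \nabla\phi(a)^\top h$ with `create_graph=True`, then differentiate it, and compare against the full Hessian from `torch.autograd.functional.hessian`.

```python
import torch

def phi(a):
    # hypothetical smooth scalar function on R^3 with a non-diagonal Hessian
    return (a ** 4).sum() / 4 + a[0] * a[1] * a[2]

a = torch.tensor([1.0, -2.0, 0.5], dtype=torch.double, requires_grad=True)
h = torch.tensor([0.3, 0.1, -1.0], dtype=torch.double)  # the constant vector h

# psi_h(a) = grad phi(a)^T h; keep the graph so psi_h can be differentiated
(grad,) = torch.autograd.grad(phi(a), a, create_graph=True)
psi_h = grad @ h

# differentiating psi_h yields the Hessian-vector product Hess phi(a) h
(hvp,) = torch.autograd.grad(psi_h, a)

# reference: the explicit Hessian times h
H = torch.autograd.functional.hessian(phi, a.detach())
print(torch.allclose(hvp, H @ h))
```

The point of the double-backward route is that it never materializes the $n \times n$ Hessian, which is what makes it usable for network weights.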

To find $\nabla^2_{\lambda\omega} f(\omega, \lambda) v$ do:
1. fix $v$ and compute $
\phi_v
\colon (\omega, \lambda) \mapsto
    \nabla_\omega f(\omega, \lambda)^\top v
$
2. compute $
\nabla_\lambda \phi_v(\omega, \lambda)
    = \nabla_\lambda \bigl( \nabla_\omega f(\omega, \lambda)^\top v \bigr)
    = \nabla^2_{\lambda\omega} f(\omega, \lambda) v
$
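The two steps above can be sketched with autograd for a hypothetical ridge-type objective $f(\omega, \lambda) = \tfrac12\|X\omega - y\|^2 + \tfrac\lambda2\|\omega\|^2$ (an assumption made so that the answer is known in closed form): here $\nabla_\omega f = X^\top(X\omega - y) + \lambda\omega$, hence $\nabla^2_{\lambda\omega} f\, v = \omega^\top v$.

```python
import torch

torch.manual_seed(0)
X = torch.randn(20, 4, dtype=torch.double)
y = torch.randn(20, dtype=torch.double)

def f(omega, lam):
    # ridge-type objective; lam plays the role of the hyperparameter
    return 0.5 * ((X @ omega - y) ** 2).sum() + 0.5 * lam * (omega ** 2).sum()

omega = torch.randn(4, dtype=torch.double, requires_grad=True)
lam = torch.tensor(0.7, dtype=torch.double, requires_grad=True)
v = torch.randn(4, dtype=torch.double)

# 1. phi_v(omega, lam) = grad_omega f(omega, lam)^T v, keeping the graph
(g_omega,) = torch.autograd.grad(f(omega, lam), omega, create_graph=True)
phi_v = g_omega @ v

# 2. differentiate phi_v w.r.t. lam: the mixed second derivative applied to v
(mixed,) = torch.autograd.grad(phi_v, lam)

# closed form for this particular f: omega^T v
print(mixed.item(), (omega.detach() @ v).item())
```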

In [None]:
import numpy as np

import torch

import torch.nn.functional as F

<br>

## Primitives

Linear operations in the direct-product space of inner product spaces:
$( E, [\cdot, \cdot])$ with $E = \prod_i E_i$ and $
[\cdot, \cdot]
    = \oplus_i \langle \cdot, \cdot \rangle_i
$
with each $(E_i, \langle \cdot, \cdot \rangle_i)$ being an inner product
space over $\mathbb{K}$:

* (`daxpy`, `dscal`) for $a, b \in E$ and $\lambda \in \mathbb{K}$ we have $\lambda a + b \in E$,
given by $(\lambda a + b)_i = \lambda a_i + b_i$ in each $E_i$.

* (`ddot`, `dnorm`) for $a, b \in E$ we have $[a, b] = \sum_i \langle a_i, b_i \rangle_i$
and $\|a\| = \sqrt{[a, a]}$.
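The source of `tools.trcg` is not shown, so the following is only a hypothetical re-implementation of these primitives, assuming an element of $E$ is a list of tensors (one per factor $E_i$), $\mathbb{K} = \mathbb{R}$, and the BLAS-style argument order `daxpy(alpha, x, y)` $= \alpha x + y$ (consistent with how `daxpy` is called in the Hessian-vector closure later on). The actual `tools.trcg` functions may differ.

```python
import math
import torch

# hypothetical sketches of the primitives on E = prod_i E_i,
# where an element of E is a list of tensors, one per factor E_i (real case)
def dscal(alpha, x):
    # (alpha * x)_i = alpha * x_i
    return [alpha * u for u in x]

def daxpy(alpha, x, y):
    # (alpha * x + y)_i = alpha * x_i + y_i
    return [alpha * u + v for u, v in zip(x, y)]

def ddot(x, y):
    # [x, y] = sum_i <x_i, y_i>_i with the Euclidean inner product per factor;
    # returns a tensor, so the result stays differentiable
    return sum((u * v).sum() for u, v in zip(x, y))

def dnorm(x):
    # |x| = sqrt([x, x])
    return math.sqrt(float(ddot(x, x)))

a = [torch.ones(2, 3), torch.arange(4.0)]
b = [torch.zeros(2, 3), torch.ones(4)]
print(float(ddot(a, b)), dnorm(a))
```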

In [None]:
from tools.trcg import clone, dzero, daxpy, ddot

from tools.trcg import trcg

<br>

Get the data:

In [None]:
import torch.utils.data

X = torch.randn(1024, 10).double()
y = X @ torch.randn(10, 1).double()

train = torch.utils.data.TensorDataset(X[:25], y[:25])
test = torch.utils.data.TensorDataset(X[250:], y[250:])

<br>

Set up the losses.

In [None]:
# create the hyperparameter directly in double precision so that it stays a leaf tensor
alpha = torch.tensor(1., dtype=torch.double, requires_grad=True)

#### What about a within-model hyperparameter?

In [None]:
def mse_loss(model, X, y):
    l2_reg = sum(p.norm()**2 for n, p in model.named_parameters()
                 if "bias" not in n)
    return 0.5 * F.mse_loss(model(X), y) + 0.5 * alpha * l2_reg

In [None]:
loss = mse_loss

In [None]:
from torch.nn import Sequential, Linear, LeakyReLU, Tanh

model = Sequential(
    Linear(10, 64),
    LeakyReLU(),
    Linear(64, 64),
    LeakyReLU(),
    Linear(64, 1),
#     Linear(10, 1, bias=False)
).double()

<br>

## Hyperparameter optimization with approximate gradient

Observations
* at $\omega^*(\lambda) = \arg\min_\omega \mathcal{L}_\mathrm{train}(\omega, \lambda)$ we have $\nabla_\omega \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda) \equiv 0$ for all $\lambda$; differentiating this identity in $\lambda$ gives

$$
    \nabla_\lambda \omega^*(\lambda) 
        \nabla^2_{\omega\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
    + \nabla^2_{\lambda\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
        = 0
    \,, $$


* using the chain rule we get for $
F\colon \lambda \mapsto \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
$

$$
    \nabla_\lambda F(\lambda)
        = \nabla_\lambda \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
        + \nabla_\lambda \omega^*(\lambda)
            \, \nabla_\omega \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
    \,. $$

[HOAG](https://arxiv.org/pdf/1602.02355.pdf) Algorithm:
1. (approx) solve $\omega^*(\lambda) = \arg\min_\omega \mathcal{L}_\mathrm{train}(\omega, \lambda)$
2. (approx) find $q_\lambda$ such that
$
    \nabla^2_{\omega\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
        \, q_\lambda
    = \nabla_\omega \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
$ (one can add $c I_\omega$ for $0 < c \ll 1$ to the Hessian to stabilize CG)
3. compute
$$
    \nabla_\lambda F(\lambda)
        = \nabla_\lambda \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
        - \nabla^2_{\lambda\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda) \, q_\lambda
    \,. $$

Therefore 

$$
    \nabla_\lambda F(\lambda)
        = \nabla_\lambda \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
        - \nabla^2_{\lambda\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
        \underbrace{
            \bigl(
                \nabla^2_{\omega\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
            \bigr)^{-1}
            \nabla_\omega \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
        }_{q_\lambda}
    \,. $$
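To sanity-check the formula, here is a minimal NumPy sketch on a hypothetical ridge-regression problem where everything is available in closed form. Assumptions for the example only: $\mathcal{L}_\mathrm{train}(\omega, \lambda) = \tfrac12\|y_\mathrm{tr} - X_\mathrm{tr}\omega\|^2 + \tfrac\lambda2\|\omega\|^2$ and a test loss without the regularizer, so $\nabla_\lambda \mathcal{L}_\mathrm{test} = 0$ and $\nabla^2_{\lambda\omega}\mathcal{L}_\mathrm{train} = \omega^\top$. The three HOAG steps are done exactly and compared with a central finite difference of $F$.

```python
import numpy as np

rng = np.random.default_rng(0)
X_tr, X_te = rng.standard_normal((30, 5)), rng.standard_normal((50, 5))
w_true = rng.standard_normal(5)
y_tr = X_tr @ w_true + 0.5 * rng.standard_normal(30)
y_te = X_te @ w_true + 0.5 * rng.standard_normal(50)

def omega_star(lam):
    # step 1 (exact here): argmin_w 0.5 |y_tr - X_tr w|^2 + 0.5 lam |w|^2
    return np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(5), X_tr.T @ y_tr)

def F(lam):
    # outer objective: plain test loss at omega*(lam)
    w = omega_star(lam)
    return 0.5 * np.sum((y_te - X_te @ w) ** 2)

def hypergrad(lam):
    w = omega_star(lam)
    H = X_tr.T @ X_tr + lam * np.eye(5)   # Hessian of L_train in w
    g = X_te.T @ (X_te @ w - y_te)        # grad_w L_test at omega*
    q = np.linalg.solve(H, g)             # step 2: H q = g
    # step 3: grad_lam L_test (zero here) minus grad^2_{lam w} L_train q = w^T q
    return 0.0 - w @ q

lam, eps = 2.0, 1e-4
fd = (F(lam + eps) - F(lam - eps)) / (2 * eps)
print(hypergrad(lam), fd)
```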

<br>

### Steps

#### 1. (approx) find stationary point

$\omega^*(\lambda)$ sits in `model.parameters()`

**(note)**
we have to train the model to $
\nabla_\omega \mathcal{L}_\mathrm{train}
    (\omega^*(\lambda), \lambda)
    \approx 0
$
(almost zero gradient), and preferably to a local minimum (so that CG
below uses a positive definite Hessian).

In [None]:
from mlss2019bdl import fit

model.zero_grad()

# 1. find (approx) \arg \min_\omega L_train(\omega, \lambda)
fit(model, train, criterion=loss, batch_size=32, n_epochs=1000, verbose=True)

In [None]:
loss_h = loss(model, *train.tensors)  # L_train(\omega^*(\lambda), \lambda)
grad_h_omega = torch.autograd.grad(loss_h, model.parameters(), create_graph=False)

sum(map(torch.norm, grad_h_omega))

<br>

#### 2. (approx) find $q_\lambda$

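Write $h = \mathcal{L}_\mathrm{train}$ and, for a fixed vector $\delta$,
$\ell_\delta(\omega) = \nabla_\omega h(\omega, \lambda)^\top \delta$. Then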
$$
    \nabla^2_{\omega\omega}
        h(\omega, \lambda) \delta
        = \tfrac{\partial}{\partial \omega}
            \bigl\{ \nabla_\omega^\top h(\omega, \lambda) \delta \bigr \}
        = \tfrac{\partial}{\partial \omega} \ell_\delta(\omega)
        = d\bigl\{ dh(\omega, \lambda)(\delta) \bigr \}
    \,. $$

In [None]:
# 2.2 create a hess-vect closure
def get_hesv_op(model, dataset):
    # get \nabla_\omega L_train(\omega^*(\lambda), \lambda)
    loss_h = loss(model, *dataset.tensors)
    grad_h_omega = torch.autograd.grad(loss_h, model.parameters(), create_graph=True)

    def _closure(delta, nugget=1e-6):  # adding a tiny diagonal helps a lot
        with torch.enable_grad():
            # \omega \mapsto \nabla_\omega^\top h(\omega, \lambda) \delta
            grad_h_vect = ddot(grad_h_omega, delta)

            # get \nabla^2_{\omega\omega} L_train(\omega^*(\lambda), \lambda) \delta
            hesv_h_omega = torch.autograd.grad(grad_h_vect, model.parameters(), retain_graph=True)

            return daxpy(nugget, delta, hesv_h_omega)

    return _closure

<br>

Evaluate the test loss and its $\omega$-gradient at the stationary point $\omega^*(\lambda)$.

In [None]:
loss_g = loss(model, *test.tensors)  # L_test(\omega^*(\lambda), \lambda)

grad_g_omega = torch.autograd.grad(loss_g, model.parameters(), create_graph=False)

<br>

Solve $
\nabla^2_{\omega\omega}
    \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
    \,\, \mathbf{x}
    = \nabla_\omega \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
$
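The notebook delegates this solve to `trcg` from `tools.trcg`, whose source is not shown. For reference, a plain conjugate-gradient sketch (no trust region; `cg` is a hypothetical helper, not the `tools.trcg` API) that only touches the matrix through a matvec closure, the same way `get_hesv_op` exposes the Hessian. The test matrix is an assumed symmetric positive definite stand-in for the Hessian.

```python
import numpy as np

def cg(matvec, b, rtol=1e-8, atol=1e-12, max_iter=1000):
    # plain conjugate gradient for A x = b with A symmetric positive definite,
    # where A is available only through the closure `matvec`
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rtr = r @ r
    tol = rtol * np.sqrt(rtr) + atol
    for _ in range(max_iter):
        if np.sqrt(rtr) <= tol:
            break
        Ap = matvec(p)
        alpha = rtr / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rtr_new = r @ r
        p = r + (rtr_new / rtr) * p
        rtr = rtr_new
    return x

rng = np.random.default_rng(0)
M = rng.standard_normal((8, 8))
A = M @ M.T + np.eye(8)  # symmetric positive definite stand-in for the Hessian
b = rng.standard_normal(8)
x = cg(lambda v: A @ v, b)
print(np.linalg.norm(A @ x - b))
```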

In [None]:
Ax = get_hesv_op(model, train)  # lambda v: [-p for p in v]

x, r = dzero(clone(grad_g_omega)), clone(grad_g_omega)
print(trcg(Ax, r, x, rtol=1e-8, atol=1e-9, verbose=True))

assert all(map(torch.allclose, Ax(x), grad_g_omega))

#### Check the gradient norms

In [None]:
loss_h = loss(model, *train.tensors)  # L_train(\omega^*(\lambda), \lambda)
grad_h_omega = torch.autograd.grad(loss_h, model.parameters(), create_graph=False)

In [None]:
sum(map(torch.norm, grad_h_omega)), sum(map(torch.norm, grad_g_omega))

<br>

#### 3. (approx) find the gradient


$$
    \nabla_\lambda F(\lambda)
        = \nabla_\lambda \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
        - \nabla^2_{\lambda\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
        \underbrace{
            \bigl(
                \nabla^2_{\omega\omega} \mathcal{L}_\mathrm{train}(\omega^*(\lambda), \lambda)
            \bigr)^{-1}
            \nabla_\omega \mathcal{L}_\mathrm{test}(\omega^*(\lambda), \lambda)
        }_{q_\lambda}
    \,. $$

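Similarly, with $h = \mathcal{L}_\mathrm{train}$ and
$\ell_\delta(\omega, \lambda) = \nabla_\omega h(\omega, \lambda)^\top \delta$,
differentiating in $\lambda$ instead of $\omega$ gives the mixed term:
$$
    \nabla^2_{\lambda\omega}
        h(\omega, \lambda) \delta
        = \tfrac{\partial}{\partial \lambda}
            \bigl\{ \nabla_\omega^\top h(\omega, \lambda) \delta \bigr \}
        = \tfrac{\partial}{\partial \lambda} \ell_\delta(\omega, \lambda)
    \,. $$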

In [None]:
def grad_lambda(q):
    # get \nabla_\omega L_train(\omega^*(\lambda), \lambda)
    loss_h = loss(model, *train.tensors)
    grad_h_omega = torch.autograd.grad(loss_h, model.parameters(), create_graph=True)

    # \lambda \mapsto \nabla_\omega^\top h(\omega, \lambda) q
    grad_h_vect = ddot(grad_h_omega, q)

    # get \nabla^2_{\lambda\omega} L_train(\omega^*(\lambda), \lambda) q
    hesv_h_lambda = torch.autograd.grad(grad_h_vect, alpha, retain_graph=False)

    # get \nabla_\lambda L_test(\omega^*(\lambda), \lambda)
    loss_g = loss(model, *test.tensors)
    grad_g_lambda = torch.autograd.grad(loss_g, alpha, create_graph=False)
    
    # return daxpy(-1, hesv_h_lambda, grad_g_lambda)
    return grad_g_lambda, hesv_h_lambda

In [None]:
grad, hesv = grad_lambda(x)
grad, hesv

In [None]:
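eta = 1e-3
# gradient step on the hypergradient
# \nabla_\lambda F = \nabla_\lambda L_test - \nabla^2_{\lambda\omega} L_train q_\lambda
alpha.data -= eta * (grad[0] - hesv[0])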

In [None]:
alpha

In [None]:
assert False

<br>

# Trunk

$$
\mathcal{L}(\omega, \lambda)
    = \tfrac12 \| y - X \omega \|^2_2
    + \tfrac\lambda2 \|\omega\|^2_2
    \,. $$

$$
\nabla_\omega \mathcal{L}
    = - X^\top(y - X\omega) + \omega \lambda
    = (X^\top X + \lambda I) \, \omega - X^\top y
    \,. $$

$$
\nabla_\lambda \mathcal{L}
    = \tfrac12 \|\omega \|^2_2
    \,. $$

$$
\nabla_{\omega \lambda} \mathcal{L}
    = \nabla_{\lambda \omega} \mathcal{L}
    = \omega
    \,. $$

Now $\omega^*(\lambda) = (X^\top X + \lambda I)^{-1} X^\top y$, whence

$$
F(\lambda)
    = \mathcal{L}(\omega^*(\lambda), \lambda)
    = \tfrac12 \| y - X \omega^*(\lambda) \|^2_2
    + \tfrac\lambda2 \| \omega^*(\lambda) \|^2_2
    \,, $$

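and, since $\nabla_\omega \mathcal{L}(\omega^*(\lambda), \lambda) = 0$, the
dependence through $\omega^*(\lambda)$ drops out of the chain rule, leaving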

$$
\nabla_\lambda F(\lambda)
    = \tfrac12 \|\omega^*(\lambda)\|^2_2
    \,. $$
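This can be checked numerically. A minimal NumPy sketch on random data (an assumption for the example), comparing a central finite difference of $F(\lambda) = \mathcal{L}(\omega^*(\lambda), \lambda)$ with $\tfrac12\|\omega^*(\lambda)\|^2_2$:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((40, 6))
y = rng.standard_normal(40)

def omega_star(lam):
    # closed-form minimizer (X^T X + lam I)^{-1} X^T y
    return np.linalg.solve(X.T @ X + lam * np.eye(6), X.T @ y)

def F(lam):
    # F(lam) = L(omega*(lam), lam), the same loss on both levels
    w = omega_star(lam)
    return 0.5 * np.sum((y - X @ w) ** 2) + 0.5 * lam * np.sum(w ** 2)

lam, eps = 0.5, 1e-4
fd = (F(lam + eps) - F(lam - eps)) / (2 * eps)
closed_form = 0.5 * np.sum(omega_star(lam) ** 2)
print(fd, closed_form)
```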

In [None]:
loss_h = loss(model, *train.tensors)

grad_h_omega = torch.autograd.grad(loss_h, model.parameters(), create_graph=True)

fn = ddot(grad_h_omega, [p.data for p in model.parameters()])

torch.autograd.grad(fn, alpha, retain_graph=True)

In [None]:
sum(map(torch.norm, grad_h_omega)), sum(map(torch.norm, model.parameters()))

<br>

Consider the following setup:
$$
    l^*(\lambda)
        = \min_{\omega} l(\omega, \lambda)
        = l(\omega^*(\lambda), \lambda)
    \,, $$
where
$$
    \omega^*(\lambda) = \arg \min_{\omega} l(\omega, \lambda)
    \,. $$

Suppose everything is sufficiently smooth, and let $\ell \neq l$ be another
(outer) objective. Then differentiating $f(\lambda) = \ell(\omega^*(\lambda), \lambda)$ gives
$$
    \partial_\lambda f(\lambda)
        = \partial_1 \ell(\omega^*, \lambda) \circ \partial_\lambda \omega^*
        + \partial_2 \ell(\omega^*, \lambda)
    \,. $$

In finite dimensions we have
$$
    \nabla_\lambda f
        = \nabla_\lambda \omega^*(\lambda) \nabla_\omega \ell(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
        + \nabla_\lambda \ell(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
    \,. $$

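On the other hand, first-order optimality gives
$\nabla_\omega l(\omega, \lambda) \big\vert_{\omega^*(\lambda), \lambda} = 0$
for every $\lambda$, so for $l^*$ the first term vanishes:

$$
    \nabla_\lambda l^*(\lambda)
        = \nabla_\lambda \omega^*(\lambda) \nabla_\omega l(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
        + \nabla_\lambda l(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
        = \nabla_\lambda l(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
    \,, $$
and differentiating the identity
$\nabla_\omega l(\omega^*(\lambda), \lambda) \equiv 0$ in $\lambda$ gives
$$
    \nabla_\lambda \omega^*(\lambda) \nabla^2_{\omega\omega} l(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
        = - \nabla^2_{\lambda\omega} l(\omega, \lambda)
            \big\vert_{\omega^*(\lambda), \lambda}
    \,. $$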

The Gâteaux derivative of $\omega^*$ at $\lambda$ in the direction $v$:
$$
    \tfrac{d}{d\eta } \omega^*(\lambda + \eta v) \big \vert_{\eta = 0}
        = \lim_{\eta \to 0} \frac{
            \arg \min_{\omega} l(\omega, \lambda + \eta v)
            - \arg \min_{\omega} l(\omega, \lambda)
        }{\eta}
\,. $$
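For a hypothetical ridge problem $l(\omega, \lambda) = \tfrac12\|y - X\omega\|^2 + \tfrac\lambda2\|\omega\|^2$ (chosen so the argmin is explicit), this difference quotient can be compared with the implicit-function answer: differentiating $(X^\top X + \lambda I)\,\omega^*(\lambda) = X^\top y$ in the direction $v$ gives $(X^\top X + \lambda I)\, d\omega^* + \omega^* v = 0$. A minimal NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((30, 4))
y = rng.standard_normal(30)
A = X.T @ X

def argmin_l(lam):
    # omega*(lam) = argmin_w 0.5 |y - X w|^2 + 0.5 lam |w|^2
    return np.linalg.solve(A + lam * np.eye(4), X.T @ y)

lam, v, eta = 1.5, 1.0, 1e-6
# difference quotient from the definition, with a small fixed eta
fd = (argmin_l(lam + eta * v) - argmin_l(lam)) / eta
# implicit-function answer: d omega* = -(A + lam I)^{-1} omega* v
implicit = -np.linalg.solve(A + lam * np.eye(4), argmin_l(lam)) * v
print(np.max(np.abs(fd - implicit)))
```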

<br>

# Using a class

In [None]:
from numpy import prod

class vector(tuple):
    def __new__(cls, *tensors):
        assert tensors
        if isinstance(tensors[0], cls):
            return tensors[0]

        if isinstance(tensors[0], (tuple, list)):
            tensors = tensors[0]

        self = super().__new__(cls, tensors)
        self.shape = tuple(u.shape for u in self)
        return self

    def __pos__(self):
        return self

    def __neg__(self):
        return type(self)(*map(torch.neg, self))

    def __add__(self, other):
        return type(self)(*map(torch.add, self, other))
    __radd__ = __add__

    def __iadd__(self, other):
        return type(self)(*map(torch.Tensor.add_, self, other))

    def __sub__(self, other):
        return type(self)(*map(torch.sub, self, other))

    def __isub__(self, other):
        return type(self)(*map(torch.Tensor.sub_, self, other))

    def __mul__(self, alpha):
        return type(self)([torch.mul(u, alpha) for u in self])
    __rmul__ = __mul__

    def __imul__(self, alpha):
        return type(self)([u.mul_(alpha) for u in self])

    # Python 3 uses __truediv__ / __itruediv__ for the `/` and `/=` operators
    def __truediv__(self, alpha):
        return type(self)([torch.div(u, alpha) for u in self])

    def __itruediv__(self, alpha):
        return type(self)([u.div_(alpha) for u in self])

    def clone(self):
        return type(self)(*map(torch.clone, self))

    def zero_(self):
        return type(self)(*map(torch.zero_, self))

    def detach(self):
        return type(self)(*map(torch.detach, self))

    def requires_grad_(self, requires_grad=True):
        return type(self)([u.requires_grad_(requires_grad) for u in self])

    def to(self, *args, **kwargs):
        return type(self)([u.to(*args, **kwargs) for u in self])
    
    def tensor(self):
        return torch.cat([u.flatten() for u in self])
    
    @classmethod
    def from_tensor(cls, tensor, shapes):
        # split a flat tensor back into pieces with the given shapes
        flat = torch.split(tensor, [int(prod(s)) for s in shapes])
        return cls([u.reshape(s) for u, s in zip(flat, shapes)])

    def __matmul__(self, other):
        return torch.dot(self.tensor(), other.tensor())

    def __abs__(self):
        return torch.norm(self.tensor())

    def __len__(self):
        # total number of scalar entries (used as the CG iteration cap)
        return sum(u.numel() for u in self)

In [None]:
from math import sqrt

def trcg_vec(Ax, r, x, n_iterations=1000, tr_delta=0, rtol=1e-5, atol=1e-8,
         args=(), verbose=False):
    if n_iterations > 0:
        n_iterations = min(n_iterations, len(x))

    p, iteration = r.clone(), 0
    tr_delta_sq = tr_delta ** 2

    rtr, rtr_old = float(r @ r), 1.0
    cg_tol = sqrt(rtr) * rtol + atol
    region_breached = False
    while (iteration < n_iterations) and (sqrt(rtr) > cg_tol):
        Ap = vector(Ax(p, *args))
        iteration += 1
        if verbose:
            print("""iter %2d |Ap| %5.3e |p| %5.3e """
                  """|r| %5.3e |x| %5.3e beta %5.3e""" %
                  (iteration, abs(Ap), abs(p), abs(r), abs(x), rtr / rtr_old))
        # end if

        alpha = rtr / float(p @ Ap)
        x += alpha * p
        r -= alpha * Ap

        # check trust region (diverges from tron.cpp in liblinear and leml-imf)
        if tr_delta_sq > 0:
            xTx = float(x @ x)
            if xTx > tr_delta_sq:
                xTp = float(x @ p)
                if xTp > 0:
                    # backtrack into the trust region
                    p_nrm = abs(p)

                    q = xTp / p_nrm
                    eta = (q - sqrt(max(q * q + tr_delta_sq - xTx, 0))) / p_nrm

                    # reproject onto the boundary of the region
                    r += eta * Ap
                    x -= eta * p
                else:
                    # this never happens maybe due to CG iteration properties
                    pass
                # end if

                region_breached = True
                break
            # end if
        # end if

        rtr, rtr_old = float(r @ r), rtr
        p *= rtr / rtr_old
        p += r
    # end while

    return iteration, region_breached

In [None]:
x, r = vector(grad_g_omega).clone().zero_(), vector(grad_g_omega).clone()
print(trcg_vec(Ax, r, x, rtol=1e-9, atol=1e-9, verbose=True))

assert all(map(torch.allclose, Ax(x), grad_g_omega))

In [None]:
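class Term(object):
    """A loss term; terms combine with `+` into a sum of terms."""
    def __init__(self):
        self.terms = [self]

    def __add__(self, other):
        # build a new composite term instead of mutating `self`,
        # so that `MSELoss() + other` keeps both summands
        combined = Term()
        combined.terms = [*self.terms, *other.terms]
        return combined

    def forward(self, model, X, y):
        raise NotImplementedError

    def __call__(self, model, X, y):
        return sum(term.forward(model, X, y) for term in self.terms)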

In [None]:
class MSELoss(Term):
    def forward(self, model, X, y):
        return 0.5 * F.mse_loss(model(X), y)

<br>

# Splitting alphas and weights

In [None]:
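def get_parameters(module, prefix=""):
    # assumption: hyperparameters are registered with "alpha" in their name
    alphas, weights = [], []
    for name, par in module.named_parameters(prefix):
        (alphas if "alpha" in name else weights).append(par)
    return alphas, weights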

<br>

# Gumbel-softmax trick

```python
    prob_t = tt.nnet.softmax(logit_t)

    # Gumbel-softmax sampling: Gumbel (e^{-e^{-x}}) distributed random noise
    gumbel = -tt.log(-tt.log(theano_random_state.uniform(size=logit_t.shape) + eps) + eps)
#     logit_t = theano.ifelse.ifelse(tt.gt(tau, 0), gumbel + logit_t, logit_t)
#     inv_temp = theano.ifelse.ifelse(tt.gt(tau, 0), 1.0 / tau, tt.constant(1.0))
    logit_t = tt.switch(tt.gt(tau, 0), gumbel + logit_t, logit_t)
    inv_temp = tt.switch(tt.gt(tau, 0), 1.0 / tau, tt.constant(1.0))

    # Get the softmax: x_t is `BxV`
    x_t = tt.nnet.softmax(logit_t * inv_temp)
```

In [None]:
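temp = 1e-1  # softmax temperature tau
eps = 1e-20  # guards against log(0), as in the Theano snippet above
logit = torch.randn(10, 10)

# Gumbel(0, 1) noise: -log(-log(u)) with u ~ U[0, 1)
gumbel = -torch.log(-torch.log(torch.rand_like(logit) + eps) + eps)

proba = torch.softmax((logit + gumbel) / temp, dim=-1)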
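
The reason this works is the Gumbel-max trick: $\arg\max_k (\mathrm{logit}_k + g_k)$ with i.i.d. Gumbel(0, 1) noise $g_k$ is an exact sample from $\mathrm{softmax}(\mathrm{logit})$, and the tempered softmax above is its differentiable relaxation. A minimal Monte Carlo check, with logits and sample size chosen arbitrarily for the example:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([1.0, 0.0, -1.0, 2.0])
n = 200_000

# Gumbel(0, 1) noise via the inverse CDF; clamp keeps log away from log(0)
u = torch.rand(n, logits.numel()).clamp_(min=1e-12)
gumbel = -torch.log(-torch.log(u))

# Gumbel-max: argmax(logits + g) is distributed as softmax(logits)
samples = torch.argmax(logits + gumbel, dim=-1)
freq = torch.bincount(samples, minlength=logits.numel()).double() / n

# Gumbel-softmax: replace argmax by a low-temperature softmax (near one-hot rows)
soft = torch.softmax((logits + gumbel[:5]) / 0.1, dim=-1)

print(freq, torch.softmax(logits.double(), dim=-1))
```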

<br>