# Variational Inference with Gumbel Softmax

*\[Reference\] Learning to Explain: An Information-Theoretic Perspective on Model Interpretation, https://github.com/Jianbo-Lab/L2X*

目标：$max_\mathcal{E} I(X_S, Y)$, s.t. $S \in \mathcal{E}(X) $

$$
\begin{aligned}
& \mathbb{E} [I(X_S, Y)] \\
= & \mathbb{E} [ \frac{P(X_S, Y)}{P(X_S) \cdot P(Y)} ] \\
= & \mathbb{E} [ P(Y|X_S) ] + const \\ 
= & \mathbb{E}_{X} \mathbb{E}_{X_S|X} \mathbb{E}_{Y|X_S} [ P(Y|X_S) ] + const
\end{aligned}
$$

其中：

- $\mathbb{E}_{X}$：从真实数据分布中采样

- $\mathbb{E}_{X_S|X}$：$\mathcal{E}$ 可以看作一个 proposal，对于 $X$ 返回一个离散的子集合 $X_S$，也就是说 $\mathcal{E}(X_S|X)$ 是一个 dirac 分布。$KL(\mathcal{E}(X_S|X)|P(X_S))$ 在 $P(X_S)$ 具有 uniform 先验的情况下，是一个常数。因此可以直接通过 Monte Carlo 采样来优化 ELBO。
    - 由于 $\mathcal{E}$ 不可导，因此可以使用 gumbel softmax 进行近似。

- $\mathbb{E}_{Y|X_S} [ P(Y|X_S) ] \ge \mathbb{E}_{Y|X_S} [ Q(Y|X_S) ]$：使用 $Q$ 进行 Variational Inference，获取并优化 ELBO。（存疑？）


In [2]:
import torch
import random
from torch.utils.data import IterableDataset, DataLoader
import torch.nn.functional as F
import itertools

In [3]:
def gpu(sth):
    if torch.cuda.is_available():
        if isinstance(sth, tuple) or isinstance(sth, list):
            ret = [gpu(ele) for ele in sth]
        else:
            ret = sth.cuda()
    return ret

In [4]:
class CustomDataset(IterableDataset):
    def __iter__(self):
        while True:
            x = torch.randint(1, 10, (9, ))
            y = x[0]
            x[:5] = y
            x = x[torch.randperm(9)]
            yield x, y

data = DataLoader(CustomDataset(), batch_size=128)

In [5]:
class Feat(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 32, padding_idx=0)
        self.transformer = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=32, nhead=4), 
        num_layers=1)

    def forward(self, idxs):
        h = self.embed(idxs)
        h = self.transformer(h)
        return h

# Q(Y|X_S)
class Q(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.clf = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 10)
        )

    def forward(self, hidden):
        return self.clf(hidden)

# X_S ~ E(X)
class E(torch.nn.Module):
    def __init__(self, tau, k):
        super().__init__()
        # B x T x 25
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
            torch.nn.Flatten(-2, -1)  # remove last dim
        )
        self.tau = tau
        self.k = k
    
    def forward(self, hidden, hard=False):
        logits = self.mlp(hidden)
        logits = F.log_softmax(logits)
        if not hard:
            # probs = F.softmax(logits, dim=1)
            # ind = torch.zeros_like(probs).scatter(1, probs.topk(self.k)[1], 1.)
            # ret = probs + (ind - probs).detach()

            logits_k = logits.unsqueeze(1).expand(logits.shape[0], self.k, logits.shape[1])
            probs = F.gumbel_softmax(logits_k, tau=self.tau)
            ret = probs.max(1)[0]
        else:
            ret = torch.zeros_like(logits)
            ret.scatter_(1, logits.topk(self.k).indices, 1)
        return ret

# hidden: B x T x H
# probs : B x T
def pool(hidden, probs=None, k=None):
    if probs is None:
        return hidden.mean(1)
    else:
        hidden = probs.unsqueeze(-1) * hidden
        return hidden.sum(1) / k


In [72]:
x, y = gpu(next(iter(data)))
e(feat(x)).sum(1)

tensor([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3.], device='cuda:0', grad_fn=<SumBackward1>)

In [6]:
feat = gpu(Feat())
q = gpu(Q())
e = gpu(E(0.1, 3))
opt = torch.optim.AdamW(
    itertools.chain(feat.parameters(), q.parameters(), e.parameters()), 
    5e-4,
    weight_decay=0.01
)

feat.train()
q.train()
e.train()

for bid, batch in enumerate(data):
    # e.tau = max(0.1, (1000 - bid) / 1000 * 0.3)

    opt.zero_grad()
    x, y = gpu(batch)

    h = feat(x)

    # loss1 = F.cross_entropy(pool(h), y)

    probs = e(h, hard=False)
    logits = q(pool(h, probs, probs.sum(1, keepdims=True)))
    loss2 = F.cross_entropy(logits, y)

    # loss = 0.5 * loss1 + loss2
    loss = loss2

    if bid % 100 == 0:
        print(bid, e.tau, loss.item())
    loss.backward()
    opt.step()
    if bid == 1000:
        break

0 0.1 2.3150339126586914
100 0.1 1.1462016105651855
200 0.1 0.8118592500686646
300 0.1 0.7601400017738342
400 0.1 0.55384761095047
500 0.1 0.7607303261756897
600 0.1 0.651654839515686
700 0.1 0.7534162402153015
800 0.1 0.573569655418396
900 0.1 0.6712555885314941
1000 0.1 0.6235792636871338


In [9]:
feat.eval()
q.eval()
e.eval()

with torch.no_grad():
    x, y = gpu(next(iter(data)))
    x = x[:5]
    y = y[:5]

    h = feat(x)
    print(x)
    # print(hard_idx)

    # mean
    print("\nmean >>>")
    logits = q(pool(h))
    print("pred:", logits.argmax(1), y)

    # estimated
    print("\nexplained(soft) >>>")
    probs = e(h, hard=False)
    print(probs.sum(1))
    logits = q(pool(h, probs, e.k))
    print("pred:", logits.argmax(1), y)

    # estimated
    print("\nexplained(hard) >>>")
    probs = e(h, hard=True)
    logits = q(pool(h, probs, e.k))
    print("pred:", logits.argmax(1), y)

    # explain
    print("\nexplained")
    hard_idx = e(h, hard=True)
    explained = x.masked_fill(~(hard_idx.bool()), 0)
    print(explained)


tensor([[8, 3, 1, 1, 9, 1, 1, 1, 1],
        [4, 7, 7, 6, 6, 6, 6, 9, 6],
        [3, 7, 8, 7, 9, 9, 7, 7, 7],
        [6, 8, 8, 8, 1, 7, 2, 8, 8],
        [2, 5, 9, 9, 9, 9, 9, 2, 9]], device='cuda:0')

mean >>>
pred: tensor([1, 6, 7, 8, 9], device='cuda:0') tensor([1, 6, 7, 8, 9], device='cuda:0')

explained(soft) >>>
tensor([2.9987, 2.0110, 2.9154, 2.0000, 2.9857], device='cuda:0')
pred: tensor([1, 6, 7, 8, 9], device='cuda:0') tensor([1, 6, 7, 8, 9], device='cuda:0')

explained(hard) >>>
pred: tensor([3, 6, 7, 6, 2], device='cuda:0') tensor([1, 6, 7, 8, 9], device='cuda:0')

explained
tensor([[0, 3, 0, 0, 9, 1, 0, 0, 0],
        [0, 7, 0, 0, 6, 0, 6, 0, 0],
        [3, 7, 0, 0, 0, 0, 7, 0, 0],
        [6, 0, 0, 0, 0, 7, 2, 0, 0],
        [2, 0, 0, 0, 9, 0, 0, 2, 0]], device='cuda:0')
