In [8]:
import torch
from torch import nn, optim
import torch.nn.functional as F

In [10]:
# copied and modified from https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f

def sample_gumbel(shape, eps=1e-20, device=None, dtype=None):
    """Draw i.i.d. samples from the standard Gumbel(0, 1) distribution.

    Uses inverse-transform sampling: if U ~ Uniform[0, 1) then
    -log(-log(U)) ~ Gumbel(0, 1).

    Args:
        shape: size of the returned tensor (``torch.Size``, tuple or int).
        eps: small constant added inside both logarithms to avoid log(0).
        device: device to create the samples on; ``None`` keeps torch's
            default, which is the previous behaviour.
        dtype: floating dtype of the samples; ``None`` keeps torch's default.

    Returns:
        Tensor of the requested shape filled with Gumbel noise.
    """
    uniform = torch.rand(shape, device=device, dtype=dtype)
    return -torch.log(-torch.log(uniform + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed (continuous) one-hot sample with the Gumbel-softmax trick.

    Args:
        logits: unnormalised log-probabilities; categories lie on the last
            dimension.
        temperature: positive softmax temperature (scalar or broadcastable
            tensor). Smaller values give samples closer to one-hot.

    Returns:
        Tensor with the same shape, device and dtype as ``logits``; the last
        dimension sums to 1.
    """
    # The noise is created with torch's default device/dtype (CPU float32);
    # move it onto the logits' device/dtype so CUDA or float64 logits work.
    noise = sample_gumbel(logits.size()).to(logits)
    return F.softmax((logits + noise) / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=True):
    """Straight-through Gumbel-softmax sample.

    Args:
        logits: unnormalised log-probabilities, shape ``[*, n_class]``.
        temperature: positive softmax temperature passed to
            ``gumbel_softmax_sample``.
        hard: if True (the default, and the original behaviour), return a
            one-hot tensor in the forward pass whose gradient is that of the
            soft sample (straight-through estimator). If False, return the
            soft sample itself.

    Returns:
        Tensor of shape ``[*, n_class]``; a one-hot vector along the last
        dimension when ``hard`` is True.
    """
    y_soft = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y_soft
    # One-hot of the arg-max, scattered directly along the last dimension.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # Forward value equals y_hard; gradients flow through y_soft.
    return (y_hard - y_soft).detach() + y_soft

In [23]:
class WeirdEncoder(nn.Module):
    """Encoder with ``mu``/``sigma`` heads that emit ``n_approx`` candidates
    per latent dimension, plus a hard Gumbel-softmax choice among them.

    Args:
        n_approx: number of candidates per latent dimension.
        latent_dim: number of latent dimensions (previously hard-coded 20).
        input_dim: size of the last axis of the input (previously 784).
        hidden_dim: width of the shared hidden layer (previously 400).
        temperature: Gumbel-softmax temperature used for the decision
            (previously hard-coded 0.1).
    """

    def __init__(self, n_approx=5, latent_dim=20, input_dim=784,
                 hidden_dim=400, temperature=0.1):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.n_approx = n_approx
        self.latent_dim = latent_dim
        self.temperature = temperature

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )

        # Each head emits latent_dim * n_approx values, later viewed as
        # (-1, latent_dim, n_approx).
        self.mu = nn.Linear(hidden_dim, latent_dim * n_approx)
        self.sigma = nn.Linear(hidden_dim, latent_dim * n_approx)
        self.decision = nn.Linear(hidden_dim, latent_dim * n_approx)

    def encode(self, x):
        """Encode ``x`` into ``(mu, sigma, decision)``.

        Args:
            x: tensor whose last dimension is ``input_dim``.

        Returns:
            Tuple ``(mu, sma, dec)``, each reshaped to
            ``(-1, latent_dim, n_approx)``. ``dec`` is one-hot along the last
            axis in the forward pass (straight-through Gumbel-softmax).
            NOTE(review): ``sma`` is an unconstrained linear output and can be
            negative; confirm how callers turn it into a scale.
        """
        h = self.layers(x)
        # Flatten to [*, n_approx] so the choice is made over candidates for
        # every (row, latent dim) pair, then restore the 3-D layout.
        dec = gumbel_softmax(
            self.decision(h).reshape(-1, self.n_approx),
            temperature=self.temperature,
        ).reshape(-1, self.latent_dim, self.n_approx)
        mu = self.mu(h).reshape(-1, self.latent_dim, self.n_approx)
        sma = self.sigma(h).reshape(-1, self.latent_dim, self.n_approx)

        return mu, sma, dec

    def forward(self, x):
        """Same as :meth:`encode`; lets ``encoder(x)`` work and run module hooks."""
        return self.encode(x)

In [24]:
encoder = WeirdEncoder(n_approx=5)  # 5 candidates per latent dimension (the default)

In [14]:
# Sanity-check gumbel_softmax on random logits laid out like the encoder's
# decision head (20 rows x 5 candidates, flattened to [*, 5]).
torch.manual_seed(0)  # fixed seed so the displayed sample is reproducible

N_LATENT, N_APPROX = 20, 5
x = torch.randn(1, N_LATENT * N_APPROX).view(-1, N_APPROX)

out = gumbel_softmax(x, temperature=0.1)

z = out.reshape(1, N_LATENT, N_APPROX)
# Forward value must be exactly one-hot along the candidate axis.
assert torch.allclose(z, F.one_hot(z.argmax(dim=-1), num_classes=N_APPROX).to(z))

In [15]:
# Show the sample: 20 rows of 5 options, each row one-hot.
z

tensor([[[0., 0., 0., 0., 1.],
         [0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0.]]])