In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F

In [2]:
# copied and modified from https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f

def sample_gumbel(shape, eps=1e-20, device=None, dtype=None):
    """Draw i.i.d. samples from the standard Gumbel(0, 1) distribution.

    Uses inverse-transform sampling: if U ~ Uniform[0, 1) then
    -log(-log(U)) ~ Gumbel(0, 1).

    Args:
        shape: size of the returned tensor (e.g. ``logits.size()``).
        eps: small constant added inside both logs so that U == 0 (or
            log(U) == 0) never evaluates log(0).
        device: device on which to allocate the samples (default: torch's
            default device). Pass the logits' device to avoid a CPU/GPU
            mismatch when the noise is added to them.
        dtype: floating dtype of the samples (default: torch's default dtype).

    Returns:
        Tensor of the requested shape filled with Gumbel(0, 1) noise.
    """
    U = torch.rand(shape, device=device, dtype=dtype)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed (soft) categorical sample with the Gumbel-softmax trick.

    Args:
        logits: unnormalised log-probabilities; the class axis is the last one.
        temperature: softmax temperature; smaller values push the sample
            closer to one-hot.

    Returns:
        Tensor shaped like ``logits`` whose last axis sums to 1.
    """
    # sample_gumbel allocates on the default device/dtype; move the noise onto
    # the logits' device and dtype so CUDA (or float64) logits also work.
    noise = sample_gumbel(logits.size()).to(logits)
    return F.softmax((logits + noise) / temperature, dim=-1)

def gumbel_softmax(logits, temperature):
    """Straight-through Gumbel-softmax sample.

    The forward value is a hard one-hot vector (argmax of the relaxed sample).
    The gradient is that of the soft relaxed sample.

    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y_soft = gumbel_softmax_sample(logits, temperature)

    # Place a single 1 at the arg-max position along the class (last) axis.
    winner = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, winner, 1)

    # Straight-through estimator: the value equals y_hard, while the detached
    # correction term makes gradients flow through y_soft only.
    return y_soft + (y_hard - y_soft).detach()

In [3]:
# an autoencoder example: skeleton of the VAE interface (encode, reparametrize, decode)

class AbstractVAE(nn.Module):
    """Skeleton of a VAE: ``encode`` -> ``reparametrize`` -> ``decode``.

    Subclasses must override all three methods; the base versions raise
    ``NotImplementedError`` instead of silently returning ``None``.
    """

    def __init__(self):
        # Zero-argument super(): `super(self.__class__, self)` recurses forever
        # once this class is subclassed, because self.__class__ is then the
        # subclass and the call re-enters this same __init__.
        super().__init__()

    def encode(self, x):
        """Map an input ``x`` to posterior parameters ``(mu, logvar)``."""
        raise NotImplementedError

    def reparametrize(self, mu, logvar):
        """Draw a latent ``z`` from the posterior given by ``(mu, logvar)``."""
        raise NotImplementedError

    def decode(self, z):
        """Reconstruct ``x_hat`` from a latent ``z``."""
        raise NotImplementedError

In [37]:
class PMVAE(nn.Module):
    """VAE encoder whose posterior is selected from ``n_approx`` candidates.

    For every latent dimension the encoder predicts ``n_approx`` candidate
    ``(mu, logvar)`` pairs plus ``n_approx`` decision logits. A
    straight-through Gumbel-softmax one-hot over the decision logits then
    selects one candidate per latent dimension.

    Args:
        n_approx: number of candidate (mu, logvar) pairs per latent dimension.
        g_temp: Gumbel-softmax temperature used for the selection.
        in_dim: number of (flattened) input features.
        hidden_dim: width of the shared hidden layer.
        latent_dim: number of latent dimensions.
    """

    def __init__(self, n_approx=5, g_temp=0.1, in_dim=784, hidden_dim=400, latent_dim=20):
        # Zero-arg super(): super(self.__class__, self) recurses in subclasses.
        super().__init__()

        self.n_approx = n_approx
        self.g_temp = g_temp
        self.latent_dim = latent_dim

        self.fce = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ELU(),
        )

        self.mu       = nn.Linear(hidden_dim, latent_dim * n_approx)
        self.logvar   = nn.Linear(hidden_dim, latent_dim * n_approx)
        self.decision = nn.Linear(hidden_dim, latent_dim * n_approx)

    def encode(self, x):
        """Encode ``x`` into the selected posterior parameters.

        Args:
            x: input of shape (batch, in_dim).

        Returns:
            ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        h = self.fce(x)
        shape = (-1, self.latent_dim, self.n_approx)

        # One-hot selector of shape (batch, latent_dim, n_approx). The batch
        # axis must be kept: a flat (batch * latent_dim, n_approx) selector
        # only broadcasts against (batch, latent_dim, n_approx) when batch == 1.
        d = gumbel_softmax(
            self.decision(h).reshape(shape),
            temperature=self.g_temp,
        )

        # Multiplying by the one-hot and summing over candidates picks one
        # (mu, logvar) pair per latent dimension.
        mu = (self.mu(h).reshape(shape) * d).sum(2)
        logvar = (self.logvar(h).reshape(shape) * d).sum(2)
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Reparameterisation trick: ``z = mu + exp(0.5 * logvar) * eps``, eps ~ N(0, I).

        Returns:
            Tensor shaped like ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

In [38]:
# Seed the global RNG so the weight initialisation (and the noise drawn in
# later cells) is reproducible under Restart & Run All.
torch.manual_seed(0)
vae = PMVAE()

In [39]:
# Smoke-test the encoder on a single random flattened input.
x_demo = torch.randn(1, 784)
mu, logvar = vae.encode(x_demo)
assert mu.shape == logvar.shape == (1, 20)
mu.shape, logvar.shape