In [36]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import numpy as np
import math
import matplotlib.pyplot as plt

In [45]:
# Small constant that keeps both logarithms in rand_gumbel away from log(0).
eps = 1e-20

def rand_gumbel(shape, *, device=None, dtype=None, generator=None):
    """
    Draw i.i.d. standard Gumbel(0, 1) noise via inverse-CDF sampling.

    With u ~ Uniform[0, 1), the value -log(-log(u)) is Gumbel(0, 1) distributed.

    shape: size of the tensor to generate (tuple / torch.Size)
    device, dtype, generator: optional, forwarded to torch.rand; the defaults
        (None) keep the original behaviour (default device/dtype, global RNG)
    return.shape: `shape`, Gumbel noise
    """
    u = torch.rand(shape, generator=generator, dtype=dtype, device=device)
    # eps guards log(0) when u == 0 exactly. It is below half-precision
    # resolution, so it offers no protection for float16 noise.
    return -torch.log(-torch.log(u + eps) + eps)


def gumbel_max_sampling(logits):
    """
    Draw exact categorical samples with the Gumbel-max trick.

    Gumbel noise is added to the log-probabilities and the arg-max class
    along the last axis is returned as a one-hot vector.

    input = shape(*, n_class)
    return.shape: [*, n_class], one-hot vector
    """
    # rand_gumbel allocates on the default device; move the noise next to the
    # logits so the addition also works for non-CPU inputs.
    noise = rand_gumbel(logits.shape).to(logits.device)
    y = nn.functional.log_softmax(logits, dim=-1) + noise
    ind = y.argmax(dim=-1, keepdim=True)
    # Scatter directly along the class axis: works for any number of leading
    # dimensions without flattening and reshaping back.
    return torch.zeros_like(y).scatter_(-1, ind, 1.0)


def gumbel_softmax_sample(logits, tau):
    """
    Draw a relaxed (soft) categorical sample:
    softmax((logits + Gumbel noise) / tau) over the last axis.

    input = shape(*, n_class)
    tau: temperature dividing the perturbed logits before the softmax
    return.shape: [*, n_class], gumbel_softmax vector
    """
    # rand_gumbel allocates on the default device; move the noise next to the
    # logits so the addition also works for non-CPU inputs.
    noise = rand_gumbel(logits.shape).to(logits.device)
    y = logits + noise
    return nn.functional.softmax(y / tau, dim=-1)


def ST_gumbel_softmax_sample(logits, tau):
    """
    Straight-through Gumbel-softmax sample.

    The returned values are (up to rounding) the one-hot arg-max of a
    Gumbel-softmax sample, while gradients flow through the soft sample.

    input = shape(*, n_class)
    return.shape: [*, n_class], one-hot vector
    """
    soft = gumbel_softmax_sample(logits, tau)
    # Hard one-hot encoding of the winning class along the last axis.
    winner = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, winner, 1)
    # The detached difference contributes no gradient, so backprop only sees
    # `soft` while the forward value is the hard sample.
    return (hard - soft).detach() + soft

In [63]:
# Target categorical distribution used by every experiment below.
CLASS_PROBS = [0.1, 0.2, 0.3, 0.4]
# Log-probabilities as a single row, shape [1, n_class]; cells repeat the row
# to build a batch of identical logits.
LOGITS = torch.tensor([[math.log(p) for p in CLASS_PROBS]], dtype=torch.float32)
# Number of samples drawn for the empirical-frequency checks.
num_samples = 20000

In [64]:
# Sanity check: the mean of the one-hot Gumbel-max samples is the empirical
# class frequency, which should approximate exp(LOGITS).
print('GUMBEL MAX')
batched_logits = LOGITS.repeat(num_samples, 1)
samples = gumbel_max_sampling(batched_logits)
class_freq = samples.mean(dim=0)
print(class_freq.numpy())

GUMBEL MAX
[0.0965 0.2026 0.3003 0.4006]


In [70]:
# Empirical class frequencies of straight-through samples at tau = 0.01.
tau = 0.01
print('ST_GUMBEL_SOFTMAX, tau = ', tau)
batched_logits = LOGITS.repeat(num_samples, 1)
samples = ST_gumbel_softmax_sample(batched_logits, tau)
class_freq = samples.mean(dim=0)
print(class_freq.numpy())

ST_GUMBEL_SOFTMAX, tau =  0.01
[0.10045 0.1984  0.30675 0.3944 ]


In [71]:
# Empirical class frequencies of straight-through samples at tau = 0.1.
tau = 0.1
print('ST_GUMBEL_SOFTMAX, tau = ', tau)
batched_logits = LOGITS.repeat(num_samples, 1)
samples = ST_gumbel_softmax_sample(batched_logits, tau)
class_freq = samples.mean(dim=0)
print(class_freq.numpy())

ST_GUMBEL_SOFTMAX, tau =  0.1
[0.1008  0.20215 0.29855 0.3985 ]


In [72]:
# Empirical class frequencies of straight-through samples at tau = 1.0.
tau = 1.0
print('ST_GUMBEL_SOFTMAX, tau = ', tau)
batched_logits = LOGITS.repeat(num_samples, 1)
samples = ST_gumbel_softmax_sample(batched_logits, tau)
class_freq = samples.mean(dim=0)
print(class_freq.numpy())

ST_GUMBEL_SOFTMAX, tau =  1.0
[0.10035 0.2014  0.29985 0.3984 ]


In [73]:
# Empirical class frequencies of straight-through samples at tau = 10.0.
tau = 10.0
print('ST_GUMBEL_SOFTMAX, tau = ', tau)
batched_logits = LOGITS.repeat(num_samples, 1)
samples = ST_gumbel_softmax_sample(batched_logits, tau)
class_freq = samples.mean(dim=0)
print(class_freq.numpy())

ST_GUMBEL_SOFTMAX, tau =  10.0
[0.09925 0.20085 0.2967  0.4032 ]


In [74]:
# Empirical class frequencies of straight-through samples at tau = 100.0.
tau = 100.0
print('ST_GUMBEL_SOFTMAX, tau = ', tau)
batched_logits = LOGITS.repeat(num_samples, 1)
samples = ST_gumbel_softmax_sample(batched_logits, tau)
class_freq = samples.mean(dim=0)
print(class_freq.numpy())

ST_GUMBEL_SOFTMAX, tau =  100.0
[0.1006  0.1987  0.29975 0.40095]


In [84]:
# Five relaxed (soft) Gumbel-softmax samples at tau = 0.01.
tau = 0.01
print('GUMBEL_SOFTMAX, tau = ', tau)
# NOTE: set_printoptions changes NumPy's print format globally, i.e. for
# every array printed afterwards in this kernel session.
three_decimals = '{: 0.3f}'.format
np.set_printoptions(formatter={'float': three_decimals})
few_logits = LOGITS.repeat(5, 1)
samples = gumbel_softmax_sample(few_logits, tau)
print(samples.numpy())

GUMBEL_SOFTMAX, tau =  0.01
[[ 0.000  1.000  0.000  0.000]
 [ 0.000  0.000  0.000  1.000]
 [ 0.000  1.000  0.000  0.000]
 [ 0.000  0.000  0.000  1.000]
 [ 1.000  0.000  0.000  0.000]]


In [85]:
# Five relaxed (soft) Gumbel-softmax samples at tau = 0.1.
tau = 0.1
print('GUMBEL_SOFTMAX, tau = ', tau)
# NOTE: set_printoptions changes NumPy's print format globally, i.e. for
# every array printed afterwards in this kernel session.
three_decimals = '{: 0.3f}'.format
np.set_printoptions(formatter={'float': three_decimals})
few_logits = LOGITS.repeat(5, 1)
samples = gumbel_softmax_sample(few_logits, tau)
print(samples.numpy())

GUMBEL_SOFTMAX, tau =  0.1
[[ 0.078  0.877  0.038  0.006]
 [ 1.000  0.000  0.000  0.000]
 [ 0.000  0.000  0.963  0.037]
 [ 0.000  0.991  0.000  0.009]
 [ 0.999  0.000  0.001  0.000]]


In [86]:
# Five relaxed (soft) Gumbel-softmax samples at tau = 10.0.
tau = 10.0
print('GUMBEL_SOFTMAX, tau = ', tau)
# NOTE: set_printoptions changes NumPy's print format globally, i.e. for
# every array printed afterwards in this kernel session.
three_decimals = '{: 0.3f}'.format
np.set_printoptions(formatter={'float': three_decimals})
few_logits = LOGITS.repeat(5, 1)
samples = gumbel_softmax_sample(few_logits, tau)
print(samples.numpy())

GUMBEL_SOFTMAX, tau =  10.0
[[ 0.204  0.244  0.289  0.264]
 [ 0.229  0.246  0.258  0.268]
 [ 0.245  0.245  0.250  0.259]
 [ 0.202  0.266  0.227  0.305]
 [ 0.212  0.221  0.283  0.284]]
