-
Notifications
You must be signed in to change notification settings - Fork 8
/
gumbel_softmax.py
82 lines (51 loc) · 2.1 KB
/
gumbel_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Gumbel-Softmax layer.
"""
from typing import Any, Dict, Tuple
import tensorflow as tf
from tensorflow.keras.layers import Layer
import nalp.utils.constants as c
def gumbel_distribution(
    input_shape: Tuple[int, ...], dtype: tf.dtypes.DType = tf.float32
) -> tf.Tensor:
    """Samples a tensor from a standard Gumbel distribution.

    Uses inverse-transform sampling: with ``u`` drawn uniformly from ``[0, 1)``,
    ``-log(-log(u))`` is Gumbel-distributed.

    Args:
        input_shape: Shape of tensor to be sampled. A 1-D shape tensor (e.g. the
            result of ``tf.shape``) is also accepted, since it is passed straight
            to ``tf.random.uniform``.
        dtype: Floating-point dtype of the sampled tensor. Defaults to
            ``tf.float32``, which preserves the previous behaviour.

    Returns:
        (tf.Tensor): An input_shape tensor sampled from a Gumbel distribution.

    """

    uniform_dist = tf.random.uniform(input_shape, 0, 1, dtype=dtype)

    # The inner epsilon avoids log(0) when u == 0. The outer epsilon keeps the
    # argument of the outer log positive when u is close to 1.
    # NOTE(review): assumes c.EPSILON is representable in `dtype`; for float16
    # a very small epsilon may round to zero — confirm before using low precision.
    gumbel_dist = -1 * tf.math.log(
        -1 * tf.math.log(uniform_dist + c.EPSILON) + c.EPSILON
    )

    return gumbel_dist
class GumbelSoftmax(Layer):
    """A GumbelSoftmax class is the one in charge of a Gumbel-Softmax layer implementation.

    The layer adds Gumbel noise to its inputs and applies a temperature-scaled
    softmax. It returns both the relaxed (soft) sample and its argmax token.

    References:
        E. Jang, S. Gu, B. Poole. Categorical reparameterization with gumbel-softmax.
        Preprint arXiv:1611.01144 (2016).

    """

    def __init__(self, axis: int = -1, **kwargs) -> None:
        """Initialization method.

        Args:
            axis: Axis to perform the softmax operation.

        """

        super().__init__(**kwargs)

        self.axis = axis

    def call(self, inputs: tf.Tensor, tau: float) -> Tuple[tf.Tensor, tf.Tensor]:
        """Method that holds vital information whenever this class is called.

        Args:
            inputs: A tensorflow's tensor holding input data (logits).
            tau: Gumbel-Softmax temperature parameter.

        Returns:
            (Tuple[tf.Tensor, tf.Tensor]): Gumbel-Softmax output and its argmax token
                (int32) taken along `self.axis`.

        """

        # The sampler draws float32 noise by default. Cast it to the inputs' dtype
        # so the addition also works under a float16 mixed-precision policy or
        # for float64 layers, instead of raising a dtype mismatch.
        noise = tf.cast(gumbel_distribution(tf.shape(inputs)), inputs.dtype)

        x = tf.nn.softmax((inputs + noise) / tau, self.axis)

        # argmax is not differentiable; stop_gradient makes that explicit.
        y = tf.stop_gradient(tf.argmax(x, self.axis, tf.int32))

        return x, y

    def get_config(self) -> Dict[str, Any]:
        """Gets the configuration of the layer for further serialization.

        Returns:
            (Dict[str, Any]): Configuration dictionary.

        """

        config = {"axis": self.axis}
        base_config = super().get_config()

        # Layer-specific keys take precedence over the base configuration.
        return {**base_config, **config}