import math

from chainer.functions.activation import softplus
from chainer.functions.math import average
from chainer.functions.math import exponential
from chainer.functions.math import sum


def gaussian_kl_divergence(mean, ln_var, reduce='sum'):
    """Computes the KL-divergence of Gaussian variables from the standard one.

    Given two variables ``mean`` representing :math:`\\mu` and ``ln_var``
    representing :math:`\\log(\\sigma^2)`, this function calculates the
    elementwise KL-divergence between the given multi-dimensional Gaussian
    :math:`N(\\mu, S)` and the standard Gaussian :math:`N(0, I)`

    .. math::

        D_{\\mathbf{KL}}(N(\\mu, S) \\| N(0, I)),

    where :math:`S` is a diagonal matrix such that :math:`S_{ii} = \\sigma_i^2`
    and :math:`I` is an identity matrix.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the elementwise
    loss values. If it is ``'sum'`` or ``'mean'``, the loss values are summed
    up or averaged, respectively.

    Args:
        mean (:class:`~chainer.Variable` or :ref:`ndarray`):
            A variable representing the mean of the given
            Gaussian distribution, :math:`\\mu`.
        ln_var (:class:`~chainer.Variable` or :ref:`ndarray`):
            A variable representing the logarithm of the variance of the
            given Gaussian distribution, :math:`\\log(\\sigma^2)`.
        reduce (str): Reduction option. Its value must be either
            ``'sum'``, ``'mean'`` or ``'no'``. Otherwise, :class:`ValueError`
            is raised.

    Returns:
        ~chainer.Variable:
            A variable representing the KL-divergence between the given
            Gaussian distribution and the standard Gaussian.
            If ``reduce`` is ``'no'``, the output variable holds an array
            whose shape is the same as that of the input variables.
            If it is ``'sum'`` or ``'mean'``, the output variable holds a
            scalar value.

    """
    if reduce not in ('sum', 'mean', 'no'):
        raise ValueError(
            'only \'sum\', \'mean\' and \'no\' are valid for \'reduce\', but '
            '\'%s\' is given' % reduce)

    var = exponential.exp(ln_var)
    mean_square = mean * mean
    loss = (mean_square + var - ln_var - 1) * 0.5
    if reduce == 'sum':
        return sum.sum(loss)
    elif reduce == 'mean':
        return average.average(loss)
    else:
        return loss
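
# Minimal usage sketch, assuming a NumPy backend (Chainer loss functions accept
# raw ndarrays in place of Variables); the shapes below are arbitrary. With zero
# mean and zero log-variance (unit variance), the KL term is exactly zero.
#
#     import numpy as np
#     mu = np.zeros((8, 2), dtype=np.float32)
#     ln_var = np.zeros((8, 2), dtype=np.float32)
#     kl_sum = gaussian_kl_divergence(mu, ln_var)                # scalar Variable, 0.0
#     kl_elem = gaussian_kl_divergence(mu, ln_var, reduce='no')  # Variable, shape (8, 2)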


def bernoulli_nll(x, y, reduce='sum'):
    """Computes the negative log-likelihood of a Bernoulli distribution.

    This function calculates the negative log-likelihood of a Bernoulli
    distribution

    .. math::

        -\\log B(x; p) = -\\sum_i \\{x_i \\log(p_i) + \
        (1 - x_i)\\log(1 - p_i)\\},

    where :math:`p = \\sigma(y)`, :math:`\\sigma(\\cdot)` is a sigmoid
    function, and :math:`B(x; p)` is a Bernoulli distribution.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the elementwise
    loss values. If it is ``'sum'`` or ``'mean'``, the loss values are summed
    up or averaged, respectively.

    .. note::

        As this function applies a sigmoid function internally, you can pass
        the output of a fully-connected layer (i.e. :class:`Linear`) to this
        function directly.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
        y (:class:`~chainer.Variable` or :ref:`ndarray`): A variable
            representing the pre-sigmoid parameter of the Bernoulli
            distribution, i.e. :math:`p = \\sigma(y)`.
        reduce (str): Reduction option. Its value must be either
            ``'sum'``, ``'mean'`` or ``'no'``. Otherwise, :class:`ValueError`
            is raised.

    Returns:
        ~chainer.Variable:
            A variable representing the negative log-likelihood.
            If ``reduce`` is ``'no'``, the output variable holds an array
            whose shape is the same as that of the input variables.
            If it is ``'sum'`` or ``'mean'``, the output variable holds a
            scalar value.

    """
    if reduce not in ('sum', 'mean', 'no'):
        raise ValueError(
            'only \'sum\', \'mean\' and \'no\' are valid for \'reduce\', but '
            '\'%s\' is given' % reduce)

    loss = softplus.softplus(y) - x * y
    if reduce == 'sum':
        return sum.sum(loss)
    elif reduce == 'mean':
        return average.average(loss)
    else:
        return loss
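
# Minimal usage sketch, assuming a NumPy backend; shapes are arbitrary. ``y`` is
# the pre-sigmoid logit, so decoder outputs can be fed in without an explicit
# sigmoid. With y = 0 every element has p = 0.5, so the per-element loss is
# log(2) regardless of x.
#
#     import numpy as np
#     x = np.random.randint(0, 2, (4, 8)).astype(np.float32)  # binary targets
#     y = np.zeros((4, 8), dtype=np.float32)                   # logits
#     nll = bernoulli_nll(x, y, reduce='mean')                  # Variable holding log(2)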


def gaussian_nll(x, mean, ln_var, reduce='sum'):
    """Computes the negative log-likelihood of a Gaussian distribution.

    Given two variables ``mean`` representing :math:`\\mu` and ``ln_var``
    representing :math:`\\log(\\sigma^2)`, this function computes the
    elementwise negative log-likelihood of :math:`x` under a Gaussian
    distribution :math:`N(\\mu, S)`,

    .. math::

        -\\log N(x; \\mu, \\sigma^2) =
        \\log\\left(\\sqrt{(2\\pi)^D |S|}\\right) +
        \\frac{1}{2}(x - \\mu)^\\top S^{-1}(x - \\mu),

    where :math:`D` is the dimensionality of :math:`x` and :math:`S` is a
    diagonal matrix where :math:`S_{ii} = \\sigma_i^2`.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the elementwise
    loss values. If it is ``'sum'`` or ``'mean'``, the loss values are summed
    up or averaged, respectively.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
        mean (:class:`~chainer.Variable` or :ref:`ndarray`): A variable
            representing the mean of a Gaussian distribution, :math:`\\mu`.
        ln_var (:class:`~chainer.Variable` or :ref:`ndarray`): A variable
            representing the logarithm of the variance of a Gaussian
            distribution, :math:`\\log(\\sigma^2)`.
        reduce (str): Reduction option. Its value must be either
            ``'sum'``, ``'mean'`` or ``'no'``. Otherwise, :class:`ValueError`
            is raised.

    Returns:
        ~chainer.Variable:
            A variable representing the negative log-likelihood.
            If ``reduce`` is ``'no'``, the output variable holds an array
            whose shape is the same as that of the input variables.
            If it is ``'sum'`` or ``'mean'``, the output variable holds a
            scalar value.

    """
    if reduce not in ('sum', 'mean', 'no'):
        raise ValueError(
            'only \'sum\', \'mean\' and \'no\' are valid for \'reduce\', but '
            '\'%s\' is given' % reduce)

    x_prec = exponential.exp(-ln_var)
    x_diff = x - mean
    x_power = (x_diff * x_diff) * x_prec * -0.5
    loss = (ln_var + math.log(2 * math.pi)) / 2 - x_power
    if reduce == 'sum':
        return sum.sum(loss)
    elif reduce == 'mean':
        return average.average(loss)
    else:
        return loss
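

if __name__ == '__main__':
    # Illustrative smoke test: a minimal sketch that combines the losses above
    # into the usual VAE objective (negative ELBO) for a toy batch of random
    # data. Raw NumPy inputs and the shapes below are assumptions of this
    # sketch, not requirements of the module.
    import numpy as np

    batch, n_latent, n_visible = 4, 2, 8
    x = np.random.randint(0, 2, (batch, n_visible)).astype(np.float32)  # binary data
    y = np.random.randn(batch, n_visible).astype(np.float32)      # decoder logits
    mu = np.random.randn(batch, n_latent).astype(np.float32)      # encoder mean
    ln_var = np.random.randn(batch, n_latent).astype(np.float32)  # encoder log-variance

    rec = bernoulli_nll(x, y) / batch                 # reconstruction term
    kl = gaussian_kl_divergence(mu, ln_var) / batch   # regularization term
    nll = gaussian_nll(x, y, np.zeros_like(y))        # Gaussian alternative to bernoulli_nll

    print('reconstruction NLL per sample:', rec.data)
    print('KL term per sample:', kl.data)
    print('negative ELBO per sample:', (rec + kl).data)
    print('Gaussian NLL (unit variance):', nll.data)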