-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
bernoulli.py
177 lines (139 loc) · 5.29 KB
/
bernoulli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import numpy
import chainer
from chainer import backend
from chainer.backends import cuda
from chainer import distribution
import chainer.distributions.utils
from chainer.functions.activation import sigmoid
from chainer.functions.array import where
from chainer.functions.math import exponential
from chainer.functions.math import logarithm_1p
from chainer.functions.math import sum
from chainer import utils
from chainer.utils import cache
class BernoulliLogProb(chainer.function_node.FunctionNode):
    """Elementwise Bernoulli log-likelihood as a differentiable FunctionNode.

    Computes ``log P(x; logit) = logit * (x - 1) - log(exp(-logit) + 1)``
    with explicit handling of infinite logits and, optionally, of
    observations outside {0, 1}.
    """
    def __init__(self, binary_check=False):
        super(BernoulliLogProb, self).__init__()
        # When True, any observation x not in {0, 1} gets log-prob -inf
        # (forward) and a NaN gradient (backward).
        self.binary_check = binary_check
    def forward(self, inputs):
        # inputs: (logit, x); both retained for the backward pass.
        logit, x = inputs
        self.retain_inputs((0, 1))
        xp = backend.get_array_module(x)
        # log p(x) = logit*(x-1) - log(1 + exp(-logit)).
        # NOTE(review): exp(-logit) can overflow for large negative finite
        # logit; only exact +/-inf is special-cased below — confirm intended.
        y = logit * (x - 1) - xp.log(xp.exp(-logit) + 1)
        y = utils.force_array(y)
        # extreme logit
        # Masks saved on self so backward can NaN-out the same positions.
        logit_isinf = xp.isinf(logit)
        self.logit_ispinf = xp.bitwise_and(logit_isinf, logit > 0)
        self.logit_isminf = xp.bitwise_and(logit_isinf, logit <= 0)
        # logit=+inf => p=1 => log-prob is log(x); logit=-inf => log(1-x).
        # divide='ignore' tolerates log(0) = -inf at the impossible outcome.
        with numpy.errstate(divide='ignore', invalid='raise'):
            y = xp.where(self.logit_ispinf, xp.log(x), y)
            y = xp.where(self.logit_isminf, xp.log(1 - x), y)
        if self.binary_check:
            # Non-binary observations are impossible under Bernoulli.
            self.invalid = utils.force_array(xp.bitwise_and(x != 0, x != 1))
            y[self.invalid] = -xp.inf
        # Cast back to the logit's dtype (intermediate math may upcast).
        return utils.force_array(y, logit.dtype),
    def backward(self, indexes, grad_outputs):
        gy, = grad_outputs
        logit, x = self.get_retained_inputs()
        xp = backend.get_array_module(x)
        # d/dlogit [logit*(x-1) - log(1+exp(-logit))] = x - sigmoid(logit).
        dlogit = x - 1. / (1. + exponential.exp(-logit))
        # extreme logit
        # Gradient is undefined where logit was +/-inf (or x was non-binary
        # under binary_check) — propagate NaN there.
        nan = xp.array(xp.nan).astype(dlogit.dtype)
        logit_isinf = xp.bitwise_or(self.logit_ispinf, self.logit_isminf)
        dlogit = where.where(logit_isinf, nan, dlogit)
        if self.binary_check:
            dlogit = where.where(self.invalid, nan, dlogit)
        # sum_to reduces the broadcasted gradient back to logit's shape.
        # No gradient w.r.t. x (second return is None).
        return sum.sum_to(gy * dlogit, logit.shape), None
def _bernoulli_log_prob(logit, x, binary_check=False):
    """Apply :class:`BernoulliLogProb` and return its single output."""
    out, = BernoulliLogProb(binary_check).apply((logit, x))
    return out
class Bernoulli(distribution.Distribution):
    """Bernoulli Distribution.

    The probability mass function of the distribution is expressed as

    .. math::
        P(x = 1; p) = p \\\\
        P(x = 0; p) = 1 - p

    Args:
        p(:class:`~chainer.Variable` or :ref:`ndarray`): Parameter of
            distribution representing :math:`p`. Either `p` or `logit` (not
            both) must have a value.
        logit(:class:`~chainer.Variable` or :ref:`ndarray`): Parameter of
            distribution representing :math:`\\log\\{p/(1-p)\\}`. Either `p`
            or `logit` (not both) must have a value.
        binary_check(bool): If ``True``, ``log_prob`` and ``prob`` treat
            observations outside :math:`\\{0, 1\\}` as invalid (log-prob
            :math:`-\\infty`, probability 0).
    """
    def __init__(self, p=None, logit=None, binary_check=False):
        super(Bernoulli, self).__init__()
        # XOR: exactly one of the two parameterizations must be given.
        if not (p is None) ^ (logit is None):
            raise ValueError(
                'Either `p` or `logit` (not both) must have a value.')
        self.__p = p
        self.__logit = logit
        self.binary_check = binary_check
    @cache.cached_property
    def p(self):
        # Derive p from logit when only logit was supplied: p = sigmoid(logit).
        if self.__p is not None:
            return chainer.as_variable(self.__p)
        else:
            return sigmoid.sigmoid(self.logit)
    @cache.cached_property
    def logit(self):
        # Derive logit from p when only p was supplied:
        # logit = log(p) - log(1 - p) = log(p / (1 - p)).
        if self.__logit is not None:
            return chainer.as_variable(self.__logit)
        else:
            return exponential.log(self.p) - logarithm_1p.log1p(-self.p)
    @property
    def batch_shape(self):
        return self.p.shape
    @property
    def entropy(self):
        # H = -p*log(p) - (1-p)*log(1-p); _modified_xlogx presumably handles
        # the x*log(x) limit at x=0 — confirm against its definition.
        p = self.p
        q = p.dtype.type(1.) - p
        return (- chainer.distributions.utils._modified_xlogx(p)
                - chainer.distributions.utils._modified_xlogx(q))
    @property
    def event_shape(self):
        # Scalar-valued distribution: each event is a single 0/1 outcome.
        return ()
    @property
    def _is_gpu(self):
        # True when the parameter array lives on a CUDA device.
        return isinstance(self.p.array, cuda.ndarray)
    def log_prob(self, x):
        # Delegates to the differentiable FunctionNode-based implementation.
        return _bernoulli_log_prob(self.logit, x, self.binary_check)
    @cache.cached_property
    def mean(self):
        return self.p
    @property
    def params(self):
        return {'logit': self.logit}
    def prob(self, x):
        # Elementwise pmf: p where x==1, (1-p) where x==0.
        x = chainer.as_variable(x)
        prob = x * self.p + (1 - x) * (1 - self.p)
        if self.binary_check:
            # Zero out entries where x is not a valid {0, 1} observation.
            if self._is_gpu:
                valid = cuda.cupy.bitwise_or(x.array == 0, x.array == 1)
            else:
                valid = numpy.bitwise_or(x.array == 0, x.array == 1)
            prob *= valid
        return prob
    def sample_n(self, n):
        # Draw n independent samples; output shape is (n,) + batch_shape.
        # Sampling is not differentiable (raw array wrapped in a Variable).
        if self._is_gpu:
            eps = cuda.cupy.random.binomial(
                1, self.p.array, size=(n,)+self.p.shape)
        else:
            eps = numpy.random.binomial(
                1, self.p.array, size=(n,)+self.p.shape)
        return chainer.Variable(eps)
    @cache.cached_property
    def stddev(self):
        return self.variance ** 0.5
    @property
    def support(self):
        return '{0, 1}'
    @cache.cached_property
    def variance(self):
        # Var[x] = p * (1 - p) for a Bernoulli random variable.
        return self.p * (1 - self.p)
@distribution.register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(dist1, dist2):
    """Compute KL(dist1 || dist2) between two Bernoulli distributions.

    Expressed via logits: with softplus(-l) = log(1 + exp(-l)),
    KL = (l1 - l2) * (p1 - 1) - softplus(-l1) + softplus(-l2).
    """
    logit_diff = dist1.logit - dist2.logit
    softplus_neg1 = exponential.log(exponential.exp(-dist1.logit) + 1)
    softplus_neg2 = exponential.log(exponential.exp(-dist2.logit) + 1)
    return logit_diff * (dist1.p - 1.) - softplus_neg1 + softplus_neg2