-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
uniform.py
159 lines (128 loc) · 4.53 KB
/
uniform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import chainer
from chainer import backend
from chainer.backends import cuda
from chainer import distribution
from chainer.functions.array import broadcast
from chainer.functions.array import where
from chainer.functions.math import clip
from chainer.functions.math import exponential
from chainer.functions.math import sqrt
from chainer import utils
from chainer.utils import argument
from chainer.utils import cache
class Uniform(distribution.Distribution):

    """Uniform Distribution.

    The probability density function of the distribution is expressed as

    .. math::
        p(x; l, h) = \\begin{cases}
            \\frac{1}{h - l} & \\text{if }l \\leq x \\leq h \\\\
            0 & \\text{otherwise}
          \\end{cases}

    The distribution may be parameterized either by its bounds
    ``(low, high)`` or, equivalently, by ``(loc, scale)`` with
    ``low = loc`` and ``high = loc + scale``.

    Args:
        low(:class:`~chainer.Variable` or :ref:`ndarray`): Parameter of
            distribution representing the lower bound :math:`l`.
        high(:class:`~chainer.Variable` or :ref:`ndarray`): Parameter of
            distribution representing the higher bound :math:`h`.
    """

    def __init__(self, **kwargs):
        low = high = loc = scale = None
        if kwargs:
            low, high, loc, scale = argument.parse_kwargs(
                kwargs, ('low', low), ('high', high), ('loc', loc),
                ('scale', scale))
        self._use_low_high = low is not None and high is not None
        self._use_loc_scale = loc is not None and scale is not None
        # Exactly one of the two parameterizations must be supplied (XOR).
        if self._use_low_high == self._use_loc_scale:
            raise ValueError(
                'Either `low, high` or `loc, scale` (not both) must have a '
                'value.')
        self.__low = low
        self.__high = high
        self.__loc = loc
        self.__scale = scale

    @cache.cached_property
    def low(self):
        # Lower bound; equals ``loc`` under the loc/scale parameterization.
        if not self._use_low_high:
            return self.loc
        return chainer.as_variable(self.__low)

    @cache.cached_property
    def high(self):
        # Upper bound; equals ``loc + scale`` under loc/scale.
        if not self._use_low_high:
            return self.loc + self.scale
        return chainer.as_variable(self.__high)

    @cache.cached_property
    def loc(self):
        # Location parameter; coincides with ``low``.
        if not self._use_loc_scale:
            return self.low
        return chainer.as_variable(self.__loc)

    @cache.cached_property
    def scale(self):
        # Scale parameter; the width ``high - low`` of the support.
        if not self._use_loc_scale:
            return self.high - self.low
        return chainer.as_variable(self.__scale)

    @property
    def batch_shape(self):
        return self.low.shape

    def cdf(self, x):
        # CDF rises linearly across the support and saturates at 0 and 1
        # outside of it.
        return clip.clip((x - self.loc) / self.scale, 0., 1.)

    @cache.cached_property
    def entropy(self):
        # Differential entropy of Uniform(l, h) is log(h - l).
        return exponential.log(self.scale)

    @property
    def event_shape(self):
        return ()

    def icdf(self, x):
        # Inverse CDF: affine map from [0, 1] onto [low, high].
        return x * self.scale + self.loc

    def log_prob(self, x):
        if not isinstance(x, chainer.Variable):
            x = chainer.Variable(x)
        xp = backend.get_array_module(x)
        # Inside the support the log-density is the constant -log(scale);
        # outside it the density is zero, i.e. -inf in log space.
        inside = utils.force_array(
            (x.data >= self.low.data) & (x.data <= self.high.data))
        log_density = broadcast.broadcast_to(
            -exponential.log(self.scale), x.shape)
        return where.where(
            inside, log_density, xp.array(-xp.inf, log_density.dtype))

    @cache.cached_property
    def mean(self):
        # Midpoint of the support.
        return (self.low + self.high) / 2

    @property
    def params(self):
        # Report parameters in whichever form the instance was built with.
        if self._use_loc_scale:
            return {'loc': self.loc, 'scale': self.scale}
        return {'low': self.low, 'high': self.high}

    def sample_n(self, n):
        xp = backend.get_array_module(self.low)
        shape = (n,) + self.low.shape
        if xp is cuda.cupy:
            # CuPy can draw directly in the target dtype.
            u = xp.random.uniform(0, 1, shape, dtype=self.low.dtype)
        else:
            u = xp.random.uniform(0, 1, shape).astype(self.low.dtype)
        # Map the uniform [0, 1) draws onto [low, high) via the inverse CDF.
        return self.icdf(u)

    @cache.cached_property
    def stddev(self):
        return sqrt.sqrt(self.variance)

    @property
    def support(self):
        return '[low, high]'

    @cache.cached_property
    def variance(self):
        # Var[Uniform(l, h)] = (h - l)**2 / 12.
        return self.scale ** 2 / 12
@distribution.register_kl(Uniform, Uniform)
def _kl_uniform_uniform(dist1, dist2):
    """KL divergence KL(dist1 || dist2) between two uniform distributions."""
    xp = backend.get_array_module(dist1.low)
    # The divergence is infinite wherever dist1's support is not contained
    # in dist2's support.
    escapes = xp.logical_or(dist1.high.data > dist2.high.data,
                            dist1.low.data < dist2.low.data)
    # Otherwise KL = log(width of dist2) - log(width of dist1).
    finite_kl = (exponential.log(dist2.high - dist2.low)
                 - exponential.log(dist1.high - dist1.low))
    return where.where(escapes, xp.array(xp.inf, dist1.high.dtype), finite_kl)