forked from blei-lab/edward
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pointmass.py
231 lines (181 loc) · 6.51 KB
/
pointmass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""The Point Mass distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.contrib.distributions.python.ops import \
distribution
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
import tensorflow as tf
class PointMass(distribution.Distribution):
    """Point mass distribution(s) located at `params`.

    Every element of `params` is treated as an independent scalar
    distribution that places all of its probability on that element's value.
    Consequently the mean and mode equal `params`, and the standard deviation
    and variance are zero.
    """

    def __init__(self,
                 params,
                 validate_args=True,
                 allow_nan_stats=False,
                 name="PointMass"):
        """Construct point mass distributions located at `params`.

        Args:
          params: value convertible to a `Tensor`; the location of the mass.
          validate_args: stored and exposed through the `validate_args`
            property. This class performs no argument validation itself.
          allow_nan_stats: stored and exposed through the `allow_nan_stats`
            property.
          name: name used to scope the ops created by this distribution.
        """
        self._allow_nan_stats = allow_nan_stats
        self._validate_args = validate_args
        with ops.op_scope([params], name):
            params = ops.convert_to_tensor(params)
            self._name = name
            self._params = array_ops.identity(params, name="params")
            # One independent distribution per element of `params`; each
            # draws a scalar, hence the empty event shape.
            self._batch_shape = self._params.get_shape()
            self._event_shape = tensor_shape.TensorShape([])

    @property
    def allow_nan_stats(self):
        """Boolean describing behavior when a stat is undefined for batch member."""
        return self._allow_nan_stats

    @property
    def validate_args(self):
        """Boolean describing behavior on invalid input."""
        return self._validate_args

    @property
    def name(self):
        """Name used to scope the ops created by this distribution."""
        return self._name

    @property
    def dtype(self):
        """The dtype of `params`, and therefore of samples and statistics."""
        return self._params.dtype

    def batch_shape(self, name="batch_shape"):
        """Batch dimensions of this instance as a 1-D int32 `Tensor`.

        The product of the dimensions of the `batch_shape` is the number of
        independent distributions of this kind the instance represents.

        Args:
          name: name to give to the op.

        Returns:
          `Tensor` `batch_shape`
        """
        with ops.name_scope(self.name):
            with ops.op_scope([], name):
                return array_ops.shape(self._params)

    def get_batch_shape(self):
        """`TensorShape` available at graph construction time.

        Same meaning as `batch_shape`. May be only partially defined.

        Returns:
          batch shape
        """
        return self._batch_shape

    def event_shape(self, name="event_shape"):
        """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.

        Args:
          name: name to give to the op.

        Returns:
          `Tensor` `event_shape` (always empty: events are scalars).
        """
        with ops.name_scope(self.name):
            with ops.op_scope([], name):
                return constant_op.constant([], dtype=dtypes.int32)

    def get_event_shape(self):
        """`TensorShape` available at graph construction time.

        Same meaning as `event_shape`. May be only partially defined.

        Returns:
          event shape
        """
        return self._event_shape

    @property
    def params(self):
        """Distribution parameter: the location of the point mass."""
        return self._params

    def mean(self, name="mean"):
        """Mean of this distribution, which is `params` itself."""
        with ops.name_scope(self.name):
            with ops.op_scope([self._params], name):
                return self._params

    def mode(self, name="mode"):
        """Mode of this distribution, which coincides with the mean."""
        return self.mean(name="mode")

    def std(self, name="std"):
        """Standard deviation of this distribution (identically zero)."""
        with ops.name_scope(self.name):
            with ops.op_scope([self._params], name):
                # zeros_like keeps the dtype of `params`, so this also works
                # for integer-valued point masses.
                return self._zeros()

    def variance(self, name="variance"):
        """Variance of this distribution (identically zero)."""
        with ops.name_scope(self.name):
            with ops.op_scope([], name):
                return math_ops.square(self.std())

    def _indicator(self, x):
        """Return 1 where `x` equals `params` and 0 elsewhere, in `dtype`.

        Args:
          x: value convertible to a `Tensor` of dtype `dtype`.

        Returns:
          `Tensor` of dtype `dtype` holding the elementwise equality mask.

        Raises:
          TypeError: if the dtype of `x` differs from `dtype`.
        """
        x = ops.convert_to_tensor(x)
        if x.dtype != self.dtype:
            raise TypeError("Input x dtype does not match dtype: %s vs. %s"
                            % (x.dtype, self.dtype))
        return tf.cast(tf.equal(x, self.params), dtype=self.dtype)

    def log_prob(self, x, name="log_prob"):
        """Log prob of observations in `x` under these point mass distribution(s).

        The result is 0 where `x` equals `params` and `-inf` elsewhere.
        NOTE(review): the logarithm presumably requires a floating-point
        `dtype`; use `prob` for integer-valued point masses.

        Args:
          x: tensor of dtype `dtype`, must be broadcastable with `params`.
          name: The name to give this op.

        Returns:
          log_prob: tensor of dtype `dtype`, the log-probabilities of `x`.

        Raises:
          TypeError: if the dtype of `x` differs from `dtype`.
        """
        with ops.name_scope(self.name):
            with ops.op_scope([self._params, x], name):
                return math_ops.log(self._indicator(x))

    def cdf(self, x, name="cdf"):
        """CDF of observations in `x` under these point mass distribution(s).

        Not implemented for this distribution.

        Args:
          x: tensor of dtype `dtype`, must be broadcastable with `params`.
          name: The name to give this op.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError()

    def log_cdf(self, x, name="log_cdf"):
        """Log CDF of observations `x` under these point mass distribution(s).

        Delegates to `cdf`, so it currently raises `NotImplementedError`.

        Args:
          x: tensor of dtype `dtype`, must be broadcastable with `params`.
          name: The name to give this op.

        Returns:
          log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
        """
        with ops.name_scope(self.name):
            with ops.op_scope([self._params, x], name):
                return math_ops.log(self.cdf(x))

    def prob(self, x, name="prob"):
        """The probability of observations in `x` under these point mass(es).

        The result is 1 where `x` equals `params` and 0 elsewhere.

        Args:
          x: tensor of dtype `dtype`, must be broadcastable with `params`.
          name: The name to give this op.

        Returns:
          prob: tensor of dtype `dtype`, the prob values of `x`.

        Raises:
          TypeError: if the dtype of `x` differs from `dtype`.
        """
        with ops.name_scope(self.name):
            with ops.op_scope([self._params, x], name):
                return self._indicator(x)

    def entropy(self, name="entropy"):
        """The entropy of the point mass distribution(s).

        Not implemented for this distribution.

        Args:
          name: The name to give this op.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError()

    def sample_n(self, n, seed=None, name="sample_n"):
        """Sample `n` observations from the point mass distribution(s).

        Sampling is deterministic: every draw is `params`, so `seed` is
        unused.

        Args:
          n: the number of observations to sample; a Python integer or a
            scalar integer `Tensor`. A `Tensor` whose value is not known at
            graph construction time is evaluated with `eval()`, which needs
            a default session.
          seed: Python integer, the random seed (ignored).
          name: The name to give this op.

        Returns:
          samples: `[n, ...]`, a `Tensor` of `n` stacked copies of `params`.
        """
        with ops.name_scope(self.name):
            with ops.op_scope([self._params, n], name):
                # The copies are stacked from a Python list, so `n` must be
                # a concrete integer. Prefer the statically known value and
                # only run the graph when there is no other way to get it.
                static_n = None
                if isinstance(n, ops.Tensor):
                    static_n = tensor_util.constant_value(n)
                if static_n is not None:
                    n = static_n
                elif hasattr(n, "eval"):
                    n = n.eval()
                n = int(n)
                return tf.pack([self.params] * n)

    @property
    def is_reparameterized(self):
        """Always `True`: samples are `params` itself."""
        return True

    def _ones(self):
        """Tensor of ones with the shape and dtype of `params`."""
        return array_ops.ones_like(self._params)

    def _zeros(self):
        """Tensor of zeros with the shape and dtype of `params`."""
        return array_ops.zeros_like(self._params)

    @property
    def is_continuous(self):
        """Always `False`: the distribution is treated as discrete."""
        return False