update ParamMixture and its tests
dustinvtran committed Apr 8, 2017
1 parent 4e745b2 commit f6ec255
Showing 3 changed files with 80 additions and 75 deletions.
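In short, the commit migrates ParamMixture and its tests from the pre-1.0 TensorFlow distributions API to the TensorFlow 1.0 names. The main renames visible in the diffs below:

  Categorical(p=...)          ->  Categorical(probs=...)
  Normal(mu=..., sigma=...)   ->  Normal(loc=..., scale=...)
  Beta(a=..., b=...)          ->  Beta(concentration1=..., concentration0=...)
  Dirichlet(alpha=...)        ->  Dirichlet(concentration=...)
  dist.std()                  ->  dist.stddev()
  dist.get_batch_shape()      ->  dist.batch_shape (property; tensor version is batch_shape_tensor())
  dist.get_event_shape()      ->  dist.event_shape (property; tensor version is event_shape_tensor())
  is_reparameterized=False    ->  reparameterization_type=... (in the Distribution constructor)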
63 changes: 30 additions & 33 deletions edward/models/param_mixture.py
@@ -65,7 +65,7 @@ def __init__(self,
     parameters = locals()
     parameters.pop("self")
     values = [mixing_weights] + list(six.itervalues(component_params))
-    with tf.name_scope(name, values=values) as ns:
+    with tf.name_scope(name, values=values):
       if validate_args:
         if not isinstance(component_params, dict):
           raise TypeError("component_params must be a dict.")
@@ -74,7 +74,7 @@ def __init__(self,
 
       sample_shape = kwargs.get('sample_shape', ())
       self._mixing_weights = tf.identity(mixing_weights, name="mixing_weights")
-      self._cat = Categorical(p=self._mixing_weights,
+      self._cat = Categorical(probs=self._mixing_weights,
                               validate_args=validate_args,
                               allow_nan_stats=allow_nan_stats,
                               sample_shape=sample_shape)
@@ -86,24 +86,24 @@ def __init__(self,
 
       if validate_args:
         if not self._mixing_weights.shape[-1].is_compatible_with(
-            self._components.get_batch_shape()[0]):
+            self._components.batch_shape[0]):
           raise TypeError("Last dimension of mixing_weights must match with "
                           "the first dimension of components.")
         elif not self._mixing_weights.shape[:-1].is_compatible_with(
-            self._components.get_batch_shape()[1:]):
+            self._components.batch_shape[1:]):
           raise TypeError("Dimensions of mixing_weights are not compatible "
                           "with the dimensions of components.")
 
       try:
-        self._num_components = self._cat.p.shape.as_list()[-1]
+        self._num_components = self._cat.probs.shape.as_list()[-1]
       except:  # if p has TensorShape None
         raise NotImplementedError("Number of components must be statically "
                                   "determined.")
 
       self._mean_val = None
       self._variance_val = None
       self._stddev_val = None
-      if self._cat.p.shape.ndims <= 1:
+      if self._cat.probs.shape.ndims <= 1:
         with tf.name_scope('means'):
           try:
             comp_means = self._components.mean()
@@ -113,8 +113,8 @@ def __init__(self,
             # weights has shape batch_shape + [num_components]; change
             # to broadcast with [num_components] + batch_shape + event_shape.
             # The below reshaping only works for empty batch_shape.
-            weights = self._cat.p
-            event_rank = self._components.get_event_shape().ndims
+            weights = self._cat.probs
+            event_rank = self._components.event_shape.ndims
             for _ in range(event_rank):
               weights = tf.expand_dims(weights, -1)
 
@@ -130,15 +130,15 @@ def __init__(self,
             # This fails if _components.{mean,variance}() fails.
             pass
 
-    super(ParamMixture, self).__init__(
-        dtype=self._components.dtype,
-        is_continuous=self._components.is_continuous,
-        is_reparameterized=False,
-        validate_args=validate_args,
-        allow_nan_stats=allow_nan_stats,
-        parameters=parameters,
-        name=ns,
-        *args, **kwargs)
+    super(ParamMixture, self).__init__(
+        dtype=self._components.dtype,
+        reparameterization_type=self._components.reparameterization_type,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        graph_parents=[self._cat.value(), self._components.value()],
+        name=name,
+        *args, **kwargs)
 
   @property
   def cat(self):
@@ -152,26 +152,25 @@ def components(self):
   def num_components(self):
     return self._num_components
 
+  def _batch_shape_tensor(self):
+    return self.cat.batch_shape_tensor()
+
   def _batch_shape(self):
-    return self.cat.batch_shape()
+    return self.cat.batch_shape
 
-  def _get_batch_shape(self):
-    return self.cat.get_batch_shape()
+  def _event_shape_tensor(self):
+    return self.components.event_shape_tensor()
 
   def _event_shape(self):
-    return self.components.event_shape()
-
-  def _get_event_shape(self):
-    return self.components.get_event_shape()
+    return self.components.event_shape
 
   # # This will work in TF 1.1
   # @distribution_util.AppendDocstring(
   #     'Note that this function returns the conditional log probability of the '
   #     'observed variable given the categorical variable `cat`. For the '
   #     'marginal log probability, use `marginal_log_prob()`.')
   def _log_prob(self, x, conjugate=False, **kwargs):
-    batch_event_rank = (self.get_event_shape().ndims +
-                        self.get_batch_shape().ndims)
+    batch_event_rank = self.event_shape.ndims + self.batch_shape.ndims
     expanded_x = tf.expand_dims(x, -1 - batch_event_rank)
     if conjugate:
       log_probs = self.components.conjugate_log_prob(expanded_x)
@@ -188,14 +187,13 @@ def conjugate_log_prob(self):
 
   def marginal_log_prob(self, x, **kwargs):
     'The marginal log probability of the observed variable. Sums out `cat`.'
-    batch_event_rank = (self.get_event_shape().ndims +
-                        self.get_batch_shape().ndims)
+    batch_event_rank = self.event_shape.ndims + self.batch_shape.ndims
     expanded_x = tf.expand_dims(x, -1 - batch_event_rank)
     log_probs = self.components.log_prob(expanded_x)
 
-    p_ndims = self.cat.p.shape.ndims
+    p_ndims = self.cat.probs.shape.ndims
     perm = tf.concat([[p_ndims - 1], tf.range(p_ndims - 1)], 0)
-    transposed_p = tf.transpose(self.cat.p, perm)
+    transposed_p = tf.transpose(self.cat.probs, perm)
 
     return tf.reduce_logsumexp(log_probs + tf.log(transposed_p),
                                -1 - batch_event_rank)
@@ -213,8 +211,7 @@ def _sample_n(self, n, seed=None):
       cat_sample = tf.expand_dims(cat_sample, 0)
 
     # TODO avoid sampling n per component
-    batch_event_rank = (self.get_event_shape().ndims +
-                        self.get_batch_shape().ndims)
+    batch_event_rank = self.event_shape.ndims + self.batch_shape.ndims
    cat_axis = comp_sample.shape.ndims - 1 - batch_event_rank
    selecter = tf.one_hot(cat_sample, self.num_components,
                          axis=cat_axis, dtype=self.dtype)
@@ -234,7 +231,7 @@ def _mean(self):
 
     return self._mean_val
 
-  def _std(self):
+  def _stddev(self):
     if self._stddev_val is None:
       raise NotImplementedError()
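To make the new interface concrete, here is a minimal usage sketch (not part of the commit); it mirrors the constructor call in the tests below, assuming Edward at this commit and TensorFlow 1.0:

import numpy as np
import tensorflow as tf
from edward.models import Normal, ParamMixture

# Mixture of 3 normals, using the TF 1.0 parameter names.
probs = np.array([0.2, 0.3, 0.5], np.float32)   # mixing weights
loc = np.array([1.0, 5.0, 7.0], np.float32)     # component means
scale = np.array([1.5, 1.5, 1.5], np.float32)   # component stddevs

x = ParamMixture(probs, {'loc': loc, 'scale': scale}, Normal,
                 sample_shape=10000)

# log_prob(x) conditions on the sampled `cat`; marginal_log_prob(x)
# sums `cat` out with a logsumexp (see marginal_log_prob above).
cond_logp = x.log_prob(x)
marginal_logp = x.marginal_log_prob(x)

with tf.Session() as sess:
  samples = sess.run(x)  # shape (10000,)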
50 changes: 27 additions & 23 deletions tests/test-models/test_param_mixture_sample.py
@@ -15,45 +15,49 @@ def _test(self, n, *args, **kwargs):
     rv = ParamMixture(*args, **kwargs)
     val_est = rv.sample(n).shape
     val_true = tf.TensorShape(n).concatenate(
-        rv.cat.get_batch_shape()).concatenate(rv.components.get_event_shape())
+        rv.cat.batch_shape).concatenate(rv.components.event_shape)
     self.assertEqual(val_est, val_true)
 
-    self.assertEqual(rv.get_sample_shape(), rv.cat.get_sample_shape())
-    self.assertEqual(rv.get_sample_shape(), rv.components.get_sample_shape())
-    self.assertEqual(rv.get_batch_shape(), rv.cat.get_batch_shape())
-    self.assertEqual(rv.get_event_shape(), rv.components.get_event_shape())
+    self.assertEqual(rv.sample_shape, rv.cat.sample_shape)
+    self.assertEqual(rv.sample_shape, rv.components.sample_shape)
+    self.assertEqual(rv.batch_shape, rv.cat.batch_shape)
+    self.assertEqual(rv.event_shape, rv.components.event_shape)
 
   def test_batch_0d_event_0d(self):
     """Mixture of 3 normal distributions."""
     with self.test_session():
-      pi = np.array([0.2, 0.3, 0.5], np.float32)
-      mu = np.array([1.0, 5.0, 7.0], np.float32)
-      sigma = np.array([1.5, 1.5, 1.5], np.float32)
+      probs = np.array([0.2, 0.3, 0.5], np.float32)
+      loc = np.array([1.0, 5.0, 7.0], np.float32)
+      scale = np.array([1.5, 1.5, 1.5], np.float32)
 
-      self._test([], pi, {'mu': mu, 'sigma': sigma}, Normal)
-      self._test([5], pi, {'mu': mu, 'sigma': sigma}, Normal)
+      self._test([], probs, {'loc': loc, 'scale': scale}, Normal)
+      self._test([5], probs, {'loc': loc, 'scale': scale}, Normal)
 
   def test_batch_0d_event_1d(self):
     """Mixture of 2 Dirichlet distributions."""
     with self.test_session():
-      pi = np.array([0.4, 0.6], np.float32)
-      alpha = np.ones([2, 3], np.float32)
+      probs = np.array([0.4, 0.6], np.float32)
+      concentration = np.ones([2, 3], np.float32)
 
-      self._test([], pi, {'alpha': alpha}, Dirichlet)
-      self._test([5], pi, {'alpha': alpha}, Dirichlet)
+      self._test([], probs, {'concentration': concentration}, Dirichlet)
+      self._test([5], probs, {'concentration': concentration}, Dirichlet)
 
   def test_batch_1d_event_0d(self):
     """Two mixtures each of 3 beta distributions."""
     with self.test_session():
-      pi = np.array([[0.2, 0.3, 0.5], [0.2, 0.3, 0.5]], np.float32)
-      a = np.array([[2.0, 0.5], [1.0, 1.0], [0.5, 2.0]], np.float32)
-      b = a + 2.0
-
-      self._test([], pi, {'a': a, 'b': b}, Beta)
-      self._test([5], pi, {'a': a, 'b': b}, Beta)
-
-      pi = np.array([0.2, 0.3, 0.5], np.float32)
-      self.assertRaises(ValueError, self._test, [], pi, {'a': a, 'b': b}, Beta)
+      probs = np.array([[0.2, 0.3, 0.5], [0.2, 0.3, 0.5]], np.float32)
+      conc1 = np.array([[2.0, 0.5], [1.0, 1.0], [0.5, 2.0]], np.float32)
+      conc0 = conc1 + 2.0
+
+      self._test([], probs, {'concentration1': conc1, 'concentration0': conc0},
+                 Beta)
+      self._test([5], probs, {'concentration1': conc1, 'concentration0': conc0},
+                 Beta)
+
+      probs = np.array([0.2, 0.3, 0.5], np.float32)
+      self.assertRaises(ValueError, self._test, [], probs,
+                        {'concentration1': conc1, 'concentration0': conc0},
+                        Beta)
 
 if __name__ == '__main__':
   tf.test.main()
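As the _test helper above encodes, a draw of shape n from the mixture has shape n + batch_shape + event_shape, with the batch shape taken from cat and the event shape from the components. A small sketch of the Dirichlet case (illustrative, not from the commit):

import numpy as np
import tensorflow as tf
from edward.models import Dirichlet, ParamMixture

probs = np.array([0.4, 0.6], np.float32)     # 2 components
concentration = np.ones([2, 3], np.float32)  # each event is a 3-simplex point
rv = ParamMixture(probs, {'concentration': concentration}, Dirichlet)

# sample shape [5] + batch shape [] + event shape [3] => (5, 3)
expected = tf.TensorShape([5]).concatenate(
    rv.cat.batch_shape).concatenate(rv.components.event_shape)
assert rv.sample([5]).shape == expected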
42 changes: 23 additions & 19 deletions tests/test-models/test_param_mixture_stats.py
@@ -24,24 +24,24 @@ def _make_histograms(values, hists, hist_centers, x_axis, n_bins):
 
 class test_param_mixture_class(tf.test.TestCase):
 
-  def _test(self, pi, params, dist):
+  def _test(self, probs, params, dist):
     g = tf.Graph()
     with g.as_default():
       tf.set_random_seed(10003)
 
       N = 50000
 
-      x = ParamMixture(pi, params, dist, sample_shape=N)
+      x = ParamMixture(probs, params, dist, sample_shape=N)
       cat = x.cat
       components = x.components
 
       marginal_logp = x.marginal_log_prob(x)
       cond_logp = x.log_prob(x)
 
       comp_means = components.mean()
-      comp_stddevs = components.std()
+      comp_stddevs = components.stddev()
       marginal_mean = x.mean()
-      marginal_stddev = x.std()
+      marginal_stddev = x.stddev()
       marginal_var = x.variance()
 
       sess = self.test_session(graph=g)
@@ -62,7 +62,7 @@ def _test(self, pi, params, dist):
       # Test that per-component statistics are reasonable
       for k in range(x.num_components):
         selector = (vals[cat] == k)
-        self.assertAllClose(selector.mean(), pi[k], rtol=0.01, atol=0.01)
+        self.assertAllClose(selector.mean(), probs[k], rtol=0.01, atol=0.01)
         x_k = vals[x][selector]
         self.assertAllClose(x_k.mean(0), vals[comp_means][k],
                             rtol=0.05, atol=0.05)
@@ -85,7 +85,7 @@ def _test(self, pi, params, dist):
       self.assertLess(abs(x_pseudo_hist - x_hists).sum(0).mean(), 0.1)
 
       # Test that histograms match conditional log prob
-      for k in range(pi.shape[-1]):
+      for k in range(probs.shape[-1]):
         k_cat = k + np.zeros(x_axis.shape, np.int32)
         x_vals_k = sess.run(x, {cat: k_cat, components: vals[components]})
         _make_histograms(x_vals_k, x_hists, hist_centers, x_axis, n_bins)
@@ -99,29 +99,33 @@ def _test(self, pi, params, dist):
 
   def test_normal(self):
     """Mixture of 3 normal distributions."""
-    pi = np.array([0.2, 0.3, 0.5], np.float32)
-    mu = np.array([1.0, 5.0, 7.0], np.float32)
-    sigma = np.array([1.5, 1.5, 1.5], np.float32)
+    probs = np.array([0.2, 0.3, 0.5], np.float32)
+    loc = np.array([1.0, 5.0, 7.0], np.float32)
+    scale = np.array([1.5, 1.5, 1.5], np.float32)
 
-    self._test(pi, {'mu': mu, 'sigma': sigma}, Normal)
+    self._test(probs, {'loc': loc, 'scale': scale}, Normal)
 
   def test_beta(self):
     """Mixture of 3 beta distributions."""
-    pi = np.array([0.2, 0.3, 0.5], np.float32)
-    a = np.array([2.0, 1.0, 0.5], np.float32)
-    b = a + 2.0
+    probs = np.array([0.2, 0.3, 0.5], np.float32)
+    conc1 = np.array([2.0, 1.0, 0.5], np.float32)
+    conc0 = conc1 + 2.0
 
-    self._test(pi, {'a': a, 'b': b}, Beta)
+    self._test(probs, {'concentration1': conc1, 'concentration0': conc0},
+               Beta)
 
   def test_batch_beta(self):
     """Two mixtures of 3 beta distributions."""
-    pi = np.array([[0.2, 0.3, 0.5], [0.2, 0.3, 0.5]], np.float32)
-    a = np.array([[2.0, 0.5], [1.0, 1.0], [0.5, 2.0]], np.float32)
-    b = a + 2.0
+    probs = np.array([[0.2, 0.3, 0.5], [0.2, 0.3, 0.5]], np.float32)
+    conc1 = np.array([[2.0, 0.5], [1.0, 1.0], [0.5, 2.0]], np.float32)
+    conc0 = conc1 + 2.0
 
-    # self._test(pi, {'a': a, 'b': b}, Beta)
+    # self._test(probs, {'concentration1': conc1, 'concentration0': conc0},
+    #            Beta)
     self.assertRaises(NotImplementedError,
-                      self._test, pi, {'a': a, 'b': b}, Beta)
+                      self._test, probs,
+                      {'concentration1': conc1, 'concentration0': conc0},
+                      Beta)
 
 if __name__ == '__main__':
   tf.test.main()
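For reference, the marginal_log_prob exercised by these tests implements the standard mixture identity log p(x) = logsumexp_k(log pi_k + log p(x | component k)). A NumPy restatement for intuition (a sketch, not Edward code):

import numpy as np

def marginal_log_prob_np(x, probs, comp_log_prob):
  # log p(x) = logsumexp over components k of log(pi_k) + log p(x | k)
  log_p = np.array([np.log(probs[k]) + comp_log_prob(x, k)
                    for k in range(len(probs))])
  m = log_p.max(axis=0)
  return m + np.log(np.exp(log_p - m).sum(axis=0))  # stable logsumexp

# E.g., with the 3-normal mixture from test_normal:
probs = np.array([0.2, 0.3, 0.5])
loc = np.array([1.0, 5.0, 7.0])
scale = np.array([1.5, 1.5, 1.5])

def normal_logpdf(x, k):
  z = (x - loc[k]) / scale[k]
  return -0.5 * z ** 2 - np.log(scale[k] * np.sqrt(2.0 * np.pi))

print(marginal_log_prob_np(2.0, probs, normal_logpdf))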
