Skip to content

Commit

Permalink
Replace DP Greek args with English args
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 26, 2017
1 parent 4859ea6 commit 17a762b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 68 deletions.
113 changes: 57 additions & 56 deletions edward/models/dirichlet_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
class DirichletProcess(RandomVariable, Distribution):
"""Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`.
It has two parameters: a positive real value :math:`\\alpha`,
known as the concentration parameter (``alpha``), and a base
It has two parameters: a positive real value :math:`\\alpha`, known
as the concentration parameter (``concentration``), and a base
distribution :math:`H` (``base``).
"""
def __init__(self, alpha, base, validate_args=False, allow_nan_stats=True,
name="DirichletProcess", *args, **kwargs):
def __init__(self, concentration, base, validate_args=False,
allow_nan_stats=True, name="DirichletProcess", *args, **kwargs):
"""Initialize a batch of Dirichlet processes.
Parameters
----------
alpha : tf.Tensor
concentration : tf.Tensor
Concentration parameter. Must be positive real-valued. Its shape
determines the number of independent DPs (batch shape).
base : RandomVariable
Expand All @@ -46,67 +46,68 @@ def __init__(self, alpha, base, validate_args=False, allow_nan_stats=True,
>>> assert dp.shape == (2, 5, 3)
"""
parameters = locals()
with tf.name_scope(name, values=[alpha]):
with tf.name_scope(name, values=[concentration]):
with tf.control_dependencies([
tf.assert_positive(alpha),
tf.assert_positive(concentration),
] if validate_args else []):
# Reject any base that is NOT a RandomVariable. The previous condition was
# inverted: it raised for valid RandomVariable bases and silently accepted
# everything else, contradicting the error message below.
if validate_args and not isinstance(base, RandomVariable):
  raise TypeError("base must be a ed.RandomVariable object.")

self._alpha = tf.identity(alpha, name="alpha")
self._concentration = tf.identity(concentration, name="concentration")
self._base = base

# Create empty tensor to store future atoms.
self._theta = tf.zeros(
# Form empty tensor to store atom locations.
self._locs = tf.zeros(
[0] + self.batch_shape.as_list() + self.event_shape.as_list(),
dtype=self._base.dtype)

# Instantiate beta distribution for stick breaking proportions.
self._betadist = Beta(tf.ones_like(self._alpha), self._alpha)
# Create empty tensor to store stick breaking proportions.
self._beta = tf.zeros(
# Instantiate distribution to draw mixing proportions.
self._probs_dist = Beta(tf.ones_like(self._concentration),
self._concentration)
# Form empty tensor to store mixing proportions.
self._probs = tf.zeros(
[0] + self.batch_shape.as_list(),
dtype=self._betadist.dtype)
dtype=self._probs_dist.dtype)

super(DirichletProcess, self).__init__(
dtype=tf.int32,
reparameterization_type=NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._alpha, self._beta, self._theta],
graph_parents=[self._concentration, self._locs, self._probs],
name=name,
*args, **kwargs)

@property
def alpha(self):
"""Concentration parameter."""
return self._alpha

@property
def base(self):
"""Base distribution used for drawing the atoms."""
"""Base distribution used for drawing the atom locations."""
return self._base

@property
def beta(self):
"""Stick breaking proportions. It has shape [None] + batch_shape, where
the first dimension is the number of atoms, instantiated only as
needed."""
return self._beta
def concentration(self):
"""Concentration parameter."""
return self._concentration

@property
def locs(self):
"""Atom locations. It has shape [None] + batch_shape +
event_shape, where the first dimension is the number of atoms,
instantiated only as needed."""
return self._locs

@property
def theta(self):
"""Atoms. It has shape [None] + batch_shape + event_shape, where
def probs(self):
"""Mixing proportions. It has shape [None] + batch_shape, where
the first dimension is the number of atoms, instantiated only as
needed."""
return self._theta
return self._probs

def _batch_shape_tensor(self):
return tf.shape(self.alpha)
return tf.shape(self.concentration)

def _batch_shape(self):
return self.alpha.shape
return self.concentration.shape

def _event_shape_tensor(self):
return tf.shape(self.base)
Expand Down Expand Up @@ -159,44 +160,44 @@ def _sample_n(self, n, seed=None):
# Initialize all samples as zero, they will be overwritten in any case
draws = tf.zeros([n] + batch_shape + event_shape, dtype=self.base.dtype)

# Calculate shape invariance conditions for theta and beta as these
# Calculate shape invariance conditions for locs and probs as these
# can change shape between loop iterations.
theta_shape = tf.TensorShape([None])
beta_shape = tf.TensorShape([None])
if len(self.theta.shape) > 1:
theta_shape = theta_shape.concatenate(self.theta.shape[1:])
beta_shape = beta_shape.concatenate(self.beta.shape[1:])
locs_shape = tf.TensorShape([None])
probs_shape = tf.TensorShape([None])
if len(self.locs.shape) > 1:
locs_shape = locs_shape.concatenate(self.locs.shape[1:])
probs_shape = probs_shape.concatenate(self.probs.shape[1:])

# While we have not broken enough sticks, keep sampling.
_, _, self._theta, self._beta, samples = tf.while_loop(
_, _, self._locs, self._probs, samples = tf.while_loop(
self._sample_n_cond, self._sample_n_body,
loop_vars=[k, bools, self.theta, self.beta, draws],
loop_vars=[k, bools, self.locs, self.probs, draws],
shape_invariants=[
k.shape, bools.shape, theta_shape, beta_shape, draws.shape])
k.shape, bools.shape, locs_shape, probs_shape, draws.shape])

return samples

def _sample_n_cond(self, k, bools, theta, beta, draws):
def _sample_n_cond(self, k, bools, locs, probs, draws):
# Proceed if at least one bool is True.
return tf.reduce_any(bools)

def _sample_n_body(self, k, bools, theta, beta, draws):
def _sample_n_body(self, k, bools, locs, probs, draws):
n, batch_shape, event_shape, rank = self._temp_scope

# If necessary, break a new piece of stick, i.e.
# add a new persistent atom to theta and sample another beta
theta, beta = tf.cond(
tf.shape(theta)[0] - 1 >= k,
lambda: (theta, beta),
# add a new persistent atom location and weight.
locs, probs = tf.cond(
tf.shape(locs)[0] - 1 >= k,
lambda: (locs, probs),
lambda: (
tf.concat(
[theta, tf.expand_dims(self.base.sample(batch_shape), 0)], 0),
[locs, tf.expand_dims(self.base.sample(batch_shape), 0)], 0),
tf.concat(
[beta, tf.expand_dims(self._betadist.sample(), 0)], 0)))
theta_k = tf.gather(theta, k)
beta_k = tf.gather(beta, k)
[probs, tf.expand_dims(self._probs_dist.sample(), 0)], 0)))
locs_k = tf.gather(locs, k)
probs_k = tf.gather(probs, k)

# Assign True samples to the new theta_k.
# Assign True samples to the new locs_k.
if len(bools.shape) <= 1:
bools_tile = bools
else:
Expand All @@ -208,12 +209,12 @@ def _sample_n_body(self, k, bools, theta, beta, draws):
bools, [n] + batch_shape + [1] * len(event_shape)),
[1] + [1] * len(batch_shape) + event_shape)

theta_k_tile = tf.tile(tf.expand_dims(theta_k, 0), [n] + [1] * (rank - 1))
draws = tf.where(bools_tile, theta_k_tile, draws)
locs_k_tile = tf.tile(tf.expand_dims(locs_k, 0), [n] + [1] * (rank - 1))
draws = tf.where(bools_tile, locs_k_tile, draws)

# Flip coins according to stick probabilities.
flips = Bernoulli(beta_k).sample(n)
flips = Bernoulli(probs_k).sample(n)
# If coin lands heads, assign sample's corresponding bool to False
# (this ends its "while loop").
bools = tf.where(tf.cast(flips, tf.bool), tf.zeros_like(bools), bools)
return k + 1, bools, theta, beta, draws
return k + 1, bools, locs, probs, draws
4 changes: 2 additions & 2 deletions examples/pp_dirichlet_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def body(k, beta_k):
return stick_num


dp = dirichlet_process(alpha=10.0)
dp = dirichlet_process(10.0)

# The number of sticks broken is dynamic, changing across evaluations.
sess = tf.Session()
print(sess.run(dp))
print(sess.run(dp))

# Demo of the DirichletProcess random variable in Edward.
base = Normal(mu=0.0, sigma=1.0)
base = Normal(0.0, 1.0)

# Highly concentrated DP.
alpha = 1.0
Expand Down
20 changes: 10 additions & 10 deletions tests/test-models/test_dirichlet_process_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,32 @@

class test_dirichletprocess_sample_class(tf.test.TestCase):

def _test(self, n, alpha, base):
x = DirichletProcess(alpha=alpha, base=base)
def _test(self, n, concentration, base):
x = DirichletProcess(concentration=concentration, base=base)
val_est = x.sample(n).shape.as_list()
val_true = n + tf.convert_to_tensor(alpha).shape.as_list() + \
val_true = n + tf.convert_to_tensor(concentration).shape.as_list() + \
tf.convert_to_tensor(base).shape.as_list()
self.assertEqual(val_est, val_true)

def test_alpha_0d_base_0d(self):
def test_concentration_0d_base_0d(self):
with self.test_session():
self._test([1], 0.5, Normal(loc=0.0, scale=0.5))
self._test([5], tf.constant(0.5), Normal(loc=0.0, scale=0.5))

def test_alpha_1d_base0d(self):
def test_concentration_1d_base_0d(self):
with self.test_session():
self._test([1], np.array([0.5]), Normal(loc=0.0, scale=0.5))
self._test([5], tf.constant([0.5]), Normal(loc=0.0, scale=0.5))
self._test([1], tf.constant([0.2, 1.5]), Normal(loc=0.0, scale=0.5))
self._test([5], tf.constant([0.2, 1.5]), Normal(loc=0.0, scale=0.5))

def test_alpha_0d_base1d(self):
def test_concentration_0d_base_1d(self):
with self.test_session():
self._test([1], 0.5, Normal(loc=tf.zeros(3), scale=tf.ones(3)))
self._test([5], tf.constant(0.5),
Normal(loc=tf.zeros(3), scale=tf.ones(3)))

def test_alpha_1d_base2d(self):
def test_concentration_1d_base_2d(self):
with self.test_session():
self._test([1], np.array([0.5]),
Normal(loc=tf.zeros([3, 4]), scale=tf.ones([3, 4])))
Expand All @@ -51,11 +51,11 @@ def test_persistent_state(self):
dp = DirichletProcess(0.1, Normal(loc=0.0, scale=1.0))
x = dp.sample(5)
y = dp.sample(5)
x_data, y_data, theta = sess.run([x, y, dp.theta])
x_data, y_data, locs = sess.run([x, y, dp.locs])
for sample in x_data:
self.assertTrue(sample in theta)
self.assertTrue(sample in locs)
for sample in y_data:
self.assertTrue(sample in theta)
self.assertTrue(sample in locs)

if __name__ == '__main__':
tf.test.main()

0 comments on commit 17a762b

Please sign in to comment.