Skip to content

Commit

Permalink
Simplify inputs to DirichletProcess; improve docstrings (#583)
Browse files Browse the repository at this point in the history
* add validate_args and simplify DP __init__; update docstrings

* update docstrings for init
  • Loading branch information
dustinvtran committed Mar 24, 2017
1 parent ceb537d commit 7746b67
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 92 deletions.
3 changes: 1 addition & 2 deletions docs/tex/iclr2017.tex
Expand Up @@ -333,8 +333,7 @@ \subsubsection{Appendix A. Model Examples}
N = 1000 # number of data points
D = 5 # data dimensionality

dp = DirichletProcess(
alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D))
dp = DirichletProcess(alpha=1.0, base=Normal(mu=tf.zeros(D), sigma=tf.ones(D)))
mu = dp.sample(N)
x = Normal(mu=mu, sigma=tf.ones([N, D]))
\end{lstlisting}
Expand Down
76 changes: 41 additions & 35 deletions edward/models/dirichlet_process.py
Expand Up @@ -14,47 +14,48 @@


class DirichletProcess(RandomVariable, Distribution):
def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True,
name="DirichletProcess", value=None, *args, **kwargs):
"""Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`.
"""Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`.
It has two parameters: a positive real value :math:`\\alpha`,
known as the concentration parameter (``alpha``), and a base
distribution :math:`H` (``base_cls(*args, **kwargs)``).
It has two parameters: a positive real value :math:`\\alpha`,
known as the concentration parameter (``alpha``), and a base
distribution :math:`H` (``base``).
"""
def __init__(self, alpha, base, validate_args=False, allow_nan_stats=True,
name="DirichletProcess", *args, **kwargs):
"""Initialize a batch of Dirichlet processes.
Parameters
----------
alpha : tf.Tensor
Concentration parameter. Must be positive real-valued. Its shape
determines the number of independent DPs (batch shape).
base_cls : RandomVariable
Class of base distribution. Its shape (when instantiated)
determines the shape of an individual DP (event shape).
*args, **kwargs : optional
Arguments passed into ``base_cls``.
base : RandomVariable
Base distribution. Its shape determines the shape of an
individual DP (event shape).
Examples
--------
>>> # scalar concentration parameter, scalar base distribution
>>> dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0)
>>> dp = DirichletProcess(0.1, Normal(mu=0.0, sigma=1.0))
>>> assert dp.shape == ()
>>>
>>> # vector of concentration parameters, matrix of Exponentials
>>> dp = DirichletProcess(tf.constant([0.1, 0.4]),
... Exponential, lam=tf.ones([5, 3]))
... Exponential(lam=tf.ones([5, 3])))
>>> assert dp.shape == (2, 5, 3)
"""
parameters = locals()
parameters.pop("self")
with tf.name_scope(name, values=[alpha]) as ns:
with tf.control_dependencies([]):
with tf.name_scope(name, values=[alpha]):
with tf.control_dependencies([
tf.assert_positive(alpha),
] if validate_args else []):
if validate_args and not isinstance(base, RandomVariable):
raise TypeError("base must be a ed.RandomVariable object.")

self._alpha = tf.identity(alpha, name="alpha")
self._base_cls = base_cls
self._base_args = args
self._base_kwargs = kwargs
self._base = base

# Instantiate base distribution.
self._base = self._base_cls(*self._base_args, **self._base_kwargs)
# Create empty tensor to store future atoms.
self._theta = tf.zeros(
[0] +
Expand All @@ -63,28 +64,33 @@ def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True,
dtype=self._base.dtype)

# Instantiate beta distribution for stick breaking proportions.
self._betadist = Beta(a=tf.ones_like(self.alpha), b=self.alpha)
self._betadist = Beta(a=tf.ones_like(self._alpha), b=self._alpha)
# Create empty tensor to store stick breaking proportions.
self._beta = tf.zeros(
[0] + self.get_batch_shape().as_list(),
dtype=self._betadist.dtype)

super(DirichletProcess, self).__init__(
dtype=tf.int32,
is_continuous=False,
is_reparameterized=False,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._alpha, self._beta, self._theta],
name=ns,
value=value)
super(DirichletProcess, self).__init__(
dtype=tf.int32,
is_continuous=False,
is_reparameterized=False,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._alpha, self._beta, self._theta],
name=name,
*args, **kwargs)

@property
def alpha(self):
"""Concentration parameter."""
return self._alpha

@property
def base(self):
"""Base distribution used for drawing the atoms."""
return self._base

@property
def beta(self):
"""Stick breaking proportions. It has shape [None] + batch_shape, where
Expand All @@ -106,10 +112,10 @@ def _get_batch_shape(self):
return self.alpha.shape

def _event_shape(self):
return tf.shape(self._base)
return tf.shape(self.base)

def _get_event_shape(self):
return self._base.shape
return self.base.shape

def _sample_n(self, n, seed=None):
"""Sample ``n`` draws from the DP. Draws from the base
Expand Down Expand Up @@ -154,7 +160,7 @@ def _sample_n(self, n, seed=None):
bools = tf.ones([n] + batch_shape, dtype=tf.bool)

# Initialize all samples as zero, they will be overwritten in any case
draws = tf.zeros([n] + batch_shape + event_shape, dtype=self._base.dtype)
draws = tf.zeros([n] + batch_shape + event_shape, dtype=self.base.dtype)

# Calculate shape invariance conditions for theta and beta as these
# can change shape between loop iterations.
Expand Down Expand Up @@ -187,7 +193,7 @@ def _sample_n_body(self, k, bools, theta, beta, draws):
lambda: (theta, beta),
lambda: (
tf.concat(
[theta, tf.expand_dims(self._base.sample(batch_shape), 0)], 0),
[theta, tf.expand_dims(self.base.sample(batch_shape), 0)], 0),
tf.concat(
[beta, tf.expand_dims(self._betadist.sample(), 0)], 0)))
theta_k = tf.gather(theta, k)
Expand Down
40 changes: 29 additions & 11 deletions edward/models/empirical.py
Expand Up @@ -12,26 +12,44 @@ class Empirical(RandomVariable, Distribution):
"""Empirical random variable."""
def __init__(self, params, validate_args=False, allow_nan_stats=True,
name="Empirical", *args, **kwargs):
"""Initialize an ``Empirical`` random variable.
Parameters
----------
params : tf.Tensor
Collection of samples. Its outer (left-most) dimension
determines the number of samples.
Examples
--------
>>> # 100 samples of a scalar
>>> x = Empirical(params=tf.zeros(100))
>>> assert x.shape == ()
>>>
>>> # 5 samples of a 2 x 3 matrix
>>> x = Empirical(params=tf.zeros([5, 2, 3]))
>>> assert x.shape == (2, 3)
"""
parameters = locals()
parameters.pop("self")
with tf.name_scope(name, values=[params]) as ns:
with tf.name_scope(name, values=[params]):
with tf.control_dependencies([]):
self._params = tf.identity(params, name="params")
try:
self._n = tf.shape(self._params)[0]
except ValueError: # scalar params
self._n = tf.constant(1)

super(Empirical, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params, self._n],
name=ns,
*args, **kwargs)
super(Empirical, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params, self._n],
name=name,
*args, **kwargs)

@staticmethod
def _param_shapes(sample_shape):
Expand Down
40 changes: 29 additions & 11 deletions edward/models/point_mass.py
Expand Up @@ -16,21 +16,39 @@ class PointMass(RandomVariable, Distribution):
"""
def __init__(self, params, validate_args=False, allow_nan_stats=True,
name="PointMass", *args, **kwargs):
"""Initialize a ``PointMass`` random variable.
Parameters
----------
params : tf.Tensor
The location with all probability mass.
Examples
--------
>>> # scalar
>>> x = PointMass(params=28.3)
>>> assert x.shape == ()
>>>
>>> # 5 x 2 x 3 tensor
>>> x = PointMass(params=tf.zeros([5, 2, 3]))
>>> assert x.shape == (5, 2, 3)
"""
parameters = locals()
parameters.pop("self")
with tf.name_scope(name, values=[params]) as ns:
with tf.name_scope(name, values=[params]):
with tf.control_dependencies([]):
self._params = tf.identity(params, name="params")
super(PointMass, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params],
name=ns,
*args, **kwargs)

super(PointMass, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params],
name=name,
*args, **kwargs)

@staticmethod
def _param_shapes(sample_shape):
Expand Down
19 changes: 9 additions & 10 deletions examples/pp_dirichlet_process.py
Expand Up @@ -46,12 +46,11 @@ def body(k, beta_k):
print(sess.run(dp))

# Demo of the DirichletProcess random variable in Edward.
base_cls = Normal
kwargs = {'mu': 0.0, 'sigma': 1.0}
base = Normal(mu=0.0, sigma=1.0)

# Highly concentrated DP.
alpha = 1.0
dp = DirichletProcess(alpha, base_cls, **kwargs)
dp = DirichletProcess(alpha, base)
x = dp.sample(1000)
samples = sess.run(x)
plt.hist(samples, bins=100, range=(-3.0, 3.0))
Expand All @@ -60,7 +59,7 @@ def body(k, beta_k):

# More spread out DP.
alpha = 50.0
dp = DirichletProcess(alpha, base_cls, **kwargs)
dp = DirichletProcess(alpha, base)
x = dp.sample(1000)
samples = sess.run(x)
plt.hist(samples, bins=100, range=(-3.0, 3.0))
Expand All @@ -69,7 +68,7 @@ def body(k, beta_k):

# States persist across calls to sample() in a DP.
alpha = 1.0
dp = DirichletProcess(alpha, base_cls, **kwargs)
dp = DirichletProcess(alpha, base)
x = dp.sample(50)
y = dp.sample(75)
samples_x, samples_y = sess.run([x, y])
Expand All @@ -82,13 +81,13 @@ def body(k, beta_k):

# ``theta`` is the distribution indirectly returned by the DP.
# Fetching theta is the same as fetching the Dirichlet process.
dp = DirichletProcess(alpha, base_cls, **kwargs)
theta = base_cls(value=tf.cast(dp, tf.float32), **kwargs)
dp = DirichletProcess(alpha, base)
theta = Normal(0.0, 1.0, value=tf.cast(dp, tf.float32))
print(sess.run([dp, theta]))
print(sess.run([dp, theta]))

# DirichletProcess can also take in non-scalar concentrations and bases.
base_cls = Exponential
kwargs = {'lam': tf.ones([5, 2])}
dp = DirichletProcess(tf.constant([0.1, 0.6, 0.4]), base_cls, **kwargs)
alpha = tf.constant([0.1, 0.6, 0.4])
base = Exponential(lam=tf.ones([5, 2]))
dp = DirichletProcess(alpha, base)
print(dp)
3 changes: 1 addition & 2 deletions notebooks/iclr2017.ipynb
Expand Up @@ -518,8 +518,7 @@
"N = 1000 # number of data points\n",
"D = 5 # data dimensionality\n",
"\n",
"dp = DirichletProcess(\n",
" alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D))\n",
"dp = DirichletProcess(alpha=1.0, base=Normal(mu=tf.zeros(D), sigma=tf.ones(D)))\n",
"mu = dp.sample(N)\n",
"x = Normal(mu=mu, sigma=tf.ones([N, D]))"
]
Expand Down
41 changes: 20 additions & 21 deletions tests/test-models/test_dirichlet_process_sample.py
Expand Up @@ -10,46 +10,45 @@

class test_dirichletprocess_sample_class(tf.test.TestCase):

def _test(self, n, alpha, base_cls, *args, **kwargs):
x = DirichletProcess(alpha=alpha, base_cls=base_cls, *args, **kwargs)
base = base_cls(*args, **kwargs)
def _test(self, n, alpha, base):
x = DirichletProcess(alpha=alpha, base=base)
val_est = x.sample(n).shape.as_list()
val_true = n + tf.convert_to_tensor(alpha).shape.as_list() + \
tf.convert_to_tensor(base).shape.as_list()
self.assertEqual(val_est, val_true)

def test_alpha_0d_base_0d(self):
with self.test_session():
self._test([1], 0.5, Normal, mu=0.0, sigma=0.5)
self._test([5], tf.constant(0.5), Normal, mu=0.0, sigma=0.5)
self._test([1], 0.5, Normal(mu=0.0, sigma=0.5))
self._test([5], tf.constant(0.5), Normal(mu=0.0, sigma=0.5))

def test_alpha_1d_base0d(self):
with self.test_session():
self._test([1], np.array([0.5]), Normal, mu=0.0, sigma=0.5)
self._test([5], tf.constant([0.5]), Normal, mu=0.0, sigma=0.5)
self._test([1], tf.constant([0.2, 1.5]), Normal, mu=0.0, sigma=0.5)
self._test([5], tf.constant([0.2, 1.5]), Normal, mu=0.0, sigma=0.5)
self._test([1], np.array([0.5]), Normal(mu=0.0, sigma=0.5))
self._test([5], tf.constant([0.5]), Normal(mu=0.0, sigma=0.5))
self._test([1], tf.constant([0.2, 1.5]), Normal(mu=0.0, sigma=0.5))
self._test([5], tf.constant([0.2, 1.5]), Normal(mu=0.0, sigma=0.5))

def test_alpha_0d_base1d(self):
with self.test_session():
self._test([1], 0.5, Normal, mu=tf.zeros(3), sigma=tf.ones(3))
self._test([5], tf.constant(0.5), Normal,
mu=tf.zeros(3), sigma=tf.ones(3))
self._test([1], 0.5, Normal(mu=tf.zeros(3), sigma=tf.ones(3)))
self._test([5], tf.constant(0.5),
Normal(mu=tf.zeros(3), sigma=tf.ones(3)))

def test_alpha_1d_base2d(self):
with self.test_session():
self._test([1], np.array([0.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([5], tf.constant([0.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([1], tf.constant([0.2, 1.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([5], tf.constant([0.2, 1.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([1], np.array([0.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))
self._test([5], tf.constant([0.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))
self._test([1], tf.constant([0.2, 1.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))
self._test([5], tf.constant([0.2, 1.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))

def test_persistent_state(self):
with self.test_session() as sess:
dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0)
dp = DirichletProcess(0.1, Normal(mu=0.0, sigma=1.0))
x = dp.sample(5)
y = dp.sample(5)
x_data, y_data, theta = sess.run([x, y, dp.theta])
Expand Down

0 comments on commit 7746b67

Please sign in to comment.