Skip to content

Commit

Permalink
Different attempt at FlowDistribution.
Browse files Browse the repository at this point in the history
  • Loading branch information
botev committed Nov 16, 2017
1 parent 50fb854 commit 94e4371
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 16 deletions.
21 changes: 13 additions & 8 deletions examples/normalizing_flows/dlgm_nf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,19 @@ def vae(observed, n, x_dim, z_dim, n_particles):
return model


def q_net(x, z_dim, n_particles):
def q_net(x, z_dim, n_particles, n_planar_flows):
with zs.BayesianNet() as variational:
lz_x = tf.layers.dense(tf.to_float(x), 500, activation=tf.nn.relu)
lz_x = tf.layers.dense(lz_x, 500, activation=tf.nn.relu)
z_mean = tf.layers.dense(lz_x, z_dim)
z_logstd = tf.layers.dense(lz_x, z_dim)
z = zs.Normal('z', z_mean, logstd=z_logstd, group_ndims=1,
n_samples=n_particles)
# z = zs.Normal('z', z_mean, logstd=z_logstd, group_ndims=1,
# n_samples=n_particles)

def flow(samples, log_samples):
return zs.planar_normalizing_flow(samples, log_samples, n_iters=n_planar_flows)

z = zs.NormalFlow('z', flow, z_mean, logstd=z_logstd, group_ndims=1, n_samples=n_particles)
return variational


Expand Down Expand Up @@ -66,14 +71,14 @@ def log_joint(observed):
log_pz, log_px_z = model.local_log_prob(['z', 'x'])
return log_pz + log_px_z

variational = q_net(x, z_dim, n_particles)
variational = q_net(x, z_dim, n_particles, n_planar_flows)
qz_samples, log_qz = variational.query('z', outputs=True,
local_log_prob=True)
# TODO: add tests for repeated calls of flows
qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
n_iters=n_planar_flows)
qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
n_iters=n_planar_flows)
# qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
# n_iters=n_planar_flows)
# qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
# n_iters=n_planar_flows)

lower_bound = zs.variational.elbo(log_joint,
observed={'x': x},
Expand Down
78 changes: 77 additions & 1 deletion zhusuan/distributions/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
__all__ = [
'Empirical',
'Implicit',
'FlowDistribution',
]


Expand Down Expand Up @@ -141,4 +142,79 @@ def _prob(self, given):
if self.is_continuous:
return (2 * prob - 1) * inf
else:
return prob
return prob


class FlowDistribution(Distribution):
"""
The class of FlowDistribution distribution.
See :class:`~zhusuan.distributions.base.FlowDistribution` for details.
:param name: A string. The name of the `StochasticTensor`. Must be unique
in the `BayesianNet` context.
:param base: An instance of `Distribution` parametrizing the base distribution.
:param forward: A forward function which describes how we transform the samples
from the base distribution. The signature of the function should be:
transformed, log_det = forward(base_samples)
:param inverse: An inverse function which maps from the transformed samples to
to base samples. The signature of the function should be:
base_samples, log_det = inverse(transformed_samples)
:param group_ndims: A 0-D `int32` Tensor representing the number of
dimensions in `batch_shape` (counted from the end) that are grouped
into a single event, so that their probabilities are calculated
together. Default is 0, which means a single value is an event.
See :class:`~zhusuan.distributions.base.Distribution` for more detailed
explanation.
"""

def __init__(self,
base,
forward,
inverse=None,
group_ndims=0,
**kwargs):
self.base = base
self.forward = forward
self.inverse = inverse
super(FlowDistribution, self).__init__(
dtype=base.dtype,
param_dtype=base.dtype,
is_continuous=base.dtype.is_floating,
group_ndims=group_ndims,
is_reparameterized=False,
**kwargs)

def _value_shape(self):
return self.base.value_shape()

def _get_value_shape(self):
return self.base.get_value_shape()

def _batch_shape(self):
return self.base.batch_shape()

def _get_batch_shape(self):
return self.base.get_batch_shape()

def _sample(self, n_samples):
return self.sample_and_log_prob(n_samples)[0]

def _log_prob(self, given):
if self.inverse is None:
raise ValueError("Flow distribution can only calculate log_prob through `sample_and_log_prob` "
"if `inverse=None`.")
else:
base_given, log_det = self.inverse(given)
log_prob = self.base.log_prob(base_given)
return log_prob - log_det

def _prob(self, given):
return tf.exp(self.log_prob(given))

def sample_and_log_prob(self, n_samples=None):
try:
base_sample, log_prob = self.base.sample_and_log_prob(n_samples)
except:
base_sample = self.base.sample(n_samples)
log_prob = self.base.log_prob(base_sample)
transformed, log_det = self.forward(base_sample, log_prob)
return transformed, log_det
22 changes: 15 additions & 7 deletions zhusuan/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from zhusuan.model.utils import Context
from zhusuan.utils import TensorArithmeticMixin
from zhusuan.distributions import FlowDistribution


__all__ = [
Expand Down Expand Up @@ -114,10 +115,20 @@ def tensor(self):
"StochasticTensor('{}') not compatible "
"with its observed value. Error message: {}".format(
self._name, e))
elif isinstance(self._distribution, FlowDistribution):
self._tensor, self._local_log_prob = self._distribution.\
sample_and_log_prob(self._n_samples)
else:
self._tensor = self.sample(self._n_samples)
return self._tensor

@property
def local_log_prob(self):
tensor = self.tensor
if not hasattr(self, '_local_log_prob'):
self._local_log_prob = self.log_prob(tensor)
return self._local_log_prob

def get_shape(self):
return self.tensor.get_shape()

Expand Down Expand Up @@ -336,15 +347,12 @@ def local_log_prob(self, name_or_names):
:return: A Tensor or a list of Tensors.
"""
self._check_names_exist(name_or_names)
self._check_names_exist(name_or_names)
if isinstance(name_or_names, (tuple, list)):
ret = []
for name in name_or_names:
s_tensor = self._stochastic_tensors[name]
ret.append(s_tensor.log_prob(s_tensor.tensor))
return [self._stochastic_tensors[name].local_log_prob
for name in name_or_names]
else:
s_tensor = self._stochastic_tensors[name_or_names]
ret = s_tensor.log_prob(s_tensor.tensor)
return ret
return self._stochastic_tensors[name_or_names].local_log_prob

def query(self, name_or_names, outputs=False, local_log_prob=False):
"""
Expand Down
57 changes: 57 additions & 0 deletions zhusuan/model/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'Concrete',
'Empirical',
'Implicit',
'NormalFlow',
]


Expand Down Expand Up @@ -884,3 +885,59 @@ def __init__(self,
**kwargs
)
super(Implicit, self).__init__(name, norm, n_samples)


class NormalFlow(StochasticTensor):
"""
The class of univariate Normal `StochasticTensor`.
See :class:`~zhusuan.model.base.StochasticTensor` for details.
.. warning::
The order of arguments `logstd`/`std` will change to `std`/`logstd`
in the coming version.
:param name: A string. The name of the `StochasticTensor`. Must be unique
in the `BayesianNet` context.
:param mean: A `float` Tensor. The mean of the Normal distribution.
Should be broadcastable to match `logstd`.
:param logstd: A `float` Tensor. The log standard deviation of the Normal
distribution. Should be broadcastable to match `mean`.
:param std: A `float` Tensor. The standard deviation of the Normal
distribution. Should be positive and broadcastable to match `mean`.
:param n_samples: A 0-D `int32` Tensor or None. Number of samples
generated by this `StochasticTensor`.
:param group_ndims: A 0-D `int32` Tensor representing the number of
dimensions in `batch_shape` (counted from the end) that are grouped
into a single event, so that their probabilities are calculated
together. Default is 0, which means a single value is an event.
See :class:`~zhusuan.distributions.base.Distribution` for more detailed
explanation.
:param is_reparameterized: A Bool. If True, gradients on samples from this
`StochasticTensor` are allowed to propagate into inputs, using the
reparametrization trick from (Kingma, 2013).
:param check_numerics: Bool. Whether to check numeric issues.
"""

def __init__(self,
name,
forward,
mean=0.,
logstd=None,
std=None,
n_samples=None,
group_ndims=0,
is_reparameterized=True,
check_numerics=False,
**kwargs):
norm = distributions.Normal(
mean,
logstd=logstd,
std=std,
group_ndims=group_ndims,
is_reparameterized=is_reparameterized,
check_numerics=check_numerics,
**kwargs
)
flow = distributions.FlowDistribution(norm, forward, group_ndims=group_ndims)
super(NormalFlow, self).__init__(name, flow, n_samples)

0 comments on commit 94e4371

Please sign in to comment.