-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
167 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numpy as np | ||
import pytest | ||
import tensorflow as tf | ||
|
||
from tests.layers.flows.helper import (QuadraticFlow, quadratic_transform, | ||
npyops, invertible_flow_standard_check) | ||
from tfsnippet import FlowDistribution, Normal | ||
from tfsnippet.layers import InvertFlow, PlanarNormalizingFlow, BaseFlow | ||
|
||
|
||
class InvertFlowTestCase(tf.test.TestCase): | ||
|
||
def test_invert_flow(self): | ||
with self.test_session() as sess: | ||
# test invert a normal flow | ||
flow = QuadraticFlow(2., 5.) | ||
inv_flow = flow.invert() | ||
|
||
self.assertIsInstance(inv_flow, InvertFlow) | ||
self.assertEqual(inv_flow.x_value_ndims, 0) | ||
self.assertEqual(inv_flow.y_value_ndims, 0) | ||
self.assertFalse(inv_flow.require_batch_dims) | ||
|
||
test_x = np.arange(12, dtype=np.float32) + 1. | ||
test_y, test_log_det = quadratic_transform(npyops, test_x, 2., 5.) | ||
|
||
self.assertFalse(flow._has_built) | ||
y, log_det_y = inv_flow.inverse_transform(tf.constant(test_x)) | ||
self.assertTrue(flow._has_built) | ||
|
||
np.testing.assert_allclose(sess.run(y), test_y) | ||
np.testing.assert_allclose(sess.run(log_det_y), test_log_det) | ||
invertible_flow_standard_check(self, inv_flow, sess, test_y) | ||
|
||
# test invert an InvertFlow | ||
inv_inv_flow = inv_flow.invert() | ||
self.assertIs(inv_inv_flow, flow) | ||
|
||
# test use with FlowDistribution | ||
normal = Normal(mean=1., std=2.) | ||
inv_flow = QuadraticFlow(2., 5.).invert() | ||
distrib = FlowDistribution(normal, inv_flow) | ||
distrib_log_det = distrib.log_prob(test_x) | ||
np.testing.assert_allclose( | ||
*sess.run([distrib_log_det, | ||
normal.log_prob(test_y) + test_log_det]) | ||
) | ||
|
||
def test_property(self): | ||
class _Flow(BaseFlow): | ||
@property | ||
def explicitly_invertible(self): | ||
return True | ||
|
||
inv_flow = _Flow(x_value_ndims=2, y_value_ndims=3, | ||
require_batch_dims=True).invert() | ||
self.assertTrue(inv_flow.require_batch_dims) | ||
self.assertEqual(inv_flow.x_value_ndims, 3) | ||
self.assertEqual(inv_flow.y_value_ndims, 2) | ||
|
||
def test_errors(self): | ||
with pytest.raises(ValueError, match='`flow` must be an explicitly ' | ||
'invertible flow'): | ||
_ = InvertFlow(object()) | ||
|
||
with pytest.raises(ValueError, match='`flow` must be an explicitly ' | ||
'invertible flow'): | ||
_ = PlanarNormalizingFlow().invert() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from .base import BaseFlow | ||
|
||
__all__ = ['InvertFlow'] | ||
|
||
|
||
class InvertFlow(BaseFlow): | ||
""" | ||
Turn a :class:`BaseFlow` into its inverted flow. | ||
This class is particularly useful when the flow is (theoretically) defined | ||
in the opposite direction to the direction of network initialization. | ||
For example, define `z -> x`, but initialized by feeding `x`. | ||
""" | ||
|
||
def __init__(self, flow, name=None, scope=None): | ||
""" | ||
Construct a new :class:`InvertFlow`. | ||
Args: | ||
flow (BaseFlow): The underlying flow. | ||
""" | ||
if not isinstance(flow, BaseFlow) or not flow.explicitly_invertible: | ||
raise ValueError('`flow` must be an explicitly invertible flow: ' | ||
'got {!r}'.format(flow)) | ||
self._flow = flow | ||
|
||
super(InvertFlow, self).__init__( | ||
x_value_ndims=flow.y_value_ndims, | ||
y_value_ndims=flow.x_value_ndims, | ||
require_batch_dims=flow.require_batch_dims, | ||
name=name, | ||
scope=scope | ||
) | ||
|
||
def invert(self): | ||
""" | ||
Get the original flow, inverted by this :class:`InvertFlow`. | ||
Returns: | ||
BaseFlow: The original flow. | ||
""" | ||
return self._flow | ||
|
||
@property | ||
def explicitly_invertible(self): | ||
return True | ||
|
||
def build(self, input=None): # pragma: no cover | ||
# since `flow` should be inverted, we should build `flow` in | ||
# `inverse_transform` rather than in `transform` or `build` | ||
pass | ||
|
||
def transform(self, x, compute_y=True, compute_log_det=True, name=None): | ||
return self._flow.inverse_transform( | ||
y=x, compute_x=compute_y, compute_log_det=compute_log_det, | ||
name=name | ||
) | ||
|
||
def inverse_transform(self, y, compute_x=True, compute_log_det=True, | ||
name=None): | ||
return self._flow.transform( | ||
x=y, compute_y=compute_x, compute_log_det=compute_log_det, | ||
name=name | ||
) | ||
|
||
def _build(self, input=None): | ||
raise RuntimeError('Should never be called.') # pragma: no cover | ||
|
||
def _transform(self, x, compute_y, compute_log_det): | ||
raise RuntimeError('Should never be called.') # pragma: no cover | ||
|
||
def _inverse_transform(self, y, compute_x, compute_log_det): | ||
raise RuntimeError('Should never be called.') # pragma: no cover |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters