Skip to content

Commit

Permalink
added InvertFlow
Browse files Browse the repository at this point in the history
  • Loading branch information
haowen-xu committed Jan 15, 2019
1 parent 7ed63e5 commit 0fc46b7
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 6 deletions.
68 changes: 68 additions & 0 deletions tests/layers/flows/test_invert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import pytest
import tensorflow as tf

from tests.layers.flows.helper import (QuadraticFlow, quadratic_transform,
npyops, invertible_flow_standard_check)
from tfsnippet import FlowDistribution, Normal
from tfsnippet.layers import InvertFlow, PlanarNormalizingFlow, BaseFlow


class InvertFlowTestCase(tf.test.TestCase):

def test_invert_flow(self):
with self.test_session() as sess:
# test invert a normal flow
flow = QuadraticFlow(2., 5.)
inv_flow = flow.invert()

self.assertIsInstance(inv_flow, InvertFlow)
self.assertEqual(inv_flow.x_value_ndims, 0)
self.assertEqual(inv_flow.y_value_ndims, 0)
self.assertFalse(inv_flow.require_batch_dims)

test_x = np.arange(12, dtype=np.float32) + 1.
test_y, test_log_det = quadratic_transform(npyops, test_x, 2., 5.)

self.assertFalse(flow._has_built)
y, log_det_y = inv_flow.inverse_transform(tf.constant(test_x))
self.assertTrue(flow._has_built)

np.testing.assert_allclose(sess.run(y), test_y)
np.testing.assert_allclose(sess.run(log_det_y), test_log_det)
invertible_flow_standard_check(self, inv_flow, sess, test_y)

# test invert an InvertFlow
inv_inv_flow = inv_flow.invert()
self.assertIs(inv_inv_flow, flow)

# test use with FlowDistribution
normal = Normal(mean=1., std=2.)
inv_flow = QuadraticFlow(2., 5.).invert()
distrib = FlowDistribution(normal, inv_flow)
distrib_log_det = distrib.log_prob(test_x)
np.testing.assert_allclose(
*sess.run([distrib_log_det,
normal.log_prob(test_y) + test_log_det])
)

def test_property(self):
class _Flow(BaseFlow):
@property
def explicitly_invertible(self):
return True

inv_flow = _Flow(x_value_ndims=2, y_value_ndims=3,
require_batch_dims=True).invert()
self.assertTrue(inv_flow.require_batch_dims)
self.assertEqual(inv_flow.x_value_ndims, 3)
self.assertEqual(inv_flow.y_value_ndims, 2)

def test_errors(self):
with pytest.raises(ValueError, match='`flow` must be an explicitly '
'invertible flow'):
_ = InvertFlow(object())

with pytest.raises(ValueError, match='`flow` must be an explicitly '
'invertible flow'):
_ = PlanarNormalizingFlow().invert()
7 changes: 4 additions & 3 deletions tfsnippet/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

__all__ = [
'ActNorm', 'BaseFlow', 'BaseLayer', 'CouplingLayer', 'FeatureMappingFlow',
'FeatureShufflingFlow', 'InvertibleConv2d', 'InvertibleDense',
'MultiLayerFlow', 'PlanarNormalizingFlow', 'SequentialFlow', 'act_norm',
'avg_pool2d', 'broadcast_log_det_against_input', 'conv2d', 'deconv2d',
'FeatureShufflingFlow', 'InvertFlow', 'InvertibleConv2d',
'InvertibleDense', 'MultiLayerFlow', 'PlanarNormalizingFlow',
'SequentialFlow', 'act_norm', 'avg_pool2d',
'broadcast_log_det_against_input', 'conv2d', 'deconv2d',
'default_kernel_initializer', 'dense', 'global_avg_pool2d',
'l2_regularizer', 'max_pool2d', 'planar_normalizing_flows',
'resnet_conv2d_block', 'resnet_deconv2d_block', 'resnet_general_block',
Expand Down
3 changes: 2 additions & 1 deletion tfsnippet/layers/flows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import *
from .coupling import *
from .invert import *
from .linear import *
from .planar_nf import *
from .rearrangement import *
Expand All @@ -8,7 +9,7 @@

__all__ = [
'BaseFlow', 'CouplingLayer', 'FeatureMappingFlow', 'FeatureShufflingFlow',
'InvertibleConv2d', 'InvertibleDense', 'MultiLayerFlow',
'InvertFlow', 'InvertibleConv2d', 'InvertibleDense', 'MultiLayerFlow',
'PlanarNormalizingFlow', 'SequentialFlow',
'broadcast_log_det_against_input', 'planar_normalizing_flows',
]
17 changes: 17 additions & 0 deletions tfsnippet/layers/flows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ def __init__(self,
self._x_input_spec = None # type: InputSpec
self._y_input_spec = None # type: InputSpec

def invert(self):
"""
Get the inverted flow from this flow.
The :meth:`transform()` will become the :meth:`inverse_transform()`
in the inverted flow, and the :meth:`inverse_transform()` will become
the :meth:`transform()` in the inverted flow.
If the current flow has not been initialized, it must be initialized
via :meth:`inverse_transform()` in the new flow.
Returns:
tfsnippet.layers.InvertFlow: The inverted flow.
"""
from .invert import InvertFlow
return InvertFlow(self)

@property
def x_value_ndims(self):
"""
Expand Down
74 changes: 74 additions & 0 deletions tfsnippet/layers/flows/invert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from .base import BaseFlow

__all__ = ['InvertFlow']


class InvertFlow(BaseFlow):
"""
Turn a :class:`BaseFlow` into its inverted flow.
This class is particularly useful when the flow is (theoretically) defined
in the opposite direction to the direction of network initialization.
For example, define `z -> x`, but initialized by feeding `x`.
"""

def __init__(self, flow, name=None, scope=None):
"""
Construct a new :class:`InvertFlow`.
Args:
flow (BaseFlow): The underlying flow.
"""
if not isinstance(flow, BaseFlow) or not flow.explicitly_invertible:
raise ValueError('`flow` must be an explicitly invertible flow: '
'got {!r}'.format(flow))
self._flow = flow

super(InvertFlow, self).__init__(
x_value_ndims=flow.y_value_ndims,
y_value_ndims=flow.x_value_ndims,
require_batch_dims=flow.require_batch_dims,
name=name,
scope=scope
)

def invert(self):
"""
Get the original flow, inverted by this :class:`InvertFlow`.
Returns:
BaseFlow: The original flow.
"""
return self._flow

@property
def explicitly_invertible(self):
return True

def build(self, input=None): # pragma: no cover
# since `flow` should be inverted, we should build `flow` in
# `inverse_transform` rather than in `transform` or `build`
pass

def transform(self, x, compute_y=True, compute_log_det=True, name=None):
return self._flow.inverse_transform(
y=x, compute_x=compute_y, compute_log_det=compute_log_det,
name=name
)

def inverse_transform(self, y, compute_x=True, compute_log_det=True,
name=None):
return self._flow.transform(
x=y, compute_y=compute_x, compute_log_det=compute_log_det,
name=name
)

def _build(self, input=None):
raise RuntimeError('Should never be called.') # pragma: no cover

def _transform(self, x, compute_y, compute_log_det):
raise RuntimeError('Should never be called.') # pragma: no cover

def _inverse_transform(self, y, compute_x, compute_log_det):
raise RuntimeError('Should never be called.') # pragma: no cover
4 changes: 2 additions & 2 deletions tfsnippet/layers/normalization/act_norm_.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class ActNorm(FeatureMappingFlow):
`bias` and `scale` are initialized such that `y` will have zero mean and
unit variance for the initial mini-batch of `x`.
It can be initialized only through the forward pass. You may need to use
:meth:`invert()` to get a inverted flow if you need to initialize the
parameters via the opposite direction.
:meth:`BaseFlow.invert()` to get a inverted flow if you need to initialize
the parameters via the opposite direction.
"""

_build_require_input = True
Expand Down

0 comments on commit 0fc46b7

Please sign in to comment.