Add a Shift bijector.
PiperOrigin-RevId: 434468445
gpapamak authored and DistraxDev committed Mar 14, 2022
1 parent 36ab44c commit 71bd1c2
Showing 3 changed files with 194 additions and 25 deletions.
84 changes: 84 additions & 0 deletions distrax/_src/bijectors/shift.py
@@ -0,0 +1,84 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shift bijector."""

from typing import Tuple, Union

from distrax._src.bijectors import bijector as base
import jax
import jax.numpy as jnp

Array = base.Array
Numeric = Union[Array, float]


class Shift(base.Bijector):
  """Bijector that translates its input elementwise.

  The bijector is defined as follows:

  - Forward: `y = x + shift`
  - Forward Jacobian determinant: `log|det J(x)| = 0`
  - Inverse: `x = y - shift`
  - Inverse Jacobian determinant: `log|det J(y)| = 0`

  where `shift` parameterizes the bijector.
  """

  def __init__(self, shift: Numeric):
    """Initializes a `Shift` bijector.

    Args:
      shift: the bijector's shift parameter. Can also be batched.
    """
    super().__init__(event_ndims_in=0, is_constant_jacobian=True)
    self._shift = shift
    self._batch_shape = jnp.shape(self._shift)

  @property
  def shift(self) -> Numeric:
    """The bijector's shift."""
    return self._shift

  def forward(self, x: Array) -> Array:
    """Computes y = f(x)."""
    return x + self._shift

  def forward_log_det_jacobian(self, x: Array) -> Array:
    """Computes log|det J(f)(x)|."""
    batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape)
    return jnp.zeros(batch_shape, dtype=x.dtype)

  def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
    """Computes y = f(x) and log|det J(f)(x)|."""
    return self.forward(x), self.forward_log_det_jacobian(x)

  def inverse(self, y: Array) -> Array:
    """Computes x = f^{-1}(y)."""
    return y - self._shift

  def inverse_log_det_jacobian(self, y: Array) -> Array:
    """Computes log|det J(f^{-1})(y)|."""
    return self.forward_log_det_jacobian(y)

  def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
    """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
    return self.inverse(y), self.inverse_log_det_jacobian(y)

  def same_as(self, other: base.Bijector) -> bool:
    """Returns True if this bijector is guaranteed to be the same as `other`."""
    if type(other) is Shift:  # pylint: disable=unidiomatic-typecheck
      return self.shift is other.shift
    return False
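
A quick usage sketch (not part of the commit) of how the new bijector behaves. It assumes the module import path used by the tests below; the values in the comments follow directly from the definitions above.

# Hypothetical usage example, assuming the import path from shift_test.py.
import jax.numpy as jnp
from distrax._src.bijectors.shift import Shift

bijector = Shift(jnp.array([1., 2., 3.]))      # batched shift parameter
x = jnp.zeros((3,))
y = bijector.forward(x)                        # [1., 2., 3.]
logdet = bijector.forward_log_det_jacobian(x)  # zeros: dy/dx = 1 elementwise
x_back = bijector.inverse(y)                   # recovers [0., 0., 0.]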
108 changes: 108 additions & 0 deletions distrax/_src/bijectors/shift_test.py
@@ -0,0 +1,108 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `shift.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.bijectors.shift import Shift
from distrax._src.bijectors.tanh import Tanh
import jax
import jax.numpy as jnp
import numpy as np


class ShiftTest(parameterized.TestCase):

  def test_jacobian_is_constant_property(self):
    bijector = Shift(jnp.ones((4,)))
    self.assertTrue(bijector.is_constant_jacobian)
    self.assertTrue(bijector.is_constant_log_det)

  def test_properties(self):
    bijector = Shift(jnp.array([1., 2., 3.]))
    np.testing.assert_array_equal(bijector.shift, np.array([1., 2., 3.]))

  @chex.all_variants
  @parameterized.parameters(
      {'batch_shape': (), 'param_shape': ()},
      {'batch_shape': (3,), 'param_shape': ()},
      {'batch_shape': (), 'param_shape': (3,)},
      {'batch_shape': (2, 3), 'param_shape': (2, 3)},
  )
  def test_forward_methods(self, batch_shape, param_shape):
    bijector = Shift(jnp.ones(param_shape))
    prng = jax.random.PRNGKey(42)
    x = jax.random.normal(prng, batch_shape)
    output_shape = jnp.broadcast_shapes(batch_shape, param_shape)
    y1 = self.variant(bijector.forward)(x)
    logdet1 = self.variant(bijector.forward_log_det_jacobian)(x)
    y2, logdet2 = self.variant(bijector.forward_and_log_det)(x)
    self.assertEqual(y1.shape, output_shape)
    self.assertEqual(y2.shape, output_shape)
    self.assertEqual(logdet1.shape, output_shape)
    self.assertEqual(logdet2.shape, output_shape)
    np.testing.assert_allclose(y1, x + 1., 1e-6)
    np.testing.assert_allclose(y2, x + 1., 1e-6)
    np.testing.assert_allclose(logdet1, 0., 1e-6)
    np.testing.assert_allclose(logdet2, 0., 1e-6)

  @chex.all_variants
  @parameterized.parameters(
      {'batch_shape': (), 'param_shape': ()},
      {'batch_shape': (3,), 'param_shape': ()},
      {'batch_shape': (), 'param_shape': (3,)},
      {'batch_shape': (2, 3), 'param_shape': (2, 3)},
  )
  def test_inverse_methods(self, batch_shape, param_shape):
    bijector = Shift(jnp.ones(param_shape))
    prng = jax.random.PRNGKey(42)
    y = jax.random.normal(prng, batch_shape)
    output_shape = jnp.broadcast_shapes(batch_shape, param_shape)
    x1 = self.variant(bijector.inverse)(y)
    logdet1 = self.variant(bijector.inverse_log_det_jacobian)(y)
    x2, logdet2 = self.variant(bijector.inverse_and_log_det)(y)
    self.assertEqual(x1.shape, output_shape)
    self.assertEqual(x2.shape, output_shape)
    self.assertEqual(logdet1.shape, output_shape)
    self.assertEqual(logdet2.shape, output_shape)
    np.testing.assert_allclose(x1, y - 1., 1e-6)
    np.testing.assert_allclose(x2, y - 1., 1e-6)
    np.testing.assert_allclose(logdet1, 0., 1e-6)
    np.testing.assert_allclose(logdet2, 0., 1e-6)

  def test_jittable(self):
    @jax.jit
    def f(x, b):
      return b.forward(x)

    bij = Shift(jnp.ones((4,)))
    x = np.zeros((4,))
    f(x, bij)

  def test_same_as_itself(self):
    bij = Shift(jnp.ones((4,)))
    self.assertTrue(bij.same_as(bij))

  def test_not_same_as_others(self):
    bij = Shift(jnp.ones((4,)))
    other = Shift(jnp.zeros((4,)))
    self.assertFalse(bij.same_as(other))
    self.assertFalse(bij.same_as(Tanh()))


if __name__ == '__main__':
  absltest.main()
27 changes: 2 additions & 25 deletions distrax/_src/distributions/mvn_from_bijector.py
@@ -22,6 +22,7 @@
from distrax._src.bijectors import block
from distrax._src.bijectors import chain
from distrax._src.bijectors import diag_affine
+from distrax._src.bijectors import shift
from distrax._src.distributions import independent
from distrax._src.distributions import normal
from distrax._src.distributions import transformed
@@ -35,30 +36,6 @@
tfd = tfp.distributions

Array = chex.Array
-Numeric = Union[Array, float]
-
-
-class _Shift(base_bijector.Bijector):
-  """An affine bijector that shifts the input scalar."""
-
-  def __init__(self, shift: Numeric):
-    super().__init__(
-        event_ndims_in=0, event_ndims_out=0, is_constant_jacobian=True)
-    self._shift = shift
-    self._batch_shape = jnp.shape(shift)
-
-  def _broadcast_log_det_jac(self, x: Array) -> Array:
-    """Returns log|det J(f)(x)| with the shape broadcasted to `x`."""
-    batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape)
-    return jnp.zeros(batch_shape, dtype=x.dtype)
-
-  def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
-    """Computes y = f(x) and log|det J(f)(x)|."""
-    return x + self._shift, self._broadcast_log_det_jac(x)
-
-  def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
-    """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
-    return y - self._shift, self._broadcast_log_det_jac(y)


def _check_input_parameters_are_valid(
@@ -126,7 +103,7 @@ def __init__(self,
            scale=1.),
        reinterpreted_batch_ndims=1)
    # Form the bijector `f(x) = Ax + b`.
-    bijector = chain.Chain([block.Block(_Shift(loc), ndims=1), scale])
+    bijector = chain.Chain([block.Block(shift.Shift(loc), ndims=1), scale])
    super().__init__(distribution=std_mvn_dist, bijector=bijector)
    self._scale = scale
    self._loc = loc
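
For intuition about the replaced line, here is a small sketch (not part of the commit) of how `Block` and `Chain` compose the new `Shift`. It uses only module paths that appear in this diff; `Chain` applies its bijectors right-to-left.

# Hypothetical sketch: Block lifts the elementwise Shift to act on 1-D
# events, mirroring `block.Block(shift.Shift(loc), ndims=1)` above.
import jax.numpy as jnp
from distrax._src.bijectors import block, chain, shift

loc = jnp.array([1., -1., 0.5])
shifted = block.Block(shift.Shift(loc), ndims=1)  # event has 1 trailing dim
bijector = chain.Chain([shifted])                 # composes right-to-left
x = jnp.zeros((3,))
y, logdet = bijector.forward_and_log_det(x)       # y == loc, logdet == 0.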