Commit f400e01

Expose linear bijectors, shift bijector, and general multivariate normal.

PiperOrigin-RevId: 438573587

gpapamak authored and DistraxDev committed Mar 31, 2022
1 parent b511948 commit f400e01

Showing 14 changed files with 285 additions and 115 deletions.
12 changes: 12 additions & 0 deletions distrax/__init__.py
@@ -19,16 +19,21 @@
 from distrax._src.bijectors.bijector import BijectorLike
 from distrax._src.bijectors.block import Block
 from distrax._src.bijectors.chain import Chain
+from distrax._src.bijectors.diag_linear import DiagLinear
+from distrax._src.bijectors.diag_plus_low_rank_linear import DiagPlusLowRankLinear
 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
 from distrax._src.bijectors.inverse import Inverse
 from distrax._src.bijectors.lambda_bijector import Lambda
+from distrax._src.bijectors.linear import Linear
 from distrax._src.bijectors.lower_upper_triangular_affine import LowerUpperTriangularAffine
 from distrax._src.bijectors.masked_coupling import MaskedCoupling
 from distrax._src.bijectors.rational_quadratic_spline import RationalQuadraticSpline
 from distrax._src.bijectors.scalar_affine import ScalarAffine
+from distrax._src.bijectors.shift import Shift
 from distrax._src.bijectors.sigmoid import Sigmoid
 from distrax._src.bijectors.split_coupling import SplitCoupling
 from distrax._src.bijectors.tanh import Tanh
+from distrax._src.bijectors.triangular_linear import TriangularLinear
 from distrax._src.bijectors.unconstrained_affine import UnconstrainedAffine
 
 # Distributions.
@@ -49,6 +54,7 @@
 from distrax._src.distributions.multinomial import Multinomial
 from distrax._src.distributions.mvn_diag import MultivariateNormalDiag
 from distrax._src.distributions.mvn_diag_plus_low_rank import MultivariateNormalDiagPlusLowRank
+from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector
 from distrax._src.distributions.mvn_full_covariance import MultivariateNormalFullCovariance
 from distrax._src.distributions.mvn_tri import MultivariateNormalTri
 from distrax._src.distributions.normal import Normal
@@ -84,6 +90,8 @@
     "Block",
     "Categorical",
     "Chain",
+    "DiagLinear",
+    "DiagPlusLowRankLinear",
     "Distribution",
     "DistributionLike",
     "EpsilonGreedy",
@@ -98,6 +106,7 @@
     "Inverse",
     "Lambda",
     "Laplace",
+    "Linear",
     "LogStddevNormal",
     "Logistic",
     "LowerUpperTriangularAffine",
@@ -110,6 +119,7 @@
     "multiply_no_nan",
     "MultivariateNormalDiag",
     "MultivariateNormalDiagPlusLowRank",
+    "MultivariateNormalFromBijector",
     "MultivariateNormalFullCovariance",
     "MultivariateNormalTri",
     "Normal",
@@ -118,11 +128,13 @@
     "RationalQuadraticSpline",
     "register_inverse",
     "ScalarAffine",
+    "Shift",
     "Sigmoid",
     "Softmax",
     "SplitCoupling",
     "to_tfp",
     "Transformed",
+    "TriangularLinear",
     "UnconstrainedAffine",
     "Uniform",
 )
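The newly exported names become available from the top-level `distrax` namespace. A minimal usage sketch (assuming a distrax build that includes this commit; the values are illustrative, not from the diff):

import distrax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])

# DiagLinear: f(x) = Ax with a diagonal weight matrix A.
diag_bij = distrax.DiagLinear(diag=jnp.array([0.5, 1., 2.]))
y, log_det = diag_bij.forward_and_log_det(x)

# Shift: f(x) = x + b, now exposed at the top level.
shift_bij = distrax.Shift(jnp.ones(3))
z = shift_bij.forward(x)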
22 changes: 19 additions & 3 deletions distrax/_src/bijectors/diag_linear.py
@@ -14,15 +14,18 @@
 # ==============================================================================
 """Diagonal linear bijector."""
 
+from typing import Tuple
+
 from distrax._src.bijectors import bijector as base
 from distrax._src.bijectors import block
+from distrax._src.bijectors import linear
 from distrax._src.bijectors import scalar_affine
 import jax.numpy as jnp
 
 Array = base.Array
 
 
-class DiagLinear(block.Block):
+class DiagLinear(linear.Linear):
   """Linear bijector with a diagonal weight matrix.
 
   The bijector is defined as `f(x) = Ax` where `A` is a `DxD` diagonal matrix.
@@ -47,9 +50,18 @@ def __init__(self, diag: Array):
     """
     if diag.ndim < 1:
       raise ValueError("`diag` must have at least one dimension.")
-    bijector = scalar_affine.ScalarAffine(shift=0., scale=diag)
-    super().__init__(bijector=bijector, ndims=1)
+    self._bijector = block.Block(
+        scalar_affine.ScalarAffine(shift=0., scale=diag), ndims=1)
+    super().__init__(
+        event_dims=diag.shape[-1],
+        batch_shape=diag.shape[:-1],
+        dtype=diag.dtype)
     self._diag = diag
+    self.forward = self._bijector.forward
+    self.forward_log_det_jacobian = self._bijector.forward_log_det_jacobian
+    self.inverse = self._bijector.inverse
+    self.inverse_log_det_jacobian = self._bijector.inverse_log_det_jacobian
+    self.inverse_and_log_det = self._bijector.inverse_and_log_det
 
   @property
   def diag(self) -> Array:
@@ -61,6 +73,10 @@ def matrix(self) -> Array:
     """The full matrix `A`."""
     return jnp.vectorize(jnp.diag, signature="(k)->(k,k)")(self.diag)
 
+  def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
+    """Computes y = f(x) and log|det J(f)(x)|."""
+    return self._bijector.forward_and_log_det(x)
+
   def same_as(self, other: base.Bijector) -> bool:
     """Returns True if this bijector is guaranteed to be the same as `other`."""
     if type(other) is DiagLinear:  # pylint: disable=unidiomatic-typecheck
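Example of the resulting behaviour (an illustrative sketch, not part of the diff): `DiagLinear` now delegates the forward/inverse computations to the wrapped `Block(ScalarAffine)` while reporting the `Linear` metadata.

import jax.numpy as jnp
from distrax._src.bijectors.diag_linear import DiagLinear

bij = DiagLinear(diag=jnp.array([2., 4.]))
y, log_det = bij.forward_and_log_det(jnp.array([1., 1.]))
# y == [2., 4.] and log_det == log(2.) + log(4.), the log-determinant of A.
assert bij.event_dims == 2 and bij.batch_shape == ()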
24 changes: 19 additions & 5 deletions distrax/_src/bijectors/diag_linear_test.py
@@ -28,15 +28,29 @@
 
 class DiagLinearTest(parameterized.TestCase):
 
-  def test_jacobian_is_constant_property(self):
+  def test_static_properties(self):
     bij = DiagLinear(diag=jnp.ones((4,)))
     self.assertTrue(bij.is_constant_jacobian)
     self.assertTrue(bij.is_constant_log_det)
+    self.assertEqual(bij.event_ndims_in, 1)
+    self.assertEqual(bij.event_ndims_out, 1)
 
-  def test_properties(self):
-    bij = DiagLinear(diag=jnp.ones((4,)))
-    np.testing.assert_allclose(bij.diag, np.ones(4), atol=1e-6)
-    np.testing.assert_allclose(bij.matrix, np.eye(4), atol=1e-6)
+  @parameterized.parameters(
+      {'batch_shape': (), 'dtype': jnp.float16},
+      {'batch_shape': (2, 3), 'dtype': jnp.float32},
+  )
+  def test_properties(self, batch_shape, dtype):
+    bij = DiagLinear(diag=jnp.ones(batch_shape + (4,), dtype))
+    self.assertEqual(bij.event_dims, 4)
+    self.assertEqual(bij.batch_shape, batch_shape)
+    self.assertEqual(bij.dtype, dtype)
+    self.assertEqual(bij.diag.shape, batch_shape + (4,))
+    self.assertEqual(bij.matrix.shape, batch_shape + (4, 4))
+    self.assertEqual(bij.diag.dtype, dtype)
+    self.assertEqual(bij.matrix.dtype, dtype)
+    np.testing.assert_allclose(bij.diag, 1., atol=1e-6)
+    np.testing.assert_allclose(
+        bij.matrix, np.tile(np.eye(4), batch_shape + (1, 1)), atol=1e-6)
 
   def test_raises_with_invalid_parameters(self):
     with self.assertRaises(ValueError):
20 changes: 18 additions & 2 deletions distrax/_src/bijectors/diag_plus_low_rank_linear.py
Expand Up @@ -19,6 +19,7 @@
from distrax._src.bijectors import bijector as base
from distrax._src.bijectors import chain
from distrax._src.bijectors import diag_linear
from distrax._src.bijectors import linear
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -144,7 +145,7 @@ def _check_shapes_are_valid(diag: Array,
f"`u_matrix.shape = {u_shape}` and `v_matrix.shape = {v_shape}`.")


class DiagPlusLowRankLinear(chain.Chain):
class DiagPlusLowRankLinear(linear.Linear):
"""Linear bijector whose weights are a low-rank perturbation of a diagonal.
The bijector is defined as `f(x) = Ax` where `A = S + UV^T` and:
Expand Down Expand Up @@ -190,10 +191,21 @@ def __init__(self, diag: Array, u_matrix: Array, v_matrix: Array):
id_plus_low_rank_linear = _IdentityPlusLowRankLinear(
u_matrix=u_matrix / diag[..., None],
v_matrix=v_matrix)
super().__init__([diag_linear.DiagLinear(diag), id_plus_low_rank_linear])
self._bijector = chain.Chain(
[diag_linear.DiagLinear(diag), id_plus_low_rank_linear])
batch_shape = jnp.broadcast_shapes(
diag.shape[:-1], u_matrix.shape[:-2], v_matrix.shape[:-2])
dtype = jnp.result_type(diag, u_matrix, v_matrix)
super().__init__(
event_dims=diag.shape[-1], batch_shape=batch_shape, dtype=dtype)
self._diag = diag
self._u_matrix = u_matrix
self._v_matrix = v_matrix
self.forward = self._bijector.forward
self.forward_log_det_jacobian = self._bijector.forward_log_det_jacobian
self.inverse = self._bijector.inverse
self.inverse_log_det_jacobian = self._bijector.inverse_log_det_jacobian
self.inverse_and_log_det = self._bijector.inverse_and_log_det

@property
def diag(self) -> Array:
Expand All @@ -218,6 +230,10 @@ def matrix(self) -> Array:
signature="(d),(d,k),(d,k)->(d,d)")
return batched(self._diag, self._u_matrix, self._v_matrix)

def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
"""Computes y = f(x) and log|det J(f)(x)|."""
return self._bijector.forward_and_log_det(x)

def same_as(self, other: base.Bijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
if type(other) is DiagPlusLowRankLinear: # pylint: disable=unidiomatic-typecheck
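Illustrative sketch of the `matrix` property after this change (example values, not from the diff): the dense matrix equals the diagonal plus the low-rank update, `A = S + UV^T`.

import jax.numpy as jnp
from distrax._src.bijectors.diag_plus_low_rank_linear import DiagPlusLowRankLinear

diag = jnp.array([1., 2., 3.])
u_matrix = jnp.ones((3, 2))
v_matrix = 2. * jnp.ones((3, 2))
bij = DiagPlusLowRankLinear(diag=diag, u_matrix=u_matrix, v_matrix=v_matrix)
# Dense weight matrix: diag(s) plus the rank-2 update U @ V.T.
expected = jnp.diag(diag) + u_matrix @ v_matrix.T
assert jnp.allclose(bij.matrix, expected)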
35 changes: 26 additions & 9 deletions distrax/_src/bijectors/diag_plus_low_rank_linear_test.py
@@ -28,24 +28,41 @@
 
 class DiagPlusLowRankLinearTest(parameterized.TestCase):
 
-  def test_jacobian_is_constant_property(self):
+  def test_static_properties(self):
     bij = DiagPlusLowRankLinear(
         diag=jnp.ones((4,)),
         u_matrix=jnp.ones((4, 2)),
         v_matrix=jnp.ones((4, 2)))
     self.assertTrue(bij.is_constant_jacobian)
     self.assertTrue(bij.is_constant_log_det)
+    self.assertEqual(bij.event_ndims_in, 1)
+    self.assertEqual(bij.event_ndims_out, 1)
 
-  def test_properties(self):
+  @parameterized.parameters(
+      {'batch_shape': (), 'dtype': jnp.float16},
+      {'batch_shape': (2, 3), 'dtype': jnp.float32},
+  )
+  def test_properties(self, batch_shape, dtype):
     bij = DiagPlusLowRankLinear(
-        diag=jnp.ones((4,)),
-        u_matrix=2. * jnp.ones((4, 2)),
-        v_matrix=3. * jnp.ones((4, 2)))
-    np.testing.assert_allclose(bij.diag, np.ones(4), atol=1e-6)
-    np.testing.assert_allclose(bij.u_matrix, np.full((4, 2), 2.), atol=1e-6)
-    np.testing.assert_allclose(bij.v_matrix, np.full((4, 2), 3.), atol=1e-6)
+        diag=jnp.ones(batch_shape + (4,), dtype),
+        u_matrix=2. * jnp.ones(batch_shape + (4, 2), dtype),
+        v_matrix=3. * jnp.ones(batch_shape + (4, 2), dtype))
+    self.assertEqual(bij.event_dims, 4)
+    self.assertEqual(bij.batch_shape, batch_shape)
+    self.assertEqual(bij.dtype, dtype)
+    self.assertEqual(bij.diag.shape, batch_shape + (4,))
+    self.assertEqual(bij.u_matrix.shape, batch_shape + (4, 2))
+    self.assertEqual(bij.v_matrix.shape, batch_shape + (4, 2))
+    self.assertEqual(bij.matrix.shape, batch_shape + (4, 4))
+    self.assertEqual(bij.diag.dtype, dtype)
+    self.assertEqual(bij.u_matrix.dtype, dtype)
+    self.assertEqual(bij.v_matrix.dtype, dtype)
+    self.assertEqual(bij.matrix.dtype, dtype)
+    np.testing.assert_allclose(bij.diag, 1., atol=1e-6)
+    np.testing.assert_allclose(bij.u_matrix, 2., atol=1e-6)
+    np.testing.assert_allclose(bij.v_matrix, 3., atol=1e-6)
     np.testing.assert_allclose(
-        bij.matrix, np.eye(4) + np.full((4, 4), 12.), atol=1e-6)
+        bij.matrix, np.tile(np.eye(4) + 12., batch_shape + (1, 1)), atol=1e-6)
 
   @parameterized.named_parameters(
       ('diag is 0d', {'diag': np.ones(()),
76 changes: 76 additions & 0 deletions distrax/_src/bijectors/linear.py
@@ -0,0 +1,76 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear bijector."""
+
+import abc
+from typing import Sequence, Tuple
+
+from distrax._src.bijectors import bijector as base
+import jax.numpy as jnp
+
+Array = base.Array
+
+
+class Linear(base.Bijector, metaclass=abc.ABCMeta):
+  """Base class for linear bijectors.
+
+  This class provides a base class for bijectors defined as `f(x) = Ax`,
+  where `A` is a `DxD` matrix and `x` is a `D`-dimensional vector.
+  """
+
+  def __init__(self,
+               event_dims: int,
+               batch_shape: Sequence[int],
+               dtype: jnp.dtype):
+    """Initializes a `Linear` bijector.
+
+    Args:
+      event_dims: the dimensionality `D` of the event `x`. It is assumed that
+        `x` is a vector of length `event_dims`.
+      batch_shape: the batch shape of the bijector.
+      dtype: the data type of matrix `A`.
+    """
+    super().__init__(event_ndims_in=1, is_constant_jacobian=True)
+    self._event_dims = event_dims
+    self._batch_shape = tuple(batch_shape)
+    self._dtype = dtype
+
+  @property
+  def matrix(self) -> Array:
+    """The matrix `A` of the transformation.
+
+    To be optionally implemented in a subclass.
+
+    Returns:
+      An array of shape `batch_shape + (event_dims, event_dims)` and data type
+      `dtype`.
+    """
+    raise NotImplementedError(
+        f"Linear bijector {self.name} does not implement `matrix`.")
+
+  @property
+  def event_dims(self) -> int:
+    """The dimensionality `D` of the event `x`."""
+    return self._event_dims
+
+  @property
+  def batch_shape(self) -> Tuple[int, ...]:
+    """The batch shape of the bijector."""
+    return self._batch_shape
+
+  @property
+  def dtype(self) -> jnp.dtype:
+    """The data type of matrix `A`."""
+    return self._dtype
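A minimal concrete subclass sketch (hypothetical; `ScaledIdentity` is not part of this commit). A subclass implements `forward_and_log_det` and may override `matrix`:

import jax.numpy as jnp
from distrax._src.bijectors import linear

class ScaledIdentity(linear.Linear):
  """f(x) = c * x for a fixed scalar c, i.e. A = c * I."""

  def __init__(self, scale: float, event_dims: int):
    super().__init__(event_dims=event_dims, batch_shape=(), dtype=jnp.float32)
    self._scale = scale

  def forward_and_log_det(self, x):
    # For A = c * I acting on a D-dimensional event, log|det A| = D * log|c|.
    log_det = self.event_dims * jnp.log(jnp.abs(jnp.asarray(self._scale)))
    return self._scale * x, log_det

  @property
  def matrix(self):
    return self._scale * jnp.eye(self.event_dims, dtype=self.dtype)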
48 changes: 48 additions & 0 deletions distrax/_src/bijectors/linear_test.py
@@ -0,0 +1,48 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `linear.py`."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from distrax._src.bijectors import linear
+import jax.numpy as jnp
+
+
+class MockLinear(linear.Linear):
+
+  def forward_and_log_det(self, x):
+    raise Exception
+
+
+class LinearTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      {'event_dims': 1, 'batch_shape': (), 'dtype': jnp.float16},
+      {'event_dims': 10, 'batch_shape': (2, 3), 'dtype': jnp.float32})
+  def test_properties(self, event_dims, batch_shape, dtype):
+    bij = MockLinear(event_dims, batch_shape, dtype)
+    self.assertEqual(bij.event_ndims_in, 1)
+    self.assertEqual(bij.event_ndims_out, 1)
+    self.assertTrue(bij.is_constant_jacobian)
+    self.assertTrue(bij.is_constant_log_det)
+    self.assertEqual(bij.event_dims, event_dims)
+    self.assertEqual(bij.batch_shape, batch_shape)
+    self.assertEqual(bij.dtype, dtype)
+    with self.assertRaises(NotImplementedError):
+      bij.matrix  # pylint: disable=pointless-statement
+
+
+if __name__ == '__main__':
+  absltest.main()
(The remaining 7 of the 14 changed files are not shown.)
