-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Expose linear bijectors, shift bijector, and general multivariate nor…
…mal. PiperOrigin-RevId: 438573587
- Loading branch information
Showing
14 changed files
with
285 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Linear bijector.""" | ||
|
||
import abc | ||
from typing import Sequence, Tuple | ||
|
||
from distrax._src.bijectors import bijector as base | ||
import jax.numpy as jnp | ||
|
||
Array = base.Array | ||
|
||
|
||
class Linear(base.Bijector, metaclass=abc.ABCMeta): | ||
"""Base class for linear bijectors. | ||
This class provides a base class for bijectors defined as `f(x) = Ax`, | ||
where `A` is a `DxD` matrix and `x` is a `D`-dimensional vector. | ||
""" | ||
|
||
def __init__(self, | ||
event_dims: int, | ||
batch_shape: Sequence[int], | ||
dtype: jnp.dtype): | ||
"""Initializes a `Linear` bijector. | ||
Args: | ||
event_dims: the dimensionality `D` of the event `x`. It is assumed that | ||
`x` is a vector of length `event_dims`. | ||
batch_shape: the batch shape of the bijector. | ||
dtype: the data type of matrix `A`. | ||
""" | ||
super().__init__(event_ndims_in=1, is_constant_jacobian=True) | ||
self._event_dims = event_dims | ||
self._batch_shape = tuple(batch_shape) | ||
self._dtype = dtype | ||
|
||
@property | ||
def matrix(self) -> Array: | ||
"""The matrix `A` of the transformation. | ||
To be optionally implemented in a subclass. | ||
Returns: | ||
An array of shape `batch_shape + (event_dims, event_dims)` and data type | ||
`dtype`. | ||
""" | ||
raise NotImplementedError( | ||
f"Linear bijector {self.name} does not implement `matrix`.") | ||
|
||
@property | ||
def event_dims(self) -> int: | ||
"""The dimensionality `D` of the event `x`.""" | ||
return self._event_dims | ||
|
||
@property | ||
def batch_shape(self) -> Tuple[int, ...]: | ||
"""The batch shape of the bijector.""" | ||
return self._batch_shape | ||
|
||
@property | ||
def dtype(self) -> jnp.dtype: | ||
"""The data type of matrix `A`.""" | ||
return self._dtype |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for `linear.py`.""" | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from distrax._src.bijectors import linear | ||
import jax.numpy as jnp | ||
|
||
|
||
class MockLinear(linear.Linear): | ||
|
||
def forward_and_log_det(self, x): | ||
raise Exception | ||
|
||
|
||
class LinearTest(parameterized.TestCase): | ||
|
||
@parameterized.parameters( | ||
{'event_dims': 1, 'batch_shape': (), 'dtype': jnp.float16}, | ||
{'event_dims': 10, 'batch_shape': (2, 3), 'dtype': jnp.float32}) | ||
def test_properties(self, event_dims, batch_shape, dtype): | ||
bij = MockLinear(event_dims, batch_shape, dtype) | ||
self.assertEqual(bij.event_ndims_in, 1) | ||
self.assertEqual(bij.event_ndims_out, 1) | ||
self.assertTrue(bij.is_constant_jacobian) | ||
self.assertTrue(bij.is_constant_log_det) | ||
self.assertEqual(bij.event_dims, event_dims) | ||
self.assertEqual(bij.batch_shape, batch_shape) | ||
self.assertEqual(bij.dtype, dtype) | ||
with self.assertRaises(NotImplementedError): | ||
bij.matrix # pylint: disable=pointless-statement | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
Oops, something went wrong.