Demote TriangularAffine to TriangularLinear.
PiperOrigin-RevId: 434483035
gpapamak authored and DistraxDev committed Mar 14, 2022
1 parent 71bd1c2 commit 31de305
Showing 6 changed files with 86 additions and 125 deletions.
30 changes: 14 additions & 16 deletions distrax/_src/bijectors/lower_upper_triangular_affine.py
@@ -15,11 +15,13 @@
"""LU-decomposed affine bijector."""

from distrax._src.bijectors import bijector as base
from distrax._src.bijectors import block
from distrax._src.bijectors import chain
from distrax._src.bijectors import triangular_affine
from distrax._src.bijectors import shift
from distrax._src.bijectors import triangular_linear
from distrax._src.bijectors import unconstrained_affine
import jax.numpy as jnp


Array = base.Array


@@ -56,7 +58,7 @@ class LowerUpperTriangularAffine(chain.Chain):
"""

def __init__(self, matrix: Array, bias: Array):
"""Initializes a LowerUpperTriangularAffine bijector.
"""Initializes a `LowerUpperTriangularAffine` bijector.
Args:
matrix: a square matrix parameterizing `L` and `U` as described in the
@@ -65,28 +67,24 @@ class docstring. Can also be a batch of matrices. If `matrix` is the
generally not equal to the product `LU`.
bias: the vector `b` in `LUx + b`. Can also be a batch of vectors.
"""
if matrix.ndim < 2:
raise ValueError(f"`matrix` must have at least 2 dimensions, got"
f" {matrix.ndim}.")
unconstrained_affine.check_affine_parameters(matrix, bias)
self._upper = triangular_linear.TriangularLinear(matrix, is_lower=False)
dim = matrix.shape[-1]
# z = Ux
self._upper_linear = triangular_affine.TriangularAffine(
matrix, bias=jnp.zeros((dim,)), is_lower=False)
# y = Lz + b
lower = jnp.eye(dim) + jnp.tril(matrix, -1) # Replace diagonal with ones.
self._lower_affine = triangular_affine.TriangularAffine(
lower, bias, is_lower=True)
super().__init__([self._lower_affine, self._upper_linear])
self._lower = triangular_linear.TriangularLinear(lower, is_lower=True)
self._shift = block.Block(shift.Shift(bias), 1)
self._bias = bias
super().__init__([self._shift, self._lower, self._upper])

@property
def lower(self) -> Array:
"""The lower triangular matrix `L` with ones in the diagonal."""
return self._lower_affine.matrix
return self._lower.matrix

@property
def upper(self) -> Array:
"""The upper triangular matrix `U`."""
return self._upper_linear.matrix
return self._upper.matrix

@property
def matrix(self) -> Array:
@@ -96,7 +94,7 @@ def matrix(self) -> Array:
@property
def bias(self) -> Array:
"""The shift `b` of the transformation."""
return self._lower_affine.bias
return self._bias

def same_as(self, other: base.Bijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
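For reference, here is a minimal sketch, not part of the commit, of the decomposition the new `__init__` builds: with `U = triu(matrix)` and `L = eye + tril(matrix, -1)`, chaining `Shift(bias)`, `TriangularLinear(L)` and `TriangularLinear(U)` reproduces the affine map `y = LUx + b`. Plain JAX, with illustrative values:

```python
import jax.numpy as jnp

D = 4
W = jnp.arange(1.0, 1.0 + D * D).reshape(D, D) / D  # illustrative square matrix
b = jnp.ones((D,))                                  # illustrative bias
x = jnp.linspace(-1.0, 1.0, D)

U = jnp.triu(W)                   # z = Ux
L = jnp.eye(D) + jnp.tril(W, -1)  # y = Lz + b; unit diagonal

# Chain([shift, lower, upper]) applies `upper` first, then `lower`,
# then the shift:
y_chain = L @ (U @ x) + b
y_affine = (L @ U) @ x + b        # the equivalent LU-affine map
assert jnp.allclose(y_chain, y_affine, atol=1e-5)
```

Factoring the bias out into a standalone `Shift` block is what lets `TriangularLinear` drop its `bias` parameter entirely, as the rest of the diff shows.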
distrax/_src/bijectors/triangular_affine.py → distrax/_src/bijectors/triangular_linear.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Triangular affine bijector."""
"""Triangular linear bijector."""

import functools
from typing import Tuple

from distrax._src.bijectors import bijector as base
from distrax._src.bijectors import unconstrained_affine
import jax
import jax.numpy as jnp

@@ -30,48 +29,50 @@ def _triangular_logdet(matrix: Array) -> Array:
return jnp.sum(jnp.log(jnp.abs(jnp.diag(matrix))))


def _forward_unbatched(x: Array, matrix: Array, bias: Array) -> Array:
return matrix @ x + bias
def _forward_unbatched(x: Array, matrix: Array) -> Array:
return matrix @ x


def _inverse_unbatched(
y: Array, matrix: Array, bias: Array, is_lower: bool) -> Array:
return jax.scipy.linalg.solve_triangular(matrix, y - bias, lower=is_lower)
def _inverse_unbatched(y: Array, matrix: Array, is_lower: bool) -> Array:
return jax.scipy.linalg.solve_triangular(matrix, y, lower=is_lower)


class TriangularAffine(base.Bijector):
"""An affine bijector whose weight matrix is triangular.
class TriangularLinear(base.Bijector):
"""A linear bijector whose weight matrix is triangular.
The bijector is defined as `f(x) = Ax + b` where `A` is a DxD triangular
matrix.
The bijector is defined as `f(x) = Ax` where `A` is a DxD triangular matrix.
The Jacobian determinant can be computed in O(D) as follows:
log|det J(x)| = log|det A| = sum(log|diag(A)|)
The inverse is computed in O(D^2) by solving the triangular system Ax = y - b.
The inverse is computed in O(D^2) by solving the triangular system `Ax = y`.
The bijector is invertible if and only if all diagonal elements of A are
The bijector is invertible if and only if all diagonal elements of `A` are
non-zero. It is the responsibility of the user to make sure that this is the
case; the class will make no attempt to verify that the bijector is
invertible.
"""

def __init__(self, matrix: Array, bias: Array, is_lower: bool = True):
"""Initializes a TriangularAffine bijector.
def __init__(self, matrix: Array, is_lower: bool = True):
"""Initializes a `TriangularLinear` bijector.
Args:
matrix: a square matrix whose triangular part defines `A` in `Ax + b`. Can
also be a batch of matrices. Whether `A` is the lower or upper
triangular part of `matrix` is determined by `is_lower`.
bias: the vector `b` in `Ax + b`. Can also be a batch of vectors.
matrix: a square matrix whose triangular part defines `A`. Can also be a
batch of matrices. Whether `A` is the lower or upper triangular part of
`matrix` is determined by `is_lower`.
is_lower: if True, `A` is set to the lower triangular part of `matrix`. If
False, `A` is set to the upper triangular part of `matrix`.
"""
if matrix.ndim < 2:
raise ValueError(f"`matrix` must have at least 2 dimensions, got"
f" {matrix.ndim}.")
if matrix.shape[-2] != matrix.shape[-1]:
raise ValueError(f"`matrix` must be square; instead, it has shape"
f" {matrix.shape[-2:]}.")
super().__init__(event_ndims_in=1, is_constant_jacobian=True)
self._batch_shape = unconstrained_affine.common_batch_shape(matrix, bias)
self._batch_shape = matrix.shape[:-2]
self._matrix = jnp.tril(matrix) if is_lower else jnp.triu(matrix)
self._bias = bias
self._is_lower = is_lower
triangular_logdet = jnp.vectorize(_triangular_logdet, signature="(m,m)->()")
self._logdet = triangular_logdet(self._matrix)
@@ -81,11 +82,6 @@ def matrix(self) -> Array:
"""The triangular matrix `A` of the transformation."""
return self._matrix

@property
def bias(self) -> Array:
"""The shift `b` of the transformation."""
return self._bias

@property
def is_lower(self) -> bool:
"""True if `A` is lower triangular, False if upper triangular."""
@@ -94,8 +90,8 @@ def is_lower(self) -> bool:
def forward(self, x: Array) -> Array:
"""Computes y = f(x)."""
self._check_forward_input_shape(x)
batched = jnp.vectorize(_forward_unbatched, signature="(m),(m,m),(m)->(m)")
return batched(x, self._matrix, self._bias)
batched = jnp.vectorize(_forward_unbatched, signature="(m),(m,m)->(m)")
return batched(x, self._matrix)

def forward_log_det_jacobian(self, x: Array) -> Array:
"""Computes log|det J(f)(x)|."""
@@ -112,8 +108,8 @@ def inverse(self, y: Array) -> Array:
self._check_inverse_input_shape(y)
batched = jnp.vectorize(
functools.partial(_inverse_unbatched, is_lower=self._is_lower),
signature="(m),(m,m),(m)->(m)")
return batched(y, self._matrix, self._bias)
signature="(m),(m,m)->(m)")
return batched(y, self._matrix)

def inverse_log_det_jacobian(self, y: Array) -> Array:
"""Computes log|det J(f^{-1})(y)|."""
@@ -125,10 +121,9 @@ def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:

def same_as(self, other: base.Bijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
if type(other) is TriangularAffine: # pylint: disable=unidiomatic-typecheck
if type(other) is TriangularLinear: # pylint: disable=unidiomatic-typecheck
return all((
self.matrix is other.matrix,
self.bias is other.bias,
self.is_lower is other.is_lower,
))
return False
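
As an aside, a short self-contained sketch (plain JAX, arbitrary illustrative values; not from this commit) of the two properties the docstring claims: the log-determinant costs O(D) via the diagonal, and the inverse costs O(D^2) via a triangular solve:

```python
import jax
import jax.numpy as jnp

D = 4
W = jax.random.uniform(jax.random.PRNGKey(0), (D, D)) + jnp.eye(D)
A = jnp.tril(W)     # is_lower=True
x = jnp.ones((D,))

y = A @ x           # forward: f(x) = Ax

# O(D) log-determinant: log|det A| = sum(log|diag(A)|).
logdet = jnp.sum(jnp.log(jnp.abs(jnp.diag(A))))
assert jnp.allclose(logdet, jnp.linalg.slogdet(A)[1], atol=1e-5)

# O(D^2) inverse: solve the triangular system Ax = y.
x_rec = jax.scipy.linalg.solve_triangular(A, y, lower=True)
assert jnp.allclose(x_rec, x, atol=1e-5)
```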
distrax/_src/bijectors/triangular_affine_test.py → distrax/_src/bijectors/triangular_linear_test.py
@@ -12,76 +12,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `triangular_affine.py`."""
"""Tests for `triangular_linear.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.bijectors.tanh import Tanh
from distrax._src.bijectors.triangular_affine import TriangularAffine
from distrax._src.bijectors.triangular_linear import TriangularLinear
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


class TriangularAffineTest(parameterized.TestCase):
class TriangularLinearTest(parameterized.TestCase):

def test_jacobian_is_constant_property(self):
bijector = TriangularAffine(
matrix=jnp.eye(4), bias=jnp.zeros((4,)))
bijector = TriangularLinear(matrix=jnp.eye(4))
self.assertTrue(bijector.is_constant_jacobian)
self.assertTrue(bijector.is_constant_log_det)

@parameterized.parameters(True, False)
def test_properties(self, is_lower):
bijector = TriangularAffine(
matrix=jnp.ones((4, 4)),
bias=jnp.ones((4,)),
is_lower=is_lower)
bijector = TriangularLinear(matrix=jnp.ones((4, 4)), is_lower=is_lower)
tri = np.tril if is_lower else np.triu
np.testing.assert_allclose(bijector.matrix, tri(np.ones(4)), atol=1e-6)
np.testing.assert_allclose(bijector.bias, np.ones((4,)), atol=1e-6)
self.assertEqual(bijector.is_lower, is_lower)

@parameterized.named_parameters(
('matrix is 0d', {'matrix': np.zeros(()), 'bias': np.zeros((4,))}),
('matrix is 1d', {'matrix': np.zeros((4,)), 'bias': np.zeros((4,))}),
('bias is 0d', {'matrix': np.zeros((4, 4)), 'bias': np.zeros(())}),
('matrix is not square',
{'matrix': np.zeros((3, 4)), 'bias': np.zeros((4,))}),
('matrix and bias shapes do not agree',
{'matrix': np.zeros((4, 4)), 'bias': np.zeros((3,))}),
('matrix is 0d', {'matrix': np.zeros(())}),
('matrix is 1d', {'matrix': np.zeros((4,))}),
('matrix is not square', {'matrix': np.zeros((3, 4))}),
)
def test_raises_with_invalid_parameters(self, bij_params):
with self.assertRaises(ValueError):
TriangularAffine(**bij_params)
TriangularLinear(**bij_params)

@chex.all_variants
@parameterized.parameters(
((5,), (5,), (5,)),
((5,), (5,), ()),
((5,), (), (5,)),
((), (5,), (5,)),
((), (), (5,)),
((), (5,), ()),
((5,), (), ()),
((5,), (5,)),
((5,), ()),
((), (5,)),
)
def test_batched_parameters(self, matrix_batch_shape, bias_batch_shape,
input_batch_shape):
def test_batched_parameters(self, matrix_batch_shape, input_batch_shape):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
matrix = jax.random.uniform(
next(prng), matrix_batch_shape + (4, 4)) + jnp.eye(4)
bias = jax.random.normal(next(prng), bias_batch_shape + (4,))
bijector = TriangularAffine(matrix, bias)
bijector = TriangularLinear(matrix)

x = jax.random.normal(next(prng), input_batch_shape + (4,))
y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
z, logdet_inv = self.variant(bijector.inverse_and_log_det)(x)

output_batch_shape = jnp.broadcast_arrays(
matrix[..., 0, 0], bias[..., 0], x[..., 0])[0].shape
matrix[..., 0, 0], x[..., 0])[0].shape

self.assertEqual(y.shape, output_batch_shape + (4,))
self.assertEqual(z.shape, output_batch_shape + (4,))
@@ -90,33 +75,29 @@ def test_batched_parameters(self, matrix_batch_shape, bias_batch_shape,

matrix = jnp.broadcast_to(
matrix, output_batch_shape + (4, 4)).reshape((-1, 4, 4))
bias = jnp.broadcast_to(bias, output_batch_shape + (4,)).reshape((-1, 4))
x = jnp.broadcast_to(x, output_batch_shape + (4,)).reshape((-1, 4))
y = y.reshape((-1, 4))
z = z.reshape((-1, 4))
logdet_fwd = logdet_fwd.flatten()
logdet_inv = logdet_inv.flatten()

for i in range(np.prod(output_batch_shape)):
bijector = TriangularAffine(matrix[i], bias[i])
bijector = TriangularLinear(matrix[i])
this_y, this_logdet_fwd = self.variant(bijector.forward_and_log_det)(x[i])
this_z, this_logdet_inv = self.variant(bijector.inverse_and_log_det)(x[i])
np.testing.assert_allclose(this_y, y[i], atol=9e-3)
np.testing.assert_allclose(this_y, y[i], rtol=8e-3)
np.testing.assert_allclose(this_z, z[i], atol=7e-6)
np.testing.assert_allclose(this_logdet_fwd, logdet_fwd[i], atol=1e-7)
np.testing.assert_allclose(this_logdet_inv, logdet_inv[i], atol=7e-6)

@chex.all_variants
@parameterized.parameters(
{'batch_shape': (), 'param_shape': (), 'is_lower': True},
{'batch_shape': (3,), 'param_shape': (3,), 'is_lower': True},
{'batch_shape': (2, 3), 'param_shape': (3,), 'is_lower': False},
{'batch_shape': (), 'is_lower': True},
{'batch_shape': (3,), 'is_lower': True},
{'batch_shape': (2, 3), 'is_lower': False},
)
def test_identity_initialization(self, batch_shape, param_shape, is_lower):
bijector = TriangularAffine(
matrix=jnp.eye(4),
bias=jnp.zeros(param_shape + (4,)),
is_lower=is_lower)
def test_identity_initialization(self, batch_shape, is_lower):
bijector = TriangularLinear(matrix=jnp.eye(4), is_lower=is_lower)
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jax.random.normal(next(prng), batch_shape + (4,))

@@ -139,8 +120,7 @@ def test_identity_initialization(self, batch_shape, param_shape, is_lower):
def test_inverse_methods(self, batch_shape, param_shape, is_lower):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
matrix = jax.random.uniform(next(prng), param_shape + (4, 4)) + jnp.eye(4)
bias = jax.random.normal(next(prng), param_shape + (4,))
bijector = TriangularAffine(matrix, bias, is_lower)
bijector = TriangularLinear(matrix, is_lower)
x = jax.random.normal(next(prng), batch_shape + (4,))
y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
x_rec, logdet_inv = self.variant(bijector.inverse_and_log_det)(y)
@@ -152,8 +132,7 @@ def test_inverse_methods(self, batch_shape, param_shape, is_lower):
def test_forward_jacobian_det(self, is_lower):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
matrix = jax.random.uniform(next(prng), (4, 4)) + jnp.eye(4)
bias = jax.random.normal(next(prng), (4,))
bijector = TriangularAffine(matrix, bias, is_lower)
bijector = TriangularLinear(matrix, is_lower)

batched_x = jax.random.normal(next(prng), (10, 4))
single_x = jax.random.normal(next(prng), (4,))
@@ -169,8 +148,7 @@ def test_forward_jacobian_det(self, is_lower):
def test_inverse_jacobian_det(self, is_lower):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
matrix = jax.random.uniform(next(prng), (4, 4)) + jnp.eye(4)
bias = jax.random.normal(next(prng), (4,))
bijector = TriangularAffine(matrix, bias, is_lower)
bijector = TriangularLinear(matrix, is_lower)

batched_y = jax.random.normal(next(prng), (10, 4))
single_y = jax.random.normal(next(prng), (4,))
@@ -182,7 +160,7 @@ def test_inverse_jacobian_det(self, is_lower):
np.testing.assert_allclose(logdet, logdet_numerical, atol=5e-5)

def test_raises_on_invalid_input_shape(self):
bij = TriangularAffine(matrix=jnp.eye(4), bias=jnp.zeros((4,)))
bij = TriangularLinear(matrix=jnp.eye(4))
for fn in [bij.forward, bij.inverse,
bij.forward_log_det_jacobian, bij.inverse_log_det_jacobian,
bij.forward_and_log_det, bij.inverse_and_log_det]:
@@ -194,17 +172,17 @@ def test_jittable(self):
def f(x, b):
return b.forward(x)

bij = TriangularAffine(matrix=jnp.eye(4), bias=jnp.zeros((4,)))
bij = TriangularLinear(matrix=jnp.eye(4))
x = np.zeros((4,))
f(x, bij)

def test_same_as_itself(self):
bij = TriangularAffine(matrix=jnp.eye(4), bias=jnp.zeros((4,)))
bij = TriangularLinear(matrix=jnp.eye(4))
self.assertTrue(bij.same_as(bij))

def test_not_same_as_others(self):
bij = TriangularAffine(matrix=jnp.eye(4), bias=jnp.zeros((4,)))
other = TriangularAffine(matrix=jnp.ones((4, 4)), bias=jnp.zeros((4,)))
bij = TriangularLinear(matrix=jnp.eye(4))
other = TriangularLinear(matrix=jnp.ones((4, 4)))
self.assertFalse(bij.same_as(other))
self.assertFalse(bij.same_as(Tanh()))

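For context on what `test_batched_parameters` now exercises, a brief sketch (plain JAX, illustrative shapes; not from this commit) of how `jnp.vectorize` with a core `signature` broadcasts a batch of matrices against a single unbatched input, mirroring the bijector's `forward` and `inverse`:

```python
import functools

import jax
import jax.numpy as jnp

D = 4
matrices = jax.random.uniform(jax.random.PRNGKey(1), (5, D, D)) + jnp.eye(D)
A = jnp.tril(matrices)  # a batch of 5 lower-triangular matrices
x = jnp.ones((D,))      # a single unbatched input

# Core signature "(m),(m,m)->(m)": batch dims broadcast, core dims map.
forward = jnp.vectorize(lambda x, m: m @ x, signature="(m),(m,m)->(m)")
y = forward(x, A)       # x is broadcast against the batch of matrices
assert y.shape == (5, D)

inverse = jnp.vectorize(
    functools.partial(jax.scipy.linalg.solve_triangular, lower=True),
    signature="(m,m),(m)->(m)")
x_rec = inverse(A, y)
assert jnp.allclose(x_rec, jnp.broadcast_to(x, (5, D)), atol=1e-5)
```

Dropping `bias` is also why the test's broadcast bookkeeping shrinks from three batch shapes to two.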
(Diff for the remaining 3 changed files not shown.)
