Demote DiagAffine to DiagLinear. #127

Merged · 1 commit · Mar 15, 2022
distrax/_src/bijectors/diag_affine.py → distrax/_src/bijectors/diag_linear.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diagonal affine bijector."""
"""Diagonal linear bijector."""

from distrax._src.bijectors import bijector as base
from distrax._src.bijectors import block
@@ -22,31 +22,11 @@
Array = base.Array


def _check_shapes_are_valid(diag: Array, bias: Array):
"""Checks array shapes are valid, raises `ValueError` if not."""
if diag.ndim < 1:
raise ValueError("`diag` must have at least one dimension.")
if bias.ndim < 1:
raise ValueError("`bias` must have at least one dimension.")
if bias.shape[-1] != diag.shape[-1]:
raise ValueError(
f"Both `bias` and `diag` must have the same number of dimensions; got "
f"`bias.shape[-1]={bias.shape[-1]}` and "
f"`diag.shape[-1]={diag.shape[-1]}`.")
try:
jnp.broadcast_shapes(diag.shape, bias.shape)
except ValueError:
raise ValueError(
f"The shapes of `bias` and `diag` are not broadcastable; got "
f"`bias.shape={bias.shape}` and `diag.shape={diag.shape}`.") from None
class DiagLinear(block.Block):
"""Linear bijector with a diagonal weight matrix.


class DiagAffine(block.Block):
"""Affine bijector with a diagonal weight matrix.

The bijector is defined as `f(x) = Ax + b` where `A` is a `DxD` diagonal
matrix and `b` is a `D`-dimensional vector. Additional dimensions, if any,
index batches.
The bijector is defined as `f(x) = Ax` where `A` is a `DxD` diagonal matrix.
Additional dimensions, if any, index batches.

The Jacobian determinant is trivially computed by taking the product of the
diagonal entries in `A`. The inverse transformation `x = f^{-1}(y)` is
@@ -58,40 +38,31 @@ class DiagAffine(block.Block):
invertible.
"""

def __init__(self, diag: Array, bias: Array):
def __init__(self, diag: Array):
"""Initializes the bijector.

Args:
diag: a vector of length D, the diagonal of matrix `A`. Can also be a
batch of such vectors.
bias: the vector `b` in `Ax + b`. Can also be a batch of such vectors.
"""
_check_shapes_are_valid(diag=diag, bias=bias)
bijector = scalar_affine.ScalarAffine(shift=bias, scale=diag)
if diag.ndim < 1:
raise ValueError("`diag` must have at least one dimension.")
bijector = scalar_affine.ScalarAffine(shift=0., scale=diag)
super().__init__(bijector=bijector, ndims=1)
self._diag = diag
self._bias = bias

@property
def diag(self) -> Array:
"""Vector of length D, the diagonal of matrix `A`."""
return self._diag

@property
def bias(self) -> Array:
"""The bias `b` of the transformation."""
return self._bias

@property
def matrix(self) -> Array:
"""The full matrix `A`."""
return jnp.vectorize(jnp.diag, signature="(k)->(k,k)")(self.diag)

def same_as(self, other: base.Bijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
if type(other) is DiagAffine: # pylint: disable=unidiomatic-typecheck
return all((
self.diag is other.diag,
self.bias is other.bias,
))
if type(other) is DiagLinear: # pylint: disable=unidiomatic-typecheck
return self.diag is other.diag
return False
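
The class docstring above defines the new bijector as `f(x) = Ax` with a `DxD` diagonal `A`, so its log-determinant is `sum(log|diag|)`. A minimal usage sketch of the post-merge API, assuming only the module path and methods that appear in this diff:

import jax.numpy as jnp
from distrax._src.bijectors.diag_linear import DiagLinear

diag = jnp.array([0.5, 2.0, 4.0])
bij = DiagLinear(diag)

x = jnp.ones((3,))
y, log_det = bij.forward_and_log_det(x)          # y = diag * x
x_rec, inv_log_det = bij.inverse_and_log_det(y)  # x_rec = y / diag

# For f(x) = Ax with diagonal A, log|det J| = sum(log|diag|).
assert jnp.allclose(y, diag * x)
assert jnp.allclose(log_det, jnp.sum(jnp.log(jnp.abs(diag))))
assert jnp.allclose(x_rec, x)
assert jnp.allclose(inv_log_det, -log_det)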
distrax/_src/bijectors/diag_affine_test.py → distrax/_src/bijectors/diag_linear_test.py
@@ -12,83 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `diag_affine.py`."""
"""Tests for `diag_linear.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.bijectors.diag_affine import DiagAffine
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.bijectors.tanh import Tanh
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


class DiagAffineTest(parameterized.TestCase):
class DiagLinearTest(parameterized.TestCase):

def test_jacobian_is_constant_property(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
self.assertTrue(bij.is_constant_jacobian)
self.assertTrue(bij.is_constant_log_det)

def test_properties(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
np.testing.assert_allclose(bij.diag, np.ones(4), atol=1e-6)
np.testing.assert_allclose(bij.bias, np.zeros((4,)), atol=1e-6)
np.testing.assert_allclose(bij.matrix, np.eye(4), atol=1e-6)

@parameterized.named_parameters(
('diag is 0d', {'diag': np.ones(()),
'bias': np.zeros((4,))}),
('bias is 0d', {'diag': np.ones((4,)),
'bias': np.zeros(())}),
('inconsistent dim', {'diag': np.ones((3,)),
'bias': np.zeros((4,))}),
('not broadcastable', {'diag': np.ones((3, 4)),
'bias': np.zeros((2, 4))}),
)
def test_raises_with_invalid_parameters(self, params):
def test_raises_with_invalid_parameters(self):
with self.assertRaises(ValueError):
DiagAffine(**params)
DiagLinear(diag=np.ones(()))

@chex.all_variants
@parameterized.parameters(
((5,), (5,), (5,)),
((5,), (), ()),
((), (5,), ()),
((), (), (5,)),
((5,), (5,)),
((5,), ()),
((), (5,)),
)
def test_batched_parameters(self, diag_batch_shape, bias_batch_shape,
input_batch_shape):
def test_batched_parameters(self, diag_batch_shape, input_batch_shape):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), diag_batch_shape + (4,)) + 0.5
bias = jax.random.normal(next(prng), bias_batch_shape + (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)

x = jax.random.normal(next(prng), input_batch_shape + (4,))
y, logdet_fwd = self.variant(bij.forward_and_log_det)(x)
z, logdet_inv = self.variant(bij.inverse_and_log_det)(x)

output_batch_shape = jnp.broadcast_shapes(
diag_batch_shape, bias_batch_shape, input_batch_shape)
diag_batch_shape, input_batch_shape)

self.assertEqual(y.shape, output_batch_shape + (4,))
self.assertEqual(z.shape, output_batch_shape + (4,))
self.assertEqual(logdet_fwd.shape, output_batch_shape)
self.assertEqual(logdet_inv.shape, output_batch_shape)

diag = jnp.broadcast_to(diag, output_batch_shape + (4,)).reshape((-1, 4))
bias = jnp.broadcast_to(bias, output_batch_shape + (4,)).reshape((-1, 4))
x = jnp.broadcast_to(x, output_batch_shape + (4,)).reshape((-1, 4))
y = y.reshape((-1, 4))
z = z.reshape((-1, 4))
logdet_fwd = logdet_fwd.flatten()
logdet_inv = logdet_inv.flatten()

for i in range(np.prod(output_batch_shape)):
bij = DiagAffine(diag=diag[i], bias=bias[i])
bij = DiagLinear(diag=diag[i])
this_y, this_logdet_fwd = self.variant(bij.forward_and_log_det)(x[i])
this_z, this_logdet_inv = self.variant(bij.inverse_and_log_det)(x[i])
np.testing.assert_allclose(this_y, y[i], atol=1e-6)
@@ -102,9 +87,7 @@ def test_batched_parameters(self, diag_batch_shape, bias_batch_shape,
{'batch_shape': (2, 3), 'param_shape': (3,)},
)
def test_identity_initialization(self, batch_shape, param_shape):
bij = DiagAffine(
diag=jnp.ones(param_shape + (4,)),
bias=jnp.zeros(param_shape + (4,)))
bij = DiagLinear(diag=jnp.ones(param_shape + (4,)))
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jax.random.normal(next(prng), batch_shape + (4,))

@@ -126,8 +109,7 @@ def test_identity_initialization(self, batch_shape, param_shape):
def test_inverse_methods(self, batch_shape, param_shape):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), param_shape + (4,)) + 0.5
bias = jax.random.normal(next(prng), param_shape + (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)
x = jax.random.normal(next(prng), batch_shape + (4,))
y, logdet_fwd = self.variant(bij.forward_and_log_det)(x)
x_rec, logdet_inv = self.variant(bij.inverse_and_log_det)(y)
@@ -138,8 +120,7 @@ def test_inverse_methods(self, batch_shape, param_shape):
def test_forward_jacobian_det(self):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), (4,)) + 0.5
bias = jax.random.normal(next(prng), (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)

batched_x = jax.random.normal(next(prng), (10, 4))
single_x = jax.random.normal(next(prng), (4,))
@@ -154,8 +135,7 @@ def test_forward_jacobian_det(self):
def test_inverse_jacobian_det(self):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), (4,)) + 0.5
bias = jax.random.normal(next(prng), (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)

batched_y = jax.random.normal(next(prng), (10, 4))
single_y = jax.random.normal(next(prng), (4,))
@@ -167,7 +147,7 @@ def test_inverse_jacobian_det(self):
np.testing.assert_allclose(logdet, logdet_numerical, atol=5e-4)

def test_raises_on_invalid_input_shape(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
for fn_name, fn in [
('forward', bij.forward),
('inverse', bij.inverse),
@@ -185,17 +165,17 @@ def test_jittable(self):
def f(x, b):
return b.forward(x)

bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
x = np.zeros((4,))
f(x, bij)

def test_same_as_itself(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
self.assertTrue(bij.same_as(bij))

def test_not_same_as_others(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
other = DiagAffine(diag=jnp.ones((4,)), bias=jnp.ones((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
other = DiagLinear(diag=2. * jnp.ones((4,)))
self.assertFalse(bij.same_as(other))
self.assertFalse(bij.same_as(Tanh()))

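
The Jacobian tests above compare the analytic log-determinant against a numerical value (`logdet_numerical`); the helper that computes it is not shown in this hunk, but a plausible sketch using `jax.jacfwd` and `slogdet` (a hypothetical helper, not necessarily the one in the distrax test utilities) is:

import jax
import jax.numpy as jnp
from distrax._src.bijectors.diag_linear import DiagLinear

def numerical_forward_log_det(bij, x):
  # log|det J| of the forward map, computed from the dense Jacobian at x.
  jacobian = jax.jacfwd(bij.forward)(x)
  return jnp.linalg.slogdet(jacobian)[1]

bij = DiagLinear(jnp.array([0.5, 2.0, 4.0]))
print(numerical_forward_log_det(bij, jnp.ones(3)))  # ~log(0.5 * 2.0 * 4.0) = log(4.)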
4 changes: 2 additions & 2 deletions distrax/_src/distributions/mvn_diag.py
@@ -17,7 +17,7 @@
from typing import Optional

import chex
from distrax._src.bijectors.diag_affine import DiagAffine
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.distributions import distribution
from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector
from distrax._src.utils import conversion
@@ -88,7 +88,7 @@ def __init__(self,
bias = jnp.zeros_like(loc, shape=loc.shape[-1:])
bias = jnp.expand_dims(
bias, axis=list(range(len(broadcasted_shapes) - bias.ndim)))
scale = DiagAffine(bias=jnp.zeros_like(loc), diag=scale_diag)
scale = DiagLinear(scale_diag)
super().__init__(
loc=loc, scale=scale, batch_shape=broadcasted_shapes[:-1],
dtype=jnp.result_type(loc, scale_diag))
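
With this change, `MultivariateNormalDiag` builds its scale as `DiagLinear(scale_diag)` and applies the location shift separately. A short sketch of the user-facing behaviour, assuming the public `distrax.MultivariateNormalDiag` constructor with `loc`/`scale_diag` arguments (values are illustrative):

import jax.numpy as jnp
import distrax

dist = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(3),
    scale_diag=jnp.array([0.5, 1.0, 2.0]))

# The diagonal scale feeds the DiagLinear fast paths in mvn_from_bijector.py.
print(dist.stddev())    # [0.5, 1.0, 2.0]
print(dist.variance())  # [0.25, 1.0, 4.0]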
6 changes: 2 additions & 4 deletions distrax/_src/distributions/mvn_diag_plus_low_rank.py
@@ -17,7 +17,7 @@
from typing import Optional

import chex
from distrax._src.bijectors.diag_affine import DiagAffine
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.bijectors.diag_plus_low_rank_affine import DiagPlusLowRankAffine
from distrax._src.distributions import distribution
from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector
@@ -170,9 +170,7 @@ def __init__(self,

if scale_u_matrix is None:
# The scale matrix is diagonal.
scale = DiagAffine(
diag=self._scale_diag,
bias=jnp.zeros(loc.shape[-1:], dtype=dtype))
scale = DiagLinear(self._scale_diag)
else:
scale = DiagPlusLowRankAffine(
bias=jnp.zeros(loc.shape[-1:], dtype=dtype),
12 changes: 6 additions & 6 deletions distrax/_src/distributions/mvn_from_bijector.py
@@ -21,7 +21,7 @@
from distrax._src.bijectors import bijector as base_bijector
from distrax._src.bijectors import block
from distrax._src.bijectors import chain
from distrax._src.bijectors import diag_affine
from distrax._src.bijectors import diag_linear
from distrax._src.bijectors import shift
from distrax._src.distributions import independent
from distrax._src.distributions import normal
@@ -51,7 +51,7 @@ def _check_input_parameters_are_valid(
f'attribute of the scale bijector: '
f'{scale.event_ndims_out} != 1.')
if not scale.is_constant_jacobian:
raise ValueError('The scale bijector should be an affine bijector.')
raise ValueError('The scale bijector should be a linear bijector.')
if loc.ndim < 1:
raise ValueError('`loc` must have at least 1 dimension.')
if loc.ndim - 1 > len(batch_shape):
@@ -151,7 +151,7 @@ def covariance(self) -> Array:
The covariance matrix, of shape `k x k` (broadcasted to match the batch
shape of the distribution).
"""
if isinstance(self.scale, diag_affine.DiagAffine):
if isinstance(self.scale, diag_linear.DiagLinear):
result = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(self.variance())
else:
result = jax.vmap(self.scale.forward, in_axes=-2, out_axes=-2)(
@@ -161,15 +161,15 @@

def variance(self) -> Array:
"""Calculates the variance of all one-dimensional marginals."""
if isinstance(self.scale, diag_affine.DiagAffine):
if isinstance(self.scale, diag_linear.DiagLinear):
result = jnp.square(self.scale.diag)
else:
result = jnp.sum(self._scale_matrix * self._scale_matrix, axis=-1)
return jnp.broadcast_to(result, self.batch_shape + self.event_shape)

def stddev(self) -> Array:
"""Calculates the standard deviation (the square root of the variance)."""
if isinstance(self.scale, diag_affine.DiagAffine):
if isinstance(self.scale, diag_linear.DiagLinear):
result = jnp.abs(self.scale.diag)
else:
result = jnp.sqrt(self.variance())
@@ -216,7 +216,7 @@ def _scale_matrix(d: MultivariateNormalLike) -> Array:
def _has_diagonal_scale(d: MultivariateNormalLike) -> bool:
"""Determines if the scale matrix `A` is diagonal."""
if (isinstance(d, MultivariateNormalFromBijector)
and isinstance(d.scale, diag_affine.DiagAffine)):
and isinstance(d.scale, diag_linear.DiagLinear)):
return True
elif isinstance(d, tfd.MultivariateNormalDiag):
# This does not cover all cases, but we do not have access to the TFP
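
In the diagonal branch of `covariance` above, the full matrix is assembled by placing the per-dimension variances on a diagonal via `jnp.vectorize(jnp.diag, ...)`. A small standalone sketch of what that expression produces (shapes and numbers are illustrative):

import jax.numpy as jnp

# Batch of per-dimension variances, shape (2, 3).
variance = jnp.array([[0.25, 1.0, 4.0],
                      [9.0, 1.0, 0.25]])

# Same construction as the DiagLinear branch of `covariance`:
cov = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(variance)
print(cov.shape)  # (2, 3, 3): one diagonal covariance matrix per batch element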