
Commit

Demote DiagAffine to DiagLinear.
PiperOrigin-RevId: 434523936
gpapamak authored and DistraxDev committed Mar 15, 2022
1 parent 6f0551d commit 4b58e48
Showing 7 changed files with 56 additions and 109 deletions.
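At a glance, the change to caller code (a minimal sketch, assuming the `Chain` and `Shift` bijectors imported elsewhere in this diff keep their usual class names and signatures): the old `DiagAffine(diag, bias)` computed `f(x) = Ax + b`, while the new `DiagLinear(diag)` computes `f(x) = Ax`, so the bias can be recovered by composing with a shift.

# Sketch only, not part of the commit. Chain and Shift are assumed APIs.
import jax.numpy as jnp
from distrax._src.bijectors.chain import Chain          # assumed import path
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.bijectors.shift import Shift          # assumed import path

diag = jnp.array([0.5, 2.0, 1.5])
bias = jnp.array([1.0, -1.0, 0.0])

linear = DiagLinear(diag)              # f(x) = Ax
affine = Chain([Shift(bias), linear])  # f(x) = Ax + b, the old DiagAffine behaviour

x = jnp.ones((3,))
y = affine.forward(x)                  # equals diag * x + bias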
distrax/_src/bijectors/diag_linear.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diagonal affine bijector."""
"""Diagonal linear bijector."""

from distrax._src.bijectors import bijector as base
from distrax._src.bijectors import block
@@ -22,31 +22,11 @@
Array = base.Array


def _check_shapes_are_valid(diag: Array, bias: Array):
"""Checks array shapes are valid, raises `ValueError` if not."""
if diag.ndim < 1:
raise ValueError("`diag` must have at least one dimension.")
if bias.ndim < 1:
raise ValueError("`bias` must have at least one dimension.")
if bias.shape[-1] != diag.shape[-1]:
raise ValueError(
f"Both `bias` and `diag` must have the same number of dimensions; got "
f"`bias.shape[-1]={bias.shape[-1]}` and "
f"`diag.shape[-1]={diag.shape[-1]}`.")
try:
jnp.broadcast_shapes(diag.shape, bias.shape)
except ValueError:
raise ValueError(
f"The shapes of `bias` and `diag` are not broadcastable; got "
f"`bias.shape={bias.shape}` and `diag.shape={diag.shape}`.") from None
class DiagLinear(block.Block):
"""Linear bijector with a diagonal weight matrix.

class DiagAffine(block.Block):
"""Affine bijector with a diagonal weight matrix.
The bijector is defined as `f(x) = Ax + b` where `A` is a `DxD` diagonal
matrix and `b` is a `D`-dimensional vector. Additional dimensions, if any,
index batches.
The bijector is defined as `f(x) = Ax` where `A` is a `DxD` diagonal matrix.
Additional dimensions, if any, index batches.
The Jacobian determinant is trivially computed by taking the product of the
diagonal entries in `A`. The inverse transformation `x = f^{-1}(y)` is
@@ -58,40 +38,31 @@ class DiagAffine(block.Block):
invertible.
"""

def __init__(self, diag: Array, bias: Array):
def __init__(self, diag: Array):
"""Initializes the bijector.
Args:
diag: a vector of length D, the diagonal of matrix `A`. Can also be a
batch of such vectors.
bias: the vector `b` in `Ax + b`. Can also be a batch of such vectors.
"""
_check_shapes_are_valid(diag=diag, bias=bias)
bijector = scalar_affine.ScalarAffine(shift=bias, scale=diag)
if diag.ndim < 1:
raise ValueError("`diag` must have at least one dimension.")
bijector = scalar_affine.ScalarAffine(shift=0., scale=diag)
super().__init__(bijector=bijector, ndims=1)
self._diag = diag
self._bias = bias

@property
def diag(self) -> Array:
"""Vector of length D, the diagonal of matrix `A`."""
return self._diag

@property
def bias(self) -> Array:
"""The bias `b` of the transformation."""
return self._bias

@property
def matrix(self) -> Array:
"""The full matrix `A`."""
return jnp.vectorize(jnp.diag, signature="(k)->(k,k)")(self.diag)

def same_as(self, other: base.Bijector) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
if type(other) is DiagAffine: # pylint: disable=unidiomatic-typecheck
return all((
self.diag is other.diag,
self.bias is other.bias,
))
if type(other) is DiagLinear: # pylint: disable=unidiomatic-typecheck
return self.diag is other.diag
return False
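The docstring above describes a constant Jacobian and an elementwise inverse; a minimal sanity sketch of those properties (assuming only the `forward_and_log_det` and `inverse` methods exercised by the tests below):

# Sketch only, not part of the commit.
import jax.numpy as jnp
from distrax._src.bijectors.diag_linear import DiagLinear

diag = jnp.array([0.5, 2.0, 4.0])
bij = DiagLinear(diag)

x = jnp.array([1.0, -3.0, 0.25])
y, logdet = bij.forward_and_log_det(x)   # y = diag * x

# log|det J| = sum(log|diag|), independent of x.
assert jnp.allclose(logdet, jnp.sum(jnp.log(jnp.abs(diag))))
# The inverse divides elementwise by the diagonal.
assert jnp.allclose(bij.inverse(y), x)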
distrax/_src/bijectors/diag_linear_test.py
@@ -12,83 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `diag_affine.py`."""
"""Tests for `diag_linear.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.bijectors.diag_affine import DiagAffine
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.bijectors.tanh import Tanh
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


class DiagAffineTest(parameterized.TestCase):
class DiagLinearTest(parameterized.TestCase):

def test_jacobian_is_constant_property(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
self.assertTrue(bij.is_constant_jacobian)
self.assertTrue(bij.is_constant_log_det)

def test_properties(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
np.testing.assert_allclose(bij.diag, np.ones(4), atol=1e-6)
np.testing.assert_allclose(bij.bias, np.zeros((4,)), atol=1e-6)
np.testing.assert_allclose(bij.matrix, np.eye(4), atol=1e-6)

@parameterized.named_parameters(
('diag is 0d', {'diag': np.ones(()),
'bias': np.zeros((4,))}),
('bias is 0d', {'diag': np.ones((4,)),
'bias': np.zeros(())}),
('inconsistent dim', {'diag': np.ones((3,)),
'bias': np.zeros((4,))}),
('not broadcastable', {'diag': np.ones((3, 4)),
'bias': np.zeros((2, 4))}),
)
def test_raises_with_invalid_parameters(self, params):
def test_raises_with_invalid_parameters(self):
with self.assertRaises(ValueError):
DiagAffine(**params)
DiagLinear(diag=np.ones(()))

@chex.all_variants
@parameterized.parameters(
((5,), (5,), (5,)),
((5,), (), ()),
((), (5,), ()),
((), (), (5,)),
((5,), (5,)),
((5,), ()),
((), (5,)),
)
def test_batched_parameters(self, diag_batch_shape, bias_batch_shape,
input_batch_shape):
def test_batched_parameters(self, diag_batch_shape, input_batch_shape):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), diag_batch_shape + (4,)) + 0.5
bias = jax.random.normal(next(prng), bias_batch_shape + (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)

x = jax.random.normal(next(prng), input_batch_shape + (4,))
y, logdet_fwd = self.variant(bij.forward_and_log_det)(x)
z, logdet_inv = self.variant(bij.inverse_and_log_det)(x)

output_batch_shape = jnp.broadcast_shapes(
diag_batch_shape, bias_batch_shape, input_batch_shape)
diag_batch_shape, input_batch_shape)

self.assertEqual(y.shape, output_batch_shape + (4,))
self.assertEqual(z.shape, output_batch_shape + (4,))
self.assertEqual(logdet_fwd.shape, output_batch_shape)
self.assertEqual(logdet_inv.shape, output_batch_shape)

diag = jnp.broadcast_to(diag, output_batch_shape + (4,)).reshape((-1, 4))
bias = jnp.broadcast_to(bias, output_batch_shape + (4,)).reshape((-1, 4))
x = jnp.broadcast_to(x, output_batch_shape + (4,)).reshape((-1, 4))
y = y.reshape((-1, 4))
z = z.reshape((-1, 4))
logdet_fwd = logdet_fwd.flatten()
logdet_inv = logdet_inv.flatten()

for i in range(np.prod(output_batch_shape)):
bij = DiagAffine(diag=diag[i], bias=bias[i])
bij = DiagLinear(diag=diag[i])
this_y, this_logdet_fwd = self.variant(bij.forward_and_log_det)(x[i])
this_z, this_logdet_inv = self.variant(bij.inverse_and_log_det)(x[i])
np.testing.assert_allclose(this_y, y[i], atol=1e-6)
@@ -102,9 +87,7 @@ def test_batched_parameters(self, diag_batch_shape, bias_batch_shape,
{'batch_shape': (2, 3), 'param_shape': (3,)},
)
def test_identity_initialization(self, batch_shape, param_shape):
bij = DiagAffine(
diag=jnp.ones(param_shape + (4,)),
bias=jnp.zeros(param_shape + (4,)))
bij = DiagLinear(diag=jnp.ones(param_shape + (4,)))
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jax.random.normal(next(prng), batch_shape + (4,))

@@ -126,8 +109,7 @@ def test_identity_initialization(self, batch_shape, param_shape):
def test_inverse_methods(self, batch_shape, param_shape):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), param_shape + (4,)) + 0.5
bias = jax.random.normal(next(prng), param_shape + (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)
x = jax.random.normal(next(prng), batch_shape + (4,))
y, logdet_fwd = self.variant(bij.forward_and_log_det)(x)
x_rec, logdet_inv = self.variant(bij.inverse_and_log_det)(y)
@@ -138,8 +120,7 @@ def test_inverse_methods(self, batch_shape, param_shape):
def test_forward_jacobian_det(self):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), (4,)) + 0.5
bias = jax.random.normal(next(prng), (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)

batched_x = jax.random.normal(next(prng), (10, 4))
single_x = jax.random.normal(next(prng), (4,))
@@ -154,8 +135,7 @@ def test_forward_jacobian_det(self):
def test_inverse_jacobian_det(self):
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
diag = jax.random.uniform(next(prng), (4,)) + 0.5
bias = jax.random.normal(next(prng), (4,))
bij = DiagAffine(diag, bias)
bij = DiagLinear(diag)

batched_y = jax.random.normal(next(prng), (10, 4))
single_y = jax.random.normal(next(prng), (4,))
@@ -167,7 +147,7 @@ def test_inverse_jacobian_det(self):
np.testing.assert_allclose(logdet, logdet_numerical, atol=5e-4)

def test_raises_on_invalid_input_shape(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
for fn_name, fn in [
('forward', bij.forward),
('inverse', bij.inverse),
@@ -185,17 +165,17 @@ def test_jittable(self):
def f(x, b):
return b.forward(x)

bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
x = np.zeros((4,))
f(x, bij)

def test_same_as_itself(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
self.assertTrue(bij.same_as(bij))

def test_not_same_as_others(self):
bij = DiagAffine(diag=jnp.ones((4,)), bias=jnp.zeros((4,)))
other = DiagAffine(diag=jnp.ones((4,)), bias=jnp.ones((4,)))
bij = DiagLinear(diag=jnp.ones((4,)))
other = DiagLinear(diag=2. * jnp.ones((4,)))
self.assertFalse(bij.same_as(other))
self.assertFalse(bij.same_as(Tanh()))

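Note that, as the last two tests indicate, `same_as` compares the `diag` arrays by object identity (`is`), not by value; a small sketch:

# Sketch only, not part of the commit.
import jax.numpy as jnp
from distrax._src.bijectors.diag_linear import DiagLinear

diag = jnp.ones((4,))
a = DiagLinear(diag)
b = DiagLinear(diag)            # shares the same array object
c = DiagLinear(jnp.ones((4,)))  # equal values, different array object

assert a.same_as(b)       # True: identical `diag` object
assert not a.same_as(c)   # False: identity check, not value check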
4 changes: 2 additions & 2 deletions distrax/_src/distributions/mvn_diag.py
@@ -17,7 +17,7 @@
from typing import Optional

import chex
from distrax._src.bijectors.diag_affine import DiagAffine
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.distributions import distribution
from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector
from distrax._src.utils import conversion
@@ -88,7 +88,7 @@ def __init__(self,
bias = jnp.zeros_like(loc, shape=loc.shape[-1:])
bias = jnp.expand_dims(
bias, axis=list(range(len(broadcasted_shapes) - bias.ndim)))
scale = DiagAffine(bias=jnp.zeros_like(loc), diag=scale_diag)
scale = DiagLinear(scale_diag)
super().__init__(
loc=loc, scale=scale, batch_shape=broadcasted_shapes[:-1],
dtype=jnp.result_type(loc, scale_diag))
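For `MultivariateNormalDiag`, the scale bijector is now `DiagLinear(scale_diag)` with no explicit zero bias. A usage sketch, assuming the public class name and constructor are otherwise unchanged by this commit:

# Sketch only, not part of the commit.
import jax.numpy as jnp
from distrax._src.distributions.mvn_diag import MultivariateNormalDiag  # assumed class name

dist = MultivariateNormalDiag(loc=jnp.array([0.0, 1.0]),
                              scale_diag=jnp.array([0.5, 2.0]))

print(dist.stddev())    # [0.5 2. ]   i.e. |scale_diag|
print(dist.variance())  # [0.25 4. ]  i.e. scale_diag ** 2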
6 changes: 2 additions & 4 deletions distrax/_src/distributions/mvn_diag_plus_low_rank.py
@@ -17,7 +17,7 @@
from typing import Optional

import chex
from distrax._src.bijectors.diag_affine import DiagAffine
from distrax._src.bijectors.diag_linear import DiagLinear
from distrax._src.bijectors.diag_plus_low_rank_affine import DiagPlusLowRankAffine
from distrax._src.distributions import distribution
from distrax._src.distributions.mvn_from_bijector import MultivariateNormalFromBijector
@@ -170,9 +170,7 @@ def __init__(self,

if scale_u_matrix is None:
# The scale matrix is diagonal.
scale = DiagAffine(
diag=self._scale_diag,
bias=jnp.zeros(loc.shape[-1:], dtype=dtype))
scale = DiagLinear(self._scale_diag)
else:
scale = DiagPlusLowRankAffine(
bias=jnp.zeros(loc.shape[-1:], dtype=dtype),
12 changes: 6 additions & 6 deletions distrax/_src/distributions/mvn_from_bijector.py
@@ -21,7 +21,7 @@
from distrax._src.bijectors import bijector as base_bijector
from distrax._src.bijectors import block
from distrax._src.bijectors import chain
from distrax._src.bijectors import diag_affine
from distrax._src.bijectors import diag_linear
from distrax._src.bijectors import shift
from distrax._src.distributions import independent
from distrax._src.distributions import normal
@@ -51,7 +51,7 @@ def _check_input_parameters_are_valid(
f'attribute of the scale bijector: '
f'{scale.event_ndims_out} != 1.')
if not scale.is_constant_jacobian:
raise ValueError('The scale bijector should be an affine bijector.')
raise ValueError('The scale bijector should be a linear bijector.')
if loc.ndim < 1:
raise ValueError('`loc` must have at least 1 dimension.')
if loc.ndim - 1 > len(batch_shape):
@@ -151,7 +151,7 @@ def covariance(self) -> Array:
The covariance matrix, of shape `k x k` (broadcasted to match the batch
shape of the distribution).
"""
if isinstance(self.scale, diag_affine.DiagAffine):
if isinstance(self.scale, diag_linear.DiagLinear):
result = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(self.variance())
else:
result = jax.vmap(self.scale.forward, in_axes=-2, out_axes=-2)(
@@ -161,15 +161,15 @@

def variance(self) -> Array:
"""Calculates the variance of all one-dimensional marginals."""
if isinstance(self.scale, diag_affine.DiagAffine):
if isinstance(self.scale, diag_linear.DiagLinear):
result = jnp.square(self.scale.diag)
else:
result = jnp.sum(self._scale_matrix * self._scale_matrix, axis=-1)
return jnp.broadcast_to(result, self.batch_shape + self.event_shape)

def stddev(self) -> Array:
"""Calculates the standard deviation (the square root of the variance)."""
if isinstance(self.scale, diag_affine.DiagAffine):
if isinstance(self.scale, diag_linear.DiagLinear):
result = jnp.abs(self.scale.diag)
else:
result = jnp.sqrt(self.variance())
@@ -216,7 +216,7 @@ def _scale_matrix(d: MultivariateNormalLike) -> Array:
def _has_diagonal_scale(d: MultivariateNormalLike) -> bool:
"""Determines if the scale matrix `A` is diagonal."""
if (isinstance(d, MultivariateNormalFromBijector)
and isinstance(d.scale, diag_affine.DiagAffine)):
and isinstance(d.scale, diag_linear.DiagLinear)):
return True
elif isinstance(d, tfd.MultivariateNormalDiag):
# This does not cover all cases, but we do not have access to the TFP
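The `DiagLinear` fast path above reduces covariance, variance, and standard deviation to elementwise operations on the diagonal. A sketch of the same arithmetic with hypothetical values, mirroring the `jnp.vectorize` call in the diff:

# Sketch only, not part of the commit.
import jax.numpy as jnp

# Batch of 2 diagonal scales over a 3-dimensional event.
scale_diag = jnp.array([[0.5, 2.0, 1.0],
                        [3.0, 1.0, 0.2]])

variance = jnp.square(scale_diag)       # shape (2, 3)
stddev = jnp.abs(scale_diag)            # shape (2, 3)
covariance = jnp.vectorize(             # shape (2, 3, 3): diag of each variance row
    jnp.diag, signature='(k)->(k,k)')(variance)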