From da915f5ec46a57e4eaeef166a5f866ca58650499 Mon Sep 17 00:00:00 2001
From: Matteo Hessel
Date: Mon, 31 Oct 2022 09:50:58 -0700
Subject: [PATCH] Add clipped distributions to distrax.

PiperOrigin-RevId: 485076789
---
 distrax/__init__.py                        |   6 +
 distrax/_src/distributions/clipped.py      | 140 ++++++++++++++
 distrax/_src/distributions/clipped_test.py | 204 +++++++++++++++++++++
 3 files changed, 350 insertions(+)
 create mode 100644 distrax/_src/distributions/clipped.py
 create mode 100644 distrax/_src/distributions/clipped_test.py

diff --git a/distrax/__init__.py b/distrax/__init__.py
index 2f55000..2139f55 100644
--- a/distrax/__init__.py
+++ b/distrax/__init__.py
@@ -41,6 +41,9 @@
 from distrax._src.distributions.beta import Beta
 from distrax._src.distributions.categorical import Categorical
 from distrax._src.distributions.categorical_uniform import CategoricalUniform
+from distrax._src.distributions.clipped import ClippedDistribution
+from distrax._src.distributions.clipped import ClippedLogistic
+from distrax._src.distributions.clipped import ClippedNormal
 from distrax._src.distributions.deterministic import Deterministic
 from distrax._src.distributions.dirichlet import Dirichlet
 from distrax._src.distributions.distribution import Distribution
@@ -96,6 +99,9 @@
     "Categorical",
     "CategoricalUniform",
     "Chain",
+    "ClippedDistribution",
+    "ClippedLogistic",
+    "ClippedNormal",
     "Deterministic",
     "DiagLinear",
     "DiagPlusLowRankLinear",
diff --git a/distrax/_src/distributions/clipped.py b/distrax/_src/distributions/clipped.py
new file mode 100644
index 0000000..7d6fa9b
--- /dev/null
+++ b/distrax/_src/distributions/clipped.py
@@ -0,0 +1,140 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Clipped distributions."""
+
+from typing import Tuple
+
+import chex
+from distrax._src.distributions import distribution as base_distribution
+from distrax._src.distributions import logistic
+from distrax._src.distributions import normal
+from distrax._src.utils import conversion
+import jax.numpy as jnp
+
+
+Array = chex.Array
+PRNGKey = chex.PRNGKey
+Numeric = chex.Numeric
+DistributionLike = base_distribution.DistributionLike
+
+
+class ClippedDistribution(base_distribution.Distribution):
+  """A clipped distribution."""
+
+  def __init__(
+      self,
+      distribution: DistributionLike,
+      minimum: Numeric,
+      maximum: Numeric):
+    """Wraps a distribution, clipping samples outside `[minimum, maximum]`.
+
+    Samples outside of `[minimum, maximum]` are clipped to the boundary.
+    The log probability of samples outside of this range is `-inf`.
+
+    Args:
+      distribution: a Distrax / TFP distribution to be wrapped.
+      minimum: can be a `scalar` or `vector`; if a vector, it must have no more
+        dims than `distribution.batch_shape` and be broadcastable to it.
+      maximum: can be a `scalar` or `vector`; if a vector, it must have no more
+        dims than `distribution.batch_shape` and be broadcastable to it.
+    """
+    super().__init__()
+    if distribution.event_shape:
+      raise ValueError('The wrapped distribution must have event shape ().')
+    if (jnp.array(minimum).ndim > len(distribution.batch_shape) or
+        jnp.array(maximum).ndim > len(distribution.batch_shape)):
+      raise ValueError(
+          'The minimum and maximum clipping boundaries must be scalars or '
+          'vectors with at most as many dimensions as the batch_shape of '
+          'the distribution, i.e. we can broadcast min/max to batch_shape '
+          'but not vice versa.')
+    self._distribution = conversion.as_distribution(distribution)
+    self._minimum = jnp.broadcast_to(minimum, self._distribution.batch_shape)
+    self._maximum = jnp.broadcast_to(maximum, self._distribution.batch_shape)
+    self._log_prob_minimum = self._distribution.log_cdf(minimum)
+    self._log_prob_maximum = self._distribution.log_survival_function(maximum)
+
+  def _sample_n(self, key: PRNGKey, n: int) -> Array:
+    """See `Distribution._sample_n`."""
+    raw_sample = self._distribution.sample(seed=key, sample_shape=[n])
+    return jnp.clip(raw_sample, self._minimum, self._maximum)
+
+  def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]:
+    """See `Distribution._sample_n_and_log_prob`."""
+    samples = self._sample_n(key, n)
+    return samples, self.log_prob(samples)
+
+  def log_prob(self, value: Array) -> Array:
+    """See `Distribution.log_prob`."""
+    # The log_prob can be used to compute expectations by explicitly
+    # integrating over the discrete and continuous elements.
+    # Info about mixed distributions:
+    # http://www.randomservices.org/random/dist/Mixed.html
+    # At the boundaries, log_prob returns the log of the clipped tail mass
+    # (an atom), not the log of a density.
+    log_prob = jnp.where(
+        jnp.equal(value, self._minimum),
+        self._log_prob_minimum,
+        jnp.where(jnp.equal(value, self._maximum),
+                  self._log_prob_maximum,
+                  self._distribution.log_prob(value)))
+    # Assign -inf log_prob to values outside of the boundaries.
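+    # Note the strict inequalities below: values exactly at the boundaries
+    # keep the atom log-probabilities computed above.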
+    return jnp.where(
+        jnp.logical_or(value < self._minimum, value > self._maximum),
+        -jnp.inf,
+        log_prob)
+
+  @property
+  def minimum(self) -> Numeric:
+    return self._minimum
+
+  @property
+  def maximum(self) -> Numeric:
+    return self._maximum
+
+  @property
+  def distribution(self) -> DistributionLike:
+    return self._distribution
+
+  @property
+  def event_shape(self) -> Tuple[int, ...]:
+    return ()
+
+  @property
+  def batch_shape(self) -> Tuple[int, ...]:
+    return self._distribution.batch_shape
+
+  def __getitem__(self, index) -> 'ClippedDistribution':
+    """See `Distribution.__getitem__`."""
+    index = base_distribution.to_batch_shape_index(self.batch_shape, index)
+    return ClippedDistribution(
+        distribution=self.distribution[index],
+        minimum=self.minimum[index],
+        maximum=self.maximum[index])
+
+
+class ClippedNormal(ClippedDistribution):
+  """A clipped normal distribution."""
+
+  def __init__(
+      self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric):
+    distribution = normal.Normal(loc=loc, scale=scale)
+    super().__init__(distribution, minimum=minimum, maximum=maximum)
+
+
+class ClippedLogistic(ClippedDistribution):
+  """A clipped logistic distribution."""
+
+  def __init__(
+      self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric):
+    distribution = logistic.Logistic(loc=loc, scale=scale)
+    super().__init__(distribution, minimum=minimum, maximum=maximum)
diff --git a/distrax/_src/distributions/clipped_test.py b/distrax/_src/distributions/clipped_test.py
new file mode 100644
index 0000000..54e8e55
--- /dev/null
+++ b/distrax/_src/distributions/clipped_test.py
@@ -0,0 +1,204 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `clipped.py`."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import chex
+from distrax._src.distributions import clipped
+from distrax._src.distributions import logistic
+from distrax._src.distributions import normal
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+MINIMUM = -1.0
+MAXIMUM = 1.0
+LOC = MINIMUM
+SCALE = 0.1
+SIZE = 3
+
+
+class ClippedDistributionTest(parameterized.TestCase):
+
+  @parameterized.parameters([
+      [clipped.ClippedLogistic, logistic.Logistic],
+      [clipped.ClippedNormal, normal.Normal],
+  ])
+  def test_clipped_logprob(self, factory, unclipped_factory):
+    distribution = factory(
+        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
+    unclipped = unclipped_factory(loc=LOC, scale=SCALE)
+
+    np.testing.assert_allclose(
+        unclipped.log_prob(0.0),
+        distribution.log_prob(0.0))
+    np.testing.assert_allclose(
+        unclipped.log_prob(0.8),
+        distribution.log_prob(0.8))
+
+    # Testing outside of the boundary.
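+    # Values outside [MINIMUM, MAXIMUM] carry zero probability mass, so
+    # their log_prob is exactly -inf.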
+    self.assertEqual(-np.inf, distribution.log_prob(MINIMUM - 0.1))
+    self.assertEqual(-np.inf, distribution.log_prob(MAXIMUM + 0.1))
+
+  @parameterized.parameters([
+      [clipped.ClippedLogistic, logistic.Logistic],
+      [clipped.ClippedNormal, normal.Normal],
+  ])
+  def test_batched_clipped_logprob(self, factory, unclipped_factory):
+    distribution = factory(
+        loc=jnp.array([LOC]*SIZE),
+        scale=jnp.array([SCALE]*SIZE),
+        minimum=MINIMUM,
+        maximum=MAXIMUM)
+    unclipped = unclipped_factory(loc=LOC, scale=SCALE)
+
+    np.testing.assert_allclose(
+        unclipped.log_prob(jnp.array([0.0]*SIZE)),
+        distribution.log_prob(jnp.array([0.0]*SIZE)))
+    np.testing.assert_allclose(
+        unclipped.log_prob(jnp.array([0.8]*SIZE)),
+        distribution.log_prob(jnp.array([0.8]*SIZE)))
+
+    # Testing outside of the boundary.
+    np.testing.assert_allclose(
+        -np.inf, distribution.log_prob(jnp.array([MINIMUM - 0.1]*SIZE)))
+    np.testing.assert_allclose(
+        -np.inf, distribution.log_prob(jnp.array([MAXIMUM + 0.1]*SIZE)))
+
+  @parameterized.parameters([
+      [clipped.ClippedLogistic, logistic.Logistic],
+      [clipped.ClippedNormal, normal.Normal],
+  ])
+  def test_clipped_sampled_and_logprob(self, factory, unclipped_factory):
+    distribution = factory(
+        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
+    unclipped = unclipped_factory(loc=LOC, scale=SCALE)
+
+    for rng in jax.random.split(jax.random.PRNGKey(42), 5):
+      sample, log_prob = distribution.sample_and_log_prob(seed=rng)
+      unclipped_sample, unclipped_log_prob = unclipped.sample_and_log_prob(
+          seed=rng)
+      if float(unclipped_sample) > MAXIMUM:
+        np.testing.assert_allclose(sample, MAXIMUM, atol=1e-5)
+      elif float(unclipped_sample) < MINIMUM:
+        np.testing.assert_allclose(sample, MINIMUM, atol=1e-5)
+      else:
+        np.testing.assert_allclose(sample, unclipped_sample, atol=1e-5)
+        np.testing.assert_allclose(log_prob, unclipped_log_prob, atol=1e-5)
+
+  @parameterized.parameters([
+      [clipped.ClippedLogistic, logistic.Logistic],
+      [clipped.ClippedNormal, normal.Normal],
+  ])
+  def test_clipped_sample(self, factory, unclipped_factory):
+    distribution = factory(
+        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
+    unclipped = unclipped_factory(loc=LOC, scale=SCALE)
+
+    for rng in jax.random.split(jax.random.PRNGKey(42), 5):
+      sample = distribution.sample(seed=rng)
+      unclipped_sample = unclipped.sample(seed=rng)
+      if float(unclipped_sample) > MAXIMUM:
+        np.testing.assert_allclose(sample, MAXIMUM, atol=1e-5)
+      elif float(unclipped_sample) < MINIMUM:
+        np.testing.assert_allclose(sample, MINIMUM, atol=1e-5)
+      else:
+        np.testing.assert_allclose(sample, unclipped_sample, atol=1e-5)
+
+  @parameterized.parameters([
+      [clipped.ClippedLogistic],
+      [clipped.ClippedNormal],
+  ])
+  def test_extremes(self, factory):
+    minimum = -1.0
+    maximum = 1.0
+    scale = 0.01
+
+    # Using extreme loc.
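+    # With `loc` pinned to a clipping boundary, roughly half of the mass
+    # collapses onto that boundary's atom, so the boundary log-probs must
+    # remain finite.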
+    distribution = factory(
+        loc=minimum, scale=scale, minimum=minimum, maximum=maximum)
+    self.assertTrue(np.isfinite(distribution.log_prob(minimum)))
+    self.assertTrue(np.isfinite(distribution.log_prob(maximum)))
+
+    distribution = factory(
+        loc=maximum, scale=scale, minimum=minimum, maximum=maximum)
+    self.assertTrue(np.isfinite(distribution.log_prob(minimum)))
+    self.assertTrue(np.isfinite(distribution.log_prob(maximum)))
+
+  def test_jittable(self):
+    minimum = -1.0
+    maximum = 1.0
+    loc = minimum
+    scale = 0.1
+
+    @jax.jit
+    def jitted_function(event, dist):
+      return dist.log_prob(event)
+
+    dist = clipped.ClippedLogistic(
+        loc=loc, scale=scale, minimum=minimum, maximum=maximum)
+    event = dist.sample(seed=jax.random.PRNGKey(4242))
+    log_prob = dist.log_prob(event)
+    jitted_log_prob = jitted_function(event, dist)
+
+    chex.assert_trees_all_close(
+        jitted_log_prob, log_prob, atol=1e-4, rtol=1e-4)
+
+  def test_properties(self):
+    dist = clipped.ClippedLogistic(
+        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
+    np.testing.assert_allclose(dist.minimum, MINIMUM, atol=1e-5)
+    np.testing.assert_allclose(dist.maximum, MAXIMUM, atol=1e-5)
+    dist = clipped.ClippedLogistic(
+        loc=jnp.array([LOC]*SIZE),
+        scale=jnp.array([SCALE]*SIZE),
+        minimum=MINIMUM,
+        maximum=MAXIMUM)
+    np.testing.assert_allclose(dist.minimum, MINIMUM, atol=1e-5)
+    np.testing.assert_allclose(dist.maximum, MAXIMUM, atol=1e-5)
+
+  def test_min_max_broadcasting(self):
+    dist = clipped.ClippedLogistic(
+        loc=jnp.array([LOC]*SIZE),
+        scale=jnp.array([SCALE]*SIZE),
+        minimum=MINIMUM,
+        maximum=MAXIMUM)
+    self.assertEqual(dist.minimum.shape, (SIZE,))
+    self.assertEqual(dist.maximum.shape, (SIZE,))
+
+  def test_batch_shape(self):
+    dist = clipped.ClippedLogistic(
+        loc=jnp.array([LOC]*SIZE),
+        scale=jnp.array([SCALE]*SIZE),
+        minimum=MINIMUM,
+        maximum=MAXIMUM)
+    self.assertEqual(dist.batch_shape, (SIZE,))
+
+  def test_event_shape(self):
+    dist = clipped.ClippedLogistic(
+        loc=jnp.array([LOC]*SIZE),
+        scale=jnp.array([SCALE]*SIZE),
+        minimum=jnp.array([MINIMUM]*SIZE),
+        maximum=jnp.array([MAXIMUM]*SIZE))
+    self.assertEqual(dist.event_shape, ())
+
+
+if __name__ == '__main__':
+  absltest.main()
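
Usage example (illustrative only, not part of the patch; the concrete numbers
are arbitrary, but the API calls match the code added above):

import distrax
import jax
import jax.numpy as jnp

# A standard normal clipped to [-1, 1].
dist = distrax.ClippedNormal(loc=0.0, scale=1.0, minimum=-1.0, maximum=1.0)
base = distrax.Normal(loc=0.0, scale=1.0)

# Samples never leave the clipping interval.
samples = dist.sample(seed=jax.random.PRNGKey(0), sample_shape=[1000])
assert jnp.all((samples >= -1.0) & (samples <= 1.0))

# Strictly inside (minimum, maximum) the density matches the base normal.
assert jnp.allclose(dist.log_prob(0.5), base.log_prob(0.5))

# The boundaries carry the clipped tail mass as atoms:
# P(X = minimum) = CDF(minimum) and P(X = maximum) = 1 - CDF(maximum).
assert jnp.allclose(dist.log_prob(-1.0), base.log_cdf(-1.0))
assert jnp.allclose(dist.log_prob(1.0), base.log_survival_function(1.0))

# Outside [minimum, maximum] the log probability is -inf.
assert jnp.isneginf(dist.log_prob(2.0))

Since `log_prob` mixes densities (in the interior) with probability masses
(at the boundaries), its values are not directly comparable across the two
regimes; see the mixed-distributions link in `log_prob` for background.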