Add clipped distributions to distrax.
PiperOrigin-RevId: 485360904
Showing 3 changed files with 350 additions and 0 deletions.
@@ -0,0 +1,140 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Clipped distributions."""

from typing import Tuple

import chex
from distrax._src.distributions import distribution as base_distribution
from distrax._src.distributions import logistic
from distrax._src.distributions import normal
from distrax._src.utils import conversion
import jax.numpy as jnp


Array = chex.Array
PRNGKey = chex.PRNGKey
Numeric = chex.Numeric
DistributionLike = base_distribution.DistributionLike


class ClippedDistribution(base_distribution.Distribution):
  """A clipped distribution."""

  def __init__(
      self,
      distribution: DistributionLike,
      minimum: Numeric,
      maximum: Numeric):
"""Wraps a distribution clipping samples out of `[minimum, maximum]`. | ||
The samples outside of `[minimum, maximum]` are clipped to the boundary. | ||
The log probability of samples outside of this range is `-inf`. | ||
Args: | ||
distribution: a Distrax / TFP distribution to be wrapped. | ||
minimum: can be a `scalar` or `vector`; if a vector, must have fewer dims | ||
than `distribution.batch_shape` and must be broadcastable to it. | ||
maximum: can be a `scalar` or `vector`; if a vector, must have fewer dims | ||
than `distribution.batch_shape` and must be broadcastable to it. | ||
""" | ||
    super().__init__()
    if distribution.event_shape:
      raise ValueError('The wrapped distribution must have event shape ().')
    if (jnp.array(minimum).ndim > len(distribution.batch_shape) or
        jnp.array(maximum).ndim > len(distribution.batch_shape)):
      raise ValueError(
          'The minimum and maximum clipping boundaries must be scalars or '
          'vectors with no more dimensions than the batch_shape of the '
          'distribution: i.e. we can broadcast min/max to batch_shape but '
          'not vice versa.')
    self._distribution = conversion.as_distribution(distribution)
    self._minimum = jnp.broadcast_to(minimum, self._distribution.batch_shape)
    self._maximum = jnp.broadcast_to(maximum, self._distribution.batch_shape)
    self._log_prob_minimum = self._distribution.log_cdf(minimum)
    self._log_prob_maximum = self._distribution.log_survival_function(maximum)
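    # For a continuous wrapped distribution with CDF `F`, clipping places an
    # atom of probability mass F(minimum) at `minimum` and another of mass
    # 1 - F(maximum) at `maximum`; the two values above cache their logs.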

  def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`."""
    raw_sample = self._distribution.sample(seed=key, sample_shape=[n])
    return jnp.clip(raw_sample, self._minimum, self._maximum)

  def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]:
    """See `Distribution._sample_n_and_log_prob`."""
    samples = self._sample_n(key, n)
    return samples, self.log_prob(samples)

  def log_prob(self, value: Array) -> Array:
    """See `Distribution.log_prob`."""
    # The log_prob can be used to compute expectations by explicitly
    # integrating over the discrete and continuous elements:
    #   E[f(X)] = F(min) * f(min) + (1 - F(max)) * f(max)
    #             + integral of f(x) * p(x) over the open interval (min, max).
    # Info about mixed distributions:
    # http://www.randomservices.org/random/dist/Mixed.html
    log_prob = jnp.where(
        jnp.equal(value, self._minimum),
        self._log_prob_minimum,
        jnp.where(jnp.equal(value, self._maximum),
                  self._log_prob_maximum,
                  self._distribution.log_prob(value)))
    # Assign -inf log_prob to values outside the boundaries.
    return jnp.where(
        jnp.logical_or(value < self._minimum, value > self._maximum),
        -jnp.inf,
        log_prob)

  @property
  def minimum(self) -> Numeric:
    return self._minimum

  @property
  def maximum(self) -> Numeric:
    return self._maximum

  @property
  def distribution(self) -> DistributionLike:
    return self._distribution

  @property
  def event_shape(self) -> Tuple[int, ...]:
    return ()

  @property
  def batch_shape(self) -> Tuple[int, ...]:
    return self._distribution.batch_shape

  def __getitem__(self, index) -> 'ClippedDistribution':
    """See `Distribution.__getitem__`."""
    index = base_distribution.to_batch_shape_index(self.batch_shape, index)
    return ClippedDistribution(
        distribution=self.distribution[index],
        minimum=self.minimum[index],
        maximum=self.maximum[index])


class ClippedNormal(ClippedDistribution):
  """A clipped normal distribution."""

  def __init__(
      self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric):
    distribution = normal.Normal(loc=loc, scale=scale)
    super().__init__(distribution, minimum=minimum, maximum=maximum)


class ClippedLogistic(ClippedDistribution):
  """A clipped logistic distribution."""

  def __init__(
      self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric):
    distribution = logistic.Logistic(loc=loc, scale=scale)
    super().__init__(distribution, minimum=minimum, maximum=maximum)
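A quick check of the boundary semantics implemented above (a sketch, not part of the commit): the mass the clipped distribution assigns to each boundary equals the corresponding tail mass of the wrapped distribution.

import jax.numpy as jnp
from distrax._src.distributions import clipped
from distrax._src.distributions import normal

base = normal.Normal(loc=0.0, scale=1.0)
dist = clipped.ClippedNormal(loc=0.0, scale=1.0, minimum=-1.0, maximum=1.0)

# At the lower boundary: log P(X <= minimum) under the base distribution.
assert jnp.allclose(dist.log_prob(-1.0), base.log_cdf(-1.0))
# At the upper boundary: log P(X >= maximum) under the base distribution.
assert jnp.allclose(dist.log_prob(1.0), base.log_survival_function(1.0))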
@@ -0,0 +1,204 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `clipped.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.distributions import clipped
from distrax._src.distributions import logistic
from distrax._src.distributions import normal
import jax
import jax.numpy as jnp
import numpy as np


MINIMUM = -1.0
MAXIMUM = 1.0
LOC = MINIMUM
SCALE = 0.1
SIZE = 3


class ClippedDistributionTest(parameterized.TestCase):

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_clipped_logprob(self, factory, unclipped_factory):
    distribution = factory(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    np.testing.assert_allclose(
        unclipped.log_prob(0.0),
        distribution.log_prob(0.0))
    np.testing.assert_allclose(
        unclipped.log_prob(0.8),
        distribution.log_prob(0.8))

    # Testing outside of the boundary.
    self.assertEqual(-np.inf, distribution.log_prob(MINIMUM - 0.1))
    self.assertEqual(-np.inf, distribution.log_prob(MAXIMUM + 0.1))

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_batched_clipped_logprob(self, factory, unclipped_factory):
    distribution = factory(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    np.testing.assert_allclose(
        unclipped.log_prob(jnp.array([0.0]*SIZE)),
        distribution.log_prob(jnp.array([0.0]*SIZE)))
    np.testing.assert_allclose(
        unclipped.log_prob(jnp.array([0.8]*SIZE)),
        distribution.log_prob(jnp.array([0.8]*SIZE)))

    # Testing outside of the boundary.
    np.testing.assert_allclose(
        -np.inf, distribution.log_prob(jnp.array([MINIMUM - 0.1]*SIZE)))
    np.testing.assert_allclose(
        -np.inf, distribution.log_prob(jnp.array([MAXIMUM + 0.1]*SIZE)))

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_clipped_sampled_and_logprob(self, factory, unclipped_factory):
    distribution = factory(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    for rng in jax.random.split(jax.random.PRNGKey(42), 5):
      sample, log_prob = distribution.sample_and_log_prob(seed=rng)
      unclipped_sample, unclipped_log_prob = unclipped.sample_and_log_prob(
          seed=rng)
      if float(unclipped_sample) > MAXIMUM:
        np.testing.assert_allclose(sample, MAXIMUM, atol=1e-5)
      elif float(unclipped_sample) < MINIMUM:
        np.testing.assert_allclose(sample, MINIMUM, atol=1e-5)
      else:
        np.testing.assert_allclose(sample, unclipped_sample, atol=1e-5)
        np.testing.assert_allclose(log_prob, unclipped_log_prob, atol=1e-5)

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_clipped_sample(self, factory, unclipped_factory):
    distribution = factory(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    for rng in jax.random.split(jax.random.PRNGKey(42), 5):
      sample = distribution.sample(seed=rng)
      unclipped_sample = unclipped.sample(seed=rng)
      if float(unclipped_sample) > MAXIMUM:
        np.testing.assert_allclose(sample, MAXIMUM, atol=1e-5)
      elif float(unclipped_sample) < MINIMUM:
        np.testing.assert_allclose(sample, MINIMUM, atol=1e-5)
      else:
        np.testing.assert_allclose(sample, unclipped_sample, atol=1e-5)

  @parameterized.parameters([
      [clipped.ClippedLogistic],
      [clipped.ClippedNormal],
  ])
  def test_extremes(self, factory):
    minimum = -1.0
    maximum = 1.0
    scale = 0.01

    # Using extreme loc.
    distribution = factory(
        loc=minimum, scale=scale, minimum=minimum, maximum=maximum)
    self.assertTrue(np.isfinite(distribution.log_prob(minimum)))
    self.assertTrue(np.isfinite(distribution.log_prob(maximum)))

    distribution = factory(
        loc=maximum, scale=scale, minimum=minimum, maximum=maximum)
    self.assertTrue(np.isfinite(distribution.log_prob(minimum)))
    self.assertTrue(np.isfinite(distribution.log_prob(maximum)))

  def test_jitable(self):
    minimum = -1.0
    maximum = 1.0
    loc = minimum
    scale = 0.1

    @jax.jit
    def jitted_function(event, dist):
      return dist.log_prob(event)

    dist = clipped.ClippedLogistic(
        loc=loc, scale=scale, minimum=minimum, maximum=maximum)
    event = dist.sample(seed=jax.random.PRNGKey(4242))
    log_prob = dist.log_prob(event)
    jitted_log_prob = jitted_function(event, dist)

    chex.assert_trees_all_close(
        jitted_log_prob, log_prob, atol=1e-4, rtol=1e-4)

  def test_properties(self):
    dist = clipped.ClippedLogistic(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    np.testing.assert_allclose(dist.minimum, MINIMUM, atol=1e-5)
    np.testing.assert_allclose(dist.maximum, MAXIMUM, atol=1e-5)
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    np.testing.assert_allclose(dist.minimum, MINIMUM, atol=1e-5)
    np.testing.assert_allclose(dist.maximum, MAXIMUM, atol=1e-5)

  def test_min_max_broadcasting(self):
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    self.assertEqual(dist.minimum.shape, (SIZE,))
    self.assertEqual(dist.maximum.shape, (SIZE,))

  def test_batch_shape(self):
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    self.assertEqual(dist.batch_shape, (SIZE,))

  def test_event_shape(self):
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=jnp.array([MINIMUM]*SIZE),
        maximum=jnp.array([MAXIMUM]*SIZE))
    self.assertEqual(dist.event_shape, ())


if __name__ == '__main__':
  absltest.main()
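The test module runs standalone via `absltest.main()`; assuming the file lands at the path implied by the `distrax._src.distributions` imports above (the exact filename is an assumption), it can be invoked directly:

python distrax/_src/distributions/clipped_test.py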