Skip to content

Commit

Permalink
Add clipped distributions to distrax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 485076789
  • Loading branch information
mtthss authored and DistraxDev committed Nov 1, 2022
1 parent 0485679 commit da915f5
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 0 deletions.
6 changes: 6 additions & 0 deletions distrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
from distrax._src.distributions.beta import Beta
from distrax._src.distributions.categorical import Categorical
from distrax._src.distributions.categorical_uniform import CategoricalUniform
from distrax._src.distributions.clipped import ClippedDistribution
from distrax._src.distributions.clipped import ClippedLogistic
from distrax._src.distributions.clipped import ClippedNormal
from distrax._src.distributions.deterministic import Deterministic
from distrax._src.distributions.dirichlet import Dirichlet
from distrax._src.distributions.distribution import Distribution
Expand Down Expand Up @@ -96,6 +99,9 @@
"Categorical",
"CategoricalUniform",
"Chain",
"ClippedDistribution",
"ClippedLogistic",
"ClippedNormal",
"Deterministic",
"DiagLinear",
"DiagPlusLowRankLinear",
Expand Down
140 changes: 140 additions & 0 deletions distrax/_src/distributions/clipped.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Clipped distributions."""

from typing import Tuple

import chex
from distrax._src.distributions import distribution as base_distribution
from distrax._src.distributions import logistic
from distrax._src.distributions import normal
from distrax._src.utils import conversion
import jax.numpy as jnp


# Short aliases for the chex / distrax types used in the annotations below.
Array = chex.Array
PRNGKey = chex.PRNGKey
Numeric = chex.Numeric
DistributionLike = base_distribution.DistributionLike


class ClippedDistribution(base_distribution.Distribution):
  """A distribution whose samples are clipped to `[minimum, maximum]`.

  Samples of the wrapped distribution that fall outside of the interval are
  moved onto the nearest boundary, so the boundaries carry point masses while
  the interior keeps the density of the wrapped distribution.
  """

  def __init__(
      self,
      distribution: DistributionLike,
      minimum: Numeric,
      maximum: Numeric):
    """Wraps a distribution clipping samples out of `[minimum, maximum]`.

    The samples outside of `[minimum, maximum]` are clipped to the boundary.
    The log probability of samples outside of this range is `-inf`.

    Args:
      distribution: a Distrax / TFP distribution to be wrapped. It must have
        event shape `()`.
      minimum: can be a `scalar` or `vector`; if a vector, it must not have
        more dims than `distribution.batch_shape` and must be broadcastable
        to it.
      maximum: can be a `scalar` or `vector`; if a vector, it must not have
        more dims than `distribution.batch_shape` and must be broadcastable
        to it.

    Raises:
      ValueError: if the wrapped distribution has a non-scalar event shape, or
        if `minimum` / `maximum` have more dimensions than the batch shape.
    """
    super().__init__()
    # Convert first so that the shape checks below are run against the very
    # same object that is used for sampling and log-probabilities afterwards.
    self._distribution = conversion.as_distribution(distribution)
    batch_shape = self._distribution.batch_shape
    if self._distribution.event_shape:
      raise ValueError('The wrapped distribution must have event shape ().')
    if (jnp.array(minimum).ndim > len(batch_shape) or
        jnp.array(maximum).ndim > len(batch_shape)):
      # Each fragment ends with a space: adjacent string literals are joined
      # verbatim, without any separator.
      raise ValueError(
          'The minimum and maximum clipping boundaries must be scalars or '
          'vectors with at most as many dimensions as the batch_shape of '
          'distribution: i.e. we can broadcast min/max to batch_shape but '
          'not vice versa.')
    self._minimum = jnp.broadcast_to(minimum, batch_shape)
    self._maximum = jnp.broadcast_to(maximum, batch_shape)
    # Mass that clipping moves onto each boundary: everything at or below
    # `minimum`, and everything above `maximum`, respectively.
    self._log_prob_minimum = self._distribution.log_cdf(minimum)
    self._log_prob_maximum = self._distribution.log_survival_function(maximum)

  def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`."""
    raw_sample = self._distribution.sample(seed=key, sample_shape=[n])
    return jnp.clip(raw_sample, self._minimum, self._maximum)

  def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]:
    """See `Distribution._sample_n_and_log_prob`."""
    samples = self._sample_n(key, n)
    return samples, self.log_prob(samples)

  def log_prob(self, value: Array) -> Array:
    """See `Distribution.log_prob`.

    Args:
      value: the event(s) to evaluate.

    Returns:
      The wrapped distribution's `log_cdf(minimum)` where `value` equals the
      minimum, its `log_survival_function(maximum)` where `value` equals the
      maximum, its `log_prob(value)` strictly inside the interval, and `-inf`
      outside of `[minimum, maximum]`. If a value equals both boundaries the
      minimum branch takes precedence.
    """
    # The log_prob can be used to compute expectations by explicitly integrating
    # over the discrete and continuous elements.
    # Info about mixed distributions:
    # http://www.randomservices.org/random/dist/Mixed.html
    log_prob = jnp.where(
        jnp.equal(value, self._minimum),
        self._log_prob_minimum,
        jnp.where(jnp.equal(value, self._maximum),
                  self._log_prob_maximum,
                  self._distribution.log_prob(value)))
    # Giving -inf log_prob outside the boundaries.
    return jnp.where(
        jnp.logical_or(value < self._minimum, value > self._maximum),
        -jnp.inf,
        log_prob)

  @property
  def minimum(self) -> Array:
    """Lower clipping boundary, broadcast to `batch_shape`."""
    return self._minimum

  @property
  def maximum(self) -> Array:
    """Upper clipping boundary, broadcast to `batch_shape`."""
    return self._maximum

  @property
  def distribution(self) -> base_distribution.Distribution:
    """The wrapped (unclipped) distribution."""
    return self._distribution

  @property
  def event_shape(self) -> Tuple[int, ...]:
    """Shape of a single event; always `()` for clipped distributions."""
    return ()

  @property
  def batch_shape(self) -> Tuple[int, ...]:
    """Batch shape, taken from the wrapped distribution."""
    return self._distribution.batch_shape

  def __getitem__(self, index) -> 'ClippedDistribution':
    """See `Distribution.__getitem__`."""
    index = base_distribution.to_batch_shape_index(self.batch_shape, index)
    # Slice the wrapped distribution and both boundaries consistently.
    return ClippedDistribution(
        distribution=self.distribution[index],
        minimum=self.minimum[index],
        maximum=self.maximum[index])


class ClippedNormal(ClippedDistribution):
  """A clipped normal distribution."""

  def __init__(
      self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric):
    """Builds a `Normal(loc, scale)` clipped to `[minimum, maximum]`.

    Args:
      loc: `loc` argument forwarded to `normal.Normal`.
      scale: `scale` argument forwarded to `normal.Normal`.
      minimum: lower clipping boundary.
      maximum: upper clipping boundary.
    """
    unclipped = normal.Normal(loc=loc, scale=scale)
    super().__init__(
        distribution=unclipped, minimum=minimum, maximum=maximum)


class ClippedLogistic(ClippedDistribution):
  """A clipped logistic distribution."""

  def __init__(
      self, loc: Numeric, scale: Numeric, minimum: Numeric, maximum: Numeric):
    """Builds a `Logistic(loc, scale)` clipped to `[minimum, maximum]`.

    Args:
      loc: `loc` argument forwarded to `logistic.Logistic`.
      scale: `scale` argument forwarded to `logistic.Logistic`.
      minimum: lower clipping boundary.
      maximum: upper clipping boundary.
    """
    unclipped = logistic.Logistic(loc=loc, scale=scale)
    super().__init__(
        distribution=unclipped, minimum=minimum, maximum=maximum)
204 changes: 204 additions & 0 deletions distrax/_src/distributions/clipped_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `clipped.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.distributions import clipped
from distrax._src.distributions import logistic
from distrax._src.distributions import normal
import jax
import jax.numpy as jnp
import numpy as np


# Shared fixtures for the tests below.
MINIMUM = -1.0  # Lower clipping boundary.
MAXIMUM = 1.0  # Upper clipping boundary.
LOC = MINIMUM  # The location sits exactly on the lower boundary.
SCALE = 0.1
SIZE = 3  # Batch size used by the batched tests.


class ClippedDistributionTest(parameterized.TestCase):
  """Tests for `ClippedNormal` and `ClippedLogistic`."""

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_clipped_logprob(self, factory, unclipped_factory):
    """Interior log-probs match the unclipped ones; outside they are -inf."""
    distribution = factory(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    np.testing.assert_allclose(
        unclipped.log_prob(0.0),
        distribution.log_prob(0.0))
    np.testing.assert_allclose(
        unclipped.log_prob(0.8),
        distribution.log_prob(0.8))

    # Testing outside of the boundary.
    self.assertEqual(-np.inf, distribution.log_prob(MINIMUM - 0.1))
    self.assertEqual(-np.inf, distribution.log_prob(MAXIMUM + 0.1))

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_batched_clipped_logprob(self, factory, unclipped_factory):
    """Same as `test_clipped_logprob`, with a batch of `SIZE` distributions."""
    distribution = factory(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    np.testing.assert_allclose(
        unclipped.log_prob(jnp.array([0.0]*SIZE)),
        distribution.log_prob(jnp.array([0.0]*SIZE)))
    np.testing.assert_allclose(
        unclipped.log_prob(jnp.array([0.8]*SIZE)),
        distribution.log_prob(jnp.array([0.8]*SIZE)))

    # Testing outside of the boundary.
    np.testing.assert_allclose(
        -np.inf, distribution.log_prob(jnp.array([MINIMUM - 0.1]*SIZE)))
    np.testing.assert_allclose(
        -np.inf, distribution.log_prob(jnp.array([MAXIMUM + 0.1]*SIZE)))

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_clipped_sampled_and_logprob(self, factory, unclipped_factory):
    """`sample_and_log_prob` agrees with the unclipped one up to clipping."""
    distribution = factory(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    for rng in jax.random.split(jax.random.PRNGKey(42), 5):
      # Both distributions are driven by the same key so that the clipped
      # sample can be compared against the raw one.
      sample, log_prob = distribution.sample_and_log_prob(seed=rng)
      unclipped_sample, unclipped_log_prob = unclipped.sample_and_log_prob(
          seed=rng)
      if float(unclipped_sample) > MAXIMUM:
        np.testing.assert_allclose(sample, MAXIMUM, atol=1e-5)
      elif float(unclipped_sample) < MINIMUM:
        np.testing.assert_allclose(sample, MINIMUM, atol=1e-5)
      else:
        np.testing.assert_allclose(sample, unclipped_sample, atol=1e-5)
        np.testing.assert_allclose(log_prob, unclipped_log_prob, atol=1e-5)

  @parameterized.parameters([
      [clipped.ClippedLogistic, logistic.Logistic],
      [clipped.ClippedNormal, normal.Normal],
  ])
  def test_clipped_sample(self, factory, unclipped_factory):
    """Samples equal the unclipped samples clipped to the boundaries."""
    distribution = factory(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    unclipped = unclipped_factory(loc=LOC, scale=SCALE)

    for rng in jax.random.split(jax.random.PRNGKey(42), 5):
      sample = distribution.sample(seed=rng)
      unclipped_sample = unclipped.sample(seed=rng)
      if float(unclipped_sample) > MAXIMUM:
        np.testing.assert_allclose(sample, MAXIMUM, atol=1e-5)
      elif float(unclipped_sample) < MINIMUM:
        np.testing.assert_allclose(sample, MINIMUM, atol=1e-5)
      else:
        np.testing.assert_allclose(sample, unclipped_sample, atol=1e-5)

  @parameterized.parameters([
      [clipped.ClippedLogistic],
      [clipped.ClippedNormal],
  ])
  def test_extremes(self, factory):
    """Boundary log-probs stay finite when `loc` sits on a boundary."""
    minimum = -1.0
    maximum = 1.0
    scale = 0.01

    # Using extreme loc.
    distribution = factory(
        loc=minimum, scale=scale, minimum=minimum, maximum=maximum)
    self.assertTrue(np.isfinite(distribution.log_prob(minimum)))
    self.assertTrue(np.isfinite(distribution.log_prob(maximum)))

    distribution = factory(
        loc=maximum, scale=scale, minimum=minimum, maximum=maximum)
    self.assertTrue(np.isfinite(distribution.log_prob(minimum)))
    self.assertTrue(np.isfinite(distribution.log_prob(maximum)))

  def test_jitable(self):
    """A clipped distribution can be passed through a jitted function."""
    minimum = -1.0
    maximum = 1.0
    loc = minimum
    scale = 0.1

    @jax.jit
    def jitted_function(event, dist):
      return dist.log_prob(event)

    dist = clipped.ClippedLogistic(
        loc=loc, scale=scale, minimum=minimum, maximum=maximum)
    event = dist.sample(seed=jax.random.PRNGKey(4242))
    log_prob = dist.log_prob(event)
    jitted_log_prob = jitted_function(event, dist)

    chex.assert_trees_all_close(
        jitted_log_prob, log_prob, atol=1e-4, rtol=1e-4)

  def test_properties(self):
    """`minimum` / `maximum` report the boundaries given at construction."""
    dist = clipped.ClippedLogistic(
        loc=LOC, scale=SCALE, minimum=MINIMUM, maximum=MAXIMUM)
    np.testing.assert_allclose(dist.minimum, MINIMUM, atol=1e-5)
    np.testing.assert_allclose(dist.maximum, MAXIMUM, atol=1e-5)
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    np.testing.assert_allclose(dist.minimum, MINIMUM, atol=1e-5)
    np.testing.assert_allclose(dist.maximum, MAXIMUM, atol=1e-5)

  def test_min_max_broadcasting(self):
    """Scalar boundaries are broadcast to the batch shape."""
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    # Both boundaries must be checked, not the minimum twice.
    self.assertEqual(dist.minimum.shape, (SIZE,))
    self.assertEqual(dist.maximum.shape, (SIZE,))

  def test_batch_shape(self):
    """The batch shape is the one of the wrapped distribution."""
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=MINIMUM,
        maximum=MAXIMUM)
    self.assertEqual(dist.batch_shape, (SIZE,))

  def test_event_shape(self):
    """The event shape is always `()`, also with batched boundaries."""
    dist = clipped.ClippedLogistic(
        loc=jnp.array([LOC]*SIZE),
        scale=jnp.array([SCALE]*SIZE),
        minimum=jnp.array([MINIMUM]*SIZE),
        maximum=jnp.array([MAXIMUM]*SIZE))
    self.assertEqual(dist.event_shape, ())


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit da915f5

Please sign in to comment.