Add a straight-through gradient wrapper method.
Categorical distributions wrapped this way will behave the same in the forward
pass, but use a biased gradient computed on `self.probs` in the backward pass.

PiperOrigin-RevId: 409150269
DistraxDev authored and DistraxDev committed Nov 23, 2021
1 parent af9d163 commit dab3796
Showing 2 changed files with 210 additions and 0 deletions.
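For orientation before the diff, here is a minimal usage sketch of the new wrapper. This is a hypothetical example, not part of the commit; the private module paths mirror the test file below, and the `WrappedOneHot` name is made up for illustration.

# Hypothetical usage sketch (not part of this commit): wrap a Distrax
# OneHotCategorical so that its samples carry a straight-through gradient.
import jax
import jax.numpy as jnp

from distrax._src.distributions import one_hot_categorical
from distrax._src.distributions import straight_through

# Build the wrapped class once, then use it like the original distribution.
WrappedOneHot = straight_through.straight_through_wrapper(
    one_hot_categorical.OneHotCategorical)


def loss(logits, key):
  dist = WrappedOneHot(logits=logits)
  sample = dist.sample(seed=key)  # hard one-hot in the forward pass
  return jnp.sum(sample ** 2)


logits = jnp.array([0.0, 1.0, -0.5])
# The gradient is non-zero because it flows through the underlying probs.
grads = jax.grad(loss)(logits, jax.random.PRNGKey(0))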
64 changes: 64 additions & 0 deletions distrax/_src/distributions/straight_through.py
@@ -0,0 +1,64 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Straight-through gradient sampling distribution."""
import distrax
from distrax._src.distributions import categorical
import jax


def straight_through_wrapper(  # pylint: disable=invalid-name
    Distribution,
    ) -> distrax.DistributionLike:
  """Wrap a distribution to use straight-through gradient for samples."""

  def sample(self, seed, sample_shape=()):  # pylint: disable=g-doc-args
    """Sampling with a straight-through biased gradient estimator.

    Sample a value from the distribution, but backpropagate through the
    underlying probability to compute the gradient.

    References:
      [1] Yoshua Bengio, Nicholas Léonard, Aaron Courville, Estimating or
      Propagating Gradients Through Stochastic Neurons for Conditional
      Computation, https://arxiv.org/abs/1308.3432

    Args:
      seed: a random seed.
      sample_shape: the shape of the required sample.

    Returns:
      A sample with straight-through gradient.
    """
    # pylint: disable=protected-access
    obj = Distribution(probs=self._probs, logits=self._logits)
    assert isinstance(obj, categorical.Categorical)
    sample = obj.sample(seed=seed, sample_shape=sample_shape)
    probs = obj.probs
    padded_probs = _pad(probs, sample.shape)

    # Keep the sample unchanged, but add a gradient through probs.
    sample += padded_probs - jax.lax.stop_gradient(padded_probs)
    return sample

  def _pad(probs, shape):
    """Grow probs to have the same number of dimensions as shape."""
    while len(probs.shape) < len(shape):
      probs = probs[None]
    return probs

  parent_name = Distribution.__name__
  # Return a new class that overrides `sample`.
  return type('StraightThrough' + parent_name, (Distribution,),
              {'sample': sample})
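The key line in `sample` above is the straight-through identity: adding `probs - stop_gradient(probs)` leaves the forward value untouched while routing the gradient through `probs`. A toy sketch of that identity in isolation (illustration only, not part of the commit):

# Toy illustration of the straight-through identity used in `sample` above.
import jax
import jax.numpy as jnp


def st_value(probs):
  # A hard one-hot "sample" (argmax) with no gradient of its own.
  hard = jax.lax.stop_gradient(jnp.where(probs == probs.max(), 1.0, 0.0))
  # Forward value equals `hard`; the gradient flows through `probs`.
  return hard + probs - jax.lax.stop_gradient(probs)


probs = jnp.array([0.2, 0.5, 0.3])
print(st_value(probs))                # [0. 1. 0.]
print(jax.jacobian(st_value)(probs))  # 3x3 identity matrix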
146 changes: 146 additions & 0 deletions distrax/_src/distributions/straight_through_test.py
@@ -0,0 +1,146 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `straight_through.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.distributions import one_hot_categorical
from distrax._src.distributions import straight_through
from distrax._src.utils import equivalence
from distrax._src.utils import math
import jax
import jax.numpy as jnp
import numpy as np


RTOL = 2e-3


class StraightThroughTest(equivalence.EquivalenceTest, parameterized.TestCase):

  def setUp(self):
    # pylint: disable=too-many-function-args
    super().setUp(straight_through.straight_through_wrapper(
        one_hot_categorical.OneHotCategorical))
    self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)

  # TODO(https://github.com/deepmind/chex/issues/115)
  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      ('1d logits, no shape', {'logits': [0.0, 1.0, -0.5]}, ()),
      ('1d probs, no shape', {'probs': [0.2, 0.5, 0.3]}, ()),
      ('1d logits, int shape', {'logits': [0.0, 1.0, -0.5]}, 1),
      ('1d probs, int shape', {'probs': [0.2, 0.5, 0.3]}, 1),
      ('1d logits, 1-tuple shape', {'logits': [0.0, 1.0, -0.5]}, (1,)),
      ('1d probs, 1-tuple shape', {'probs': [0.2, 0.5, 0.3]}, (1,)),
      ('1d logits, 2-tuple shape', {'logits': [0.0, 50., -0.5]}, (5, 4)),
      ('1d probs, 2-tuple shape', {'probs': [0.01, 0.99, 0.]}, (5, 4)),
      ('2d logits, no shape', {'logits': [[0.0, 1.0, -0.5],
                                          [-0.1, 0.3, 0.0]]}, ()),
      ('2d probs, no shape', {'probs': [[0.1, 0.4, 0.5],
                                        [0.5, 0.25, 0.25]]}, ()),
      ('2d logits, int shape', {'logits': [[0.0, 50.0, -0.5],
                                           [-0.1, -0.3, 0.2]]}, 4),
      ('2d probs, int shape', {'probs': [[0.005, 0.005, 0.99],
                                         [0.99, 0., 0.01]]}, 4),
      ('2d logits, 1-tuple shape', {'logits': [[0.0, 1.0, -0.5],
                                               [-0.1, 0.3, 200.0]]}, (5,)),
      ('2d probs, 1-tuple shape', {'probs': [[0., 0.01, 0.99],
                                             [0., 0.99, 0.01]]}, (5,)),
      ('2d logits, 2-tuple shape', {'logits': [[0.0, 1.0, -0.5],
                                               [-0.1, 0.3, 1000.0]]}, (5, 4)),
      ('2d probs, 2-tuple shape', {'probs': [[0.01, 0.99, 0.],
                                             [0.99, 0., 0.01]]}, (5, 4)),
  )
  def test_sample(self, dist_params, sample_shape):

    def loss(dist_params, dist_cls, sample_shape):
      """Loss on sample, used both for distrax and TFP."""
      # Sample.
      dist = dist_cls(**dist_params)
      sample_fn = dist.sample
      if hasattr(self, 'variant'):
        sample_fn = self.variant(static_argnames=('sample_shape',))(sample_fn)
      sample = sample_fn(seed=self.key, sample_shape=sample_shape)

      return jnp.sum((sample)**2).astype(jnp.float32), sample

    # TFP softmax gradient.
    def straight_through_tfp_loss(dist_params, dist_cls, sample_shape):
      """Loss on a straight-through gradient of the tfp sample."""
      # Distrax normalises the distribution parameters. We want to make sure
      # that they are normalised for tfp too, or the gradient might differ.
      try:
        dist_params['logits'] = math.normalize(logits=dist_params['logits'])
      except KeyError:
        dist_params['probs'] = math.normalize(probs=dist_params['probs'])

      # Sample.
      dist = dist_cls(**dist_params)
      sample_fn = dist.sample
      if hasattr(self, 'variant'):
        sample_fn = self.variant(static_argnames=('sample_shape',))(sample_fn)
      sample = sample_fn(seed=self.key, sample_shape=sample_shape)

      # Straight-through gradient.
      def _pad(probs, shape):
        if isinstance(shape, int):
          return probs
        while len(probs.shape) < len(shape):
          probs = probs[None]
        return probs
      probs = dist.probs_parameter()
      padded_probs = _pad(probs, sample_shape)
      sample += padded_probs - jax.lax.stop_gradient(padded_probs)

      return jnp.sum((sample)**2).astype(jnp.float32), sample

    # Straight-through gradient and sample.
    sample_grad, sample = jax.grad(loss, has_aux=True)(dist_params,
                                                       self.distrax_cls,
                                                       sample_shape)
    # TFP gradient (zero) and sample.
    tfp_sample_grad, tfp_sample = jax.grad(loss, has_aux=True)(dist_params,
                                                               self.tfp_cls,
                                                               sample_shape)
    # TFP straight-through gradient and sample.
    tfp_st_sample_grad, tfp_st_sample = jax.grad(straight_through_tfp_loss,
                                                 has_aux=True)(dist_params,
                                                               self.tfp_cls,
                                                               sample_shape)

    # TEST: the samples have the same size, and the straight-through gradient
    # doesn't affect the tfp sample.
    chex.assert_equal_shape((sample, tfp_sample))
    self.assertion_fn(tfp_sample, tfp_st_sample)
    # TEST: the TFP gradient is zero.
    assert (jnp.asarray(*tfp_sample_grad.values()) == 0).all()
    # TEST: the TFP straight-through gradient is non zero.
    assert (jnp.asarray(*tfp_st_sample_grad.values()) != 0).any()
    # Test that the TFP straight-through gradient is equal to the one from
    # distrax when the samples from distrax and tfp are the same (due to
    # stochasticity the samples can differ - we are using skewed distributions
    # on purpose in the parametrization of the test to make sure that the
    # samples match most of the time).
    sample_grad_v = jnp.stack(jnp.array(*sample_grad.values()))
    tfp_st_sample_grad_v = jnp.stack(jnp.array(*tfp_st_sample_grad.values()))
    if np.all(sample == tfp_st_sample):
      self.assertion_fn(sample_grad_v, tfp_st_sample_grad_v)


if __name__ == '__main__':
  absltest.main()
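To see why the unwrapped TFP gradient in the test above is zero, note that a plain one-hot sample is piecewise constant in its logits. A standalone illustration of this point (uses TFP's JAX substrate directly; an assumed setup, not part of the commit):

# Illustration: the gradient of a loss through a plain OneHotCategorical
# sample is zero, which is what the straight-through wrapper addresses.
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp


def plain_loss(logits, key):
  dist = tfp.distributions.OneHotCategorical(logits=logits, dtype=jnp.float32)
  sample = dist.sample(seed=key)
  return jnp.sum(sample ** 2)


logits = jnp.array([0.0, 1.0, -0.5])
print(jax.grad(plain_loss)(logits, jax.random.PRNGKey(0)))  # [0. 0. 0.]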
