Add a straight-through gradient wrapper method. #70

Merged · 1 commit · Dec 1, 2021
distrax/_src/distributions/straight_through.py (64 additions, 0 deletions)
@@ -0,0 +1,64 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Straight-through gradient sampling distribution."""
import distrax
from distrax._src.distributions import categorical
import jax


def straight_through_wrapper(  # pylint: disable=invalid-name
    Distribution,
    ) -> distrax.DistributionLike:
  """Wrap a distribution to use straight-through gradient for samples."""

  def sample(self, seed, sample_shape=()):  # pylint: disable=g-doc-args
    """Sampling with straight through biased gradient estimator.

    Sample a value from the distribution, but backpropagate through the
    underlying probability to compute the gradient.

    References:
      [1] Yoshua Bengio, Nicholas Léonard, Aaron Courville, Estimating or
      Propagating Gradients Through Stochastic Neurons for Conditional
      Computation, https://arxiv.org/abs/1308.3432

    Args:
      seed: a random seed.
      sample_shape: the shape of the required sample.

    Returns:
      A sample with straight-through gradient.
    """
    # pylint: disable=protected-access
    obj = Distribution(probs=self._probs, logits=self._logits)
    assert isinstance(obj, categorical.Categorical)
    sample = obj.sample(seed=seed, sample_shape=sample_shape)
    probs = obj.probs
    padded_probs = _pad(probs, sample.shape)

    # Keep sample unchanged, but add gradient through probs.
    sample += padded_probs - jax.lax.stop_gradient(padded_probs)
    return sample

  def _pad(probs, shape):
    """Grow probs to have the same number of dimensions as shape."""
    while len(probs.shape) < len(shape):
      probs = probs[None]
    return probs

  parent_name = Distribution.__name__
  # Return a new class, overriding sample.
  return type('StraightThrough' + parent_name, (Distribution,),
              {'sample': sample})
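
For context, here is a minimal usage sketch of the wrapper above. It is not part of the diff; it imports from the private _src paths the same way the accompanying test file does, and the logits values are made up.

# Illustrative usage sketch (not part of this PR); values are arbitrary.
import jax
import jax.numpy as jnp
from distrax._src.distributions import one_hot_categorical
from distrax._src.distributions import straight_through

# Wrap OneHotCategorical so that sampling carries a straight-through gradient.
STOneHotCategorical = straight_through.straight_through_wrapper(
    one_hot_categorical.OneHotCategorical)

def loss(logits, key):
  dist = STOneHotCategorical(logits=logits)
  sample = dist.sample(seed=key)  # forward value: an ordinary one-hot sample
  return jnp.sum(sample ** 2)

logits = jnp.array([0.0, 1.0, -0.5])
key = jax.random.PRNGKey(0)
# Without the wrapper this gradient would be zero everywhere; with the
# straight-through estimator it flows through the softmax probabilities.
grads = jax.grad(loss)(logits, key)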
distrax/_src/distributions/straight_through_test.py (155 additions, 0 deletions)
@@ -0,0 +1,155 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `straight_through.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.distributions import one_hot_categorical
from distrax._src.distributions import straight_through
from distrax._src.utils import equivalence
from distrax._src.utils import math
import jax
import jax.numpy as jnp
import numpy as np


RTOL = 2e-3


class StraightThroughTest(equivalence.EquivalenceTest, parameterized.TestCase):

  def setUp(self):
    # pylint: disable=too-many-function-args
    super().setUp(straight_through.straight_through_wrapper(
        one_hot_categorical.OneHotCategorical))
    self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d logits, no shape', {'logits': [0.0, 1.0, -0.5]}, ()),
      ('1d probs, no shape', {'probs': [0.2, 0.5, 0.3]}, ()),
      ('1d logits, int shape', {'logits': [0.0, 1.0, -0.5]}, 1),
      ('1d probs, int shape', {'probs': [0.2, 0.5, 0.3]}, 1),
      ('1d logits, 1-tuple shape', {'logits': [0.0, 1.0, -0.5]}, (1,)),
      ('1d probs, 1-tuple shape', {'probs': [0.2, 0.5, 0.3]}, (1,)),
      ('1d logits, 2-tuple shape', {'logits': [0.0, 50., -0.5]}, (5, 4)),
      ('1d probs, 2-tuple shape', {'probs': [0.01, 0.99, 0.]}, (5, 4)),
      ('2d logits, no shape', {'logits': [[0.0, 1.0, -0.5],
                                          [-0.1, 0.3, 0.0]]}, ()),
      ('2d probs, no shape', {'probs': [[0.1, 0.4, 0.5],
                                        [0.5, 0.25, 0.25]]}, ()),
      ('2d logits, int shape', {'logits': [[0.0, 50.0, -0.5],
                                           [-0.1, -0.3, 0.2]]}, 4),
      ('2d probs, int shape', {'probs': [[0.005, 0.005, 0.99],
                                         [0.99, 0., 0.01]]}, 4),
      ('2d logits, 1-tuple shape', {'logits': [[0.0, 1.0, -0.5],
                                               [-0.1, 0.3, 200.0]]}, (5,)),
      ('2d probs, 1-tuple shape', {'probs': [[0., 0.01, 0.99],
                                             [0., 0.99, 0.01]]}, (5,)),
      ('2d logits, 2-tuple shape', {'logits': [[0.0, 1.0, -0.5],
                                               [-0.1, 0.3, 1000.0]]}, (5, 4)),
      ('2d probs, 2-tuple shape', {'probs': [[0.01, 0.99, 0.],
                                             [0.99, 0., 0.01]]}, (5, 4)),
  )
  def test_sample(self, dist_params, sample_shape):

    def loss(dist_params, dist_cls, sample_shape):
      """Loss on sample, used both for distrax and TFP."""

      # Sample.
      dist = dist_cls(**dist_params)
      sample_fn = dist.sample

      def sample_fn_wrapper(seed, sample_shape):
        """To test with pmap that requires positional arguments."""
        return sample_fn(seed=seed, sample_shape=sample_shape)

      if hasattr(self, 'variant'):
        sample_fn_wrapper = self.variant(static_argnums=(1,))(sample_fn_wrapper)
      sample = sample_fn_wrapper(self.key, sample_shape)
      return jnp.sum((sample)**2).astype(jnp.float32), sample

    # TFP softmax gradient.
    def straight_through_tfp_loss(dist_params, dist_cls, sample_shape):
      """Loss on a straight-through gradient of the tfp sample."""
      # Distrax normalises the distribution parameters. We want to make sure
      # that they are normalised for tfp too, or the gradient might differ.
      try:
        dist_params['logits'] = math.normalize(logits=dist_params['logits'])
      except KeyError:
        dist_params['probs'] = math.normalize(probs=dist_params['probs'])

      # Sample.
      dist = dist_cls(**dist_params)
      sample_fn = dist.sample

      def sample_fn_wrapper(seed, sample_shape):
        """To test with pmap that requires positional arguments."""
        return sample_fn(seed=seed, sample_shape=sample_shape)

      if hasattr(self, 'variant'):
        sample_fn_wrapper = self.variant(static_argnums=(1,))(sample_fn_wrapper)
      sample = sample_fn_wrapper(self.key, sample_shape)

      # Straight-through gradient.
      def _pad(probs, shape):
        if isinstance(shape, int):
          return probs
        while len(probs.shape) < len(shape):
          probs = probs[None]
        return probs
      probs = dist.probs_parameter()
      padded_probs = _pad(probs, sample_shape)
      sample += padded_probs - jax.lax.stop_gradient(padded_probs)

      return jnp.sum((sample)**2).astype(jnp.float32), sample

    # Straight-through gradient and sample.
    sample_grad, sample = jax.grad(loss, has_aux=True)(dist_params,
                                                       self.distrax_cls,
                                                       sample_shape)
    # TFP gradient (zero) and sample.
    tfp_sample_grad, tfp_sample = jax.grad(loss, has_aux=True)(dist_params,
                                                               self.tfp_cls,
                                                               sample_shape)
    # TFP straight-through gradient and sample.
    tfp_st_sample_grad, tfp_st_sample = jax.grad(straight_through_tfp_loss,
                                                 has_aux=True)(dist_params,
                                                               self.tfp_cls,
                                                               sample_shape)

    # TEST: the samples have the same size, and the straight-through gradient
    # doesn't affect the tfp sample.
    chex.assert_equal_shape((sample, tfp_sample))
    self.assertion_fn(tfp_sample, tfp_st_sample)
    # TEST: the TFP gradient is zero.
    assert (jnp.asarray(*tfp_sample_grad.values()) == 0).all()
    # TEST: the TFP straight-through gradient is non zero.
    assert (jnp.asarray(*tfp_st_sample_grad.values()) != 0).any()
    # Test that the TFP straight-through gradient is equal to the one from
    # distrax when the samples from distrax and tfp are the same (due to
    # stochasticity the samples can differ - we are using skewed distributions
    # on purpose in the parametrization of the test to make sure that the
    # samples match most of the time).
    sample_grad_v = jnp.stack(jnp.array(*sample_grad.values()))
    tfp_st_sample_grad_v = jnp.stack(jnp.array(*tfp_st_sample_grad.values()))
    if np.all(sample == tfp_st_sample):
      self.assertion_fn(sample_grad_v, tfp_st_sample_grad_v)


if __name__ == '__main__':
  absltest.main()
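
As an aside (not part of the PR), the identity that both the wrapper and the TFP reference path in the test rely on can be checked in isolation. The sketch below assumes only jax and uses a fixed, made-up one-hot draw.

# Straight-through identity in isolation (illustrative, not from the PR).
import jax
import jax.numpy as jnp

def straight_through(sample, probs):
  # Forward pass: probs - stop_gradient(probs) == 0, so the value is unchanged.
  # Backward pass: the gradient w.r.t. probs passes through as the identity.
  return sample + probs - jax.lax.stop_gradient(probs)

probs = jnp.array([0.2, 0.5, 0.3])
sample = jnp.array([0.0, 1.0, 0.0])  # a fixed one-hot draw, for illustration

value = straight_through(sample, probs)  # equals `sample` exactly
grad = jax.grad(lambda p: jnp.sum(straight_through(sample, p) ** 2))(probs)
# `grad` equals 2 * value, i.e. the loss gradient computed as if the sample
# were the differentiable probabilities.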