An assortment of QWR improvements.
- Multiple value network synchronizations in one epoch.
- Softmax and logsumexp Q-value aggregation.
- LayerNorm observation normalization as in the ACME framework.
- Learnable standard deviation in the Gaussian distribution - either separate or shared across dimensions.

PiperOrigin-RevId: 333323682
koz4k authored and Copybara-Service committed Sep 24, 2020
1 parent 8b310b9 commit ae02435
Showing 10 changed files with 135 additions and 54 deletions.
trax/models/rl.py: 16 changes (7 additions & 9 deletions)
@@ -78,15 +78,13 @@ def Value(
def ActionInjector(mode):
if inject_actions:
if is_discrete:
encode_layer = tl.Parallel(
tl.Dense(inject_actions_dim),
tl.Embedding(vocab_size, inject_actions_dim)
)
action_encoder = tl.Embedding(vocab_size, inject_actions_dim)
else:
encode_layer = tl.Parallel(
tl.Dense(inject_actions_dim),
tl.Dense(inject_actions_dim),
)
action_encoder = tl.Dense(inject_actions_dim)
encoders = tl.Parallel(
tl.Dense(inject_actions_dim),
action_encoder,
)
if multiplicative_action_injection:
action_injector = tl.Serial(
tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
@@ -96,7 +94,7 @@ def ActionInjector(mode):
action_injector = tl.Add()
return tl.Serial(
# Input: (body output, actions).
encode_layer,
encoders,
action_injector,
models.PureMLP(
layer_widths=(inject_actions_dim,) * inject_actions_n_layers,
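For reference, a minimal standalone sketch (plain JAX, toy shapes, hypothetical variable names) of what the two injection modes above compute once the body output and the encoded actions have both been projected to inject_actions_dim:

import numpy as np
import jax.numpy as jnp

# Toy batch: 2 examples, inject_actions_dim = 4 (shapes are illustrative only).
body_out = jnp.asarray(np.random.randn(2, 4))    # output of tl.Dense(inject_actions_dim)
action_enc = jnp.asarray(np.random.randn(2, 4))  # output of the action encoder

# Additive injection (tl.Add above): features and encoded actions are summed.
additive = body_out + action_enc

# Multiplicative injection ('TanhMulGate' above): the encoded action gates the
# body features, with tanh keeping the gate in (-1, 1).
multiplicative = body_out * jnp.tanh(action_enc)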
trax/rl/actor_critic.py: 50 changes (38 additions & 12 deletions)
@@ -57,7 +57,8 @@ def __init__(self, task,
n_replay_epochs=1,
scale_value_targets=False,
q_value=False,
q_value_aggregate_max=True,
q_value_aggregate='max',
q_value_temperature=1.0,
q_value_n_samples=1,
**kwargs): # Arguments of PolicyAgent come here.
"""Configures the actor-critic trainer.
@@ -70,8 +71,9 @@ def __init__(self, task,
value_batch_size: Batch size for value model training.
value_train_steps_per_epoch: Number of steps we use to train the
value model in each epoch.
value_evals_per_epoch: Number of value trainer evaluations per RL epoch;
only affects metric reporting.
value_evals_per_epoch: Number of value trainer evaluations per RL epoch.
At every evaluation we also synchronize the weights of the target
network.
value_eval_steps: Number of value trainer steps per evaluation; only
affects metric reporting.
n_shared_layers: Number of layers to share between value and policy
@@ -86,7 +88,10 @@ def __init__(self, task,
scale_value_targets: If `True`, scale value function targets by
`1 / (1 - gamma)`.
q_value: If `True`, use Q-values as baselines.
q_value_aggregate_max: If `True`, aggregate Q-values with max (or mean).
q_value_aggregate: How to aggregate Q-values. Options: 'mean', 'max',
'softmax', 'logsumexp'.
q_value_temperature: Temperature parameter for the 'softmax' and
'logsumexp' aggregation methods.
q_value_n_samples: Number of samples to average over when calculating
baselines based on Q-values.
**kwargs: Arguments for `PolicyAgent` superclass.
@@ -112,7 +117,8 @@ def __init__(self, task,
self._value_network_scale = 1

self._q_value = q_value
self._q_value_aggregate_max = q_value_aggregate_max
self._q_value_aggregate = q_value_aggregate
self._q_value_temperature = q_value_temperature
self._q_value_n_samples = q_value_n_samples

is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete)
@@ -230,14 +236,31 @@ def _run_value_model(self, observations, dist_inputs):
values = jnp.squeeze(values, axis=-1) # Remove the singleton depth dim.
return (values, actions, log_probs)

def _aggregate_values(self, values, aggregate_max, act_log_probs):
def _aggregate_values(self, values, aggregate, act_log_probs):
temp = self._q_value_temperature
if self._q_value:
if aggregate_max:
assert values.shape[:2] == (
self._value_batch_size, self._q_value_n_samples
)
if aggregate == 'max':
# max_a Q(s, a)
values = jnp.max(values, axis=1)
elif self._sample_all_discrete_actions:
values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
elif aggregate == 'softmax':
# sum_a (Q(s, a) * w(s, a))
# where w(s, .) = softmax (Q(s, .) / T)
weights = tl.Softmax(axis=1)(values / temp)
values = jnp.sum(values * weights, axis=1)
elif aggregate == 'logsumexp':
# log(mean_a exp(Q(s, a) / T)) * T
n = values.shape[1]
values = (fastmath.logsumexp(values / temp, axis=1) - jnp.log(n)) * temp
else:
values = jnp.mean(values, axis=1)
assert aggregate == 'mean'
# mean_a Q(s, a)
if self._sample_all_discrete_actions:
values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
else:
values = jnp.mean(values, axis=1)
return np.array(values) # Move the values to CPU.

def value_batches_stream(self):
@@ -254,7 +277,7 @@ def value_batches_stream(self):
np_trajectory.observations, np_trajectory.dist_inputs
)
values = self._aggregate_values(
values, self._q_value_aggregate_max, act_log_probs)
values, self._q_value_aggregate, act_log_probs)

# TODO(pkozakowski): Add some shape assertions and docs.
# Calculate targets based on the advantages over the target network - this
@@ -311,7 +334,7 @@ def policy_batches_stream(self):
include_final_state=False):
(values, _, act_log_probs) = self._run_value_model(
np_trajectory.observations, np_trajectory.dist_inputs)
values = self._aggregate_values(values, False, act_log_probs)
values = self._aggregate_values(values, 'mean', act_log_probs)
if len(values.shape) != 2:
raise ValueError('Values are expected to have shape ' +
'[batch_size, length], got: %s' % str(values.shape))
@@ -345,6 +368,9 @@ def train_epoch(self):
self._value_train_steps_per_epoch // self._value_evals_per_epoch,
self._value_eval_steps,
)
# Update the target value network.
self._value_eval_model.weights = self._value_trainer.model_weights
self._value_eval_model.state = self._value_trainer.model_state

# Copy value weights and state to policy trainer.
if self._n_shared_layers > 0:
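The new aggregation options are easiest to see in isolation. Below is a self-contained sketch in plain JAX; aggregate_q_values is a hypothetical name (the trainer method above is _aggregate_values), and values there has shape (batch, n_samples), as the added assert checks:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def aggregate_q_values(q, mode='mean', temperature=1.0):
  """Aggregates sampled Q-values of shape (batch, n_samples) over axis 1."""
  if mode == 'max':
    return jnp.max(q, axis=1)  # max_a Q(s, a)
  if mode == 'softmax':
    # sum_a Q(s, a) * w(s, a), where w(s, .) = softmax(Q(s, .) / T).
    weights = jax.nn.softmax(q / temperature, axis=1)
    return jnp.sum(q * weights, axis=1)
  if mode == 'logsumexp':
    # log(mean_a exp(Q(s, a) / T)) * T
    n = q.shape[1]
    return (logsumexp(q / temperature, axis=1) - jnp.log(n)) * temperature
  assert mode == 'mean'
  return jnp.mean(q, axis=1)  # mean_a Q(s, a)


# With T = 1 both new modes land between 'mean' (2.0) and 'max' (3.0); as the
# temperature decreases they approach 'max', as it grows they approach 'mean'.
q = jnp.array([[1.0, 2.0, 3.0]])
for mode in ('mean', 'max', 'softmax', 'logsumexp'):
  print(mode, aggregate_q_values(q, mode))

In other words, q_value_temperature interpolates between the two behaviors that the old q_value_aggregate_max flag used to toggle.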
trax/rl/actor_critic_test.py: 4 changes (2 additions & 2 deletions)
@@ -196,7 +196,7 @@ def test_sampling_awrtrainer_cartpole(self):
advantage_estimator=advantages.monte_carlo,
advantage_normalization=False,
q_value_n_samples=3,
q_value_aggregate_max=True,
q_value_aggregate='max',
reweight=False,
)
trainer.run(1)
@@ -229,7 +229,7 @@ def test_sampling_awrtrainer_cartpole_sample_all_discrete(self):
advantage_estimator=advantages.monte_carlo,
advantage_normalization=False,
q_value_n_samples=2,
q_value_aggregate_max=True,
q_value_aggregate='max',
reweight=False,
)
trainer.run(1)
trax/rl/configs/light_awr_boxing.gin: 4 changes (2 additions & 2 deletions)
@@ -104,7 +104,7 @@ AWR.scale_value_targets = True
AWR.n_shared_layers = 0
AWR.q_value = True
AWR.q_value_n_samples = %q_value_n_samples
AWR.q_value_aggregate_max = True
AWR.q_value_aggregate = 'max'

# Parameters for SamplingAWR:
# ==============================================================================
@@ -135,7 +135,7 @@ SamplingAWR.n_replay_epochs = 2
SamplingAWR.scale_value_targets = True
SamplingAWR.n_shared_layers = 0
SamplingAWR.q_value_n_samples = %q_value_n_samples
SamplingAWR.q_value_aggregate_max = True
SamplingAWR.q_value_aggregate = 'max'
SamplingAWR.reweight = False

# Parameters for train_rl:
trax/rl/configs/light_cartpole.gin: 4 changes (2 additions & 2 deletions)
@@ -136,7 +136,7 @@ AWR.scale_value_targets = True
AWR.n_shared_layers = 0
AWR.q_value = True
AWR.q_value_n_samples = %q_value_n_samples
AWR.q_value_aggregate_max = True
AWR.q_value_aggregate = 'max'

# Parameters for SamplingAWR:
# ==============================================================================
@@ -167,7 +167,7 @@ SamplingAWR.n_replay_epochs = 50
SamplingAWR.scale_value_targets = True
SamplingAWR.n_shared_layers = 0
SamplingAWR.q_value_n_samples = %q_value_n_samples
SamplingAWR.q_value_aggregate_max = True
SamplingAWR.q_value_aggregate = 'max'
SamplingAWR.reweight = False

# Parameters for train_rl:
trax/rl/configs/light_mujoco.gin: 4 changes (2 additions & 2 deletions)
@@ -113,7 +113,7 @@ AWR.scale_value_targets = True
AWR.n_shared_layers = 1
AWR.q_value = True
AWR.q_value_n_samples = %q_value_n_samples
AWR.q_value_aggregate_max = True
AWR.q_value_aggregate = 'max'

# Parameters for SamplingAWR:
# ==============================================================================
@@ -144,7 +144,7 @@ SamplingAWR.n_replay_epochs = 50
SamplingAWR.scale_value_targets = True
SamplingAWR.n_shared_layers = 1
SamplingAWR.q_value_n_samples = %q_value_n_samples
SamplingAWR.q_value_aggregate_max = True
SamplingAWR.q_value_aggregate = 'max'
SamplingAWR.reweight = False

# Parameters for train_rl:
trax/rl/configs/light_mujoco_transformer.gin: 4 changes (2 additions & 2 deletions)
@@ -101,7 +101,7 @@ AWR.scale_value_targets = True
AWR.n_shared_layers = 1
AWR.q_value = False
AWR.q_value_n_samples = 128
AWR.q_value_aggregate_max = True
AWR.q_value_aggregate = 'max'

# Parameters for SamplingAWR:
# ==============================================================================
@@ -132,7 +132,7 @@ SamplingAWR.n_replay_epochs = 25
SamplingAWR.scale_value_targets = True
SamplingAWR.n_shared_layers = 1
SamplingAWR.q_value_n_samples = 128
SamplingAWR.q_value_aggregate_max = True
SamplingAWR.q_value_aggregate = 'max'
SamplingAWR.reweight = False

# Parameters for train_rl:
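All of the shipped configs above keep q_value_aggregate = 'max'. As a hedged sketch (assuming trax.rl has been imported so the AWR/SamplingAWR gin configurables are registered), the new modes can be selected by overriding the bindings from Python:

import gin
from trax import rl  # noqa: F401  (registers the RL gin configurables)

# Hypothetical override: use the temperature-controlled aggregation instead of
# the 'max' set in the .gin files above.
gin.parse_config("""
SamplingAWR.q_value_aggregate = 'logsumexp'
SamplingAWR.q_value_temperature = 0.5
""")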
trax/rl/distributions.py: 66 changes (48 additions & 18 deletions)
@@ -109,7 +109,8 @@ def log_prob(self, inputs, point):
axis=[-a for a in range(1, len(self._shape) + 2)],
)

def entropy(self, log_probs):
def entropy(self, inputs):
log_probs = inputs
probs = jnp.exp(log_probs)
return -jnp.sum(probs * log_probs, axis=-1)

@@ -118,46 +119,75 @@ def entropy(self, log_probs):
class Gaussian(Distribution):
"""Independent multivariate Gaussian distribution parametrized by mean."""

def __init__(self, shape=(), std=1.0):
def __init__(self, shape=(), std=1.0, learn_std=None):
"""Initializes Gaussian distribution.
Args:
shape (tuple): Shape of the sample.
std (float): Standard deviation, shared across the whole sample.
learn_std (str or None): How to learn the standard deviation - 'shared'
to have a single, shared std parameter, or 'separate' to have separate
parameters for each dimension.
"""
self._shape = shape
self._std = std
self._learn_std = learn_std

@property
def n_inputs(self):
def _n_dims(self):
return np.prod(self._shape, dtype=jnp.int32)

def _params(self, inputs):
"""Extracts the mean and std parameters from the inputs."""
assert inputs.shape[-1] == self.n_inputs
n_dims = self._n_dims
# Split the distribution inputs into two parts: mean and std.
mean = inputs[..., :n_dims]
if self._learn_std is not None:
std = inputs[..., n_dims:]
# Std is non-negative, so let's softplus it.
std = tl.Softplus()(std + self._std)
else:
std = self._std
# In case of constant or shared std, upsample it to the same dimensionality
# as the means.
std = jnp.broadcast_to(std, mean.shape)
return (mean, std)

@property
def n_inputs(self):
n_dims = self._n_dims
return {
None: n_dims,
'shared': n_dims + 1,
'separate': n_dims * 2,
}[self._learn_std]

def sample(self, inputs, temperature=1.0):
(mean, std) = self._params(inputs)
mean = jnp.reshape(mean, mean.shape[:-1] + self._shape)
std = jnp.reshape(std, std.shape[:-1] + self._shape)
if temperature == 0:
# This seemingly strange `if` avoids calling np/jnp.random in the
# PreferredMove metric.
return inputs
return mean
else:
return np.random.normal(
loc=jnp.reshape(inputs, inputs.shape[:-1] + self._shape),
scale=self._std * temperature,
)
return np.random.normal(loc=mean, scale=(std * temperature))

def log_prob(self, inputs, point):
point = point.reshape(inputs.shape[:-1] + (-1,))
return (
# L2 term.
-jnp.sum((point - inputs) ** 2, axis=-1) / (2 * self._std ** 2) -
(mean, std) = self._params(inputs)
return -jnp.sum(
# Scaled distance.
(point - mean) ** 2 / (2 * std ** 2) +
# Normalizing constant.
((jnp.log(self._std) + jnp.log(jnp.sqrt(2 * jnp.pi)))
* np.prod(self._shape))
(jnp.log(std) + jnp.log(jnp.sqrt(2 * jnp.pi))),
axis=-1,
)

# At that point self._std is not learnable, hence
# we return a constant
def entropy(self, log_probs):
del log_probs # would be helpful if self._std was learnable
return jnp.exp(self._std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e)
def entropy(self, inputs):
(_, std) = self._params(inputs)
return jnp.sum(jnp.exp(std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e), axis=-1)


# TODO(pkozakowski): Implement GaussianMixture.
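A standalone sketch of the new input-splitting logic (gaussian_params is a hypothetical helper mirroring Gaussian._params above), showing how learn_std changes the expected width of the distribution inputs:

import numpy as np
import jax
import jax.numpy as jnp


def gaussian_params(inputs, shape, init_std=1.0, learn_std=None):
  """Splits distribution inputs into (mean, std), mirroring Gaussian._params."""
  n_dims = int(np.prod(shape))
  mean = inputs[..., :n_dims]
  if learn_std is not None:
    # The std head adds 1 input ('shared') or n_dims inputs ('separate');
    # softplus keeps it positive and init_std shifts it at initialization.
    std = jax.nn.softplus(inputs[..., n_dims:] + init_std)
  else:
    std = init_std
  # Broadcast a constant or shared std up to the dimensionality of the mean.
  return mean, jnp.broadcast_to(std, mean.shape)


# Expected input widths for shape=(2, 3), i.e. n_dims = 6:
#   learn_std=None       -> 6 inputs (mean only)
#   learn_std='shared'   -> 7 inputs (mean + 1 std)
#   learn_std='separate' -> 12 inputs (mean + 6 stds)
inputs = jnp.zeros((4, 7))  # batch of 4 with a 'shared' std head
mean, std = gaussian_params(inputs, shape=(2, 3), learn_std='shared')
assert mean.shape == (4, 6) and std.shape == (4, 6)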
trax/rl/distributions_test.py: 25 changes (20 additions & 5 deletions)
@@ -14,10 +14,11 @@
# limitations under the License.

# Lint as: python3
"""Tests for initializers."""
"""Tests for trax.rl.distributions."""

from absl.testing import absltest
from absl.testing import parameterized
import gin
import gym
import numpy as np

@@ -27,11 +28,25 @@
class DistributionsTest(parameterized.TestCase):

@parameterized.named_parameters(
('discrete', gym.spaces.Discrete(n=4)),
('multi_discrete', gym.spaces.MultiDiscrete(nvec=[5, 5])),
('gaussian', gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5))),
('discrete', gym.spaces.Discrete(n=4), ''),
('multi_discrete', gym.spaces.MultiDiscrete(nvec=[5, 5]), ''),
(
'gaussian_const_std',
gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)),
'Gaussian.learn_std = None',
), (
'gaussian_shared_std',
gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)),
'Gaussian.learn_std = "shared"',
), (
'gaussian_separate_std',
gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)),
'Gaussian.learn_std = "separate"',
),
)
def test_shapes(self, space):
def test_shapes(self, space, gin_config):
gin.parse_config(gin_config)

batch_shape = (2, 3)
distribution = distributions.create_distribution(space)
inputs = np.random.random(batch_shape + (distribution.n_inputs,))
trax/rl/normalization.py: 12 changes (12 additions & 0 deletions)
@@ -111,3 +111,15 @@ def forward(self, inputs):
norm_observations = (observations - mean) / (var ** 0.5 + self._epsilon)
self.state = state
return norm_observations


@gin.configurable(blacklist=['mode'])
def LayerNormSquash(mode, width=128): # pylint: disable=invalid-name
"""Dense-LayerNorm-Tanh normalizer inspired by ACME."""
# https://github.com/deepmind/acme/blob/master/acme/jax/networks/continuous.py#L34
del mode
return tl.Serial([
tl.Dense(width),
tl.LayerNorm(),
tl.Tanh(),
])
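A quick usage sketch of the new normalizer (the width and observation shape are illustrative; this assumes the usual trax init-then-call pattern):

import numpy as np
from trax import shapes
from trax.rl import normalization

# Dense(128) -> LayerNorm -> Tanh, so observations come out squashed to (-1, 1).
normalizer = normalization.LayerNormSquash(mode='train', width=128)
obs = np.random.randn(2, 7).astype(np.float32)  # toy batch of observations
normalizer.init(shapes.signature(obs))
squashed = normalizer(obs)
assert squashed.shape == (2, 128)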
