[RLlib] Redo fix bug normalize vs unsquash actions (original PR made log-likelihood test flakey). (ray-project#17014)
sven1977 committed Jul 13, 2021
1 parent 16f1011 commit 1fd0eb8
Showing 12 changed files with 138 additions and 47 deletions.
19 changes: 15 additions & 4 deletions rllib/agents/ddpg/ddpg.py
@@ -106,10 +106,21 @@
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).

# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,

# === Optimization ===
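The arithmetic in the new `training_intensity` comment can be illustrated with a small sketch. The helper names below are hypothetical; the actual round-robin weighting lives in `rllib/agents/dqn/dqn.py::calculate_rr_weights`.

```python
# Sketch only: the "natural value" and the replay/train weight implied by
# `training_intensity` (hypothetical helpers, not RLlib's implementation).
def natural_value(train_batch_size, rollout_fragment_length, num_workers,
                  num_envs_per_worker):
    # With num_workers=0, sampling still happens on one (local) worker.
    return train_batch_size / (
        rollout_fragment_length * max(num_workers, 1) * num_envs_per_worker)


def replay_train_weight(training_intensity, **cfg):
    # None -> replay+train runs at the natural 1:1 ratio vs. rollout+insert.
    if training_intensity is None:
        return 1.0
    return training_intensity / natural_value(**cfg)


# Example from the comment above: 1000.0 / 250.0 = 4.0, i.e. replay+train
# is executed 4x as often as rollout+insert.
print(replay_train_weight(1000.0, train_batch_size=250,
                          rollout_fragment_length=1, num_workers=1,
                          num_envs_per_worker=1))
```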
19 changes: 15 additions & 4 deletions rllib/agents/dqn/dqn.py
@@ -103,10 +103,21 @@
"compress_observations": False,
# Callback to run before learning on a multi-agent batch of experiences.
"before_learn_on_batch": None,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).

# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,

# === Optimization ===
19 changes: 15 additions & 4 deletions rllib/agents/sac/sac.py
@@ -102,10 +102,21 @@
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).

# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,

# === Optimization ===
8 changes: 4 additions & 4 deletions rllib/agents/trainer.py
@@ -943,7 +943,7 @@ def compute_single_action(
policy_id: PolicyID = DEFAULT_POLICY_ID,
full_fetch: bool = False,
explore: bool = None,
normalize_actions: Optional[bool] = None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
) -> TensorStructType:
"""Computes an action for the specified policy on the local Worker.
@@ -968,8 +968,8 @@ def compute_single_action(
This is always set to True if RNN state is specified.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
normalize_actions (bool): Should actions be normalized according to
the env's/Policy's action space?
unsquash_actions (bool): Should actions be unsquashed according to
the env's/Policy's action space?
clip_actions (bool): Should actions be clipped according to the
env's/Policy's action space?
@@ -993,7 +993,7 @@ def compute_single_action(
prev_action,
prev_reward,
info,
normalize_actions=normalize_actions,
unsquash_actions=unsquash_actions,
clip_actions=clip_actions,
explore=explore)

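A hedged usage sketch of the renamed `unsquash_actions` argument on `Trainer.compute_single_action` (assumes `ray[rllib]` of this release line plus gym's `Pendulum-v0`; the config values are illustrative only):

```python
# Sketch: ask the Trainer to unsquash the policy's normalized output back
# into the env's Box bounds (this kwarg was previously misnamed
# `normalize_actions`).
import gym
import ray
from ray.rllib.agents import ppo

ray.init(ignore_reinit_error=True)
trainer = ppo.PPOTrainer(env="Pendulum-v0", config={"num_workers": 0})

obs = gym.make("Pendulum-v0").reset()
action = trainer.compute_single_action(
    obs, explore=False, unsquash_actions=True)
print(action)  # Lies within Pendulum's action bounds, e.g. [-2.0, 2.0].
```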
4 changes: 3 additions & 1 deletion rllib/offline/off_policy_estimator.py
@@ -69,7 +69,9 @@ def action_prob(self, batch: SampleBatchType) -> np.ndarray:
obs_batch=batch[SampleBatch.CUR_OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS))
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=True,
)
log_likelihoods = convert_to_numpy(log_likelihoods)
return np.exp(log_likelihoods)

21 changes: 15 additions & 6 deletions rllib/policy/eager_tf_policy.py
@@ -17,6 +17,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import TensorType

@@ -565,12 +566,15 @@ def _compute_action_helper(self, input_dict, state_batches, episodes,

@with_lock
@override(Policy)
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
def compute_log_likelihoods(
self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
actions_normalized=True,
):
if action_sampler_fn and action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
"`action_distribution_fn` and a provided "
@@ -606,6 +610,11 @@ def compute_log_likelihoods(self,
dist_class = self.dist_class

action_dist = dist_class(dist_inputs, self.model)

# Normalize actions if necessary.
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)

log_likelihoods = action_dist.logp(actions)

return log_likelihoods
31 changes: 21 additions & 10 deletions rllib/policy/policy.py
@@ -14,7 +14,7 @@
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import clip_action, \
get_base_struct_from_space, normalize_action, unbatch
get_base_struct_from_space, unbatch, unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict, Tuple, Union

@@ -161,7 +161,7 @@ def compute_single_action(
clip_actions: bool = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
normalize_actions: bool = None,
unsquash_actions: bool = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Unbatched version of compute_actions.
@@ -176,7 +176,7 @@
episode (Optional[MultiAgentEpisode]): this provides access to all
of the internal episode state, which may be useful for
model-based or multi-agent algorithms.
normalize_actions (bool): Should actions be normalized according to
unsquash_actions (bool): Should actions be unsquashed according to
the Policy's action space?
clip_actions (bool): Should actions be clipped according to the
Policy's action space?
@@ -195,8 +195,10 @@
if any.
- info (dict): Dictionary of extra features, if any.
"""
normalize_actions = \
normalize_actions if normalize_actions is not None \
# If policy works in normalized space, we should unsquash the action.
# Use value of config.normalize_actions, if None.
unsquash_actions = \
unsquash_actions if unsquash_actions is not None \
else self.config["normalize_actions"]
clip_actions = clip_actions if clip_actions is not None else \
self.config["clip_actions"]
@@ -244,9 +246,12 @@
assert len(single_action) == 1
single_action = single_action[0]

if normalize_actions:
single_action = normalize_action(single_action,
self.action_space_struct)
# If we work in normalized action space (normalize_actions=True),
# we re-translate here into the env's action space.
if unsquash_actions:
single_action = unsquash_action(single_action,
self.action_space_struct)
# Clip, according to env's action space.
elif clip_actions:
single_action = clip_action(single_action,
self.action_space_struct)
@@ -314,8 +319,10 @@ def compute_log_likelihoods(
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
"""Computes the log-prob/likelihood for a given action and observation.
Args:
@@ -330,6 +337,10 @@
Batch of previous action values.
prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
Batch of previous rewards.
actions_normalized (bool): Is the given `actions` already
normalized (between -1.0 and 1.0) or not? If not and
`normalize_actions=True`, we need to normalize the given
actions first, before calculating log likelihoods.
Returns:
TensorType: Batch of log probs/likelihoods, with shape:
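The import swap above (`normalize_action` out, `unsquash_action` in) is the heart of the fix: the two helpers are inverses of each other. A minimal sketch for a simple `Box` space, assuming the `space_utils` helpers apply the usual linear squash/unsquash:

```python
# Sketch: normalize_action maps env-space actions into [-1.0, 1.0];
# unsquash_action maps [-1.0, 1.0] back into the env's bounds.
import numpy as np
from gym.spaces import Box
from ray.rllib.utils.spaces.space_utils import (
    get_base_struct_from_space, normalize_action, unsquash_action)

space = Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)
struct = get_base_struct_from_space(space)

env_action = np.array([1.0], dtype=np.float32)
normalized = normalize_action(env_action, struct)  # ~[0.5], inside [-1, 1]
recovered = unsquash_action(normalized, struct)    # back to ~[1.0]
```

`compute_single_action` must apply `unsquash_action` (normalized space back to env space) before handing an action to the env, not `normalize_action`; that mix-up is what this commit corrects.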
20 changes: 17 additions & 3 deletions rllib/policy/tests/test_compute_log_likelihoods.py
@@ -54,7 +54,11 @@ def do_test_log_likelihood(run,
obs_batch[0],
prev_action=prev_a,
prev_reward=prev_r,
explore=True))
explore=True,
# Do not unsquash actions
# (remain in normalized [-1.0; 1.0] space).
unsquash_actions=False,
))

# Test all taken actions for their log-likelihoods vs expected values.
if continuous:
@@ -89,7 +93,9 @@ def do_test_log_likelihood(run,
np.array([a]),
preprocessed_obs_batch,
prev_action_batch=np.array([prev_a]) if prev_a else None,
prev_reward_batch=np.array([prev_r]) if prev_r else None)
prev_reward_batch=np.array([prev_r]) if prev_r else None,
actions_normalized=True,
)
check(logp, expected_logp[0], rtol=0.2)
# Test all available actions for their logp values.
else:
@@ -118,11 +124,13 @@ def test_dqn(self):
config = dqn.DEFAULT_CONFIG.copy()
# Soft-Q for DQN.
config["exploration_config"] = {"type": "SoftQ", "temperature": 0.5}
config["seed"] = 42
do_test_log_likelihood(dqn.DQNTrainer, config)

def test_pg_cont(self):
"""Tests PG's (cont. actions) compute_log_likelihoods method."""
config = pg.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
prev_a = np.array([0.0])
@@ -136,25 +144,30 @@ def test_pg_cont(self):
def test_pg_discr(self):
"""Tests PG's (cont. actions) compute_log_likelihoods method."""
config = pg.DEFAULT_CONFIG.copy()
config["seed"] = 42
prev_a = np.array(0)
do_test_log_likelihood(pg.PGTrainer, config, prev_a)

def test_ppo_cont(self):
"""Tests PPO's (cont. actions) compute_log_likelihoods method."""
config = ppo.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
prev_a = np.array([0.0])
do_test_log_likelihood(ppo.PPOTrainer, config, prev_a, continuous=True)

def test_ppo_discr(self):
"""Tests PPO's (discr. actions) compute_log_likelihoods method."""
config = ppo.DEFAULT_CONFIG.copy()
config["seed"] = 42
prev_a = np.array(0)
do_test_log_likelihood(ppo.PPOTrainer, ppo.DEFAULT_CONFIG, prev_a)
do_test_log_likelihood(ppo.PPOTrainer, config, prev_a)

def test_sac_cont(self):
"""Tests SAC's (cont. actions) compute_log_likelihoods method."""
config = sac.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["policy_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_activation"] = "linear"
prev_a = np.array([0.0])
@@ -184,6 +197,7 @@ def logp_func(means, log_stds, values, low=-1.0, high=1.0):
def test_sac_discr(self):
"""Tests SAC's (discrete actions) compute_log_likelihoods method."""
config = sac.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["policy_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_activation"] = "linear"
prev_a = np.array(0)
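The two new flags in the test are complementary: actions sampled with `unsquash_actions=False` stay in the policy's normalized space, so `actions_normalized=True` is the right flag when they are fed back in for log-likelihood computation. A hedged end-to-end sketch (assumes a default `PGTrainer` on gym's `Pendulum-v0` builds cleanly):

```python
# Sketch of the round-trip the updated test relies on.
import numpy as np
import ray
from ray.rllib.agents import pg

ray.init(ignore_reinit_error=True)
trainer = pg.PGTrainer(env="Pendulum-v0",
                       config={"num_workers": 0, "seed": 42})
policy = trainer.get_policy()

obs = np.zeros(policy.observation_space.shape, dtype=np.float32)
# Keep the sampled action in normalized [-1.0, 1.0] space ...
action, _, _ = policy.compute_single_action(obs, unsquash_actions=False)
# ... so actions_normalized=True is correct when computing its logp.
logp = policy.compute_log_likelihoods(
    actions=np.array([action]),
    obs_batch=np.array([obs]),
    actions_normalized=True)
```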
12 changes: 10 additions & 2 deletions rllib/policy/tf_policy.py
@@ -17,6 +17,7 @@
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import ModelGradients, TensorType, \
TrainerConfigDict
@@ -399,8 +400,10 @@ def compute_log_likelihoods(
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:

if self._log_likelihood is None:
raise ValueError("Cannot compute log-prob/likelihood w/o a "
Expand All @@ -411,6 +414,11 @@ def compute_log_likelihoods(
explore=False, tf_sess=self.get_session())

builder = TFRunBuilder(self._sess, "compute_log_likelihoods")

# Normalize actions if necessary.
if actions_normalized is False and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)

# Feed actions (for which we want logp values) into graph.
builder.add_feed_dict({self._action_input: actions})
# Feed observations.
15 changes: 12 additions & 3 deletions rllib/policy/torch_policy.py
@@ -22,6 +22,7 @@
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
@@ -373,8 +374,10 @@ def compute_log_likelihoods(
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:

if self.action_sampler_fn and self.action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
Expand Down Expand Up @@ -436,7 +439,13 @@ def compute_log_likelihoods(
seq_lens)

action_dist = dist_class(dist_inputs, self.model)
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

# Normalize actions if necessary.
actions = input_dict[SampleBatch.ACTIONS]
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)

log_likelihoods = action_dist.logp(actions)

return log_likelihoods

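The guard added in the eager-TF, TF, and Torch policies is the same: only re-normalize when the caller states the actions are not already in `[-1.0, 1.0]` space and the policy itself works in normalized space. A standalone sketch of that branch, outside of a `Policy` object (the `config` dict is hypothetical; assumes the `space_utils` helpers as above):

```python
# Sketch of the `actions_normalized` guard used in compute_log_likelihoods.
import numpy as np
from gym.spaces import Box
from ray.rllib.utils.spaces.space_utils import (
    get_base_struct_from_space, normalize_action)

config = {"normalize_actions": True}  # hypothetical policy config
action_space_struct = get_base_struct_from_space(
    Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32))

actions = np.array([[1.5]], dtype=np.float32)  # env-space (unsquashed)
actions_normalized = False  # caller says: not yet in [-1.0, 1.0]

if not actions_normalized and config["normalize_actions"]:
    actions = normalize_action(actions, action_space_struct)
# `actions` is now in the space the action distribution was trained in and
# can be passed to action_dist.logp(actions).
```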
(Diffs for the remaining 2 changed files are not expanded here.)
