[RLlib] APPO TF with RLModule and Learner API (ray-project#33310)
Signed-off-by: Avnish <avnishnarayan@gmail.com>
Signed-off-by: Jonathan Carter <jonathan.carter@magd.ox.ac.uk>
avnishn authored and joncarter1 committed Apr 2, 2023
1 parent 181cb61 commit 933e187
Showing 28 changed files with 1,025 additions and 99 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -927,6 +927,13 @@ py_test(
srcs = ["algorithms/appo/tests/test_appo_off_policyness.py"]
)

py_test(
name = "test_appo_learner",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
srcs = ["algorithms/appo/tests/tf/test_appo_learner.py"]
)
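
For local verification, the new test target above can presumably be run from the repository root with Bazel (assuming the standard RLlib test setup):

bazel test //rllib:test_appo_learner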

# ARS
py_test(
name = "test_ars",
212 changes: 169 additions & 43 deletions rllib/algorithms/appo/appo.py
@@ -14,17 +14,21 @@

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig
from ray.rllib.algorithms.appo.tf.appo_tf_learner import AppoHPs, LEARNER_RESULTS_KL_KEY
from ray.rllib.algorithms.ppo.ppo import UpdateKL
from ray.rllib.execution.common import _get_shared_metrics, STEPS_SAMPLED_COUNTER
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
NUM_TARGET_UPDATES,
NUM_ENV_STEPS_TRAINED,
NUM_AGENT_STEPS_TRAINED,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.metrics import ALL_MODULES, LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
ResultDict,
)
@@ -74,6 +78,7 @@ def __init__(self, algo_class=None):
# __sphinx_doc_begin__

# APPO specific settings:
self._learner_hps = AppoHPs()
self.vtrace = True
self.use_critic = True
self.use_gae = True
@@ -92,6 +97,7 @@ def __init__(self, algo_class=None):
self.num_multi_gpu_tower_stacks = 1
self.minibatch_buffer_size = 1
self.num_sgd_iter = 1
self.target_update_frequency = 1
self.replay_proportion = 0.0
self.replay_buffer_num_slots = 100
self.learner_queue_size = 16
@@ -108,6 +114,8 @@ def __init__(self, algo_class=None):
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.entropy_coeff_schedule = None
self.tau = 1.0

# __sphinx_doc_end__
# fmt: on

@@ -123,6 +131,8 @@ def training(
use_kl_loss: Optional[bool] = NotProvided,
kl_coeff: Optional[float] = NotProvided,
kl_target: Optional[float] = NotProvided,
tau: Optional[float] = NotProvided,
target_update_frequency: Optional[int] = NotProvided,
**kwargs,
) -> "APPOConfig":
"""Sets the training related configuration.
@@ -141,6 +151,19 @@
kl_coeff: Coefficient for weighting the KL-loss term.
kl_target: Target term for the KL-term to reach (via adjusting the
`kl_coeff` automatically).
tau: The factor by which to update the target policy network towards
the current policy network. Can range between 0 and 1,
e.g. updated_param = tau * current_param + (1 - tau) * target_param.
target_update_frequency: The frequency with which to update the target
policy network and tune the KL loss coefficients used during training.
After setting this parameter, the algorithm waits for at least
`target_update_frequency * minibatch_size * num_sgd_iter` samples to be
trained on by the learner group before updating the target networks and
tuning the KL loss coefficients.
NOTE: This parameter is only applicable when using the Learner API
(_enable_learner_api=True and _enable_rl_module_api=True).
Returns:
This updated AlgorithmConfig object.
@@ -158,15 +181,52 @@ def training(
self.lambda_ = lambda_
if clip_param is not NotProvided:
self.clip_param = clip_param
self._learner_hps.clip_param = clip_param
if use_kl_loss is not NotProvided:
self.use_kl_loss = use_kl_loss
if kl_coeff is not NotProvided:
self.kl_coeff = kl_coeff
self._learner_hps.kl_coeff = kl_coeff
if kl_target is not NotProvided:
self.kl_target = kl_target
self._learner_hps.kl_target = kl_target
if tau is not NotProvided:
self.tau = tau
self._learner_hps.tau = tau
if target_update_frequency is not NotProvided:
self.target_update_frequency = target_update_frequency

return self
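
For orientation, here is a minimal usage sketch of the new options wired up above. The values are illustrative rather than defaults, and only the standard APPOConfig import path and the AlgorithmConfig.framework()/training() pattern are assumed:

# Sketch only: illustrative values, not RLlib defaults.
from ray.rllib.algorithms.appo import APPOConfig

config = (
    APPOConfig()
    .framework("tf2")
    .training(
        clip_param=0.3,
        use_kl_loss=True,
        kl_coeff=1.0,
        kl_target=0.01,
        tau=1.0,  # soft-update factor for the target network
        target_update_frequency=1,  # in units of train_batch_size * num_sgd_iter
    )
)
# training() and validate() copy these values into config._learner_hps (AppoHPs),
# as shown in the diff above.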

@override(AlgorithmConfig)
def get_default_learner_class(self):
if self.framework_str == "tf2":
from ray.rllib.algorithms.appo.tf.appo_tf_learner import APPOTfLearner

return APPOTfLearner
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(AlgorithmConfig)
def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
if self.framework_str == "tf2":
from ray.rllib.algorithms.appo.appo_catalog import APPOCatalog
from ray.rllib.algorithms.appo.tf.appo_tf_rl_module import APPOTfRLModule

return SingleAgentRLModuleSpec(
module_class=APPOTfRLModule, catalog_class=APPOCatalog
)
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(ImpalaConfig)
def validate(self) -> None:
super().validate()
self._learner_hps.tau = self.tau
self._learner_hps.kl_target = self.kl_target
self._learner_hps.kl_coeff = self.kl_coeff
self._learner_hps.clip_param = self.clip_param


class UpdateTargetAndKL:
def __init__(self, workers, config):
@@ -199,15 +259,23 @@ def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)

# After init: Initialize target net.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

# TODO(avnishn):
# does this need to happen in __init__? I think we can move it to setup()
if not self.config._enable_rl_module_api:
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

@override(Impala)
def setup(self, config: AlgorithmConfig):
super().setup(config)

self.update_kl = UpdateKL(self.workers)
# TODO(avnishn):
# this attribute isn't used anywhere else in the code. I think we can safely
# delete it.
if not self.config._enable_rl_module_api:
self.update_kl = UpdateKL(self.workers)

def after_train_step(self, train_results: ResultDict) -> None:
"""Updates the target network and the KL coefficient for the APPO-loss.
@@ -222,45 +290,84 @@ def after_train_step(self, train_results: ResultDict) -> None:
train_results: The results dict collected during the most recent
training step.
"""
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
]

last_update = self._counters[LAST_TARGET_UPDATE_TS]
target_update_freq = (
self.config.num_sgd_iter * self.config.minibatch_buffer_size
)
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
if self.config._enable_learner_api and train_results:
# Using steps trained here instead of steps sampled; it is unclear why the
# other implementation uses sampled steps. With the sampled-steps scheme,
# the gap in steps sampled/trained between updates is almost always larger
# than self.config.num_sgd_iter * self.config.minibatch_buffer_size unless
# very few steps are collected. Since the default rollout fragment length
# is 50, minibatch_buffer_size * num_sgd_iter would have to reach 50 just
# to meet the threshold for a delayed target update.
# Instead, the target/KL update threshold here is based on
# train_batch_size * target_update_frequency * num_sgd_iter.
cur_ts = self._counters[
NUM_ENV_STEPS_TRAINED
if self.config.count_steps_by == "env_steps"
else NUM_AGENT_STEPS_TRAINED
]
target_update_steps_freq = (
self.config.train_batch_size
* self.config.num_sgd_iter
* self.config.target_update_frequency
)
if (cur_ts - last_update) >= target_update_steps_freq:
kls_to_update = {}
for module_id, module_results in train_results.items():
if module_id != ALL_MODULES:
kls_to_update[module_id] = module_results[LEARNER_STATS_KEY][
LEARNER_RESULTS_KL_KEY
]
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
self.learner_group.additional_update(sampled_kls=kls_to_update)

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config.use_kl_loss:

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning("No data for {}, not updating kl".format(pi_id))

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.workers.local_worker().foreach_policy_to_train(update)
else:
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
]
target_update_freq = (
self.config.num_sgd_iter * self.config.minibatch_buffer_size
)
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config.use_kl_loss:

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning(
"No data for {}, not updating kl".format(pi_id)
)

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.workers.local_worker().foreach_policy_to_train(update)
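
To make the Learner-API update trigger above concrete, a small arithmetic sketch with hypothetical numbers (not defaults):

# Hypothetical settings, for illustration only.
train_batch_size = 500
num_sgd_iter = 1
target_update_frequency = 1

target_update_steps_freq = (
    train_batch_size * num_sgd_iter * target_update_frequency
)  # = 500
# Once at least 500 env/agent steps have been trained since the last update
# (cur_ts - last_update >= 500), the per-module KL stats are gathered from
# train_results and learner_group.additional_update(sampled_kls=...) is called,
# which refreshes the target network(s) and tunes the KL coefficients.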

@override(Impala)
def training_step(self) -> ResultDict:
@@ -282,14 +389,33 @@ def get_default_policy_class(
cls, config: AlgorithmConfig
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy

return APPOTorchPolicy
if config._enable_rl_module_api:
raise ValueError(
"APPO with the torch backend is not yet supported by "
" the RLModule and Learner API."
)
else:
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy

return APPOTorchPolicy
elif config["framework"] == "tf":
if config._enable_rl_module_api:
raise ValueError(
"RLlib's RLModule and Learner API is not supported for"
" tf1. Use "
"framework='tf2' instead."
)
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy

return APPOTF1Policy
else:
if config._enable_rl_module_api:
# TODO(avnishn): This policy class doesn't work just yet
from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import (
APPOTfPolicyWithRLModule,
)

return APPOTfPolicyWithRLModule
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy

return APPOTF2Policy
24 changes: 24 additions & 0 deletions rllib/algorithms/appo/appo_catalog.py
@@ -0,0 +1,24 @@
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog


class APPOCatalog(PPOCatalog):
"""The Catalog class used to build models for APPO.
PPOCatalog provides the following models:
- ActorCriticEncoder: The encoder used to encode the observations.
- Pi Head: The head used to compute the policy logits.
- Value Function Head: The head used to compute the value function.
The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
for the policy and value function. See implementations of PPORLModuleBase for
more details.
Any custom ActorCriticEncoder can be built by overriding the
build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom
ActorCriticEncoder during RLModule runtime.
Any custom head can be built by overriding the build_pi_head() and build_vf_head()
methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
build custom heads during RLModule runtime.
"""