Commit

[RLlib] Fix issues with action masking examples. (ray-project#38095)
Signed-off-by: harborn <gangsheng.wu@intel.com>
ArturNiederfahrenhorst authored and harborn committed Aug 17, 2023
1 parent 8ad5b15 commit 3dd834b
Showing 6 changed files with 183 additions and 144 deletions.
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -2977,12 +2977,12 @@ py_test(
# --------------------------------------------------------------------

py_test(
name = "examples/action_masking_tf",
name = "examples/action_masking_tf2",
main = "examples/action_masking.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
srcs = ["examples/action_masking.py"],
args = ["--stop-iter=2", "--framework=tf"]
args = ["--stop-iter=2", "--framework=tf2"]
)

py_test(
137 changes: 45 additions & 92 deletions rllib/examples/action_masking.py
@@ -41,31 +41,21 @@

from gymnasium.spaces import Box, Discrete
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.examples.env.action_mask_env import ActionMaskEnv
from ray.rllib.examples.models.action_mask_model import (
ActionMaskModel,
TorchActionMaskModel,
from ray.rllib.examples.rl_module.action_masking_rlm import (
TorchActionMaskRLM,
TFActionMaskRLM,
)
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

from ray.tune.logger import pretty_print


def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()

# example-specific args
parser.add_argument(
"--no-masking",
action="store_true",
help="Do NOT mask invalid actions. This will likely lead to errors.",
)

# general args
parser.add_argument(
"--run", type=str, default="APPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--framework",
@@ -76,24 +66,6 @@ def get_cli_args():
parser.add_argument(
"--stop-iters", type=int, default=10, help="Number of iterations to train."
)
parser.add_argument(
"--stop-timesteps",
type=int,
default=10000,
help="Number of timesteps to train.",
)
parser.add_argument(
"--stop-reward",
type=float,
default=80.0,
help="Reward at which we stop training.",
)
parser.add_argument(
"--no-tune",
action="store_true",
help="Run without Tune using a manual train loop instead. Here,"
"there is no TensorBoard support.",
)
parser.add_argument(
"--local-mode",
action="store_true",
@@ -110,6 +82,15 @@ def get_cli_args():

ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

if args.framework == "torch":
rlm_class = TorchActionMaskRLM
elif args.framework == "tf2":
rlm_class = TFActionMaskRLM
else:
raise ValueError(f"Unsupported framework: {args.framework}")

rlm_spec = SingleAgentRLModuleSpec(module_class=rlm_class)

# main part: configure the ActionMaskEnv and the action-masking RLModule
config = (
ppo.PPOConfig()
@@ -119,75 +100,47 @@ def get_cli_args():
ActionMaskEnv,
env_config={
"action_space": Discrete(100),
# This is not going to be the observation space that our RLModule sees.
# It's only the configuration provided to the environment.
# The environment will instead create Dict observations with
# the keys "observations" and "action_mask".
"observation_space": Box(-1.0, 1.0, (5,)),
},
)
.training(
# the ActionMaskModel retrieves the invalid actions and avoids them
model={
"custom_model": ActionMaskModel
if args.framework != "torch"
else TorchActionMaskModel,
# disable action masking according to CLI
"custom_model_config": {"no_masking": args.no_masking},
},
)
# We need to disable preprocessing of observations, because preprocessing
# would flatten the observation dict of the environment.
.experimental(_disable_preprocessor_api=True)
.framework(args.framework)
.resources(
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))
)
.rl_module(rl_module_spec=rlm_spec)
)

stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}

# manual training loop (no Ray tune)
if args.no_tune:
if args.run not in {"APPO", "PPO"}:
raise ValueError("This example only supports APPO and PPO.")

algo = config.build()

# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = algo.train()
print(pretty_print(result))
# stop training if the target train steps or reward are reached
if (
result["timesteps_total"] >= args.stop_timesteps
or result["episode_reward_mean"] >= args.stop_reward
):
break

# manual test loop
print("Finished training. Running manual test/inference loop.")
# prepare environment with max 10 steps
config["env_config"]["max_episode_len"] = 10
env = ActionMaskEnv(config["env_config"])
obs, info = env.reset()
done = False
# run one iteration until done
print(f"ActionMaskEnv with {config['env_config']}")
while not done:
action = algo.compute_single_action(obs)
next_obs, reward, done, truncated, _ = env.step(action)
# observations contain original observations and the action mask
# reward is random and irrelevant here and therefore not printed
print(f"Obs: {obs}, Action: {action}")
obs = next_obs

# Run with tune for auto Algorithm creation, stopping, TensorBoard, etc.
else:
tuner = tune.Tuner(
args.run,
param_space=config.to_dict(),
run_config=air.RunConfig(stop=stop, verbose=2),
)
tuner.fit()
algo = config.build()

# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = algo.train()
print(pretty_print(result))

# manual test loop
print("Finished training. Running manual test/inference loop.")
# prepare environment with max 10 steps
config["env_config"]["max_episode_len"] = 10
env = ActionMaskEnv(config["env_config"])
obs, info = env.reset()
done = False
# run one iteration until done
print(f"ActionMaskEnv with {config['env_config']}")
while not done:
action = algo.compute_single_action(obs)
next_obs, reward, done, truncated, _ = env.step(action)
# observations contain original observations and the action mask
# reward is random and irrelevant here and therefore not printed
print(f"Obs: {obs}, Action: {action}")
obs = next_obs

print("Finished successfully without selecting invalid actions.")
ray.shutdown()
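
For reference, the example above expects the environment to emit Dict observations that pair the actual "observations" vector with a binary "action_mask" over the Discrete action space. A minimal sketch of such an environment (an illustrative stand-in, not the ActionMaskEnv shipped in ray.rllib.examples.env.action_mask_env; the class name and random-masking behavior are assumptions, while the space sizes mirror the example config above) could look like this:

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Dict, Discrete


class ToyActionMaskEnv(gym.Env):
    """Illustrative env: every observation carries a binary mask over the 100 actions."""

    def __init__(self, config=None):
        self.action_space = Discrete(100)
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": Box(-1.0, 1.0, shape=(5,)),
            }
        )

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self._obs(), {}

    def step(self, action):
        # A real implementation would typically penalize or reject masked-out
        # actions and eventually terminate the episode.
        return self._obs(), 0.0, False, False, {}

    def _obs(self):
        # Draw a random 0/1 mask, keeping at least one action valid.
        mask = np.random.randint(0, 2, size=(self.action_space.n,)).astype(np.float32)
        mask[0] = 1.0
        return {
            "action_mask": mask,
            "observations": np.random.uniform(-1.0, 1.0, size=(5,)).astype(np.float32),
        }
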
121 changes: 121 additions & 0 deletions rllib/examples/rl_module/action_masking_rlm.py
@@ -0,0 +1,121 @@
import gymnasium as gym

from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch, try_import_tf
from ray.rllib.utils.torch_utils import FLOAT_MIN

torch, nn = try_import_torch()
_, tf, _ = try_import_tf()


class ActionMaskRLMBase(RLModule):
    def __init__(self, config: RLModuleConfig):
        if not isinstance(config.observation_space, gym.spaces.Dict):
            raise ValueError(
                "This model requires the environment to provide a "
                "gym.spaces.Dict observation space."
            )
        # We need to adjust the observation space for this RL Module so that, when
        # building the default models, the RLModule does not "see" the action mask but
        # only the original observation space without the action mask. This tricks it
        # into building models that are compatible with the original observation space.
        config.observation_space = config.observation_space["observations"]

        # The PPORLModule, in its constructor, will build models for the modified
        # observation space.
        super().__init__(config)


class TorchActionMaskRLM(ActionMaskRLMBase, PPOTorchRLModule):
    def _forward_inference(self, batch, **kwargs):
        return mask_forward_fn_torch(super()._forward_inference, batch, **kwargs)

    def _forward_train(self, batch, *args, **kwargs):
        return mask_forward_fn_torch(super()._forward_train, batch, **kwargs)

    def _forward_exploration(self, batch, *args, **kwargs):
        return mask_forward_fn_torch(super()._forward_exploration, batch, **kwargs)


class TFActionMaskRLM(ActionMaskRLMBase, PPOTfRLModule):
    def _forward_inference(self, batch, **kwargs):
        return mask_forward_fn_tf(super()._forward_inference, batch, **kwargs)

    def _forward_train(self, batch, *args, **kwargs):
        return mask_forward_fn_tf(super()._forward_train, batch, **kwargs)

    def _forward_exploration(self, batch, *args, **kwargs):
        return mask_forward_fn_tf(super()._forward_exploration, batch, **kwargs)


def mask_forward_fn_torch(forward_fn, batch, **kwargs):
    _check_batch(batch)

    # Extract the available actions tensor from the observation.
    action_mask = batch[SampleBatch.OBS]["action_mask"]

    # Modify the incoming batch so that the default models can compute logits and
    # values as usual.
    batch[SampleBatch.OBS] = batch[SampleBatch.OBS]["observations"]

    outputs = forward_fn(batch, **kwargs)

    # Mask logits
    logits = outputs[SampleBatch.ACTION_DIST_INPUTS]
    # Convert action_mask into a [0.0 || -inf]-type mask.
    inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
    masked_logits = logits + inf_mask

    # Replace original values with masked values.
    outputs[SampleBatch.ACTION_DIST_INPUTS] = masked_logits

    return outputs


def mask_forward_fn_tf(forward_fn, batch, **kwargs):
    _check_batch(batch)

    # Extract the available actions tensor from the observation.
    action_mask = batch[SampleBatch.OBS]["action_mask"]

    # Modify the incoming batch so that the default models can compute logits and
    # values as usual.
    batch[SampleBatch.OBS] = batch[SampleBatch.OBS]["observations"]

    outputs = forward_fn(batch, **kwargs)

    # Mask logits
    logits = outputs[SampleBatch.ACTION_DIST_INPUTS]
    # Convert action_mask into a [0.0 || -inf]-type mask.
    inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
    masked_logits = logits + inf_mask

    # Replace original values with masked values.
    outputs[SampleBatch.ACTION_DIST_INPUTS] = masked_logits

    return outputs


def _check_batch(batch):
    """Check whether the batch contains the required keys."""
    if "action_mask" not in batch[SampleBatch.OBS]:
        raise ValueError(
            "Action mask not found in observation. This model requires "
            "the environment to provide observations that include an "
            "action mask (i.e. an observation space of the Dict space "
            "type that looks as follows: \n"
            "{'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
            "'observations': <observation_space>}"
        )
    if "observations" not in batch[SampleBatch.OBS]:
        raise ValueError(
            "Observations not found in observation. This model requires "
            "the environment to provide observations that include the "
            "original observations under the key 'observations' "
            "(i.e. an observation space of the Dict space "
            "type that looks as follows: \n"
            "{'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
            "'observations': <observation_space>}"
        )
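
The masking trick both forward functions rely on can be illustrated in isolation (a minimal standalone sketch, not part of the commit; the logits and mask values below are made up): log(1) = 0 leaves the logits of valid actions unchanged, while log(0) = -inf is clamped with FLOAT_MIN so that invalid actions end up with numerically zero probability after the softmax, without producing NaNs during backprop.

import torch

from ray.rllib.utils.torch_utils import FLOAT_MIN

logits = torch.tensor([[2.0, 0.5, -1.0, 0.3]])
action_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # actions 1 and 3 are invalid

# log(1) == 0.0 keeps valid logits; log(0) == -inf is clamped to FLOAT_MIN.
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask

probs = torch.softmax(masked_logits, dim=-1)
print(probs)  # probabilities of actions 1 and 3 are (numerically) zero

The TF path above achieves the same effect with tf.maximum(tf.math.log(action_mask), tf.float32.min).
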
10 changes: 0 additions & 10 deletions rllib/models/tests/test_distributions.py
@@ -129,16 +129,6 @@ def test_categorical(self):
expected = (probs * (probs / probs2).log()).sum(dim=-1)
check(dist_with_probs.kl(dist2), expected)

# test temperature
dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
samples = dist_with_logits.sample()
rsamples = dist_with_logits.rsample()
# expected is argmax of logits
expected = logits.argmax(dim=-1)
check(samples, expected)
# rsample should be the same as sample, but one-hot encoded
check(samples, torch.argmax(rsamples, dim=-1))

def test_diag_gaussian(self):
batch_size = 128
ndim = 4
