From 480cd26b52cdaa580deeb6a90ca644d1af71377a Mon Sep 17 00:00:00 2001
From: Victor Mao
Date: Fri, 28 Jul 2023 12:52:09 -0400
Subject: [PATCH] Update episode_v2.py with last_info_for (#37382)

Adding last_info_for method to return last info dict from the environment.
This is currently present in Episode but not EpisodeV2 and is especially
useful for custom environment values.

Signed-off-by: Victor Mao
Co-authored-by: Artur Niederfahrenhorst
Signed-off-by: e428265
---
 rllib/evaluation/episode_v2.py                |  7 ++-
 .../examples/custom_metrics_and_callbacks.py  | 52 +++++++++++++------
 2 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/rllib/evaluation/episode_v2.py b/rllib/evaluation/episode_v2.py
index fd5704c609652..969137e53b387 100644
--- a/rllib/evaluation/episode_v2.py
+++ b/rllib/evaluation/episode_v2.py
@@ -12,7 +12,7 @@
 from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI
-from ray.rllib.utils.typing import AgentID, EnvID, PolicyID, TensorType
+from ray.rllib.utils.typing import AgentID, EnvID, EnvInfoDict, PolicyID, TensorType
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.callbacks import DefaultCallbacks
@@ -370,6 +370,11 @@ def is_truncated(self, agent_id: AgentID) -> bool:
     def set_last_info(self, agent_id: AgentID, info: Dict):
         self._last_infos[agent_id] = info
 
+    def last_info_for(
+        self, agent_id: AgentID = _DUMMY_AGENT_ID
+    ) -> Optional[EnvInfoDict]:
+        return self._last_infos.get(agent_id)
+
     @property
     def length(self):
         return self.total_env_steps
diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py
index 7af974609cf75..7be0d9e2aca5d 100644
--- a/rllib/examples/custom_metrics_and_callbacks.py
+++ b/rllib/examples/custom_metrics_and_callbacks.py
@@ -9,6 +9,7 @@
 from typing import Dict, Tuple
 
 import argparse
+import gymnasium as gym
 import numpy as np
 import os
 
@@ -31,6 +32,31 @@
 parser.add_argument("--stop-iters", type=int, default=2000)
 
 
+# Create a custom CartPole environment that maintains an estimate of velocity
+class CustomCartPole(gym.Env):
+    def __init__(self, config):
+        self.env = gym.make("CartPole-v1")
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+        self._pole_angle_vel = 0.0
+        self.last_angle = 0.0
+
+    def reset(self, *, seed=None, options=None):
+        self._pole_angle_vel = 0.0
+        obs, info = self.env.reset()
+        self.last_angle = obs[2]
+        return obs, info
+
+    def step(self, action):
+        obs, rew, term, trunc, info = self.env.step(action)
+        angle = obs[2]
+        self._pole_angle_vel = (
+            0.5 * (angle - self.last_angle) + 0.5 * self._pole_angle_vel
+        )
+        info["pole_angle_vel"] = self._pole_angle_vel
+        return obs, rew, term, trunc, info
+
+
 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(
         self,
@@ -48,7 +74,7 @@ def on_episode_start(
             "ERROR: `on_episode_start()` callback should be called right "
             "after env reset!"
         )
-        print("episode {} (env-idx={}) started.".format(episode.episode_id, env_index))
+        # Create lists to store angles in
         episode.user_data["pole_angles"] = []
         episode.hist_data["pole_angles"] = []
 
@@ -72,6 +98,11 @@ def on_episode_step(
         assert pole_angle == raw_angle
         episode.user_data["pole_angles"].append(pole_angle)
 
+        # Sometimes our pole is moving fast. We can look at the latest velocity
+        # estimate from our environment and log high velocities.
+        if np.abs(episode.last_info_for()["pole_angle_vel"]) > 0.25:
+            print("This is a fast pole!")
+
     def on_episode_end(
         self,
         *,
@@ -93,26 +124,17 @@ def on_episode_end(
             "after episode is done!"
        )
         pole_angle = np.mean(episode.user_data["pole_angles"])
-        print(
-            "episode {} (env-idx={}) ended with length {} and pole "
-            "angles {}".format(
-                episode.episode_id, env_index, episode.length, pole_angle
-            )
-        )
         episode.custom_metrics["pole_angle"] = pole_angle
         episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
 
     def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch, **kwargs):
-        print("returned sample batch of size {}".format(samples.count))
+        # We can also do our own sanity checks here.
+        assert samples.count == 200, "I was expecting 200 here!"
 
     def on_train_result(self, *, algorithm, result: dict, **kwargs):
-        print(
-            "Algorithm.train() result: {} -> {} episodes".format(
-                algorithm, result["episodes_this_iter"]
-            )
-        )
         # you can mutate the result dict to add new fields to return
         result["callback_ok"] = True
+
         # Normally, RLlib would aggregate any custom metric into a mean, max and min
         # of the given metric.
         # For the sake of this example, we will instead compute the variance and mean
@@ -130,6 +152,7 @@ def on_learn_on_batch(
         self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
     ) -> None:
         result["sum_actions_in_train_batch"] = train_batch["actions"].sum()
+        # Log the sum of actions in the train batch.
         print(
             "policy.learn_on_batch() result: {} -> sum actions: {}".format(
                 policy, result["sum_actions_in_train_batch"]
@@ -148,7 +171,6 @@ def on_postprocess_trajectory(
         original_batches: Dict[str, Tuple[Policy, SampleBatch]],
         **kwargs
     ):
-        print("postprocessed {} steps".format(postprocessed_batch.count))
         if "num_batches" not in episode.custom_metrics:
             episode.custom_metrics["num_batches"] = 0
         episode.custom_metrics["num_batches"] += 1
@@ -159,7 +181,7 @@ def on_postprocess_trajectory(
 
     config = (
         PGConfig()
-        .environment("CartPole-v1")
+        .environment(CustomCartPole)
         .framework(args.framework)
         .callbacks(MyCallbacks)
         .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
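Usage note (not part of the patch): below is a minimal sketch of how the new EpisodeV2.last_info_for() accessor might be used from a custom callback, along the lines of the example file touched above. The FuelLoggingCallbacks class name, the "fuel_left" info key, and the "agent_0" agent id are hypothetical; in a single-agent env the agent_id argument can simply be omitted, and the method returns None until the environment has reported an info dict.

    import numpy as np

    from ray.rllib.algorithms.callbacks import DefaultCallbacks


    class FuelLoggingCallbacks(DefaultCallbacks):
        def on_episode_step(
            self, *, worker, base_env, policies=None, episode, env_index=None, **kwargs
        ):
            # Single-agent envs: the default (dummy) agent id is used.
            info = episode.last_info_for()
            # Multi-agent envs: pass the agent id explicitly, e.g.
            # info = episode.last_info_for("agent_0")
            if info is not None and "fuel_left" in info:  # hypothetical info key
                episode.user_data.setdefault("fuel_left", []).append(info["fuel_left"])

        def on_episode_end(
            self, *, worker, base_env, policies, episode, env_index=None, **kwargs
        ):
            fuel = episode.user_data.get("fuel_left")
            if fuel:
                # Aggregate into a custom metric, like the pole-angle example above.
                episode.custom_metrics["mean_fuel_left"] = float(np.mean(fuel))

Such a class could be plugged into a config the same way the example does, e.g. via .callbacks(FuelLoggingCallbacks).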