Update episode_v2.py with last_info_for (ray-project#37382)
Adds a last_info_for method that returns the last info dict from the environment. This method already exists in Episode but not in EpisodeV2, and it is especially useful for reading custom values that the environment reports through its info dict.

Signed-off-by: Victor Mao <vctrm67@gmail.com>
Co-authored-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
2 people authored and arvind-chandra committed Aug 31, 2023
1 parent 8d97d41 commit 480cd26
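
For context, a minimal usage sketch (not part of the commit) of reading the latest env info from an EpisodeV2 object inside a custom callback. The LogFastPoles class name is made up for illustration; the "pole_angle_vel" info key and the 0.25 threshold are taken from the example file updated below.

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class LogFastPoles(DefaultCallbacks):
    def on_episode_step(self, *, episode, **kwargs):
        # EpisodeV2.last_info_for() returns the most recent info dict for the
        # given agent (default: the single-agent dummy ID), or None if the
        # environment has not reported one yet.
        info = episode.last_info_for()
        if info is not None and abs(info.get("pole_angle_vel", 0.0)) > 0.25:
            print("This is a fast pole!")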
Showing 2 changed files with 43 additions and 16 deletions.
7 changes: 6 additions & 1 deletion rllib/evaluation/episode_v2.py
@@ -12,7 +12,7 @@
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentID, EnvID, PolicyID, TensorType
from ray.rllib.utils.typing import AgentID, EnvID, EnvInfoDict, PolicyID, TensorType

if TYPE_CHECKING:
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
@@ -370,6 +370,11 @@ def is_truncated(self, agent_id: AgentID) -> bool:
    def set_last_info(self, agent_id: AgentID, info: Dict):
        self._last_infos[agent_id] = info

    def last_info_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID
    ) -> Optional[EnvInfoDict]:
        return self._last_infos.get(agent_id)

    @property
    def length(self):
        return self.total_env_steps
52 changes: 37 additions & 15 deletions rllib/examples/custom_metrics_and_callbacks.py
@@ -9,6 +9,7 @@

from typing import Dict, Tuple
import argparse
import gymnasium as gym
import numpy as np
import os

@@ -31,6 +32,31 @@
parser.add_argument("--stop-iters", type=int, default=2000)


# Create a custom CartPole environment that maintains an estimate of velocity
class CustomCartPole(gym.Env):
    def __init__(self, config):
        self.env = gym.make("CartPole-v1")
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self._pole_angle_vel = 0.0
        self.last_angle = 0.0

    def reset(self, *, seed=None, options=None):
        self._pole_angle_vel = 0.0
        obs, info = self.env.reset()
        self.last_angle = obs[2]
        return obs, info

    def step(self, action):
        obs, rew, term, trunc, info = self.env.step(action)
        angle = obs[2]
        self._pole_angle_vel = (
            0.5 * (angle - self.last_angle) + 0.5 * self._pole_angle_vel
        )
        info["pole_angle_vel"] = self._pole_angle_vel
        return obs, rew, term, trunc, info


class MyCallbacks(DefaultCallbacks):
    def on_episode_start(
        self,
@@ -48,7 +74,7 @@ def on_episode_start(
"ERROR: `on_episode_start()` callback should be called right "
"after env reset!"
)
print("episode {} (env-idx={}) started.".format(episode.episode_id, env_index))
# Create lists to store angles in
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []

@@ -72,6 +98,11 @@ def on_episode_step(
        assert pole_angle == raw_angle
        episode.user_data["pole_angles"].append(pole_angle)

        # Sometimes our pole is moving fast. We can look at the latest velocity
        # estimate from our environment and log high velocities.
        if np.abs(episode.last_info_for()["pole_angle_vel"]) > 0.25:
            print("This is a fast pole!")

    def on_episode_end(
        self,
        *,
@@ -93,26 +124,17 @@ def on_episode_end(
"after episode is done!"
)
pole_angle = np.mean(episode.user_data["pole_angles"])
print(
"episode {} (env-idx={}) ended with length {} and pole "
"angles {}".format(
episode.episode_id, env_index, episode.length, pole_angle
)
)
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]

    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch, **kwargs):
        print("returned sample batch of size {}".format(samples.count))
        # We can also do our own sanity checks here.
        assert samples.count == 200, "I was expecting 200 here!"

    def on_train_result(self, *, algorithm, result: dict, **kwargs):
        print(
            "Algorithm.train() result: {} -> {} episodes".format(
                algorithm, result["episodes_this_iter"]
            )
        )
        # you can mutate the result dict to add new fields to return
        result["callback_ok"] = True

        # Normally, RLlib would aggregate any custom metric into a mean, max and min
        # of the given metric.
        # For the sake of this example, we will instead compute the variance and mean
@@ -130,6 +152,7 @@ def on_learn_on_batch(
        self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
    ) -> None:
        result["sum_actions_in_train_batch"] = train_batch["actions"].sum()
        # Log the sum of actions in the train batch.
        print(
            "policy.learn_on_batch() result: {} -> sum actions: {}".format(
                policy, result["sum_actions_in_train_batch"]
@@ -148,7 +171,6 @@ def on_postprocess_trajectory(
        original_batches: Dict[str, Tuple[Policy, SampleBatch]],
        **kwargs
    ):
        print("postprocessed {} steps".format(postprocessed_batch.count))
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1
@@ -159,7 +181,7 @@ def on_postprocess_trajectory(

config = (
    PGConfig()
    .environment("CartPole-v1")
    .environment(CustomCartPole)
    .framework(args.framework)
    .callbacks(MyCallbacks)
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
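
As a quick standalone check (not part of the commit), the CustomCartPole example above can be exercised directly to see the smoothed pole_angle_vel value that the callback later reads via episode.last_info_for(). This sketch assumes the class (and its gymnasium import) from the diff is defined in the current module:

env = CustomCartPole(config={})
obs, info = env.reset()
for _ in range(5):
    obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
    # The env writes its smoothed pole-angle velocity estimate into the info dict.
    print("pole_angle_vel:", info["pole_angle_vel"])
    if terminated or truncated:
        obs, info = env.reset()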
