Adding secondary metrics to gym env (facebookresearch#397)
Summary: Pull Request resolved: facebookresearch#397

Differential Revision: D26526445

fbshipit-source-id: 69e9fc46b60ce682c73c1957cda2044239f7900f
kittipatv authored and facebook-github-bot committed Feb 19, 2021
1 parent 4bcf941 commit 6419454
Showing 7 changed files with 39 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/usage.rst
@@ -284,7 +284,7 @@ which performs the following pseudo-code
# run Agent on environment, and record rewards
rewards = evaluate_for_n_episodes(
n=num_eval_episodes, env=env, agent=agent, max_steps=max_steps
)
).rewards
Even on completely random data, DQN can learn a policy that obtains scores close to the maximum possible score of 200!
37 changes: 25 additions & 12 deletions reagent/gym/runners/gymrunner.py
@@ -3,7 +3,7 @@

import logging
import pickle
from typing import Optional, Sequence
from typing import Callable, Optional, Sequence

import numpy as np
import torch.multiprocessing as mp
@@ -13,7 +13,7 @@
)
from reagent.gym.agents.agent import Agent
from reagent.gym.envs import EnvWrapper
from reagent.gym.types import Trajectory, Transition
from reagent.gym.types import Trajectory, Transition, EvaluationResults
from reagent.tensorboardX import SummaryWriterContext


@@ -68,7 +68,8 @@ def evaluate_for_n_episodes(
max_steps: Optional[int] = None,
gammas: Sequence[float] = (1.0,),
num_processes: int = 4,
) -> np.ndarray:
metrics_extractor: Optional[Callable[[Trajectory], np.ndarray]] = None,
) -> EvaluationResults:
"""Return an np array A of shape n x len(gammas)
where A[i, j] = ith episode evaluated with gamma=gammas[j].
Runs environments on num_processes, via multiprocessing.Pool.
@@ -81,20 +82,26 @@ def evaluate_one_episode(
agent: Agent,
max_steps: Optional[int],
gammas: Sequence[float],
) -> np.ndarray:
) -> EvaluationResults:
rewards = np.empty((len(gammas),))
trajectory = run_episode(
env=env, agent=agent, mdp_id=mdp_id, max_steps=max_steps
)
for i_gamma, gamma in enumerate(gammas):
rewards[i_gamma] = trajectory.calculate_cumulative_reward(gamma)
return rewards

rewards = None
metrics = None

if metrics_extractor is not None:
metrics = metrics_extractor(trajectory)

return rewards, metrics

eval_results = None
if num_processes > 1:
try:
with mp.Pool(num_processes) as pool:
rewards = unwrap_function_outputs(
eval_results = unwrap_function_outputs(
pool.map(
wrap_function_arguments(
evaluate_one_episode,
@@ -116,20 +123,26 @@ def evaluate_one_episode(
)

# if we didn't run multiprocessing, or it failed, try single-processing instead.
if rewards is None:
rewards = []
if eval_results is None:
eval_results = []
for i in range(n):
rewards.append(
eval_results.append(
evaluate_one_episode(
mdp_id=i, env=env, agent=agent, max_steps=max_steps, gammas=gammas
)
)

rewards = np.array(rewards)
rewards = np.array([r[0] for r in eval_results])
for i, gamma in enumerate(gammas):
gamma_rewards = rewards[:, i]
logger.info(
f"For gamma={gamma}, average reward is {gamma_rewards.mean()}\n"
f"Rewards list: {gamma_rewards}"
)
return rewards

metrics = None
if metrics_extractor is not None:
metrics = np.stack([r[1] for r in eval_results])

# FIXME: Also, return metrics & stuff
return EvaluationResults(rewards=rewards, metrics=metrics)
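
For readers of this commit, here is a minimal sketch of how the new metrics_extractor hook could be exercised. Only evaluate_for_n_episodes, its metrics_extractor parameter, and the EvaluationResults fields come from this diff; the episode_length_metric helper, the CartPole environment, the random-policy helper import, and the assumption that Trajectory keeps its transitions in a .transitions list are illustrative guesses taken from the surrounding ReAgent codebase, not part of the change.

import numpy as np

from reagent.gym.agents.agent import Agent
from reagent.gym.envs import Gym
from reagent.gym.policies.random_policies import make_random_policy_for_env
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
from reagent.gym.types import Trajectory


def episode_length_metric(trajectory: Trajectory) -> np.ndarray:
    # One secondary metric per episode: how many transitions were taken.
    # Assumes Trajectory stores its transitions in a `.transitions` list.
    return np.array([len(trajectory.transitions)])


env = Gym(env_name="CartPole-v0")
agent = Agent.create_for_env(env, make_random_policy_for_env(env))

results = evaluate_for_n_episodes(
    n=20,
    env=env,
    agent=agent,
    max_steps=200,
    num_processes=1,
    metrics_extractor=episode_length_metric,
)
print(results.rewards.shape)  # (20, 1) with the default gammas=(1.0,)
print(results.metrics.shape)  # (20, 1): np.stack of one extractor output per episode

With num_processes=1 the episodes run in-process (the mp.Pool branch above is skipped), which sidesteps any pickling constraints on the extractor callable.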
6 changes: 4 additions & 2 deletions reagent/gym/tests/test_gym.py
@@ -206,6 +206,7 @@ def train_policy(
return np.array(train_rewards)


# TODO: Return eval results?
def eval_policy(
env: EnvWrapper,
serving_policy: Policy,
@@ -218,13 +219,14 @@ def eval_policy(
else Agent.create_for_env(env, serving_policy)
)

eval_rewards = evaluate_for_n_episodes(
eval_results = evaluate_for_n_episodes(
n=num_eval_episodes,
env=env,
agent=agent,
max_steps=env.max_steps,
num_processes=1,
).squeeze(1)
)
eval_rewards = eval_results.rewards.squeeze(1)

logger.info("============Eval rewards==============")
logger.info(eval_rewards)
2 changes: 1 addition & 1 deletion reagent/gym/tests/test_gym_offline.py
@@ -80,7 +80,7 @@ def evaluate_cem(env, manager, num_eval_episodes: int):
agent = Agent.create_for_env(env, policy)
return evaluate_for_n_episodes(
n=num_eval_episodes, env=env, agent=agent, max_steps=env.max_steps
)
).rewards


def run_test_offline(
2 changes: 1 addition & 1 deletion reagent/gym/tests/test_world_model.py
@@ -372,7 +372,7 @@ def train_mdnrnn_and_train_on_embedded_env(
env=embed_env,
agent=agent,
num_processes=1,
)
).rewards
assert (
np.mean(rewards) >= passing_score_bar
), f"average reward doesn't pass our bar {passing_score_bar}"
6 changes: 6 additions & 0 deletions reagent/gym/types.py
@@ -84,6 +84,12 @@ def calculate_cumulative_reward(self, gamma: float = 1.0):
return sum(reward * discount for reward, discount in zip(rewards, discounts))


@dataclass
class EvaluationResults:
rewards: np.ndarray
metrics: Optional[np.ndarray] = None


class Sampler(ABC):
"""Given scores, select the action."""

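Every call site in this commit now reads .rewards off the returned dataclass instead of treating the return value as a bare array. A tiny sketch of the two access patterns this enables; the shapes and values below are made up for illustration.

import numpy as np

from reagent.gym.types import EvaluationResults

# Without a metrics_extractor, only rewards is populated.
plain = EvaluationResults(rewards=np.zeros((10, 1)))
assert plain.metrics is None

# With an extractor, metrics stacks one row of secondary metrics per episode.
rich = EvaluationResults(rewards=np.zeros((10, 1)), metrics=np.zeros((10, 3)))
print(rich.rewards.mean(axis=0))  # average reward per gamma
print(rich.metrics.mean(axis=0))  # average of each secondary metric
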
2 changes: 1 addition & 1 deletion reagent/workflow/gym_batch_rl.py
@@ -114,7 +114,7 @@ def evaluate_gym(
agent = Agent.create_for_env_with_serving_policy(env, policy)
rewards = evaluate_for_n_episodes(
n=num_eval_episodes, env=env, agent=agent, max_steps=max_steps
)
).rewards
avg_reward = np.mean(rewards)
logger.info(
f"Average reward over {num_eval_episodes} is {avg_reward}.\n"
