Merge pull request #358 from prabhatnagarajan/train_and_eval
Changes variable names in train_agent_with_evaluation
muupan authored Nov 26, 2018
2 parents 0f32aae + f3fa564 commit ee67d76
Showing 18 changed files with 49 additions and 49 deletions.
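The rename is purely mechanical: every call site swaps `eval_n_runs` for `eval_n_episodes` and `max_episode_len` for `train_max_episode_len`, while the argparse flags themselves stay unchanged (e.g. `eval_n_episodes=args.eval_n_runs`). As a minimal sketch of a post-rename call site, assuming `agent`, `env`, `eval_env`, and `args` are set up as in the example scripts changed below:

```python
from chainerrl import experiments

# Hedged sketch only: `agent`, `env`, `eval_env`, and `args` are assumed
# to come from the surrounding example script, as in the diffs below.
experiments.train_agent_with_evaluation(
    agent=agent, env=env, steps=args.steps,
    eval_n_episodes=args.eval_n_runs,            # formerly eval_n_runs=
    eval_interval=args.eval_interval,
    outdir=args.outdir,
    train_max_episode_len=args.max_episode_len,  # formerly max_episode_len=
    eval_env=eval_env)
```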
32 changes: 16 additions & 16 deletions chainerrl/experiments/train_agent.py
@@ -91,10 +91,10 @@ def train_agent(agent, env, steps, outdir, max_episode_len=None,
 def train_agent_with_evaluation(agent,
                                 env,
                                 steps,
-                                eval_n_runs,
+                                eval_n_episodes,
                                 eval_interval,
                                 outdir,
-                                max_episode_len=None,
+                                train_max_episode_len=None,
                                 step_offset=0,
                                 eval_max_episode_len=None,
                                 eval_env=None,
@@ -103,27 +103,27 @@ def train_agent_with_evaluation(agent,
                                 save_best_so_far_agent=True,
                                 logger=None,
                                 ):
-    """Train an agent while regularly evaluating it.
+    """Train an agent while periodically evaluating it.
 
     Args:
-        agent: Agent to train.
-        env: Environment train the againt against.
-        steps (int): Number of total time steps for training.
-        eval_n_runs (int): Number of runs for each time of evaluation.
+        agent: A chainerrl.agent.Agent
+        env: Environment train the agent against.
+        steps (int): Total number of timesteps for training.
+        eval_n_episodes (int): Number of episodes at each evaluation phase
         eval_interval (int): Interval of evaluation.
-        outdir (str): Path to the directory to output things.
-        max_episode_len (int): Maximum episode length.
+        outdir (str): Path to the directory to output data.
+        train_max_episode_len (int): Maximum episode length during training.
         step_offset (int): Time step from which training starts.
         eval_max_episode_len (int or None): Maximum episode length of
-            evaluation runs. If set to None, max_episode_len is used instead.
+            evaluation runs. If None, train_max_episode_len is used instead.
         eval_env: Environment used for evaluation.
         successful_score (float): Finish training if the mean score is greater
-            or equal to this value if not None
+            than or equal to this value if not None
         step_hooks (list): List of callable objects that accepts
            (env, agent, step) as arguments. They are called every step.
            See chainerrl.experiments.hooks.
-        save_best_so_far_agent (bool): If set to True, after each evaluation,
-            if the score (= mean return of evaluation episodes) exceeds
+        save_best_so_far_agent (bool): If set to True, after each evaluation
+            phase, if the score (= mean return of evaluation episodes) exceeds
            the best-so-far score, the current agent is saved.
         logger (logging.Logger): Logger used in this function.
     """
@@ -136,10 +136,10 @@ def train_agent_with_evaluation(agent,
         eval_env = env
 
     if eval_max_episode_len is None:
-        eval_max_episode_len = max_episode_len
+        eval_max_episode_len = train_max_episode_len
 
     evaluator = Evaluator(agent=agent,
-                          n_runs=eval_n_runs,
+                          n_runs=eval_n_episodes,
                           eval_interval=eval_interval, outdir=outdir,
                           max_episode_len=eval_max_episode_len,
                           env=eval_env,
@@ -150,7 +150,7 @@ def train_agent_with_evaluation(agent,
 
     train_agent(
         agent, env, steps, outdir,
-        max_episode_len=max_episode_len,
+        max_episode_len=train_max_episode_len,
         step_offset=step_offset,
         evaluator=evaluator,
         successful_score=successful_score,
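The middle hunks also pin down the fallback semantics after the rename: when `eval_max_episode_len` is left as `None`, evaluation episodes inherit the training limit. A short paraphrase of that logic (the helper name is hypothetical, not part of chainerrl):

```python
def resolve_eval_max_episode_len(eval_max_episode_len, train_max_episode_len):
    """Paraphrase of the fallback in train_agent_with_evaluation.

    Before this commit, the second argument was named max_episode_len.
    """
    if eval_max_episode_len is None:
        # No separate evaluation limit given: reuse the training limit.
        return train_max_episode_len
    return eval_max_episode_len
```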
4 changes: 2 additions & 2 deletions examples/ale/train_categorical_dqn_ale.py
@@ -148,10 +148,10 @@ def phi(x):
     else:
         experiments.train_agent_with_evaluation(
             agent=agent, env=env, steps=args.steps,
-            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
+            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
             outdir=args.outdir,
             save_best_so_far_agent=False,
-            max_episode_len=args.max_episode_len,
+            train_max_episode_len=args.max_episode_len,
             eval_env=eval_env,
         )
4 changes: 2 additions & 2 deletions examples/ale/train_dqn_ale.py
@@ -231,10 +231,10 @@ def phi(x):
     else:
         experiments.train_agent_with_evaluation(
             agent=agent, env=env, steps=args.steps,
-            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
+            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
             outdir=args.outdir,
             save_best_so_far_agent=False,
-            max_episode_len=args.max_episode_len,
+            train_max_episode_len=args.max_episode_len,
             eval_env=eval_env,
         )
4 changes: 2 additions & 2 deletions examples/ale/train_ppo_ale.py
@@ -163,9 +163,9 @@ def clip_eps_setter(env, agent, value):
         eval_env=eval_env,
         outdir=args.outdir,
         steps=args.steps,
-        eval_n_runs=args.eval_n_runs,
+        eval_n_episodes=args.eval_n_runs,
         eval_interval=args.eval_interval,
-        max_episode_len=args.max_episode_len,
+        train_max_episode_len=args.max_episode_len,
         save_best_so_far_agent=False,
         step_hooks=[
             lr_decay_hook,
4 changes: 2 additions & 2 deletions examples/atari/dqn/train_dqn.py
@@ -174,10 +174,10 @@ def phi(x):
     else:
         experiments.train_agent_with_evaluation(
             agent=agent, env=env, steps=args.steps,
-            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
+            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
             outdir=args.outdir,
             save_best_so_far_agent=False,
-            max_episode_len=args.max_episode_len,
+            train_max_episode_len=args.max_episode_len,
             eval_env=eval_env,
         )
4 changes: 2 additions & 2 deletions examples/gym/train_categorical_dqn_gym.py
@@ -163,9 +163,9 @@ def make_env(test):
     else:
         experiments.train_agent_with_evaluation(
             agent=agent, env=env, steps=args.steps,
-            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
+            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
             outdir=args.outdir, eval_env=eval_env,
-            max_episode_len=timestep_limit)
+            train_max_episode_len=timestep_limit)
 
 
 if __name__ == '__main__':
4 changes: 2 additions & 2 deletions examples/gym/train_ddpg_gym.py
@@ -171,9 +171,9 @@ def random_action():
         experiments.train_agent_with_evaluation(
             agent=agent, env=env, steps=args.steps,
             eval_env=eval_env,
-            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
+            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
             outdir=args.outdir,
-            max_episode_len=timestep_limit)
+            train_max_episode_len=timestep_limit)
 
 
 if __name__ == '__main__':
4 changes: 2 additions & 2 deletions examples/gym/train_dqn_gym.py
@@ -197,9 +197,9 @@ def make_env(test):
     else:
         experiments.train_agent_with_evaluation(
             agent=agent, env=env, steps=args.steps,
-            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
+            eval_n_episodes=args.eval_n_runs, eval_interval=args.eval_interval,
             outdir=args.outdir, eval_env=eval_env,
-            max_episode_len=timestep_limit)
+            train_max_episode_len=timestep_limit)
 
 
 if __name__ == '__main__':
4 changes: 2 additions & 2 deletions examples/gym/train_pcl_gym.py
@@ -223,9 +223,9 @@ def make_env(process_idx, test):
         eval_env=make_env(0, test=True),
         outdir=args.outdir,
         steps=args.steps,
-        eval_n_runs=args.eval_n_runs,
+        eval_n_episodes=args.eval_n_runs,
         eval_interval=args.eval_interval,
-        max_episode_len=timestep_limit)
+        train_max_episode_len=timestep_limit)
 
 
 if __name__ == '__main__':
4 changes: 2 additions & 2 deletions examples/gym/train_ppo_gym.py
@@ -203,9 +203,9 @@ def clip_eps_setter(env, agent, value):
         eval_env=make_env(True),
         outdir=args.outdir,
         steps=args.steps,
-        eval_n_runs=args.eval_n_runs,
+        eval_n_episodes=args.eval_n_runs,
         eval_interval=args.eval_interval,
-        max_episode_len=timestep_limit,
+        train_max_episode_len=timestep_limit,
         save_best_so_far_agent=False,
         step_hooks=[
             lr_decay_hook,
4 changes: 2 additions & 2 deletions examples/gym/train_reinforce_gym.py
@@ -140,9 +140,9 @@ def make_env(test):
         eval_env=eval_env,
         outdir=args.outdir,
         steps=args.steps,
-        eval_n_runs=args.eval_n_runs,
+        eval_n_episodes=args.eval_n_runs,
         eval_interval=args.eval_interval,
-        max_episode_len=timestep_limit)
+        train_max_episode_len=timestep_limit)
 
 
 if __name__ == '__main__':
4 changes: 2 additions & 2 deletions examples/gym/train_trpo_gym.py
@@ -200,9 +200,9 @@ def make_env(test):
         eval_env=make_env(test=True),
         outdir=args.outdir,
         steps=args.steps,
-        eval_n_runs=args.eval_n_runs,
+        eval_n_episodes=args.eval_n_runs,
         eval_interval=args.eval_interval,
-        max_episode_len=timestep_limit,
+        train_max_episode_len=timestep_limit,
     )
 
 
4 changes: 2 additions & 2 deletions examples/quickstart/quickstart.ipynb
@@ -465,8 +465,8 @@
 "chainerrl.experiments.train_agent_with_evaluation(\n",
 "    agent, env,\n",
 "    steps=2000,  # Train the agent for 2000 steps\n",
-"    eval_n_runs=10,  # 10 episodes are sampled for each evaluation\n",
-"    max_episode_len=200,  # Maximum length of each episodes\n",
+"    eval_n_episodes=10,  # 10 episodes are sampled for each evaluation\n",
+"    train_max_episode_len=200,  # Maximum length of each episode\n",
 "    eval_interval=1000,  # Evaluate the agent after every 1000 steps\n",
 "    outdir='result')  # Save everything to 'result' directory"
 ]
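For a runnable end-to-end check of the new keyword names, here is a sketch of the updated notebook cell with the DQN setup taken from the quickstart's earlier cells; the hyperparameters are the notebook's, not new recommendations:

```python
import chainer
import gym
import numpy as np
import chainerrl

env = gym.make('CartPole-v0')
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

# Q-function, optimizer, explorer, and replay buffer as in the quickstart.
q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
    obs_size, n_actions, n_hidden_layers=2, n_hidden_channels=50)
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)
agent = chainerrl.agents.DQN(
    q_func, optimizer, replay_buffer, gamma=0.95, explorer=explorer,
    replay_start_size=500, update_interval=1, target_update_interval=100,
    phi=lambda x: x.astype(np.float32, copy=False))

chainerrl.experiments.train_agent_with_evaluation(
    agent, env,
    steps=2000,                 # Train the agent for 2000 steps
    eval_n_episodes=10,         # 10 episodes are sampled for each evaluation
    train_max_episode_len=200,  # Maximum length of each episode
    eval_interval=1000,         # Evaluate the agent after every 1000 steps
    outdir='result')            # Save everything to 'result' directory
```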
2 changes: 1 addition & 1 deletion tests/agents_tests/basetest_training.py
@@ -52,7 +52,7 @@ def _test_training(self, gpu, steps=5000, load_model=False,
         # Train
         train_agent_with_evaluation(
             agent=agent, env=env, steps=steps, outdir=self.tmpdir,
-            eval_interval=200, eval_n_runs=5, successful_score=1,
+            eval_interval=200, eval_n_episodes=5, successful_score=1,
             eval_env=test_env)
 
         agent.stop_episode()
4 changes: 2 additions & 2 deletions tests/agents_tests/test_pcl.py
@@ -215,9 +215,9 @@ def phi(x):
             eval_env=make_env(0, True),
             outdir=self.outdir,
             steps=steps,
-            max_episode_len=2,
+            train_max_episode_len=2,
             eval_interval=200,
-            eval_n_runs=5,
+            eval_n_episodes=5,
             successful_score=1)
 
         agent.stop_episode()
4 changes: 2 additions & 2 deletions tests/agents_tests/test_ppo.py
@@ -79,10 +79,10 @@ def _test_abc(self, steps=1000000,
             steps=steps,
             outdir=self.tmpdir,
             eval_interval=200,
-            eval_n_runs=50,
+            eval_n_episodes=50,
             successful_score=successful_return,
             eval_env=test_env,
-            max_episode_len=max_episode_len,
+            train_max_episode_len=max_episode_len,
         )
 
         agent.stop_episode()
4 changes: 2 additions & 2 deletions tests/agents_tests/test_reinforce.py
@@ -137,9 +137,9 @@ def phi(x):
             eval_env=make_env(0, True),
             outdir=self.outdir,
             steps=steps,
-            max_episode_len=2,
+            train_max_episode_len=2,
             eval_interval=500,
-            eval_n_runs=5,
+            eval_n_episodes=5,
             successful_score=1)
 
         # Test
4 changes: 2 additions & 2 deletions tests/agents_tests/test_trpo.py
@@ -203,9 +203,9 @@ def _test_abc(self, steps=1000000,
             steps=steps,
             outdir=self.tmpdir,
             eval_interval=200,
-            eval_n_runs=5,
+            eval_n_episodes=5,
             successful_score=successful_return,
-            max_episode_len=max_episode_len,
+            train_max_episode_len=max_episode_len,
         )
 
         agent.stop_episode()
