Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
379 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
from __future__ import division | ||
from __future__ import absolute_import | ||
from future import standard_library | ||
standard_library.install_aliases() # NOQA | ||
|
||
import chainer | ||
import chainer.functions as F | ||
|
||
from chainerrl.agents import iqn | ||
|
||
|
||
class DoubleIQN(iqn.IQN):

    """IQN with a Double-DQN-style target.

    Greedy actions for the next state are selected with the online
    model, while their return quantiles are evaluated with the target
    model.
    """

    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions.

        Returns:
            chainer.Variable: (batch_size, N_prime).
        """
        batch_next_state = exp_batch['next_state']
        batch_size = len(exp_batch['reward'])

        # Thresholds used only to pick greedy actions with the online model.
        taus_tilde = self.xp.random.uniform(
            0, 1, size=(batch_size, self.quantile_thresholds_K)).astype('f')
        with chainer.using_config('train', False):
            if self.recurrent:
                next_tau2av, _ = self.model.n_step_forward(
                    batch_next_state,
                    exp_batch['next_recurrent_state'],
                    output_mode='concat',
                )
            else:
                next_tau2av = self.model(batch_next_state)
            greedy_actions = next_tau2av(taus_tilde).greedy_actions
            # Thresholds used to evaluate those actions with the target model.
            taus_prime = self.xp.random.uniform(
                0, 1,
                size=(batch_size, self.quantile_thresholds_N_prime)).astype('f')
            if self.recurrent:
                target_next_tau2av, _ = self.target_model.n_step_forward(
                    batch_next_state,
                    exp_batch['next_recurrent_state'],
                    output_mode='concat',
                )
            else:
                target_next_tau2av = self.target_model(batch_next_state)
            target_next_maxz = target_next_tau2av(
                taus_prime).evaluate_actions_as_quantiles(greedy_actions)

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']
        batch_discount = exp_batch['discount']
        assert batch_rewards.shape == (batch_size,)
        assert batch_terminal.shape == (batch_size,)
        assert batch_discount.shape == (batch_size,)

        def _expand(v):
            # Broadcast a (batch_size,) array across the quantile axis.
            return F.broadcast_to(v[..., None], target_next_maxz.shape)

        return (_expand(batch_rewards)
                + _expand(batch_discount)
                * (1.0 - _expand(batch_terminal))
                * target_next_maxz)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
from __future__ import print_function | ||
from __future__ import division | ||
from __future__ import unicode_literals | ||
from __future__ import absolute_import | ||
from builtins import * # NOQA | ||
from future import standard_library | ||
standard_library.install_aliases() # NOQA | ||
import argparse | ||
import functools | ||
import json | ||
import os | ||
|
||
import chainer | ||
import chainer.functions as F | ||
import chainer.links as L | ||
import gym | ||
import numpy as np | ||
|
||
import chainerrl | ||
from chainerrl import experiments | ||
from chainerrl import explorers | ||
from chainerrl import misc | ||
from chainerrl import replay_buffer | ||
from chainerrl.wrappers import atari_wrappers | ||
|
||
|
||
def parse_agent(agent):
    """Map an agent name ('IQN' or 'DoubleIQN') to its agent class.

    Raises KeyError for unknown names.
    """
    agent_classes = {
        'IQN': chainerrl.agents.IQN,
        'DoubleIQN': chainerrl.agents.DoubleIQN,
    }
    return agent_classes[agent]
|
||
|
||
def main():
    """Train IQN or DoubleIQN on an Atari env, then evaluate the best model.

    Command-line flags control the environment, training length,
    exploration schedule, and IQN quantile-threshold counts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4')
    parser.add_argument('--outdir', type=str, default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 31)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--demo', action='store_true', default=False)
    parser.add_argument('--load', type=str, default=None)
    parser.add_argument('--final-exploration-frames',
                        type=int, default=10 ** 6)
    parser.add_argument('--final-epsilon', type=float, default=0.01)
    parser.add_argument('--eval-epsilon', type=float, default=0.001)
    parser.add_argument('--steps', type=int, default=5 * 10 ** 7)
    parser.add_argument('--max-frames', type=int,
                        default=30 * 60 * 60,  # 30 minutes with 60 fps
                        help='Maximum number of frames for each episode.')
    parser.add_argument('--replay-start-size', type=int, default=5 * 10 ** 4)
    parser.add_argument('--target-update-interval',
                        type=int, default=10 ** 4)
    parser.add_argument('--agent', type=str, default='IQN',
                        choices=['IQN', 'DoubleIQN'])
    parser.add_argument('--eval-interval', type=int, default=250000)
    parser.add_argument('--eval-n-steps', type=int, default=125000)
    parser.add_argument('--update-interval', type=int, default=4)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--logging-level', type=int, default=20,
                        help='Logging level. 10:DEBUG, 20:INFO etc.')
    parser.add_argument('--render', action='store_true', default=False,
                        help='Render env states in a GUI window.')
    parser.add_argument('--monitor', action='store_true', default=False,
                        help='Monitor env. Videos and additional information'
                        ' are saved as output files.')
    parser.add_argument('--batch-accumulator', type=str, default='mean',
                        choices=['mean', 'sum'])
    parser.add_argument('--quantile-thresholds-N', type=int, default=64)
    parser.add_argument('--quantile-thresholds-N-prime', type=int, default=64)
    parser.add_argument('--quantile-thresholds-K', type=int, default=32)
    parser.add_argument('--n-best-episodes', type=int, default=200)
    args = parser.parse_args()

    import logging
    logging.basicConfig(level=args.logging_level)

    # Set a random seed used in ChainerRL.
    misc.set_random_seed(args.seed, gpus=(args.gpu,))

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2 ** 31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print('Output files are saved in {}'.format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = test_seed if test else train_seed
        env = atari_wrappers.wrap_deepmind(
            atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
            episode_life=not test,
            clip_rewards=not test)
        env.seed(int(env_seed))
        if test:
            # Randomize actions like epsilon-greedy in evaluation as well
            env = chainerrl.wrappers.RandomizeAction(env, args.eval_epsilon)
        if args.monitor:
            env = gym.wrappers.Monitor(
                env, args.outdir,
                mode='evaluation' if test else 'training')
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)
    n_actions = env.action_space.n

    # Implicit quantile network: conv feature extractor (psi), cosine
    # embedding of quantile thresholds (phi), and the final head (f).
    q_func = chainerrl.agents.iqn.ImplicitQuantileQFunction(
        psi=chainerrl.links.Sequence(
            L.Convolution2D(None, 32, 8, stride=4),
            F.relu,
            L.Convolution2D(None, 64, 4, stride=2),
            F.relu,
            L.Convolution2D(None, 64, 3, stride=1),
            F.relu,
            functools.partial(F.reshape, shape=(-1, 3136)),
        ),
        phi=chainerrl.links.Sequence(
            chainerrl.agents.iqn.CosineBasisLinear(64, 3136),
            F.relu,
        ),
        f=chainerrl.links.Sequence(
            L.Linear(None, 512),
            F.relu,
            L.Linear(None, n_actions),
        ),
    )

    # Draw the computational graph and save it in the output directory.
    fake_obss = np.zeros((4, 84, 84), dtype=np.float32)[None]
    fake_taus = np.zeros(32, dtype=np.float32)[None]
    chainerrl.misc.draw_computational_graph(
        [q_func(fake_obss)(fake_taus)],
        os.path.join(args.outdir, 'model'))

    # Use the same hyper parameters as https://arxiv.org/abs/1710.10044
    opt = chainer.optimizers.Adam(5e-5, eps=1e-2 / args.batch_size)
    opt.setup(q_func)

    rbuf = replay_buffer.ReplayBuffer(10 ** 6)

    explorer = explorers.LinearDecayEpsilonGreedy(
        1.0, args.final_epsilon,
        args.final_exploration_frames,
        lambda: np.random.randint(n_actions))

    def phi(x):
        # Feature extractor
        return np.asarray(x, dtype=np.float32) / 255

    Agent = parse_agent(args.agent)
    agent = Agent(
        q_func, opt, rbuf, gpu=args.gpu, gamma=0.99,
        explorer=explorer, replay_start_size=args.replay_start_size,
        target_update_interval=args.target_update_interval,
        update_interval=args.update_interval,
        batch_accumulator=args.batch_accumulator,
        phi=phi,
        quantile_thresholds_N=args.quantile_thresholds_N,
        quantile_thresholds_N_prime=args.quantile_thresholds_N_prime,
        quantile_thresholds_K=args.quantile_thresholds_K,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=args.eval_n_steps,
            n_episodes=None,
        )
        print('n_steps: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_steps, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=args.eval_n_steps,
            eval_n_episodes=None,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            save_best_so_far_agent=True,
            eval_env=eval_env,
        )

        dir_of_best_network = os.path.join(args.outdir, "best")
        agent.load(dir_of_best_network)

        # run 200 evaluation episodes, each capped at 30 mins of play
        # NOTE: max_frames counts emulator frames; divide by the frameskip
        # (presumably 4 from the Atari wrappers — TODO confirm) to get env
        # steps. Integer division keeps the limit an int instead of a float.
        stats = experiments.evaluator.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.n_best_episodes,
            max_episode_len=args.max_frames // 4,
            logger=None)
        with open(os.path.join(args.outdir, 'bestscores.json'), 'w') as f:
            # temporary hack to handle python 2/3 support issues.
            # json dumps does not support non-string literal dict keys
            json_stats = json.dumps(stats)
            print(str(json_stats), file=f)
        print("The results of the best scoring network:")
        for stat in stats:
            print(str(stat) + ":" + str(stats[stat]))
||
# Allow running this example directly as a script.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/bin/bash
# Smoke-test the Double IQN Atari example: run a tiny training job, then
# reload the resulting model in demo mode.
# Usage: <script> <gpu-id>

set -Ceu

outdir=$(mktemp -d)

gpu="$1"

# atari/double_iqn
# Quote all expansions so paths with spaces or glob characters cannot be
# word-split or globbed by the shell.
python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir "$outdir/atari/double_iqn" --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu "$gpu"
model=$(find "$outdir/atari/double_iqn" -name "*_finish")
python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --demo --load "$model" --outdir "$outdir/temp" --eval-n-steps 200 --gpu "$gpu"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
from __future__ import division | ||
from __future__ import absolute_import | ||
from future import standard_library | ||
from builtins import * # NOQA | ||
standard_library.install_aliases() # NOQA | ||
|
||
import chainer | ||
import chainer.functions as F | ||
import chainer.links as L | ||
from chainer import testing | ||
|
||
import basetest_dqn_like as base | ||
from basetest_training import _TestBatchTrainingMixin | ||
import chainerrl | ||
from chainerrl.agents import double_iqn | ||
from chainerrl.agents import iqn | ||
|
||
|
||
@testing.parameterize(*testing.product({
    'quantile_thresholds_N': [1, 5],
    'quantile_thresholds_N_prime': [1, 7],
}))
class TestDoubleIQNOnDiscreteABC(
        _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC):
    """DQN ABC tests for DoubleIQN with a small feed-forward model."""

    def make_q_func(self, env):
        # Tiny fully-connected implicit quantile network for the ABC env.
        n_hidden = 64
        psi = chainerrl.links.Sequence(
            L.Linear(env.observation_space.low.size, n_hidden),
            F.relu,
        )
        phi = chainerrl.links.Sequence(
            iqn.CosineBasisLinear(32, n_hidden),
            F.relu,
        )
        return iqn.ImplicitQuantileQFunction(
            psi=psi,
            phi=phi,
            f=L.Linear(n_hidden, env.action_space.n),
        )

    def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):
        # Quantile-threshold counts come from the parameterized attributes.
        return double_iqn.DoubleIQN(
            q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer,
            replay_start_size=100, target_update_interval=100,
            quantile_thresholds_N=self.quantile_thresholds_N,
            quantile_thresholds_N_prime=self.quantile_thresholds_N_prime,
        )
|
||
|
||
class TestDoubleIQNOnDiscretePOABC(
        _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC):
    """Partially observable DQN ABC tests for a recurrent DoubleIQN."""

    def make_q_func(self, env):
        # Recurrent implicit quantile network: a one-layer RNN sits in psi.
        n_hidden = 64
        return iqn.StatelessRecurrentImplicitQuantileQFunction(
            psi=chainerrl.links.StatelessRecurrentSequential(
                L.Linear(env.observation_space.low.size, n_hidden),
                F.relu,
                L.NStepRNNTanh(1, n_hidden, n_hidden, 0),
            ),
            phi=chainerrl.links.Sequence(
                iqn.CosineBasisLinear(32, n_hidden),
                F.relu,
            ),
            f=L.Linear(n_hidden, env.action_space.n,
                       initialW=chainer.initializers.LeCunNormal(1e-1)),
        )

    def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu):
        return double_iqn.DoubleIQN(
            q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer,
            replay_start_size=100, target_update_interval=100,
            quantile_thresholds_N=32,
            quantile_thresholds_N_prime=32,
            recurrent=True,
        )
Oops, something went wrong.